Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate TypedDicts for python inputs #15957

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
changes:
- type: feat
scope: sdkgen/python
description: Generate TypedDict types for inputs
4 changes: 2 additions & 2 deletions pkg/codegen/python/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (d DocLanguageHelper) GetLanguageTypeString(pkg *schema.Package, moduleName
mod: moduleName,
typeDetails: typeDetails,
}
typeName := mod.typeString(t, input, false /*acceptMapping*/)
typeName := mod.typeString(t, input, false /*acceptMapping*/, false /*forDict*/)

// Remove any package qualifiers from the type name.
if !input {
Expand Down Expand Up @@ -125,7 +125,7 @@ func (d DocLanguageHelper) GetMethodResultName(pkg *schema.Package, modName stri
mod: modName,
typeDetails: typeDetails,
}
return mod.typeString(returnType.Properties[0].Type, false, false)
return mod.typeString(returnType.Properties[0].Type, false, false, false /*forDict*/)
}
}
return fmt.Sprintf("%s.%sResult", resourceName(r), title(d.GetMethodName(m)))
Expand Down
148 changes: 123 additions & 25 deletions pkg/codegen/python/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ func title(s string) string {
}

type modLocator struct {
// Returns defining modlue for a given ObjectType. Returns nil
// Returns defining module for a given ObjectType. Returns nil
// for types that are not being generated in the current
// GeneratePacakge call.
// GeneratePackage call.
objectTypeMod func(*schema.ObjectType) *modContext
}

Expand All @@ -128,6 +128,9 @@ type modContext struct {

// Determine whether to lift single-value method return values
liftSingleValueMethodReturns bool

// Emit TypedDicts types for inputs
typedDictArgs bool
}

func (mod *modContext) isTopLevel() bool {
Expand Down Expand Up @@ -219,7 +222,7 @@ func (mod *modContext) unqualifiedObjectTypeName(t *schema.ObjectType, input boo
return name
}

func (mod *modContext) objectType(t *schema.ObjectType, input bool) string {
func (mod *modContext) objectType(t *schema.ObjectType, input bool, forDict bool) string {
var prefix string
if !input {
prefix = "outputs."
Expand All @@ -233,6 +236,9 @@ func (mod *modContext) objectType(t *schema.ObjectType, input bool) string {
}

modName, name := mod.tokenToModule(t.Token), mod.unqualifiedObjectTypeName(t, input)
if forDict {
name = name + "Dict"
}
if modName == "" && modName != mod.mod {
rootModName := "_root_outputs."
if input {
Expand Down Expand Up @@ -407,9 +413,18 @@ func (mod *modContext) generateCommonImports(w io.Writer, imports imports, typin

fmt.Fprintf(w, "import copy\n")
fmt.Fprintf(w, "import warnings\n")
if mod.typedDictArgs {
fmt.Fprintf(w, "import sys\n")
}
fmt.Fprintf(w, "import pulumi\n")
fmt.Fprintf(w, "import pulumi.runtime\n")
fmt.Fprintf(w, "from typing import %s\n", strings.Join(typingImports, ", "))
if mod.typedDictArgs {
fmt.Fprintf(w, "if sys.version_info >= (3, 11):\n")
fmt.Fprintf(w, " from typing import NotRequired, TypedDict, TypeAlias\n")
fmt.Fprintf(w, "else:\n")
fmt.Fprintf(w, " from typing_extensions import NotRequired, TypedDict, TypeAlias\n")
}
fmt.Fprintf(w, "from %s import _utilities\n", relImport)
for _, imp := range imports.strings() {
fmt.Fprintf(w, "%s\n", imp)
Expand Down Expand Up @@ -1053,9 +1068,16 @@ func (mod *modContext) genTypes(dir string, fs codegen.Fs) error {
if input && mod.details(t).inputType || !input && mod.details(t).outputType {
fmt.Fprintf(w, " '%s',\n", mod.unqualifiedObjectTypeName(t, input))
}
if input && mod.typedDictArgs && mod.details(t).inputType {
fmt.Fprintf(w, " '%sDict',\n", mod.unqualifiedObjectTypeName(t, input))
}
}
fmt.Fprintf(w, "]\n\n")

if input && mod.typedDictArgs {
fmt.Fprintf(w, "MYPY = False\n\n")
}

var hasTypes bool
for _, t := range mod.types {
if t.IsOverlay {
Expand Down Expand Up @@ -1126,7 +1148,7 @@ func (mod *modContext) genAwaitableType(w io.Writer, obj *schema.ObjectType) str
// Note that deprecation messages will be emitted on access to the property, rather than initialization.
// This avoids spamming end users with irrelevant deprecation messages.
mod.genProperties(w, obj.Properties, false /*setters*/, "", func(prop *schema.Property) string {
return mod.typeString(prop.Type, false /*input*/, false /*acceptMapping*/)
return mod.typeString(prop.Type, false /*input*/, false /*acceptMapping*/, false /*forDict*/)
})

// Produce an awaitable subclass.
Expand Down Expand Up @@ -1281,7 +1303,7 @@ func (mod *modContext) genResource(res *schema.Resource) (string, error) {

// If there's an argument type, emit it.
for _, prop := range res.InputProperties {
ty := mod.typeString(codegen.OptionalType(prop), true, true /*acceptMapping*/)
ty := mod.typeString(codegen.OptionalType(prop), true, true /*acceptMapping*/, false /*forDict*/)
fmt.Fprintf(w, ",\n %s: %s = None", InitParamName(prop.Name), ty)
}

Expand Down Expand Up @@ -1467,7 +1489,7 @@ func (mod *modContext) genResource(res *schema.Resource) (string, error) {
if hasStateInputs {
for _, prop := range res.StateInputs.Properties {
pname := InitParamName(prop.Name)
ty := mod.typeString(codegen.OptionalType(prop), true, true /*acceptMapping*/)
ty := mod.typeString(codegen.OptionalType(prop), true, true /*acceptMapping*/, false /*forDict*/)
fmt.Fprintf(w, ",\n %s: %s = None", pname, ty)
}
}
Expand Down Expand Up @@ -1502,7 +1524,7 @@ func (mod *modContext) genResource(res *schema.Resource) (string, error) {

// Write out Python property getters for each of the resource's properties.
mod.genProperties(w, res.Properties, false /*setters*/, "", func(prop *schema.Property) string {
ty := mod.typeString(prop.Type, false /*input*/, false /*acceptMapping*/)
ty := mod.typeString(prop.Type, false /*input*/, false /*acceptMapping*/, false /*forDict*/)
return fmt.Sprintf("pulumi.Output[%s]", ty)
})

Expand Down Expand Up @@ -1606,7 +1628,7 @@ func (mod *modContext) genMethodReturnType(w io.Writer, method *schema.Method) s
// Note that deprecation messages will be emitted on access to the property, rather than initialization.
// This avoids spamming end users with irrelevant deprecation messages.
mod.genProperties(w, properties, false /*setters*/, " ", func(prop *schema.Property) string {
return mod.typeString(prop.Type, false /*input*/, false /*acceptMapping*/)
return mod.typeString(prop.Type, false /*input*/, false /*acceptMapping*/, false /*forDict*/)
})

return name
Expand Down Expand Up @@ -1674,7 +1696,7 @@ func (mod *modContext) genMethods(w io.Writer, res *schema.Resource) {
}
for _, arg := range args {
pname := PyName(arg.Name)
ty := mod.typeString(arg.Type, true, false /*acceptMapping*/)
ty := mod.typeString(arg.Type, true, false /*acceptMapping*/, false /*forDict*/)
var defaultValue string
if !arg.IsRequired() {
defaultValue = " = None"
Expand Down Expand Up @@ -1909,7 +1931,7 @@ func (mod *modContext) genFunDef(w io.Writer, name, retTypeName string, args []*
argType = codegen.OptionalType(arg)
}

ty := mod.typeString(argType, true /*input*/, true /*acceptMapping*/)
ty := mod.typeString(argType, true /*input*/, true /*acceptMapping*/, false /*forDict*/)
fmt.Fprintf(w, "%s%s: %s = None,\n", ind, pname, ty)
}
fmt.Fprintf(w, "%sopts: Optional[pulumi.InvokeOptions] = None", indent)
Expand Down Expand Up @@ -1989,7 +2011,7 @@ func (mod *modContext) genEnums(w io.Writer, enums []*schema.EnumType) error {
func (mod *modContext) genEnum(w io.Writer, enum *schema.EnumType) error {
indent := " "
enumName := tokenToName(enum.Token)
underlyingType := mod.typeString(enum.ElementType, false, false)
underlyingType := mod.typeString(enum.ElementType, false, false, false /*forDict*/)

switch enum.ElementType {
case schema.StringType, schema.IntType, schema.NumberType:
Expand Down Expand Up @@ -2326,7 +2348,7 @@ func (mod *modContext) genPropDocstring(w io.Writer, name string, prop *schema.P
return
}

ty := mod.typeString(codegen.RequiredType(prop), true, acceptMapping)
ty := mod.typeString(codegen.RequiredType(prop), true, acceptMapping, false /*forDict*/)

// If this property has some documentation associated with it, we need to split it so that it is indented
// in a way that Sphinx can understand.
Expand All @@ -2345,34 +2367,51 @@ func (mod *modContext) genPropDocstring(w io.Writer, name string, prop *schema.P
}
}

func (mod *modContext) typeString(t schema.Type, input, acceptMapping bool) string {
func (mod *modContext) typeString(t schema.Type, input, acceptMapping bool, forDict bool) string {
switch t := t.(type) {
case *schema.OptionalType:
return fmt.Sprintf("Optional[%s]", mod.typeString(t.ElementType, input, acceptMapping))
typ := mod.typeString(t.ElementType, input, acceptMapping, forDict)
if forDict {
return fmt.Sprintf("NotRequired[%s]", typ)
}
return fmt.Sprintf("Optional[%s]", typ)
case *schema.InputType:
typ := mod.typeString(codegen.SimplifyInputUnion(t.ElementType), input, acceptMapping)
typ := mod.typeString(codegen.SimplifyInputUnion(t.ElementType), input, acceptMapping, forDict)
if typ == "Any" {
return typ
}
return fmt.Sprintf("pulumi.Input[%s]", typ)
case *schema.EnumType:
return mod.enumType(t)
case *schema.ArrayType:
return fmt.Sprintf("Sequence[%s]", mod.typeString(t.ElementType, input, acceptMapping))
return fmt.Sprintf("Sequence[%s]", mod.typeString(t.ElementType, input, acceptMapping, forDict))
case *schema.MapType:
return fmt.Sprintf("Mapping[str, %s]", mod.typeString(t.ElementType, input, acceptMapping))
return fmt.Sprintf("Mapping[str, %s]", mod.typeString(t.ElementType, input, acceptMapping, forDict))
case *schema.ObjectType:
typ := mod.objectType(t, input)
if forDict {
return mod.objectType(t, input, true /*dictType*/)
}
typ := mod.objectType(t, input, false /*dictType*/)
if !acceptMapping {
return typ
}
// If the type is an input and the TypedDict generation is enabled for the type's package, we
// we can emit `Union[type, dictType]` and avoid the `InputType[]` wrapper.
// dictType covers the Mapping case in `InputType = Union[T, Mapping[str, Any]]`.
pkg, err := t.PackageReference.Definition()
contract.AssertNoErrorf(err, "error loading definition for package %q", t.PackageReference.Name())
info, ok := pkg.Language["python"].(PackageInfo)
hasTypedDictArgs := ok && info.TypedDictArgs
if hasTypedDictArgs && input {
return fmt.Sprintf("Union[%s, %s]", typ, mod.objectType(t, input, true /*dictType*/))
}
return fmt.Sprintf("pulumi.InputType[%s]", typ)
case *schema.ResourceType:
return fmt.Sprintf("'%s'", mod.resourceType(t))
case *schema.TokenType:
// Use the underlying type for now.
if t.UnderlyingType != nil {
return mod.typeString(t.UnderlyingType, input, acceptMapping)
return mod.typeString(t.UnderlyingType, input, acceptMapping, forDict)
}
return "Any"
case *schema.UnionType:
Expand All @@ -2381,19 +2420,19 @@ func (mod *modContext) typeString(t schema.Type, input, acceptMapping bool) stri
// If this is an output and a "relaxed" enum, emit the type as the underlying primitive type rather than the union.
// Eg. Output[str] rather than Output[Any]
if typ, ok := e.(*schema.EnumType); ok {
return mod.typeString(typ.ElementType, input, acceptMapping)
return mod.typeString(typ.ElementType, input, acceptMapping, forDict)
}
}
if t.DefaultType != nil {
return mod.typeString(t.DefaultType, input, acceptMapping)
return mod.typeString(t.DefaultType, input, acceptMapping, forDict)
}
return "Any"
}

elementTypeSet := codegen.NewStringSet()
elements := slice.Prealloc[string](len(t.ElementTypes))
for _, e := range t.ElementTypes {
et := mod.typeString(e, input, acceptMapping)
et := mod.typeString(e, input, acceptMapping, forDict)
if !elementTypeSet.Has(et) {
elementTypeSet.Add(et)
elements = append(elements, et)
Expand Down Expand Up @@ -2501,6 +2540,11 @@ func InitParamName(name string) string {
func (mod *modContext) genObjectType(w io.Writer, obj *schema.ObjectType, input bool) error {
name := mod.unqualifiedObjectTypeName(obj, input)
resourceOutputType := !input && mod.details(obj).resourceOutputType
if input && mod.typedDictArgs {
if err := mod.genDictType(w, name, obj.Comment, obj.Properties); err != nil {
return err
}
}
return mod.genType(w, name, obj.Comment, obj.Properties, input, resourceOutputType)
}

Expand Down Expand Up @@ -2583,9 +2627,9 @@ func (mod *modContext) genType(w io.Writer, name, comment string, properties []*
}
for _, prop := range props {
pname := PyName(prop.Name)
ty := mod.typeString(prop.Type, input, false /*acceptMapping*/)
ty := mod.typeString(prop.Type, input, false /*acceptMapping*/, false /*forDict*/)
if prop.DefaultValue != nil {
ty = mod.typeString(codegen.OptionalType(prop), input, false /*acceptMapping*/)
ty = mod.typeString(codegen.OptionalType(prop), input, false /*acceptMapping*/, false /*forDict*/)
}

var defaultValue string
Expand Down Expand Up @@ -2644,9 +2688,61 @@ func (mod *modContext) genType(w io.Writer, name, comment string, properties []*

// Generate properties. Input types have getters and setters, output types only have getters.
mod.genProperties(w, props, input /*setters*/, "", func(prop *schema.Property) string {
return mod.typeString(prop.Type, input, false /*acceptMapping*/)
return mod.typeString(prop.Type, input, false /*acceptMapping*/, false /*forDict*/)
})

fmt.Fprintf(w, "\n")
return nil
}

func (mod *modContext) genDictType(w io.Writer, name, comment string, properties []*schema.Property) error {
// Sort required props first.
props := make([]*schema.Property, len(properties))
copy(props, properties)
sort.Slice(props, func(i, j int) bool {
pi, pj := props[i], props[j]
switch {
case pi.IsRequired() != pj.IsRequired():
return pi.IsRequired() && !pj.IsRequired()
default:
return pi.Name < pj.Name
}
})

indent := " "
name = pythonCase(name)

// Running mypy gets very slow when there are a lot of TypedDicts.
// https://github.com/python/mypy/issues/17231
// For now we only use the TypedDict types when using a typechecker
// other than mypy. For mypy we define the XXXArgsDict class as an alias
// to the type `Mapping[str, Any]`.
fmt.Fprintf(w, "if not MYPY:\n")
fmt.Fprintf(w, "%sclass %sDict(TypedDict):\n", indent, name)

indent += " "

if comment != "" {
printComment(w, comment, indent)
}

for _, prop := range props {
pname := PyName(prop.Name)
ty := mod.typeString(prop.Type, true /*input*/, false /*acceptMapping*/, true /*forDict*/)
fmt.Fprintf(w, "%s%s: %s\n", indent, pname, ty)
if prop.Comment != "" {
printComment(w, prop.Comment, indent)
}
}

if len(props) == 0 {
fmt.Fprintf(w, "%spass\n", indent)
}

indent = " "
fmt.Fprintf(w, "elif False:\n")
fmt.Fprintf(w, "%s%sDict: TypeAlias = Mapping[str, Any]\n", indent, name)

fmt.Fprintf(w, "\n")
return nil
}
Expand Down Expand Up @@ -2742,6 +2838,7 @@ func generateModuleContextMap(tool string, pkg *schema.Package, info PackageInfo
modNameOverrides: info.ModuleNameOverrides,
compatibility: info.Compatibility,
liftSingleValueMethodReturns: info.LiftSingleValueMethodReturns,
typedDictArgs: info.TypedDictArgs,
}

if modName != "" && codegen.PkgEquals(p, pkg.Reference()) {
Expand Down Expand Up @@ -3226,6 +3323,7 @@ func calculateDeps(requires map[string]string) ([][2]string, error) {
deps := []string{
"semver>=2.8.1",
"parver>=0.2.1",
"typing-extensions>=4.11; python_version < \"3.11\"",
}
for dep := range requires {
deps = append(deps, dep)
Expand Down
4 changes: 4 additions & 0 deletions pkg/codegen/python/gen_program.go
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,10 @@ func (g *generator) argumentTypeName(expr model.Expression, destType model.Type)
if m, ok := pkgInfo.ModuleNameOverrides[module]; ok {
modName = m
}
if pkgInfo.TypedDictArgs {
// Package supports TypedDicts, return an empty string so we use a dict instead of the Args class.
return ""
}
}
}
return tokenToQualifiedName(pkgName, modName, member) + "Args"
Expand Down