Skip to content

Commit

Permalink
feat: object TypeExtension
Browse files Browse the repository at this point in the history
  • Loading branch information
Kitefiko committed May 3, 2024
1 parent 65142b8 commit 367f8af
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 16 deletions.
5 changes: 4 additions & 1 deletion strawberry/experimental/pydantic/error_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
normalize_type,
)
from strawberry.object_type import _process_type, _wrap_dataclass
from strawberry.types.type_extension import TypeExtension
from strawberry.types.type_resolver import _get_fields
from strawberry.utils.typing import get_list_annotation, is_list

Expand Down Expand Up @@ -114,7 +115,9 @@ def wrap(cls: Type) -> Type:
]

wrapped = _wrap_dataclass(cls)
extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped, {}))
extra_fields = cast(
List[dataclasses.Field], _get_fields(wrapped, TypeExtension(), {})
)
private_fields = get_private_fields(wrapped)

all_model_fields.extend(
Expand Down
3 changes: 2 additions & 1 deletion strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from strawberry.field import StrawberryField
from strawberry.object_type import _process_type, _wrap_dataclass
from strawberry.types.type_extension import TypeExtension
from strawberry.types.type_resolver import _get_fields
from strawberry.utils.dataclasses import add_custom_init_fn

Expand Down Expand Up @@ -177,7 +178,7 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]:
)

wrapped = _wrap_dataclass(cls)
extra_strawberry_fields = _get_fields(wrapped, {})
extra_strawberry_fields = _get_fields(wrapped, TypeExtension(), {})
extra_fields = cast(List[dataclasses.Field], extra_strawberry_fields)
private_fields = get_private_fields(wrapped)

Expand Down
3 changes: 2 additions & 1 deletion strawberry/federation/schema_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from strawberry.field import StrawberryField, field
from strawberry.object_type import _wrap_dataclass
from strawberry.schema_directive import Location, StrawberrySchemaDirective
from strawberry.types.type_extension import TypeExtension
from strawberry.types.type_resolver import _get_fields


Expand Down Expand Up @@ -39,7 +40,7 @@ def schema_directive(
) -> Callable[..., T]:
def _wrap(cls: T) -> T:
cls = _wrap_dataclass(cls)
fields = _get_fields(cls, {})
fields = _get_fields(cls, TypeExtension(), {})

cls.__strawberry_directive__ = StrawberryFederationSchemaDirective(
python_name=cls.__name__,
Expand Down
28 changes: 22 additions & 6 deletions strawberry/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
)
from .field import StrawberryField, field
from .type import get_object_definition
from .types.type_extension import TypeExtension
from .types.type_resolver import _get_fields
from .types.types import (
StrawberryObjectDefinition,
)
from .types.types import StrawberryObjectDefinition
from .utils.dataclasses import add_custom_init_fn
from .utils.deprecations import DEPRECATION_MESSAGES, DeprecatedDescriptor
from .utils.str_converters import to_camel_case
Expand Down Expand Up @@ -135,17 +134,19 @@ def _process_type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
original_type_annotations: Optional[Dict[str, Any]] = None,
) -> T:
name = name or to_camel_case(cls.__name__)
extension = extension or TypeExtension()
original_type_annotations = original_type_annotations or {}

interfaces = _get_interfaces(cls)
fields = _get_fields(cls, original_type_annotations)
fields = _get_fields(cls, extension, original_type_annotations)
is_type_of = getattr(cls, "is_type_of", None)
resolve_type = getattr(cls, "resolve_type", None)

cls.__strawberry_definition__ = StrawberryObjectDefinition(
cls.__strawberry_definition__ = extension.create_object_definition(
name=name,
is_input=is_input,
is_interface=is_interface,
Expand Down Expand Up @@ -186,7 +187,7 @@ def _process_type(

setattr(cls, field_.python_name, wrapped_func)

return cls
return extension.after_process(cls) # type: ignore


@overload
Expand All @@ -202,6 +203,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
) -> T: ...


Expand All @@ -217,6 +219,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
) -> Callable[[T], T]: ...


Expand All @@ -229,6 +232,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
) -> Union[T, Callable[[T], T]]:
"""Annotates a class as a GraphQL type.
Expand All @@ -238,6 +242,8 @@ def type(
>>> class X:
>>> field_abc: str = "ABC"
"""
if extension is None:
extension = TypeExtension()

def wrap(cls: Type) -> T:
if not inspect.isclass(cls):
Expand Down Expand Up @@ -266,6 +272,7 @@ def wrap(cls: Type) -> T:
if field and isinstance(field, StrawberryField) and field.type_annotation:
original_type_annotations[field_name] = field.type_annotation.annotation

extension.before_wrap_dataclass(cls)
wrapped = _wrap_dataclass(cls)

return _process_type(
Expand All @@ -276,6 +283,7 @@ def wrap(cls: Type) -> T:
description=description,
directives=directives,
extend=extend,
extension=extension,
original_type_annotations=original_type_annotations,
)

Expand All @@ -295,6 +303,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> T: ...


Expand All @@ -307,6 +316,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> Callable[[T], T]: ...


Expand All @@ -316,6 +326,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
):
"""Annotates a class as a GraphQL Input type.
Example usage:
Expand All @@ -330,6 +341,7 @@ def input(
description=description,
directives=directives,
is_input=True,
extension=extension,
)


Expand All @@ -343,6 +355,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> T: ...


Expand All @@ -355,6 +368,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> Callable[[T], T]: ...


Expand All @@ -367,6 +381,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
):
"""Annotates a class as a GraphQL Interface.
Example usage:
Expand All @@ -381,6 +396,7 @@ def interface(
description=description,
directives=directives,
is_interface=True,
extension=extension,
)


Expand Down
3 changes: 2 additions & 1 deletion strawberry/schema_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing_extensions import dataclass_transform

from strawberry.object_type import _wrap_dataclass
from strawberry.types.type_extension import TypeExtension
from strawberry.types.type_resolver import _get_fields

from .directive import directive_field
Expand Down Expand Up @@ -54,7 +55,7 @@ def schema_directive(
) -> Callable[..., T]:
def _wrap(cls: T) -> T:
cls = _wrap_dataclass(cls)
fields = _get_fields(cls, {})
fields = _get_fields(cls, TypeExtension(), {})

cls.__strawberry_directive__ = StrawberrySchemaDirective(
python_name=cls.__name__,
Expand Down
59 changes: 59 additions & 0 deletions strawberry/types/type_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence

from strawberry.types.types import StrawberryObjectDefinition

if TYPE_CHECKING:
from dataclasses import Field

from graphql import GraphQLAbstractType, GraphQLResolveInfo

from strawberry.field import StrawberryField
from strawberry.type import WithStrawberryObjectDefinition


class TypeExtension:
def on_field(self, field: Field | StrawberryField) -> Field | StrawberryField:
"""Called for each field, _MUST_ return valid field"""
return field

def before_wrap_dataclass(self, cls: type) -> None:
"""Called before class is wrapped as dataclass"""
return

def after_process(self, cls: type[WithStrawberryObjectDefinition]) -> type:
"""Called after entire process finishes"""
return cls

def create_object_definition(
self,
origin: type[Any],
name: str,
is_input: bool,
is_interface: bool,
interfaces: List[StrawberryObjectDefinition],
description: Optional[str],
directives: Optional[Sequence[object]],
extend: bool,
fields: List[StrawberryField],
is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]],
resolve_type: Optional[
Callable[[Any, GraphQLResolveInfo, GraphQLAbstractType], str]
],
) -> StrawberryObjectDefinition:
"""Hook for creation of StrawberryObjectDefinition for __strawberry_definition__ attr"""

return StrawberryObjectDefinition(
name=name,
is_input=is_input,
is_interface=is_interface,
interfaces=interfaces,
description=description,
directives=directives,
origin=origin,
extend=extend,
fields=fields,
is_type_of=is_type_of,
resolve_type=resolve_type,
)
20 changes: 14 additions & 6 deletions strawberry/types/type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import sys
from typing import Any, Dict, List, Type
from typing import TYPE_CHECKING, Any, Dict, List, Type

from strawberry.annotation import StrawberryAnnotation
from strawberry.exceptions import (
Expand All @@ -15,9 +15,14 @@
from strawberry.type import has_object_definition
from strawberry.unset import UNSET

if TYPE_CHECKING:
from strawberry.types.type_extension import TypeExtension


def _get_fields(
cls: Type[Any], original_type_annotations: Dict[str, Type[Any]]
cls: Type[Any],
extension: TypeExtension,
original_type_annotations: Dict[str, Type[Any]],
) -> List[StrawberryField]:
"""Get all the strawberry fields off a strawberry.type cls
Expand Down Expand Up @@ -78,7 +83,14 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
origins.setdefault(field.name, base)

# then we can proceed with finding the fields for the current class
extension_hook = extension.on_field
for field in dataclasses.fields(cls): # type: ignore
if field.name in original_type_annotations:
field.type = original_type_annotations[field.name]

# Extension field hook
field = extension_hook(field=field) # noqa: PLW2901

if isinstance(field, StrawberryField):
# Check that the field type is not Private
if is_private(field.type):
Expand Down Expand Up @@ -155,10 +167,6 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
assert_message = "Field must have a name by the time the schema is generated"
assert field_name is not None, assert_message

if field.name in original_type_annotations:
field.type = original_type_annotations[field.name]
field.type_annotation = StrawberryAnnotation(annotation=field.type)

# TODO: Raise exception if field_name already in fields
fields[field_name] = field

Expand Down

0 comments on commit 367f8af

Please sign in to comment.