Skip to content

Commit

Permalink
Object Extension Draft
Browse files Browse the repository at this point in the history
  • Loading branch information
botberry authored and Kitefiko committed Feb 16, 2024
1 parent c926395 commit 4fa4129
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 9 deletions.
34 changes: 27 additions & 7 deletions strawberry/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
)
from .field import StrawberryField, field
from .type import get_object_definition
from .types.type_extension import TypeExtension
from .types.type_resolver import _get_fields
from .types.types import (
StrawberryObjectDefinition,
)
from .types.types import StrawberryObjectDefinition
from .utils.dataclasses import add_custom_init_fn
from .utils.deprecations import DEPRECATION_MESSAGES, DeprecatedDescriptor
from .utils.str_converters import to_camel_case
Expand Down Expand Up @@ -133,15 +132,20 @@ def _process_type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
):
name = name or to_camel_case(cls.__name__)
# TODO: Can _process_type be changes so builder is required?
# breaks peoples code, but IS internal function
if extension is None:
extension = TypeExtension()

interfaces = _get_interfaces(cls)
fields = _get_fields(cls)
name = name or to_camel_case(cls.__name__)
is_type_of = getattr(cls, "is_type_of", None)
resolve_type = getattr(cls, "resolve_type", None)
interfaces = _get_interfaces(cls)
fields = _get_fields(cls, extension)

cls.__strawberry_definition__ = StrawberryObjectDefinition(
cls.__strawberry_definition__ = extension.create_object_definition(
name=name,
is_input=is_input,
is_interface=is_interface,
Expand Down Expand Up @@ -182,6 +186,7 @@ def _process_type(

setattr(cls, field_.python_name, wrapped_func)

extension.after_process(cls)
return cls


Expand All @@ -198,6 +203,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
) -> T:
...

Expand All @@ -214,6 +220,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
) -> Callable[[T], T]:
...

Expand All @@ -227,6 +234,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
extension: Optional[TypeExtension] = None,
) -> Union[T, Callable[[T], T]]:
"""Annotates a class as a GraphQL type.
Expand All @@ -236,6 +244,8 @@ def type(
>>> class X:
>>> field_abc: str = "ABC"
"""
if extension is None:
extension = TypeExtension()

def wrap(cls: Type):
if not inspect.isclass(cls):
Expand All @@ -247,6 +257,7 @@ def wrap(cls: Type):
exc = ObjectIsNotClassError.type
raise exc(cls)

extension.before_wrap_dataclass(cls)
wrapped = _wrap_dataclass(cls)
return _process_type(
wrapped,
Expand All @@ -256,6 +267,7 @@ def wrap(cls: Type):
description=description,
directives=directives,
extend=extend,
extension=extension,
)

if cls is None:
Expand All @@ -274,6 +286,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> T:
...

Expand All @@ -287,6 +300,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> Callable[[T], T]:
...

Expand All @@ -297,6 +311,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
):
"""Annotates a class as a GraphQL Input type.
Example usage:
Expand All @@ -311,6 +326,7 @@ def input(
description=description,
directives=directives,
is_input=True,
extension=extension,
)


Expand All @@ -324,6 +340,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> T:
...

Expand All @@ -337,6 +354,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
) -> Callable[[T], T]:
...

Expand All @@ -350,6 +368,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extension: Optional[TypeExtension] = None,
):
"""Annotates a class as a GraphQL Interface.
Example usage:
Expand All @@ -364,6 +383,7 @@ def interface(
description=description,
directives=directives,
is_interface=True,
extension=extension,
)


Expand Down
56 changes: 56 additions & 0 deletions strawberry/types/type_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Type

from strawberry.types.types import StrawberryObjectDefinition

if TYPE_CHECKING:
from dataclasses import Field

from graphql import GraphQLAbstractType, GraphQLResolveInfo

from strawberry.field import StrawberryField
from strawberry.type import WithStrawberryObjectDefinition


class TypeExtension:
def before_wrap_dataclass(self, cls: Type) -> None:
pass

def on_field(self, field: Field | StrawberryField) -> Any:
return field

def create_object_definition(
self,
origin: Type[Any],
name: str,
is_input: bool,
is_interface: bool,
interfaces: List[StrawberryObjectDefinition],
description: Optional[str],
directives: Optional[Sequence[object]],
extend: bool,
fields: List[StrawberryField],
is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]],
resolve_type: Optional[
Callable[[Any, GraphQLResolveInfo, GraphQLAbstractType], str]
],
) -> StrawberryObjectDefinition:
"""Posibility to use custom StrawberryObjectDefinition"""

return StrawberryObjectDefinition(
name=name,
is_input=is_input,
is_interface=is_interface,
interfaces=interfaces,
description=description,
directives=directives,
origin=origin,
extend=extend,
fields=fields,
is_type_of=is_type_of,
resolve_type=resolve_type,
)

def after_process(self, cls: Type[WithStrawberryObjectDefinition]) -> None:
pass
14 changes: 12 additions & 2 deletions strawberry/types/type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import sys
from typing import Dict, List, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Type

from strawberry.annotation import StrawberryAnnotation
from strawberry.exceptions import (
Expand All @@ -15,8 +15,13 @@
from strawberry.type import has_object_definition
from strawberry.unset import UNSET

if TYPE_CHECKING:
from strawberry.types.type_extension import TypeExtension

def _get_fields(cls: Type) -> List[StrawberryField]:

def _get_fields(
cls: Type, extension: Optional[TypeExtension] = None
) -> List[StrawberryField]:
"""Get all the strawberry fields off a strawberry.type cls
This function returns a list of StrawberryFields (one for each field item), while
Expand Down Expand Up @@ -74,8 +79,13 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
if field.python_name in base.__annotations__:
origins.setdefault(field.name, base)

extension_hook = extension.on_field if extension else lambda field: field

# then we can proceed with finding the fields for the current class
for field in dataclasses.fields(cls): # type: ignore
# Extension field hook
field = extension_hook(field=field) # noqa: PLW2901 type: ignore

if isinstance(field, StrawberryField):
# Check that the field type is not Private
if is_private(field.type):
Expand Down

0 comments on commit 4fa4129

Please sign in to comment.