Skip to content

Commit

Permalink
Object Extension Draft
Browse files Browse the repository at this point in the history
  • Loading branch information
Kitefiko committed Feb 11, 2024
1 parent c926395 commit 5cbe67a
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 10 deletions.
93 changes: 85 additions & 8 deletions strawberry/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
)
from typing_extensions import dataclass_transform

from graphql import GraphQLAbstractType, GraphQLResolveInfo

from .exceptions import (
MissingFieldAnnotationError,
MissingReturnAnnotationError,
ObjectIsNotClassError,
)
from .field import StrawberryField, field
from .type import get_object_definition
from .type import WithStrawberryObjectDefinition, get_object_definition
from .types.type_resolver import _get_fields
from .types.types import (
StrawberryObjectDefinition,
)
from .types.types import StrawberryObjectDefinition
from .utils.dataclasses import add_custom_init_fn
from .utils.deprecations import DEPRECATION_MESSAGES, DeprecatedDescriptor
from .utils.str_converters import to_camel_case
Expand Down Expand Up @@ -124,6 +124,57 @@ def _wrap_dataclass(cls: Type[Any]):
return dclass


class StrawberryObjectBuilder:
def before_wrap_dataclass(self, cls: Type):
"""Modify class before any processing, one could add fields here"""

def on_field(self, field: dataclasses.Field[Any]) -> Any:
"""Modify field during _get_fields
translate from `auto` type
completely change field class to custom one
"""
return field

def create_object_definition(
self,
origin: Type[Any],
name: str,
is_input: bool,
is_interface: bool,
interfaces: List[StrawberryObjectDefinition],
description: Optional[str],
directives: Optional[Sequence[object]],
extend: bool,
fields: List[StrawberryField],
is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]],
resolve_type: Optional[
Callable[[Any, GraphQLResolveInfo, GraphQLAbstractType], str]
],
) -> StrawberryObjectDefinition:
"""Posibility to use custom StrawberryObjectDefinition"""

return StrawberryObjectDefinition(
name=name,
is_input=is_input,
is_interface=is_interface,
interfaces=interfaces,
description=description,
directives=directives,
origin=origin,
extend=extend,
fields=fields,
is_type_of=is_type_of,
resolve_type=resolve_type,
)

def after_process(self, cls: Type[WithStrawberryObjectDefinition]):
"""Dont know tbh"""


_DEFAULT_STRAWBERRY_BUILDER = StrawberryObjectBuilder()


def _process_type(
cls: Type,
*,
Expand All @@ -133,15 +184,22 @@ def _process_type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
builder: Optional[StrawberryObjectBuilder] = None,
):
name = name or to_camel_case(cls.__name__)
# TODO: Can _process_type be changes so builder is required?
# breaks peoples code, but IS internal function
if builder is None:
builder = _DEFAULT_STRAWBERRY_BUILDER
elif not isinstance(builder, StrawberryObjectBuilder):
raise TypeError("Appropriate Strawberry Error about invalid Builder")

Check warning on line 194 in strawberry/object_type.py

View check run for this annotation

Codecov / codecov/patch

strawberry/object_type.py#L194

Added line #L194 was not covered by tests

interfaces = _get_interfaces(cls)
fields = _get_fields(cls)
name = name or to_camel_case(cls.__name__)
is_type_of = getattr(cls, "is_type_of", None)
resolve_type = getattr(cls, "resolve_type", None)
interfaces = _get_interfaces(cls)
fields = _get_fields(cls, builder)

cls.__strawberry_definition__ = StrawberryObjectDefinition(
cls.__strawberry_definition__ = builder.create_object_definition(
name=name,
is_input=is_input,
is_interface=is_interface,
Expand Down Expand Up @@ -182,6 +240,7 @@ def _process_type(

setattr(cls, field_.python_name, wrapped_func)

builder.after_process(cls)
return cls


Expand All @@ -198,6 +257,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
builder: Optional[StrawberryObjectBuilder] = None,
) -> T:
...

Expand All @@ -214,6 +274,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
builder: Optional[StrawberryObjectBuilder] = None,
) -> Callable[[T], T]:
...

Expand All @@ -227,6 +288,7 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
builder: Optional[StrawberryObjectBuilder] = None,
) -> Union[T, Callable[[T], T]]:
"""Annotates a class as a GraphQL type.
Expand All @@ -247,6 +309,12 @@ def wrap(cls: Type):
exc = ObjectIsNotClassError.type
raise exc(cls)

if builder is not None:
if not isinstance(builder, StrawberryObjectBuilder):
raise TypeError("Appropriate Strawberry Error about invalid Builder")

Check warning on line 314 in strawberry/object_type.py

View check run for this annotation

Codecov / codecov/patch

strawberry/object_type.py#L314

Added line #L314 was not covered by tests

builder.before_wrap_dataclass(cls)

Check warning on line 316 in strawberry/object_type.py

View check run for this annotation

Codecov / codecov/patch

strawberry/object_type.py#L316

Added line #L316 was not covered by tests

wrapped = _wrap_dataclass(cls)
return _process_type(
wrapped,
Expand All @@ -256,6 +324,7 @@ def wrap(cls: Type):
description=description,
directives=directives,
extend=extend,
builder=builder,
)

if cls is None:
Expand All @@ -274,6 +343,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
builder: Optional[StrawberryObjectBuilder] = None,
) -> T:
...

Expand All @@ -287,6 +357,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
builder: Optional[StrawberryObjectBuilder] = None,
) -> Callable[[T], T]:
...

Expand All @@ -297,6 +368,7 @@ def input(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
builder: Optional[StrawberryObjectBuilder] = None,
):
"""Annotates a class as a GraphQL Input type.
Example usage:
Expand All @@ -311,6 +383,7 @@ def input(
description=description,
directives=directives,
is_input=True,
builder=builder,
)


Expand All @@ -324,6 +397,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
builder: Optional[StrawberryObjectBuilder] = None,
) -> T:
...

Expand All @@ -337,6 +411,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
builder: Optional[StrawberryObjectBuilder] = None,
) -> Callable[[T], T]:
...

Expand All @@ -350,6 +425,7 @@ def interface(
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
builder: Optional[StrawberryObjectBuilder] = None,
):
"""Annotates a class as a GraphQL Interface.
Example usage:
Expand All @@ -364,6 +440,7 @@ def interface(
description=description,
directives=directives,
is_interface=True,
builder=builder,
)


Expand Down
18 changes: 16 additions & 2 deletions strawberry/types/type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import sys
from typing import Dict, List, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Type

from strawberry.annotation import StrawberryAnnotation
from strawberry.exceptions import (
Expand All @@ -15,8 +15,13 @@
from strawberry.type import has_object_definition
from strawberry.unset import UNSET

if TYPE_CHECKING:
from strawberry.object_type import StrawberryObjectBuilder

def _get_fields(cls: Type) -> List[StrawberryField]:

def _get_fields(
cls: Type, builder: Optional[StrawberryObjectBuilder] = None
) -> List[StrawberryField]:
"""Get all the strawberry fields off a strawberry.type cls
This function returns a list of StrawberryFields (one for each field item), while
Expand Down Expand Up @@ -49,6 +54,12 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
passing a named function (i.e. not an anonymous lambda) to strawberry.field
(typically as a decorator).
"""
if builder is None:
# Builder would be required here
from strawberry.object_type import _DEFAULT_STRAWBERRY_BUILDER

builder = _DEFAULT_STRAWBERRY_BUILDER

fields: Dict[str, StrawberryField] = {}

# before trying to find any fields, let's first add the fields defined in
Expand Down Expand Up @@ -76,6 +87,9 @@ class if one is not set by either using an explicit strawberry.field(name=...) o

# then we can proceed with finding the fields for the current class
for field in dataclasses.fields(cls): # type: ignore
# Builder field hook
field = builder.on_field(field=field) # noqa: PLW2901

if isinstance(field, StrawberryField):
# Check that the field type is not Private
if is_private(field.type):
Expand Down

0 comments on commit 5cbe67a

Please sign in to comment.