From bebf21a006022cbd4947eaa44c4758b68be56826 Mon Sep 17 00:00:00 2001 From: Jacob Wegner Date: Sun, 26 May 2024 06:44:34 -0500 Subject: [PATCH] Fix AttributeError when using optimizer and prefetch_related (#533) * Add a test capturing the failure * Short-circuit when a queryset has been sliced There are several places in the codebase where qs._result_cache is evaluated. If a QuerySet has been sliced (e.g., prefetched with pagination in DjangoOptimizerExtension), it has been evaluated as a list and will not have the `_result_cache` attrib. * Update test to verify short-circuiting works as expected --- strawberry_django/optimizer.py | 4 ++ strawberry_django/permissions.py | 4 ++ strawberry_django/resolvers.py | 4 ++ tests/test_optimizer.py | 79 ++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 7497aad6..47df7606 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -694,6 +694,10 @@ def optimize( if isinstance(qs, BaseManager): qs = cast(QuerySet[_M], qs.all()) + if isinstance(qs, list): + # return sliced queryset as-is + return qs + # Avoid optimizing twice and also modify an already resolved queryset if ( get_queryset_config(qs).optimized or qs._result_cache is not None # type: ignore diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index 866d1b96..ef7bdb90 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -108,6 +108,10 @@ def filter_with_perms(qs: QuerySet[_M], info: Info) -> QuerySet[_M]: if not context.checkers or context.is_safe: return qs + if isinstance(qs, list): + # return sliced queryset as-is + return qs + # Do not do anything is results are cached if qs._result_cache is not None: # type: ignore set_perm_safe(False) diff --git a/strawberry_django/resolvers.py b/strawberry_django/resolvers.py index 2cde0ff4..77d1c974 100644 --- a/strawberry_django/resolvers.py +++ b/strawberry_django/resolvers.py @@ -27,6 +27,10 @@ def default_qs_hook(qs: models.QuerySet[_M]) -> models.QuerySet[_M]: + if isinstance(qs, list): + # return sliced queryset as-is + return qs + # FIXME: We probably won't need this anymore when we can use graphql-core 3.3.0+ # as its `complete_list_value` gives a preference to async iteration it if is # provided by the object. diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 28ed4056..7dd7acfa 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1025,3 +1025,82 @@ def test_query_nested_connection_with_filter(db, gql_client: GraphQLTestClient): assert { edge["node"]["id"] for edge in result["issuesWithFilters"]["edges"] } == expected + + +@pytest.mark.django_db(transaction=True) +def test_query_with_optimizer_paginated_prefetch(): + @strawberry_django.type(Milestone, pagination=True) + class MilestoneTypeWithNestedPrefetch: + @strawberry_django.field() + def name(self, info) -> str: + return self.name + + @strawberry_django.type( + Project, + ) + class ProjectTypeWithPrefetch: + @strawberry_django.field() + def name(self, info) -> str: + return self.name + + milestones: List[MilestoneTypeWithNestedPrefetch] + + milestone1 = MilestoneFactory.create() + project = milestone1.project + MilestoneFactory.create(project=project) + + @strawberry.type + class Query: + projects: List[ProjectTypeWithPrefetch] = strawberry_django.field() + + query1 = utils.generate_query(Query, enable_optimizer=False) + query_str = """ + query TestQuery { + projects { + name + milestones (pagination: {limit: 1}) { + name + } + } + } + """ + + # NOTE: The following assertion doesn't work because the + # DjangoOptimizerExtension instance is not the one within the + # generate_query wrapper + """ + assert DjangoOptimizerExtension.enabled.get() + """ + result1 = query1(query_str) + + assert isinstance(result1, ExecutionResult) + assert not result1.errors + assert result1.data == { + "projects": [ + { + "name": project.name, + "milestones": [ + { + "name": milestone1.name, + }, + ], + }, + ], + } + + query2 = utils.generate_query(Query, enable_optimizer=True) + result2 = query2(query_str) + + assert isinstance(result2, ExecutionResult) + assert result2.data == { + "projects": [ + { + "name": project.name, + "milestones": [ + { + "name": milestone1.name, + }, + ], + }, + ], + }