Skip to content

Commit

Permalink
🐛 fix: request_type to permission_getter
Browse files Browse the repository at this point in the history
  • Loading branch information
MDKVMT committed May 16, 2024
1 parent 51eb265 commit ec905b9
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 12 deletions.
6 changes: 5 additions & 1 deletion graphemy/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def get_filter(
if None not in key:
for k in range(len(id)):
f.append(
id[k] if type(id[k]) == int else id[k][1:] if id[k].startswith('_') else getattr(model, id[k])
id[k]
if type(id[k]) == int
else id[k][1:]
if id[k].startswith('_')
else getattr(model, id[k])
== bindparam(
f'p{i}_{j}_{k}',
literal_execute=not isinstance(key[k], date),
Expand Down
2 changes: 1 addition & 1 deletion graphemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,5 @@ def __init_subclass__(cls):
for attr in to_remove:
del cls.__annotations__[attr]

async def permission_getter(info: Info) -> bool:
async def permission_getter(info: Info, request_type: str) -> bool:
return True
10 changes: 7 additions & 3 deletions graphemy/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Dict

import strawberry
from fastapi import Request
from fastapi import Request, Response
from graphql.error import GraphQLError
from graphql.error.graphql_error import format_error as format_graphql_error
from sqlalchemy.engine.base import Engine
Expand Down Expand Up @@ -165,8 +165,12 @@ def __init__(
strawberry.field(hello_world),
)

async def get_context(request: Request) -> dict:
context = await context_getter(request) if context_getter else {}
async def get_context(request: Request, response: Response) -> dict:
context = (
await context_getter(request, response)
if context_getter
else {}
)
for k, (func, return_class) in functions.items():
context[k] = GraphemyDataLoader(
load_fn=func
Expand Down
30 changes: 24 additions & 6 deletions graphemy/schemas/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,21 @@ class Schema:
attr.target if isinstance(attr.target, list) else [attr.target]
)
target = [returned_class.__tablename__ + '.' + t for t in target]
filtered_pairs = [(s, t) for s, t in zip(source, target) if not (isinstance(s, int) or s.startswith('_') or isinstance(t, int) or t.startswith('_'))]
source, target = zip(*filtered_pairs) if filtered_pairs else ([], [])

if len(source)>0 and len(target)>0:
filtered_pairs = [
(s, t)
for s, t in zip(source, target)
if not (
isinstance(s, int)
or s.startswith('_')
or isinstance(t, int)
or t.startswith('_')
)
]
source, target = (
zip(*filtered_pairs) if filtered_pairs else ([], [])
)

if len(source) > 0 and len(target) > 0:
cls.__table__.append_constraint(
ForeignKeyConstraint(source, target)
)
Expand Down Expand Up @@ -152,7 +163,14 @@ async def loader_func(
"""The dynamically generated DataLoader function."""
filter_args = vars(filters) if filters else None
source_value = (
[ attr if type(attr) == int else attr[1:] if attr.startswith('_') else getattr(self, attr) for attr in field_value.source]
[
attr
if type(attr) == int
else attr[1:]
if attr.startswith('_')
else getattr(self, attr)
for attr in field_value.source
]
if isinstance(field_value.source, list)
else getattr(self, field_value.source)
)
Expand Down Expand Up @@ -193,7 +211,7 @@ async def query(
self, info: Info, filters: filter | None = None
) -> list[cls.__strawberry_schema__]:
if not await cls.permission_getter(
info
info, 'query'
) or not await Setup.get_permission(cls, info.context, 'query'):
return []
data = await get_all(cls, filters, Setup.query_filter(cls, info))
Expand Down
2 changes: 1 addition & 1 deletion graphemy/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_auth(cls, module: 'Graphemy', request_type: str) -> BasePermission:
class IsAuthenticated(BasePermission):
async def has_permission(self, source, info, **kwargs) -> bool:
if not await module.permission_getter(
info
info, request_type
) or not await cls.get_permission(
module, info.context, request_type
):
Expand Down

0 comments on commit ec905b9

Please sign in to comment.