Skip to content

Commit

Permalink
Fix #389
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertoPrevato committed Jul 2, 2023
1 parent fed357d commit 8e48342
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 40 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixes `TypeError` when writing a request without host header.
- Add support for `Pydantic` `v2`: meaning feature parity with support for
Pydantic v1 (generating OpenAPI Documentation).
- Add support for `Union` types in sub-properties of request handlers input and
output types, for generating OpenAPI Documentation, both using simple classes
and Pydantic [#389](https://github.com/Neoteroi/BlackSheep/issues/389)

## [2.0a7] - 2023-05-31 :corn:

Expand Down
10 changes: 0 additions & 10 deletions blacksheep/server/openapi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,3 @@ def __init__(self, content_type: str) -> None:
"have unique type."
)
self.content_type = content_type


class UnsupportedUnionTypeException(DocumentationException):
def __init__(self, unsupported_type) -> None:
super().__init__(
f"Union types are not supported for automatic generation of "
"OpenAPI Documentation. The annotation that caused exception is: "
f"{unsupported_type}."
)
self.unsupported_type = unsupported_type
49 changes: 28 additions & 21 deletions blacksheep/server/openapi/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@
DocstringInfo,
get_handler_docstring_info,
)
from blacksheep.server.openapi.exceptions import (
DuplicatedContentTypeDocsException,
UnsupportedUnionTypeException,
)
from blacksheep.server.openapi.exceptions import DuplicatedContentTypeDocsException
from blacksheep.server.routing import Router

from ..application import Application
Expand Down Expand Up @@ -115,14 +112,13 @@ def check_union(object_type: Any) -> Tuple[bool, Any]:
if hasattr(object_type, "is_required"):
# Pydantic v2
return object_type.is_required(), object_type
# support only Union[None, Type] - that is equivalent of Optional[Type]
if type(None) not in object_type.__args__ or len(object_type.__args__) > 2:
raise UnsupportedUnionTypeException(object_type)

for possible_type in object_type.__args__:
if type(None) is possible_type:
continue
return True, possible_type
if type(None) in object_type.__args__ and len(object_type.__args__) == 2:
for possible_type in object_type.__args__:
if type(None) is possible_type:
continue
return True, possible_type
return type(None) not in object_type.__args__, object_type
return False, object_type


Expand Down Expand Up @@ -177,10 +173,6 @@ def get_type_fields(self, object_type, register_type) -> List[FieldInfo]:
return [FieldInfo(field.name, field.type) for field in fields(object_type)]


def _is_optional_any_of(any_of) -> bool:
return len(any_of) == 2 and any(item.get("type") == "null" for item in any_of)


def _try_is_subclass(object_type, check_type):
try:
return issubclass(object_type, check_type)
Expand Down Expand Up @@ -214,9 +206,6 @@ def _allow_none(self, field_info):
return not field_info.is_required()

def _handle_any_of(self, any_of, types_args, register_type):
# Currently only optional
assert _is_optional_any_of(any_of)

for item in any_of:
if "$ref" in item:
obj_type_name = item["$ref"].split("/")[-1]
Expand Down Expand Up @@ -591,15 +580,21 @@ def _handle_object_type(
required: List[str],
context_type_args: Optional[Dict[Any, Type]] = None,
) -> Reference:
type_name = self.get_type_name(object_type, context_type_args)
reference = self._register_schema(
return self._handle_object_type_schema(
object_type,
context_type_args,
Schema(
type=ValueType.OBJECT,
required=required or None,
properties=properties,
),
type_name,
)

def _handle_object_type_schema(
self, object_type, context_type_args, schema: Schema
):
type_name = self.get_type_name(object_type, context_type_args)
reference = self._register_schema(schema, type_name)
self._objects_references[object_type] = reference
self._objects_references[type_name] = reference
return reference
Expand Down Expand Up @@ -738,6 +733,18 @@ def _try_get_schema_for_generic(
) -> Optional[Reference]:
origin = get_origin(object_type)

if origin is Union:
schema = Schema(
ValueType.OBJECT,
any_of=[
self.get_schema_by_type(child_type, context_type_args)
for child_type in object_type.__args__
],
)
return self._handle_object_type_schema(
object_type, context_type_args, schema
)

required: List[str] = []
properties: Dict[str, Union[Schema, Reference]] = {}

Expand Down
233 changes: 224 additions & 9 deletions tests/test_openapi_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
ResponseInfo,
SecurityInfo,
)
from blacksheep.server.openapi.exceptions import (
DuplicatedContentTypeDocsException,
UnsupportedUnionTypeException,
)
from blacksheep.server.openapi.exceptions import DuplicatedContentTypeDocsException
from blacksheep.server.openapi.v3 import (
DataClassTypeHandler,
OpenAPIHandler,
Expand Down Expand Up @@ -296,11 +293,6 @@ async def example():
await app.start()


def test_raises_for_union_type(docs):
with pytest.raises(UnsupportedUnionTypeException):
docs.get_schema_by_type(Union[Foo, Ufo])


@pytest.mark.parametrize(
"annotation,expected_result",
[
Expand Down Expand Up @@ -3037,3 +3029,226 @@ def create_cat(self, cat: Cat) -> None:
- name: Parrots
""".strip()
)


@dataclass
class A:
a_prop: int


@dataclass
class B:
b_prop: str


@dataclass
class C:
c_prop: str


@dataclass
class D:
d_prop: float


@dataclass
class E:
e_prop: int


@dataclass
class F:
f_prop: str
f_prop2: A


@dataclass
class AnyOfTestClass:
sub_prop: Union[A, B, C]


@dataclass
class AnyOfResponseTestClass:
data: Union[D, E, F]


class APyd(BaseModel):
a_prop: int


class BPyd(BaseModel):
b_prop: str


class CPyd(BaseModel):
c_prop: str


class DPyd(BaseModel):
d_prop: float


class EPyd(BaseModel):
e_prop: int


class FPyd(BaseModel):
f_prop: str
f_prop2: APyd


class AnyOfTestClassPyd(BaseModel):
sub_prop: Union[APyd, BPyd, CPyd]


class AnyOfResponseTestClassPyd(BaseModel):
data: Union[DPyd, EPyd, FPyd]


@pytest.mark.asyncio
async def test_any_of_dataclasses(docs: OpenAPIHandler, serializer: Serializer):
app = get_app()
docs.bind_app(app)

@app.router.post("/one")
def one(data: AnyOfTestClass) -> AnyOfResponseTestClass:
...

await app.start()

yaml = serializer.to_yaml(docs.generate_documentation(app))

expected_fragments = [
"""
/one:
post:
responses:
'200':
description: Success response
content:
application/json:
schema:
$ref: '#/components/schemas/AnyOfResponseTestClass'
operationId: one
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/AnyOfTestClass'
required: true
""",
"""
D:
type: object
required:
- d_prop
properties:
d_prop:
type: number
format: float
nullable: false
""",
"""
UnionOfDAndEAndF:
type: object
anyOf:
- $ref: '#/components/schemas/D'
- $ref: '#/components/schemas/E'
- $ref: '#/components/schemas/F'
""",
"""
AnyOfResponseTestClass:
type: object
properties:
data:
$ref: '#/components/schemas/UnionOfDAndEAndF'
""",
"""
UnionOfAAndBAndC:
type: object
anyOf:
- $ref: '#/components/schemas/A'
- $ref: '#/components/schemas/B'
- $ref: '#/components/schemas/C'
""",
]

for fragment in expected_fragments:
assert fragment.strip() in yaml


@pytest.mark.asyncio
async def test_any_of_pydantic_models(docs: OpenAPIHandler, serializer: Serializer):
app = get_app()
docs.bind_app(app)

@app.router.post("/one")
def one(data: AnyOfTestClassPyd) -> AnyOfResponseTestClassPyd:
...

await app.start()

yaml = serializer.to_yaml(docs.generate_documentation(app))

expected_fragments = [
"""
/one:
post:
responses:
'200':
description: Success response
content:
application/json:
schema:
$ref: '#/components/schemas/AnyOfResponseTestClassPyd'
operationId: one
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/AnyOfTestClassPyd'
required: true
""",
"""
DPyd:
type: object
required:
- d_prop
properties:
d_prop:
type: number
format: float
nullable: false
""",
"""
AnyOfResponseTestClassPyd:
type: object
required:
- data
properties:
data:
title: Data
anyOf:
- $ref: '#/components/schemas/DPyd'
- $ref: '#/components/schemas/EPyd'
- $ref: '#/components/schemas/FPyd'
""",
"""
AnyOfTestClassPyd:
type: object
required:
- sub_prop
properties:
sub_prop:
title: Sub Prop
anyOf:
- $ref: '#/components/schemas/APyd'
- $ref: '#/components/schemas/BPyd'
- $ref: '#/components/schemas/CPyd'
""",
]

for fragment in expected_fragments:
assert fragment.strip() in yaml

0 comments on commit 8e48342

Please sign in to comment.