diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f9f9ccb..8cf21506 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixes `TypeError` when writing a request without host header. - Add support for `Pydantic` `v2`: meaning feature parity with support for Pydantic v1 (generating OpenAPI Documentation). +- Add support for `Union` types in sub-properties of request handlers input and + output types, for generating OpenAPI Documentation, both using simple classes + and Pydantic [#389](https://github.com/Neoteroi/BlackSheep/issues/389) ## [2.0a7] - 2023-05-31 :corn: diff --git a/blacksheep/server/openapi/exceptions.py b/blacksheep/server/openapi/exceptions.py index ba269d70..d91f5d17 100644 --- a/blacksheep/server/openapi/exceptions.py +++ b/blacksheep/server/openapi/exceptions.py @@ -10,13 +10,3 @@ def __init__(self, content_type: str) -> None: "have unique type." ) self.content_type = content_type - - -class UnsupportedUnionTypeException(DocumentationException): - def __init__(self, unsupported_type) -> None: - super().__init__( - f"Union types are not supported for automatic generation of " - "OpenAPI Documentation. The annotation that caused exception is: " - f"{unsupported_type}." - ) - self.unsupported_type = unsupported_type diff --git a/blacksheep/server/openapi/v3.py b/blacksheep/server/openapi/v3.py index 5ad42a30..07c69894 100644 --- a/blacksheep/server/openapi/v3.py +++ b/blacksheep/server/openapi/v3.py @@ -52,10 +52,7 @@ DocstringInfo, get_handler_docstring_info, ) -from blacksheep.server.openapi.exceptions import ( - DuplicatedContentTypeDocsException, - UnsupportedUnionTypeException, -) +from blacksheep.server.openapi.exceptions import DuplicatedContentTypeDocsException from blacksheep.server.routing import Router from ..application import Application @@ -115,14 +112,13 @@ def check_union(object_type: Any) -> Tuple[bool, Any]: if hasattr(object_type, "is_required"): # Pydantic v2 return object_type.is_required(), object_type - # support only Union[None, Type] - that is equivalent of Optional[Type] - if type(None) not in object_type.__args__ or len(object_type.__args__) > 2: - raise UnsupportedUnionTypeException(object_type) - for possible_type in object_type.__args__: - if type(None) is possible_type: - continue - return True, possible_type + if type(None) in object_type.__args__ and len(object_type.__args__) == 2: + for possible_type in object_type.__args__: + if type(None) is possible_type: + continue + return True, possible_type + return type(None) not in object_type.__args__, object_type return False, object_type @@ -177,10 +173,6 @@ def get_type_fields(self, object_type, register_type) -> List[FieldInfo]: return [FieldInfo(field.name, field.type) for field in fields(object_type)] -def _is_optional_any_of(any_of) -> bool: - return len(any_of) == 2 and any(item.get("type") == "null" for item in any_of) - - def _try_is_subclass(object_type, check_type): try: return issubclass(object_type, check_type) @@ -214,9 +206,6 @@ def _allow_none(self, field_info): return not field_info.is_required() def _handle_any_of(self, any_of, types_args, register_type): - # Currently only optional - assert _is_optional_any_of(any_of) - for item in any_of: if "$ref" in item: obj_type_name = item["$ref"].split("/")[-1] @@ -591,15 +580,21 @@ def _handle_object_type( required: List[str], context_type_args: Optional[Dict[Any, Type]] = None, ) -> Reference: - type_name = self.get_type_name(object_type, context_type_args) - reference = self._register_schema( + return self._handle_object_type_schema( + object_type, + context_type_args, Schema( type=ValueType.OBJECT, required=required or None, properties=properties, ), - type_name, ) + + def _handle_object_type_schema( + self, object_type, context_type_args, schema: Schema + ): + type_name = self.get_type_name(object_type, context_type_args) + reference = self._register_schema(schema, type_name) self._objects_references[object_type] = reference self._objects_references[type_name] = reference return reference @@ -738,6 +733,18 @@ def _try_get_schema_for_generic( ) -> Optional[Reference]: origin = get_origin(object_type) + if origin is Union: + schema = Schema( + ValueType.OBJECT, + any_of=[ + self.get_schema_by_type(child_type, context_type_args) + for child_type in object_type.__args__ + ], + ) + return self._handle_object_type_schema( + object_type, context_type_args, schema + ) + required: List[str] = [] properties: Dict[str, Union[Schema, Reference]] = {} diff --git a/tests/test_openapi_v3.py b/tests/test_openapi_v3.py index 52fe8ca1..07d9dffe 100644 --- a/tests/test_openapi_v3.py +++ b/tests/test_openapi_v3.py @@ -33,10 +33,7 @@ ResponseInfo, SecurityInfo, ) -from blacksheep.server.openapi.exceptions import ( - DuplicatedContentTypeDocsException, - UnsupportedUnionTypeException, -) +from blacksheep.server.openapi.exceptions import DuplicatedContentTypeDocsException from blacksheep.server.openapi.v3 import ( DataClassTypeHandler, OpenAPIHandler, @@ -296,11 +293,6 @@ async def example(): await app.start() -def test_raises_for_union_type(docs): - with pytest.raises(UnsupportedUnionTypeException): - docs.get_schema_by_type(Union[Foo, Ufo]) - - @pytest.mark.parametrize( "annotation,expected_result", [ @@ -3037,3 +3029,226 @@ def create_cat(self, cat: Cat) -> None: - name: Parrots """.strip() ) + + +@dataclass +class A: + a_prop: int + + +@dataclass +class B: + b_prop: str + + +@dataclass +class C: + c_prop: str + + +@dataclass +class D: + d_prop: float + + +@dataclass +class E: + e_prop: int + + +@dataclass +class F: + f_prop: str + f_prop2: A + + +@dataclass +class AnyOfTestClass: + sub_prop: Union[A, B, C] + + +@dataclass +class AnyOfResponseTestClass: + data: Union[D, E, F] + + +class APyd(BaseModel): + a_prop: int + + +class BPyd(BaseModel): + b_prop: str + + +class CPyd(BaseModel): + c_prop: str + + +class DPyd(BaseModel): + d_prop: float + + +class EPyd(BaseModel): + e_prop: int + + +class FPyd(BaseModel): + f_prop: str + f_prop2: APyd + + +class AnyOfTestClassPyd(BaseModel): + sub_prop: Union[APyd, BPyd, CPyd] + + +class AnyOfResponseTestClassPyd(BaseModel): + data: Union[DPyd, EPyd, FPyd] + + +@pytest.mark.asyncio +async def test_any_of_dataclasses(docs: OpenAPIHandler, serializer: Serializer): + app = get_app() + docs.bind_app(app) + + @app.router.post("/one") + def one(data: AnyOfTestClass) -> AnyOfResponseTestClass: + ... + + await app.start() + + yaml = serializer.to_yaml(docs.generate_documentation(app)) + + expected_fragments = [ + """ + /one: + post: + responses: + '200': + description: Success response + content: + application/json: + schema: + $ref: '#/components/schemas/AnyOfResponseTestClass' + operationId: one + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AnyOfTestClass' + required: true + """, + """ + D: + type: object + required: + - d_prop + properties: + d_prop: + type: number + format: float + nullable: false + """, + """ + UnionOfDAndEAndF: + type: object + anyOf: + - $ref: '#/components/schemas/D' + - $ref: '#/components/schemas/E' + - $ref: '#/components/schemas/F' + """, + """ + AnyOfResponseTestClass: + type: object + properties: + data: + $ref: '#/components/schemas/UnionOfDAndEAndF' + """, + """ + UnionOfAAndBAndC: + type: object + anyOf: + - $ref: '#/components/schemas/A' + - $ref: '#/components/schemas/B' + - $ref: '#/components/schemas/C' + """, + ] + + for fragment in expected_fragments: + assert fragment.strip() in yaml + + +@pytest.mark.asyncio +async def test_any_of_pydantic_models(docs: OpenAPIHandler, serializer: Serializer): + app = get_app() + docs.bind_app(app) + + @app.router.post("/one") + def one(data: AnyOfTestClassPyd) -> AnyOfResponseTestClassPyd: + ... + + await app.start() + + yaml = serializer.to_yaml(docs.generate_documentation(app)) + + expected_fragments = [ + """ + /one: + post: + responses: + '200': + description: Success response + content: + application/json: + schema: + $ref: '#/components/schemas/AnyOfResponseTestClassPyd' + operationId: one + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AnyOfTestClassPyd' + required: true + """, + """ + DPyd: + type: object + required: + - d_prop + properties: + d_prop: + type: number + format: float + nullable: false + """, + """ + AnyOfResponseTestClassPyd: + type: object + required: + - data + properties: + data: + title: Data + anyOf: + - $ref: '#/components/schemas/DPyd' + - $ref: '#/components/schemas/EPyd' + - $ref: '#/components/schemas/FPyd' + """, + """ + AnyOfTestClassPyd: + type: object + required: + - sub_prop + properties: + sub_prop: + title: Sub Prop + anyOf: + - $ref: '#/components/schemas/APyd' + - $ref: '#/components/schemas/BPyd' + - $ref: '#/components/schemas/CPyd' + """, + ] + + for fragment in expected_fragments: + assert fragment.strip() in yaml