diff --git a/openapi_core/schema/schemas.py b/openapi_core/schema/schemas.py deleted file mode 100644 index 9cdc2e92..00000000 --- a/openapi_core/schema/schemas.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Any -from typing import Dict - -from openapi_core.spec import Spec - - -def get_all_properties(schema: Spec) -> Dict[str, Any]: - properties = schema.get("properties", {}) - properties_dict = dict(list(properties.items())) - - if "allOf" not in schema: - return properties_dict - - for subschema in schema / "allOf": - subschema_props = get_all_properties(subschema) - properties_dict.update(subschema_props) - - return properties_dict diff --git a/openapi_core/unmarshalling/schemas/datatypes.py b/openapi_core/unmarshalling/schemas/datatypes.py index 96008373..e3335953 100644 --- a/openapi_core/unmarshalling/schemas/datatypes.py +++ b/openapi_core/unmarshalling/schemas/datatypes.py @@ -1,3 +1,4 @@ +from collections import namedtuple from typing import Dict from typing import Optional @@ -5,3 +6,7 @@ CustomFormattersDict = Dict[str, Formatter] FormattersDict = Dict[Optional[str], Formatter] +SchemaUnmarshaller = namedtuple( + "SchemaUnmarshaller", + ["schema", "unmarshaller"], +) diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index 66184cba..ef897ab3 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -102,19 +102,11 @@ def create( if schema_type in self.COMPLEX_UNMARSHALLERS: complex_klass = self.COMPLEX_UNMARSHALLERS[schema_type] return complex_klass( - schema, validator, formatter, self, context=self.context + schema, validator, self, formatter, context=self.context ) klass = self.UNMARSHALLERS[schema_type] - return klass(schema, validator, formatter) - - def get_formatter( - self, type_format: str, default_formatters: FormattersDict - ) -> Optional[Formatter]: - try: - return self.custom_formatters[type_format] - except KeyError: - return default_formatters.get(type_format) + return klass(schema, validator, self, formatter) def get_validator(self, schema: Spec) -> Validator: resolver = schema.accessor.resolver # type: ignore diff --git a/openapi_core/unmarshalling/schemas/formatters.py b/openapi_core/unmarshalling/schemas/formatters.py index 47dd52b8..f9f6c982 100644 --- a/openapi_core/unmarshalling/schemas/formatters.py +++ b/openapi_core/unmarshalling/schemas/formatters.py @@ -8,20 +8,20 @@ class Formatter: def validate(self, value: Any) -> bool: return True - def unmarshal(self, value: Any) -> Any: + def format(self, value: Any) -> Any: return value @classmethod def from_callables( cls, - validate: Optional[Callable[[Any], Any]] = None, - unmarshal: Optional[Callable[[Any], Any]] = None, + validate_callable: Optional[Callable[[Any], Any]] = None, + format_callable: Optional[Callable[[Any], Any]] = None, ) -> "Formatter": attrs = {} - if validate is not None: - attrs["validate"] = staticmethod(validate) - if unmarshal is not None: - attrs["unmarshal"] = staticmethod(unmarshal) + if validate_callable is not None: + attrs["validate"] = staticmethod(validate_callable) + if format_callable is not None: + attrs["format"] = staticmethod(format_callable) klass: Type[Formatter] = type("Formatter", (cls,), attrs) return klass() diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index c2704a5c..bf4d0688 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -19,9 +19,9 @@ from openapi_schema_validator._types import is_string from openapi_core.extensions.models.factories import ModelPathFactory -from openapi_core.schema.schemas import get_all_properties from openapi_core.spec import Spec from openapi_core.unmarshalling.schemas.datatypes import FormattersDict +from openapi_core.unmarshalling.schemas.datatypes import SchemaUnmarshaller from openapi_core.unmarshalling.schemas.enums import UnmarshalContext from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, @@ -31,7 +31,11 @@ ) from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError +from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError from openapi_core.unmarshalling.schemas.exceptions import ValidateError +from openapi_core.unmarshalling.schemas.finders import AllOfSchemasFinder +from openapi_core.unmarshalling.schemas.finders import AnyOfSchemasFinder +from openapi_core.unmarshalling.schemas.finders import OneOfSchemaFinder from openapi_core.unmarshalling.schemas.formatters import Formatter from openapi_core.unmarshalling.schemas.util import format_byte from openapi_core.unmarshalling.schemas.util import format_date @@ -49,6 +53,7 @@ class BaseSchemaUnmarshaller: + TYPE = NotImplemented FORMATTERS: FormattersDict = { None: Formatter(), } @@ -57,18 +62,15 @@ def __init__( self, schema: Spec, validator: Validator, - formatter: Optional[Formatter], + unmarshallers_factory: "SchemaUnmarshallersFactory", + formatter: Optional[Formatter] = None, ): self.schema = schema self.validator = validator - self.format = schema.getkey("format") + self.unmarshallers_factory = unmarshallers_factory - if formatter is None: - if self.format not in self.FORMATTERS: - raise FormatterNotFoundError(self.format) - self.formatter = self.FORMATTERS[self.format] - else: - self.formatter = formatter + self.schema_format = schema.getkey("format") + self.formatter = formatter def __call__(self, value: Any) -> Any: if value is None: @@ -78,28 +80,86 @@ def __call__(self, value: Any) -> Any: return self.unmarshal(value) - def _formatter_validate(self, value: Any) -> None: - result = self.formatter.validate(value) + def _clone(self, schema: Spec) -> "ObjectUnmarshaller": + return self.unmarshallers_factory.create( + schema, + type_override=self.TYPE, + ) + + def _find_one_of_schema(self, value: Any) -> Optional[SchemaUnmarshaller]: + finder = OneOfSchemaFinder(self.schema, self.unmarshallers_factory) + return finder.find(value) + + def _find_any_of_schemas(self, value: Any) -> Iterable[SchemaUnmarshaller]: + finder = AnyOfSchemasFinder(self.schema, self.unmarshallers_factory) + yield from finder.find(value) + + def _find_all_of_schemas(self, value: Any) -> Iterable[SchemaUnmarshaller]: + finder = AllOfSchemasFinder(self.schema, self.unmarshallers_factory) + yield from finder.find(value) + + def _get_formatter(self) -> Formatter: + if self.formatter is not None: + return self.formatter + + if self.schema_format not in self.FORMATTERS: + raise FormatterNotFoundError(self.schema_format) + return self.FORMATTERS[self.schema_format] + + def _get_best_formatter(self, value: Any) -> Formatter: + if self.formatter is not None: + return self.formatter + + if self.schema_format is None: + for schema, unmarshaller in self._find_all_of_schemas(value): + if "format" in schema: + return unmarshaller._get_formatter() + + one_of = self._find_one_of_schema(value) + if one_of is not None: + if "format" in one_of.schema: + return one_of.unmarshaller._get_formatter() + + for schema, unmarshaller in self._find_any_of_schemas(value): + if "format" in schema: + return unmarshaller._get_formatter() + + if self.schema_format not in self.FORMATTERS: + raise FormatterNotFoundError(self.schema_format) + return self.FORMATTERS[self.schema_format] + + def _validate_format(self, value: Any) -> None: + if self.formatter is not None: + formatter = self.formatter + else: + if self.schema_format not in self.FORMATTERS: + raise FormatterNotFoundError(self.schema_format) + formatter = self.FORMATTERS[self.schema_format] + + result = formatter.validate(value) if not result: - schema_type = self.schema.getkey("type", "any") - raise InvalidSchemaValue(value, schema_type) + raise InvalidSchemaValue(value, self.TYPE) + + def format(self, value: Any) -> Any: + formatter = self._get_best_formatter(value) + try: + return formatter.format(value) + except ValueError as exc: + raise InvalidSchemaFormatValue(value, self.schema_format, exc) def validate(self, value: Any) -> None: errors_iter = self.validator.iter_errors(value) errors = tuple(errors_iter) if errors: - schema_type = self.schema.getkey("type", "any") - raise InvalidSchemaValue(value, schema_type, schema_errors=errors) + raise InvalidSchemaValue(value, self.TYPE, schema_errors=errors) def unmarshal(self, value: Any) -> Any: - try: - return self.formatter.unmarshal(value) - except ValueError as exc: - raise InvalidSchemaFormatValue(value, self.format, exc) + return self.format(value) class StringUnmarshaller(BaseSchemaUnmarshaller): + TYPE = "string" FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_string, None), str), "password": Formatter.from_callables( @@ -126,6 +186,7 @@ class StringUnmarshaller(BaseSchemaUnmarshaller): class IntegerUnmarshaller(BaseSchemaUnmarshaller): + TYPE = "integer" FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_integer, None), int), "int32": Formatter.from_callables( @@ -139,6 +200,7 @@ class IntegerUnmarshaller(BaseSchemaUnmarshaller): class NumberUnmarshaller(BaseSchemaUnmarshaller): + TYPE = "number" FORMATTERS: FormattersDict = { None: Formatter.from_callables( partial(is_number, None), format_number @@ -154,6 +216,7 @@ class NumberUnmarshaller(BaseSchemaUnmarshaller): class BooleanUnmarshaller(BaseSchemaUnmarshaller): + TYPE = "boolean" FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_bool, None), forcebool), } @@ -171,17 +234,17 @@ def __init__( self, schema: Spec, validator: Validator, - formatter: Optional[Formatter], unmarshallers_factory: "SchemaUnmarshallersFactory", + formatter: Optional[Formatter] = None, context: Optional[UnmarshalContext] = None, ): - super().__init__(schema, validator, formatter) - self.unmarshallers_factory = unmarshallers_factory + super().__init__(schema, validator, unmarshallers_factory, formatter) self.context = context class ArrayUnmarshaller(ComplexUnmarshaller): + TYPE = "array" FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_array, None), list), } @@ -192,8 +255,8 @@ def items_unmarshaller(self) -> "BaseSchemaUnmarshaller": items_schema = self.schema.get("items", Spec.from_dict({})) return self.unmarshallers_factory.create(items_schema) - def __call__(self, value: Any) -> Optional[List[Any]]: - value = super().__call__(value) + def unmarshal(self, value: Any) -> Optional[List[Any]]: + value = super().unmarshal(value) if value is None and self.schema.getkey("nullable", False): return None return list(map(self.items_unmarshaller, value)) @@ -201,6 +264,7 @@ def __call__(self, value: Any) -> Optional[List[Any]]: class ObjectUnmarshaller(ComplexUnmarshaller): + TYPE = "object" FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_object, None), dict), } @@ -210,21 +274,16 @@ def object_class_factory(self) -> ModelPathFactory: return ModelPathFactory() def unmarshal(self, value: Any) -> Any: - properties = self.unmarshal_raw(value) + properties = self.format(value) fields: Iterable[str] = properties and properties.keys() or [] object_class = self.object_class_factory.create(self.schema, fields) return object_class(**properties) - def unmarshal_raw(self, value: Any) -> Any: - try: - value = self.formatter.unmarshal(value) - except ValueError as exc: - schema_format = self.schema.getkey("format") - raise InvalidSchemaFormatValue(value, schema_format, exc) - else: - return self._unmarshal_object(value) + def format(self, value: Any) -> Any: + formatted = super().format(value) + return self._unmarshal_properties(formatted) def _clone(self, schema: Spec) -> "ObjectUnmarshaller": return cast( @@ -232,48 +291,25 @@ def _clone(self, schema: Spec) -> "ObjectUnmarshaller": self.unmarshallers_factory.create(schema, "object"), ) - def _unmarshal_object(self, value: Any) -> Any: + def _unmarshal_properties(self, value: Any) -> Any: properties = {} - if "oneOf" in self.schema: - one_of_properties = None - for one_of_schema in self.schema / "oneOf": - try: - unmarshalled = self._clone(one_of_schema).unmarshal_raw( - value - ) - except (UnmarshalError, ValueError): - pass - else: - if one_of_properties is not None: - log.warning("multiple valid oneOf schemas found") - continue - one_of_properties = unmarshalled - - if one_of_properties is None: - log.warning("valid oneOf schema not found") - else: - properties.update(one_of_properties) - - elif "anyOf" in self.schema: - any_of_properties = None - for any_of_schema in self.schema / "anyOf": - try: - unmarshalled = self._clone(any_of_schema).unmarshal_raw( - value - ) - except (UnmarshalError, ValueError): - pass - else: - any_of_properties = unmarshalled - break - - if any_of_properties is None: - log.warning("valid anyOf schema not found") - else: - properties.update(any_of_properties) + for _, unmarshaller in self._find_all_of_schemas(value): + all_of_properties = unmarshaller.format(value) + properties.update(all_of_properties) + + one_of = self._find_one_of_schema(value) + if one_of is not None: + one_of_properties = one_of.unmarshaller.format(value) + properties.update(one_of_properties) + + for _, unmarshaller in self._find_any_of_schemas(value): + any_of_properties = unmarshaller.format(value) + properties.update(any_of_properties) - for prop_name, prop in get_all_properties(self.schema).items(): + schema_properties = self.schema.get("properties", {}) + schema_properties_dict = dict(list(schema_properties.items())) + for prop_name, prop in schema_properties_dict.items(): read_only = prop.getkey("readOnly", False) if self.context == UnmarshalContext.REQUEST and read_only: continue @@ -335,6 +371,7 @@ def unmarshal(self, value: Any) -> Any: class AnyUnmarshaller(ComplexUnmarshaller): + TYPE = "any" SCHEMA_TYPES_ORDER = [ "object", "array", @@ -344,77 +381,44 @@ class AnyUnmarshaller(ComplexUnmarshaller): "string", ] - def unmarshal(self, value: Any) -> Any: - one_of_schema = self._get_one_of_schema(value) - if one_of_schema: - return self.unmarshallers_factory.create(one_of_schema)(value) - - any_of_schema = self._get_any_of_schema(value) - if any_of_schema: - return self.unmarshallers_factory.create(any_of_schema)(value) - - all_of_schema = self._get_all_of_schema(value) - if all_of_schema: - return self.unmarshallers_factory.create(all_of_schema)(value) + _best_unmarshaller: Optional[BaseSchemaUnmarshaller] = None + def _get_best_unmarshaller(self, value: Any) -> BaseSchemaUnmarshaller: for schema_type in self.SCHEMA_TYPES_ORDER: unmarshaller = self.unmarshallers_factory.create( self.schema, type_override=schema_type ) # validate with validator of formatter (usualy type validator) try: - unmarshaller._formatter_validate(value) + unmarshaller._validate_format(value) except ValidateError: continue else: - return unmarshaller(value) + return unmarshaller - log.warning("failed to unmarshal any type") - return value + raise UnmarshallerError("Unmarshaller not found for any type") - def _get_one_of_schema(self, value: Any) -> Optional[Spec]: - if "oneOf" not in self.schema: - return None + def get_best_unmarshaller(self, value: Any) -> Any: + if self._best_unmarshaller is None: + self._best_unmarshaller = self._get_best_unmarshaller(value) + return self._best_unmarshaller - one_of_schemas = self.schema / "oneOf" - for subschema in one_of_schemas: - unmarshaller = self.unmarshallers_factory.create(subschema) - try: - unmarshaller.validate(value) - except ValidateError: - continue - else: - return subschema - return None + def format(self, value: Any) -> Any: + unmarshaller = self.get_best_unmarshaller(value) + return unmarshaller.format(value) - def _get_any_of_schema(self, value: Any) -> Optional[Spec]: - if "anyOf" not in self.schema: - return None + def unmarshal(self, value: Any) -> Any: + # one_of_schema = self._get_one_of_schema(value) + # if one_of_schema: + # return self.unmarshallers_factory.create(one_of_schema)(value) - any_of_schemas = self.schema / "anyOf" - for subschema in any_of_schemas: - unmarshaller = self.unmarshallers_factory.create(subschema) - try: - unmarshaller.validate(value) - except ValidateError: - continue - else: - return subschema - return None + # any_of_schema = self._get_any_of_schema(value) + # if any_of_schema: + # return self.unmarshallers_factory.create(any_of_schema)(value) - def _get_all_of_schema(self, value: Any) -> Optional[Spec]: - if "allOf" not in self.schema: - return None + # all_of_schema = self._get_all_of_schema(value) + # if all_of_schema: + # return self.unmarshallers_factory.create(all_of_schema)(value) - all_of_schemas = self.schema / "allOf" - for subschema in all_of_schemas: - if "type" not in subschema: - continue - unmarshaller = self.unmarshallers_factory.create(subschema) - try: - unmarshaller.validate(value) - except ValidateError: - continue - else: - return subschema - return None + unmarshaller = self.get_best_unmarshaller(value) + return unmarshaller.unmarshal(value) diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index 3ce50db4..0a16d7f0 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -719,6 +719,17 @@ def test_schema_any(self, unmarshaller_factory): spec = Spec.from_dict(schema) assert unmarshaller_factory(spec)("string") == "string" + def test_schema_any_object(self, unmarshaller_factory): + schema = { + "required": ["someint"], + "properties": {"someint": {"type": "integer"}}, + } + spec = Spec.from_dict(schema) + result = unmarshaller_factory(spec)({"someint": 1}) + + assert is_dataclass(result) + assert result.someint == 1 + @pytest.mark.parametrize( "value", [