You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
8.9 KiB
210 lines
8.9 KiB
# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import TypedDict
|
|
|
|
from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema
|
|
from typing_extensions import TypeAlias
|
|
|
|
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
|
|
|
|
|
|
class GatherResult(TypedDict):
|
|
"""Schema traversing result."""
|
|
|
|
collected_references: dict[str, DefinitionReferenceSchema | None]
|
|
"""The collected definition references.
|
|
|
|
If a definition reference schema can be inlined, it means that there is
|
|
only one in the whole core schema. As such, it is stored as the value.
|
|
Otherwise, the value is set to `None`.
|
|
"""
|
|
|
|
deferred_discriminator_schemas: list[CoreSchema]
|
|
"""The list of core schemas having the discriminator application deferred."""
|
|
|
|
|
|
class MissingDefinitionError(LookupError):
|
|
"""A reference was pointing to a non-existing core schema."""
|
|
|
|
def __init__(self, schema_reference: str, /) -> None:
|
|
self.schema_reference = schema_reference
|
|
|
|
|
|
@dataclass
|
|
class GatherContext:
|
|
"""The current context used during core schema traversing.
|
|
|
|
Context instances should only be used during schema traversing.
|
|
"""
|
|
|
|
definitions: dict[str, CoreSchema]
|
|
"""The available definitions."""
|
|
|
|
deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list)
|
|
"""The list of core schemas having the discriminator application deferred.
|
|
|
|
Internally, these core schemas have a specific key set in the core metadata dict.
|
|
"""
|
|
|
|
collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict)
|
|
"""The collected definition references.
|
|
|
|
If a definition reference schema can be inlined, it means that there is
|
|
only one in the whole core schema. As such, it is stored as the value.
|
|
Otherwise, the value is set to `None`.
|
|
|
|
During schema traversing, definition reference schemas can be added as candidates, or removed
|
|
(by setting the value to `None`).
|
|
"""
|
|
|
|
|
|
def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
|
|
meta = schema.get('metadata')
|
|
if meta is not None and 'pydantic_internal_union_discriminator' in meta:
|
|
ctx.deferred_discriminator_schemas.append(schema) # pyright: ignore[reportArgumentType]
|
|
|
|
|
|
def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
|
|
schema_ref = def_ref_schema['schema_ref']
|
|
|
|
if schema_ref not in ctx.collected_references:
|
|
definition = ctx.definitions.get(schema_ref)
|
|
if definition is None:
|
|
raise MissingDefinitionError(schema_ref)
|
|
|
|
# The `'definition-ref'` schema was only encountered once, make it
|
|
# a candidate to be inlined:
|
|
ctx.collected_references[schema_ref] = def_ref_schema
|
|
traverse_schema(definition, ctx)
|
|
if 'serialization' in def_ref_schema:
|
|
traverse_schema(def_ref_schema['serialization'], ctx)
|
|
traverse_metadata(def_ref_schema, ctx)
|
|
else:
|
|
# The `'definition-ref'` schema was already encountered, meaning
|
|
# the previously encountered schema (and this one) can't be inlined:
|
|
ctx.collected_references[schema_ref] = None
|
|
|
|
|
|
def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
|
|
# TODO When we drop 3.9, use a match statement to get better type checking and remove
|
|
# file-level type ignore.
|
|
# (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
|
|
schema_type = schema['type']
|
|
|
|
if schema_type == 'definition-ref':
|
|
traverse_definition_ref(schema, context)
|
|
# `traverse_definition_ref` handles the possible serialization and metadata schemas:
|
|
return
|
|
elif schema_type == 'definitions':
|
|
traverse_schema(schema['schema'], context)
|
|
for definition in schema['definitions']:
|
|
traverse_schema(definition, context)
|
|
elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
|
|
if 'items_schema' in schema:
|
|
traverse_schema(schema['items_schema'], context)
|
|
elif schema_type == 'tuple':
|
|
if 'items_schema' in schema:
|
|
for s in schema['items_schema']:
|
|
traverse_schema(s, context)
|
|
elif schema_type == 'dict':
|
|
if 'keys_schema' in schema:
|
|
traverse_schema(schema['keys_schema'], context)
|
|
if 'values_schema' in schema:
|
|
traverse_schema(schema['values_schema'], context)
|
|
elif schema_type == 'union':
|
|
for choice in schema['choices']:
|
|
if isinstance(choice, tuple):
|
|
traverse_schema(choice[0], context)
|
|
else:
|
|
traverse_schema(choice, context)
|
|
elif schema_type == 'tagged-union':
|
|
for v in schema['choices'].values():
|
|
traverse_schema(v, context)
|
|
elif schema_type == 'chain':
|
|
for step in schema['steps']:
|
|
traverse_schema(step, context)
|
|
elif schema_type == 'lax-or-strict':
|
|
traverse_schema(schema['lax_schema'], context)
|
|
traverse_schema(schema['strict_schema'], context)
|
|
elif schema_type == 'json-or-python':
|
|
traverse_schema(schema['json_schema'], context)
|
|
traverse_schema(schema['python_schema'], context)
|
|
elif schema_type in {'model-fields', 'typed-dict'}:
|
|
if 'extras_schema' in schema:
|
|
traverse_schema(schema['extras_schema'], context)
|
|
if 'computed_fields' in schema:
|
|
for s in schema['computed_fields']:
|
|
traverse_schema(s, context)
|
|
for s in schema['fields'].values():
|
|
traverse_schema(s, context)
|
|
elif schema_type == 'dataclass-args':
|
|
if 'computed_fields' in schema:
|
|
for s in schema['computed_fields']:
|
|
traverse_schema(s, context)
|
|
for s in schema['fields']:
|
|
traverse_schema(s, context)
|
|
elif schema_type == 'arguments':
|
|
for s in schema['arguments_schema']:
|
|
traverse_schema(s['schema'], context)
|
|
if 'var_args_schema' in schema:
|
|
traverse_schema(schema['var_args_schema'], context)
|
|
if 'var_kwargs_schema' in schema:
|
|
traverse_schema(schema['var_kwargs_schema'], context)
|
|
elif schema_type == 'arguments-v3':
|
|
for s in schema['arguments_schema']:
|
|
traverse_schema(s['schema'], context)
|
|
elif schema_type == 'call':
|
|
traverse_schema(schema['arguments_schema'], context)
|
|
if 'return_schema' in schema:
|
|
traverse_schema(schema['return_schema'], context)
|
|
elif schema_type == 'computed-field':
|
|
traverse_schema(schema['return_schema'], context)
|
|
elif schema_type == 'function-before':
|
|
if 'schema' in schema:
|
|
traverse_schema(schema['schema'], context)
|
|
if 'json_schema_input_schema' in schema:
|
|
traverse_schema(schema['json_schema_input_schema'], context)
|
|
elif schema_type == 'function-plain':
|
|
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
|
|
if 'return_schema' in schema:
|
|
traverse_schema(schema['return_schema'], context)
|
|
if 'json_schema_input_schema' in schema:
|
|
traverse_schema(schema['json_schema_input_schema'], context)
|
|
elif schema_type == 'function-wrap':
|
|
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
|
|
if 'return_schema' in schema:
|
|
traverse_schema(schema['return_schema'], context)
|
|
if 'schema' in schema:
|
|
traverse_schema(schema['schema'], context)
|
|
if 'json_schema_input_schema' in schema:
|
|
traverse_schema(schema['json_schema_input_schema'], context)
|
|
else:
|
|
if 'schema' in schema:
|
|
traverse_schema(schema['schema'], context)
|
|
|
|
if 'serialization' in schema:
|
|
traverse_schema(schema['serialization'], context)
|
|
traverse_metadata(schema, context)
|
|
|
|
|
|
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
|
|
"""Traverse the core schema and definitions and return the necessary information for schema cleaning.
|
|
|
|
During the core schema traversing, any `'definition-ref'` schema is:
|
|
|
|
- Validated: the reference must point to an existing definition. If this is not the case, a
|
|
`MissingDefinitionError` exception is raised.
|
|
- Stored in the context: the actual reference is stored in the context. Depending on whether
|
|
the `'definition-ref'` schema is encountered more that once, the schema itself is also
|
|
saved in the context to be inlined (i.e. replaced by the definition it points to).
|
|
"""
|
|
context = GatherContext(definitions)
|
|
traverse_schema(schema, context)
|
|
|
|
return {
|
|
'collected_references': context.collected_references,
|
|
'deferred_discriminator_schemas': context.deferred_discriminator_schemas,
|
|
}
|