Skip to content

Commit

Permalink
handle v2 BaseModel in visit_collection
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed May 10, 2024
1 parent cb84f4a commit 04bf3d5
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
This module contains the definition for futures as well as utilities for resolving
futures in nested data structures.
"""

import asyncio
import warnings
from functools import partial
Expand Down
57 changes: 47 additions & 10 deletions src/prefect/utilities/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Iterable,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
Expand All @@ -32,6 +33,7 @@

if HAS_PYDANTIC_V2:
import pydantic.v1 as pydantic
from pydantic import BaseModel as V2BaseModel
else:
import pydantic

Expand Down Expand Up @@ -78,7 +80,7 @@ def __repr__(self) -> str:


def dict_to_flatdict(
dct: Dict[KT, Union[Any, Dict[KT, Any]]], _parent: Tuple[KT, ...] = None
dct: Dict[KT, Union[Any, Dict[KT, Any]]], _parent: Optional[Tuple[KT, ...]] = None
) -> Dict[Tuple[KT, ...], Any]:
"""Converts a (nested) dictionary to a flattened representation.
Expand Down Expand Up @@ -121,11 +123,11 @@ def flatdict_to_dict(
typ = type(dct)
result = cast(Dict[KT, Union[VT, Dict[KT, VT]]], typ())
for key_tuple, value in dct.items():
current_dict = result
current_dict: Dict = result
for prefix_key in key_tuple[:-1]:
# Build nested dictionaries up for the current key tuple
# Use `setdefault` in case the nested dict has already been created
current_dict = current_dict.setdefault(prefix_key, typ()) # type: ignore
current_dict = current_dict.setdefault(prefix_key, typ())
# Set the value
current_dict[key_tuple[-1]] = value

Expand Down Expand Up @@ -166,7 +168,7 @@ def listrepr(objs: Iterable[Any], sep: str = " ") -> str:
def extract_instances(
objects: Iterable,
types: Union[Type[T], Tuple[Type[T], ...]] = object,
) -> Union[List[T], Dict[Type[T], T]]:
) -> Union[List[T], Mapping[Type[T], T]]:
"""
Extract objects from a file and returns a dict of type -> instances
Expand All @@ -178,21 +180,22 @@ def extract_instances(
If a single type is given: a list of instances of that type
If a tuple of types is given: a mapping of type to a list of instances
"""
types = ensure_iterable(types)
types_: Iterable[Type[T]] = ensure_iterable(types)

# Create a mapping of type -> instance from the exec values
ret = defaultdict(list)

for o in objects:
# We iterate here so that the key is the passed type rather than type(o)
for type_ in types:
for type_ in types_:
if isinstance(o, type_):
ret[type_].append(o)

if len(types) == 1:
return ret[types[0]]
list_of_types = list(types_)
if len(list_of_types) == 1:
return ret[list_of_types[0]]

return ret
return ret # type: ignore


def batched_iterable(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
Expand Down Expand Up @@ -288,7 +291,7 @@ def visit_nested(expr):

def visit_expression(expr):
if context is not None:
return visit_fn(expr, context)
return visit_fn(expr, context) # type: ignore
else:
return visit_fn(expr)

Expand Down Expand Up @@ -380,6 +383,40 @@ def visit_expression(expr):
else:
result = None

elif isinstance(expr, V2BaseModel):
expr = cast(V2BaseModel, expr)
model_fields = {
f
for f in expr.model_fields_set.union(expr.model_fields)
if hasattr(expr, f)
}
items = [visit_nested(getattr(expr, key)) for key in model_fields]

if return_data:
# Collect fields with aliases so reconstruction can use the correct field name
aliases = {
key: info.alias
for key, info in expr.model_fields.items()
if info.alias is not None
}

model_data = {
aliases.get(key) or key: value
for key, value in zip(model_fields, items)
}

# Create a new instance of the model using the `create_model` function

model_instance = typ(**model_data)

# Restore private attributes after creating the new model
for attr in expr.__private_attributes__:
setattr(model_instance, attr, getattr(expr, attr))

result = model_instance
else:
result = None

else:
result = result if return_data else None

Expand Down
18 changes: 17 additions & 1 deletion tests/utilities/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if HAS_PYDANTIC_V2:
import pydantic.v1 as pydantic
from pydantic import BaseModel as V2BaseModel
else:
import pydantic

Expand Down Expand Up @@ -44,7 +45,7 @@ def test_autoenum_repr(self):
assert repr(Color.RED) == str(Color.RED) == "Color.RED"

def test_autoenum_can_be_json_serialized_with_default_encoder(self):
json.dumps(Color.RED) == "RED"
assert json.dumps(Color.RED) == "RED"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -502,6 +503,21 @@ def visit(expr, context):
# Only the first two items should be visited
assert result == [2, 3, [3, [4, 5, 6]]]

@pytest.mark.skipif(not HAS_PYDANTIC_V2, reason="Only runs with Pydantic v2")
def test_visit_collection_v2_base_model(self, capsys):
class V2Model(V2BaseModel):
x: int
y: int

input = V2Model(x=1, y=2)
result = visit_collection(
input, visit_fn=negative_even_numbers, return_data=True
)
assert result == V2Model(x=1, y=-2)
out = capsys.readouterr().out
assert "Function called on 1" in out
assert "Function called on 2" in out


class TestRemoveKeys:
def test_remove_single_key(self):
Expand Down

0 comments on commit 04bf3d5

Please sign in to comment.