Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fork a small amount of Orbax code into Pax to deal with writing "aggregate" files, as Orbax will soon lose this ability. #869

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
RESTORE_TYPE_NONE = 'None'
RESTORE_TYPE_DICT = 'Dict'
RESTORE_TYPE_LIST = 'List'
RESTORE_TYPE_TUPLE = 'Tuple'

_DEFAULT_DRIVER = 'file'
_PROCESS_SUBDIR_PREFIX = 'ocdbt.process_'
Expand Down Expand Up @@ -92,10 +93,13 @@ async def _assert_parameter_files_exist(


def get_empty_value_typestr(value: Any) -> str:
"""Get a typestr for an empty value."""
if not utils.is_supported_empty_aggregation_type(value):
raise ValueError(f'{value} is not a supported empty aggregation type.')
if isinstance(value, list):
return RESTORE_TYPE_LIST
elif isinstance(value, tuple):
return RESTORE_TYPE_TUPLE
elif isinstance(value, dict):
return RESTORE_TYPE_DICT
elif isinstance(value, type(None)):
Expand All @@ -107,6 +111,7 @@ def get_empty_value_typestr(value: Any) -> str:
def is_empty_typestr(typestr: str) -> bool:
return (
typestr == RESTORE_TYPE_LIST
or typestr == RESTORE_TYPE_TUPLE
or typestr == RESTORE_TYPE_DICT
or typestr == RESTORE_TYPE_NONE
)
Expand All @@ -115,6 +120,8 @@ def is_empty_typestr(typestr: str) -> bool:
def get_empty_value_from_typestr(typestr: str) -> Any:
if typestr == RESTORE_TYPE_LIST:
return []
elif typestr == RESTORE_TYPE_TUPLE:
return ()
elif typestr == RESTORE_TYPE_DICT:
return {}
elif typestr == RESTORE_TYPE_NONE:
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def all_leaves_are_placeholders(tree: PyTree) -> bool:
def is_supported_empty_aggregation_type(value: Any) -> bool:
"""Determines if the *empty* value is supported for aggregation."""
# Check isinstance first to avoid `not` checks on jax.Arrays (raises error).
return isinstance(value, (dict, list, type(None))) and not value
return isinstance(value, (dict, list, tuple, type(None))) and not value


def is_supported_aggregation_type(value: Any) -> bool:
Expand Down