Skip to content

Commit

Permalink
The internal changes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583520695
  • Loading branch information
Orbax Authors committed Nov 21, 2023
1 parent 726c630 commit b7a0e11
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class _NonTrackableMetadata:
Most fields of this dataclass are python containers (dict, list, tuple). If
they are attached a tf.Module directly, TF will turn them into TF trackable
wrappers (DictWrapper, ListWrapper, etc.), thus mutate their orginal PyTree
wrappers (DictWrapper, ListWrapper, etc.), thus mutate their original PyTree
def. Therefore, we create this dataclass to hold the metadata to avoid such
implicit conversion. See also
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-due-to-tfmodule-magic-conversion-during-attribute-assignment
Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(
jit_compile: Union[bool, Mapping[str, bool]] = True,
name: Optional[str] = None,
pspecs: Optional[PyTree] = None,
target_platform: Optional[str] = None,
):
"""JaxModule constructor.
Expand Down Expand Up @@ -116,6 +117,8 @@ def __init__(
same structure as ``params``. If set, the leaves of ``params`` must be
jax.Array, and ``JaxModule`` must be created within a DTensor export
context from ``with maybe_enable_dtensor_export_on(mesh)``.
target_platform: the target platform to export the model on; used to
generate the serve tags.
"""
if callable(apply_fn):
apply_fn_map: dict[str, ApplyFn] = {self.DEFAULT_METHOD_KEY: apply_fn}
Expand Down Expand Up @@ -184,6 +187,7 @@ def __init__(
var_trainable=trainable,
var_pspecs=pspecs,
)
self._target_platform = target_platform

def update_variables(self, params: PyTree):
"""Updates the variables associated with self.
Expand All @@ -199,7 +203,7 @@ def update_variables(self, params: PyTree):
raise ValueError(
'The PyTree structure of the updated parameters must be the same as'
f' that of the original parameters. Got new treedef: {treedef},'
f' orignal treedef: {self._nontrackable_metadata.var_treedef}'
f' original treedef: {self._nontrackable_metadata.var_treedef}'
)
new_vars = _jax_params_to_tf_variables(
params,
Expand Down Expand Up @@ -232,11 +236,18 @@ def native_serialization_platforms(self) -> list[str]:
native_serialization_platforms_list.append(
jax2tf_kwargs['native_serialization_platforms']
)
if len(jax2tf_kwargs['native_serialization_platforms']) > 1:
assert (
self._target_platform
), 'target_platform must be provided for multi-platform export.'

if not native_serialization_platforms_list:
return [jax2tf.jax_export.default_lowering_platform()]
else:
native_serialization_platforms = native_serialization_platforms_list[0]
if len(native_serialization_platforms_list[0]) > 1:
native_serialization_platforms = self._target_platform
else:
native_serialization_platforms = native_serialization_platforms_list[0]
for item in native_serialization_platforms_list:
assert item == native_serialization_platforms, (
'all ApplyFn must use exactly same native_serialization_platforms'
Expand Down

0 comments on commit b7a0e11

Please sign in to comment.