Skip to content

Commit

Permalink
Modify _write_metadata_file to async.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636942197
  • Loading branch information
liangyaning33 authored and Orbax Authors committed May 24, 2024
1 parent 7713642 commit a09e829
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
- Modify `_write_metadata_file` to Async.

## [0.5.14] - 2024-05-23

### Changed
Expand Down
42 changes: 28 additions & 14 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import collections
import dataclasses
import json
import os
import time
from typing import Any, Callable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -572,10 +573,12 @@ def _maybe_set_default_save_args(value, args_):
logging.debug('param_info: %s', param_infos)
logging.debug('save_args: %s', save_args)

# TODO(b/285888834): Allow this to be asynchronous.
metadata_future = None
if utils.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
self._write_metadata_file(directory, item, save_args, self._use_zarr3)
metadata_future = await self._write_metadata_file(
directory, item, save_args, self._use_zarr3
)
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/metadata_write_duration_secs',
time.time() - metadata_write_start_time,
Expand All @@ -589,7 +592,11 @@ def _maybe_set_default_save_args(value, args_):
'/jax/checkpoint/write/async/aggregate_write_duration_secs',
time.time() - aggregate_file_write_start_time,
)
return commit_futures + [aggregate_commit_future]
return (
commit_futures + [aggregate_commit_future] + [metadata_future]
if metadata_future is not None
else commit_futures + [aggregate_commit_future]
)

def save(self, directory: epath.Path, *args, **kwargs):
"""Saves the provided item.
Expand Down Expand Up @@ -824,23 +831,30 @@ def _read_aggregate_file(self, directory: epath.Path) -> PyTree:
else:
return utils.pytree_structure(directory)

def _write_metadata_file(
async def _write_metadata_file(
self,
directory: epath.Path,
item: PyTree,
save_args: PyTree,
use_zarr3: bool = False,
):
(directory / METADATA_FILE).write_text(
json.dumps(
tree_metadata.TreeMetadata.build(
item,
save_args=save_args,
type_handler_registry=self._type_handler_registry,
use_zarr3=use_zarr3,
).to_json()
)
) -> future.Future:
kvstore = type_handlers._get_tensorstore_spec( # pylint: disable=protected-access
os.fspath(directory), name=METADATA_FILE, use_ocdbt=False
)['kvstore']
tspec = {'driver': 'json', 'kvstore': kvstore}
txn = ts.Transaction()
metadata_ts_context = type_handlers.get_ts_context(use_ocdbt=False)
t = await ts.open(tspec, open=True, context=metadata_ts_context)
metadata_content = tree_metadata.TreeMetadata.build(
item,
save_args=save_args,
type_handler_registry=self._type_handler_registry,
use_zarr3=use_zarr3,
)
write_future = t.with_transaction(txn).write(metadata_content)
await write_future
commit_future = txn.commit_async()
return commit_future

def _read_metadata_file(
self, directory: epath.Path
Expand Down

0 comments on commit a09e829

Please sign in to comment.