Skip to content

Commit

Permalink
Add write_record_metadata to PyTorchFileWriter
Browse files Browse the repository at this point in the history
ghstack-source-id: 2b57a55587f881fdaa747e7716290e19f0ef0224
Pull Request resolved: pytorch#125184
  • Loading branch information
mikaylagawarecki authored and albanD committed May 7, 2024
1 parent 22bcfc2 commit 6dcb9dd
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 25 deletions.
57 changes: 41 additions & 16 deletions caffe2/serialize/inline_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,15 +620,35 @@ size_t ostream_write_func(
return ret;
}

// miniz seek callback: moves the output position forward by `n` bytes
// without writing any payload.
//
// `file_ofs` is the offset miniz passes in; if it differs from the writer's
// own `current_pos_`, CAFFE_THROW is triggered. The position reported by
// `seek_func_` is compared against `current_pos_ + n`; a disagreement is
// recorded in `err_seen_` instead of being thrown. `n` is always returned.
//
// combined_uncomp_crc32_ is intentionally not updated here: no buffer is
// provided, so there are no bytes from which to derive an uncompressed CRC32.
size_t ostream_seek_func(
    void* pOpaque,
    mz_uint64 file_ofs,
    size_t n) {
  auto* writer = static_cast<PyTorchStreamWriter*>(pOpaque);

  const bool offsets_agree = (writer->current_pos_ == file_ofs);
  if (!offsets_agree) {
    CAFFE_THROW("unexpected pos ", writer->current_pos_, " vs ", file_ofs);
  }

  // Ask the underlying sink to skip ahead, then verify where it ended up.
  const size_t landed_at = writer->seek_func_(n);
  const bool seek_ok = (landed_at == writer->current_pos_ + n);
  if (!seek_ok) {
    writer->err_seen_ = true;
  }

  writer->current_pos_ += n;
  return n;
}

// Constructs a writer from a file name: archive_name_ is initialized from
// basename(file_name), and the rest of the initialization is delegated to
// setup(file_name).
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
: archive_name_(basename(file_name)) {
setup(file_name);
}

PyTorchStreamWriter::PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func)
const std::function<size_t(const void*, size_t)> writer_func,
const std::function<size_t(size_t)> seek_func)
: archive_name_("archive"),
writer_func_(writer_func) {
writer_func_(writer_func),
seek_func_(seek_func) {
setup(archive_name_);
}

Expand Down Expand Up @@ -657,10 +677,15 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_stream_.write(static_cast<const char*>(buf), nbytes);
return !file_stream_ ? 0 : nbytes;
};
seek_func_ = [this](size_t nbytes) -> size_t {
file_stream_.seekp(nbytes, std::ios_base::cur);
return file_stream_.tellp();
};
}

ar_->m_pIO_opaque = this;
ar_->m_pWrite = ostream_write_func;
ar_->m_pSeek = ostream_seek_func;

mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
valid("initializing archive ", file_name.c_str());
Expand Down Expand Up @@ -690,20 +715,20 @@ void PyTorchStreamWriter::writeRecord(
detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
mz_zip_writer_add_mem_ex_v2(
ar_.get(),
full_name.c_str(),
data,
size,
nullptr,
0,
flags,
0,
0,
nullptr,
padding_.c_str(),
padding_size,
nullptr,
0);
/*pZip=*/ar_.get(),
/*pArchive_name=*/full_name.c_str(),
/*pBuf=*/data,
/*buf_size=*/size,
/*pComment=*/nullptr,
/*comment_size=*/0,
/*level_and_flags=*/flags,
/*uncomp_size=*/0,
/*uncomp_crc32=*/0,
/*last_modified=*/nullptr,
/*user_extra_data=*/padding_.c_str(),
/*user_extra_data_len=*/padding_size,
/*user_extra_data_central=*/nullptr,
/*user_extra_data_central_len=*/0);
valid("writing file ", name.c_str());
files_written_.insert(name);
}
Expand Down
17 changes: 16 additions & 1 deletion caffe2/serialize/inline_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,21 @@ class TORCH_API PyTorchStreamReader final {
size_t additional_reader_size_threshold_;
};

// Fallback seek callback, used as the default `seek_func` argument of
// PyTorchStreamWriter when the caller supplies only a writer function.
// It unconditionally fails through TORCH_CHECK, so an attempt to write
// record metadata without a real seek implementation is reported.
//
// Declared `inline` instead of living in an unnamed namespace: this is a
// header, and an unnamed namespace would give every translation unit its own
// internal-linkage copy. The default argument in the class declaration would
// then name a different function in each TU (an ODR hazard), and TUs that
// never use it would get unused-function warnings.
//
// @param nbytes  Number of bytes the caller wanted to skip (unused).
// @return        0; not expected to be reached when TORCH_CHECK(false, ...)
//                raises, kept so every path has a return value.
inline size_t default_seek_func(size_t /*nbytes*/) {
  TORCH_CHECK(false, "attempting to write record metadata but seek_func unimplemented, please implement seek_func");
  return 0;
}

class TORCH_API PyTorchStreamWriter final {
public:
explicit PyTorchStreamWriter(const std::string& archive_name);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func);
const std::function<size_t(const void*, size_t)> writer_func,
const std::function<size_t(size_t)> seek_func = default_seek_func);

void setMinVersion(const uint64_t version);

Expand Down Expand Up @@ -246,6 +256,7 @@ class TORCH_API PyTorchStreamWriter final {
std::string padding_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
std::function<size_t(size_t)> seek_func_;
uint64_t combined_uncomp_crc32_ = 0;
std::string serialization_id_;

Expand All @@ -259,6 +270,10 @@ class TORCH_API PyTorchStreamWriter final {
uint64_t file_ofs,
const void* pBuf,
size_t n);
friend size_t ostream_seek_func(
void* pOpaque,
uint64_t file_ofs,
size_t n);
};

namespace detail {
Expand Down
44 changes: 44 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4000,6 +4000,50 @@ def test_serialization_dtype(self, dtype, weights_only):
y['even'][0] = torch.tensor(-0.25, dtype=dtype)
self.assertEqual(y['x'][:2].to(dtype=torch.float32), torch.tensor([-0.25, 0.25]))

@parametrize('filename', (True, False))
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
def test_filewriter_metadata_writing(self, filename):
    # Checks that an archive whose storage records were added with
    # write_record_metadata (no payload bytes) loads to the same state_dict
    # as a regular checkpoint whose storage bytes were overwritten with
    # zero bytes.
    sd = torch.nn.Linear(3, 5).state_dict()
    weight_nbytes = sd['weight'].untyped_storage().nbytes()
    bias_nbytes = sd['bias'].untyped_storage().nbytes()
    # TemporaryFileName will give a string
    # NamedTemporaryFile will be treated as a buffer
    file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile

    with file_creation_func() as f, file_creation_func() as g:
        # save state_dict in f
        torch.save(sd, f)
        if not filename:
            f.seek(0)
        # extract 'data.pkl' for use in our fake checkpoint
        with torch.serialization._open_file_like(f, 'rb') as opened_file:
            with torch.serialization._open_zipfile_reader(opened_file) as zip_file:
                data_file = io.BytesIO(zip_file.get_record('data.pkl'))
                data_0_offset = zip_file.get_record_offset('data/0')
                data_1_offset = zip_file.get_record_offset('data/1')

        # write nulls for 'data/0' and 'data/1'
        # NOTE: these must be real NUL bytes (b'\x00'), not the ASCII digit
        # b'0' (0x30); otherwise the reference checkpoint holds tiny non-zero
        # values rather than zeros.
        with open(f if filename else f.name, 'rb+') as opened_f:
            opened_f.seek(data_0_offset)
            opened_f.write(b'\x00' * weight_nbytes)
            opened_f.seek(data_1_offset)
            opened_f.write(b'\x00' * bias_nbytes)

        with torch.serialization._open_zipfile_writer(g) as zip_file:
            data_value = data_file.getvalue()
            zip_file.write_record('data.pkl', data_value, len(data_value))
            zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
            # Only write metadata for storages
            zip_file.write_record_metadata('data/0', weight_nbytes)
            zip_file.write_record_metadata('data/1', bias_nbytes)

        if not filename:
            # buffers must be rewound before they are read back
            f.seek(0)
            g.seek(0)
        sd_loaded = torch.load(g)
        sd_loaded_ref = torch.load(f)
        self.assertEqual(sd_loaded, sd_loaded_ref)

def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)
Expand Down
20 changes: 15 additions & 5 deletions third_party/miniz-2.1.0/miniz.c
Original file line number Diff line number Diff line change
Expand Up @@ -6250,6 +6250,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
mz_uint32 extra_size = 0;
mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE];
mz_uint16 bit_flags = 0;
mz_bool write_metadata_only = buf_size && !pBuf;

if ((int)level_and_flags < 0)
level_and_flags = MZ_DEFAULT_LEVEL;
Expand All @@ -6263,7 +6264,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
level = level_and_flags & 0xF;
store_data_uncompressed = ((!level) || (level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA));

if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || ((buf_size) && (!pBuf)) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER);

pState = pZip->m_pState;
Expand Down Expand Up @@ -6308,7 +6309,9 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n

if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))
{
uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
if (!write_metadata_only) {
uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
}
uncomp_size = buf_size;
if (uncomp_size <= 3)
{
Expand All @@ -6330,8 +6333,8 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
if (!pState->m_zip64)
{
/* Bail early if the archive would obviously become too large */
if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size
+ MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len +
if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size
+ MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len +
pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + user_extra_data_central_len
+ MZ_ZIP_DATA_DESCRIPTER_SIZE32) > 0xFFFFFFFF)
{
Expand Down Expand Up @@ -6442,7 +6445,14 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n

if (store_data_uncompressed)
{
if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size)
mz_bool write_failed;
if (write_metadata_only) {
write_failed = pZip->m_pSeek(pZip->m_pIO_opaque, cur_archive_file_ofs, buf_size) != buf_size;
} else {
write_failed = pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size;
}

if (write_failed)
{
pZip->m_pFree(pZip->m_pAlloc_opaque, pComp);
return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED);
Expand Down
6 changes: 4 additions & 2 deletions third_party/miniz-2.1.0/miniz.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@



/* Defines to completely disable specific portions of miniz.c:
/* Defines to completely disable specific portions of miniz.c:
If all macros here are defined the only functionality remaining will be CRC-32, adler-32, tinfl, and tdefl. */

/* Define MINIZ_NO_STDIO to disable all usage and any functions which rely on stdio for file I/O. */
Expand All @@ -139,7 +139,7 @@
/* Define MINIZ_NO_ZLIB_COMPATIBLE_NAMES to disable zlib names, to prevent conflicts against stock zlib. */
#define MINIZ_NO_ZLIB_COMPATIBLE_NAMES

/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc.
/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc.
Note if MINIZ_NO_MALLOC is defined then the user must always provide custom user alloc/free/realloc
callbacks to the zlib and archive API's, and a few stand-alone helper API's which don't provide custom user
functions (such as tdefl_compress_mem_to_heap() and tinfl_decompress_mem_to_heap()) won't work. */
Expand Down Expand Up @@ -980,6 +980,7 @@ typedef struct

/* User-supplied I/O callbacks stored on mz_zip_archive. pOpaque receives the archive's m_pIO_opaque pointer. */
typedef size_t (*mz_file_read_func)(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n);
/* Write callback: mz_zip_writer_add_mem_ex_v2 treats any return value other than n as a write failure. */
typedef size_t (*mz_file_write_func)(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n);
/* Seek callback: called instead of the write callback when mz_zip_writer_add_mem_ex_v2 is given a non-zero buf_size with a NULL pBuf
   (metadata-only record). Receives the current archive offset and the byte count n; any return value other than n is treated as a write failure. */
typedef size_t (*mz_file_seek_func)(void *pOpaque, mz_uint64 file_ofs, size_t n);
typedef mz_bool (*mz_file_needs_keepalive)(void *pOpaque);

struct mz_zip_internal_state_tag;
Expand Down Expand Up @@ -1071,6 +1072,7 @@ typedef struct mz_zip_archive /* note: added name so it can be forward declared

mz_file_read_func m_pRead;
mz_file_write_func m_pWrite;
mz_file_seek_func m_pSeek;
mz_file_needs_keepalive m_pNeeds_keepalive;
void *m_pIO_opaque;

Expand Down
14 changes: 13 additions & 1 deletion torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,9 +1394,21 @@ void initJITBindings(PyObject* module) {
buffer.attr("write")(std::move(memory_view));
return size;
};
return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
auto seek_func = [=](size_t offset) {
auto current_pos = py::cast<size_t>(buffer.attr("tell")());
buffer.attr("seek")(
offset, py::module::import("os").attr("SEEK_CUR"));
return current_pos + offset;
};
return std::make_unique<PyTorchStreamWriter>(
std::move(writer_func), std::move(seek_func));
}))
.def(py::init<const std::function<size_t(const void*, size_t)>&>())
.def(
"write_record_metadata",
[](PyTorchStreamWriter& self, const std::string& name, size_t size) {
return self.writeRecord(name, nullptr, size);
})
.def(
"write_record",
[](PyTorchStreamWriter& self,
Expand Down

0 comments on commit 6dcb9dd

Please sign in to comment.