[RUNTIME] Simplify caching API
This simplifies the caching API by requiring callers to pass the entire
list of files to cache as a "group" to the `put_group` method. This allows
caching backends to, for example, serialize these files into a single
cache blob, which can avoid issues with cache atomicity.
andrewjcg committed Mar 22, 2024
1 parent 5ac38e7 commit 1cd5a85
Showing 5 changed files with 66 additions and 81 deletions.
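For orientation, the reworked interface boils down to passing raw bytes keyed by filename, with whole groups written in a single call. A minimal usage sketch follows (the `cache_manager` instance and the helper names are illustrative stand-ins; the `put`, `put_group`, and `get_group` signatures follow the diff below):

from typing import Dict, Optional

def save_artifacts(cache_manager, key: str, artifacts: Dict[str, bytes]) -> Dict[str, str]:
    # New API: hand over every file of the group at once as raw bytes, so a
    # backend can serialize them into one blob and the write is all-or-nothing.
    # Returns a mapping of filename -> materialized local path.
    return cache_manager.put_group(key, artifacts)

def load_artifacts(cache_manager, key: str) -> Optional[Dict[str, str]]:
    # Cache hit: mapping of filename -> local path. Cache miss: None.
    return cache_manager.get_group(key)

def save_single_file(cache_manager, filename: str, data: bytes) -> str:
    # Single files now also take (key, bytes) and return the cached path.
    return cache_manager.put(filename, data)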
4 changes: 2 additions & 2 deletions python/test/backend/test_device_backend.py
@@ -100,7 +100,7 @@ def __init__(self):
f.write(src)
so = build_for_backend("ext_utils", src_path, tmpdir)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
cache_path = cache.put(fname, fname.encode("utf-8"))
import importlib.util
spec = importlib.util.spec_from_file_location("ext_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
@@ -185,7 +185,7 @@ def make_launcher_stub(self, name, signature, constants):
f.write(src)
so = build_for_backend(name, src_path, tmpdir)
with open(so, "rb") as f:
so_path = so_cache_manager.put(f.read(), so_name, binary=True)
so_path = so_cache_manager.put(so_name, f.read())
type(self).stub_so_path = so_path
return so_path
else:
25 changes: 13 additions & 12 deletions python/triton/compiler/compiler.py
@@ -237,12 +237,10 @@ def compile(src, target=None, options=None):
enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
fn_override_manager = get_override_manager(src.hash()) if enable_override else None
fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
metadata_filename = f"{src.name}.json"
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
metadata_path = metadata_group.get(metadata_filename)
if metadata_path is not None:
cache_key = src.name
metadata_group = fn_cache_manager.get_group(cache_key)
if metadata_group is not None:
# cache hit!
metadata = json.loads(Path(metadata_path).read_text())
return CompiledKernel(src, metadata_group, hash)
# initialize metadata
metadata = {
@@ -266,21 +264,24 @@
except Exception as e:
filter_traceback(e)
raise
metadata_group = {}
for ext, compile_ir in list(stages.items())[first_stage:]:
next_module = compile_ir(module, metadata)
ir_filename = f"{src.name}.{ext}"
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
if isinstance(next_module, bytes):
data = next_module
else:
data = str(next_module).encode("utf-8")
metadata_group[ir_filename] = data
if fn_dump_manager is not None:
fn_dump_manager.put(next_module, ir_filename)
if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)):
fn_dump_manager.put(ir_filename, data)
if fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None:
print(f"\nOverriding kernel with file {ir_filename}")
full_name = fn_override_manager.get_file(ir_filename)
next_module = parse(full_name, ext, context)
module = next_module
# write-back metadata
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
binary=False)
fn_cache_manager.put_group(metadata_filename, metadata_group)
metadata_group[f"{src.name}.json"] = json.dumps(metadata, default=vars).encode("utf-8")
metadata_group = fn_cache_manager.put_group(cache_key, metadata_group)
# return handle to compiled kernel
return CompiledKernel(src, metadata_group, hash)
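In short, the compile path now accumulates every stage's output as bytes and commits the whole group once, rather than writing each file to the cache as it is produced. A simplified sketch of that pattern (the function and helper names here are illustrative, not part of the change):

import json

def cache_compilation_outputs(cache_manager, name: str, stage_outputs: dict, metadata: dict) -> dict:
    def to_bytes(module) -> bytes:
        # Binary modules pass through; textual IR is UTF-8 encoded, as in the diff.
        return module if isinstance(module, bytes) else str(module).encode("utf-8")

    group = {f"{name}.{ext}": to_bytes(module) for ext, module in stage_outputs.items()}
    # Metadata is written back as one more member of the group.
    group[f"{name}.json"] = json.dumps(metadata, default=vars).encode("utf-8")
    # Single group write; returns filename -> cached local path.
    return cache_manager.put_group(name, group)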

114 changes: 49 additions & 65 deletions python/triton/runtime/cache.py
@@ -1,6 +1,7 @@
import importlib
import json
import os
import pickle
import random
from abc import ABC, abstractmethod
from pathlib import Path
@@ -26,19 +27,19 @@ def __init__(self, key):
pass

@abstractmethod
def get_file(self, filename) -> Optional[str]:
def get_file(self, key: str) -> Optional[str]:
pass

@abstractmethod
def put(self, data, filename, binary=True) -> str:
def put(self, key: str, data: bytes) -> str:
pass

@abstractmethod
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
def put_group(self, key: str, files: Dict[str, bytes]) -> Dict[str, str]:
pass

@abstractmethod
def put_group(self, filename: str, group: Dict[str, str]):
def get_group(self, key: str) -> Optional[Dict[str, str]]:
pass


@@ -68,20 +69,20 @@ def __init__(self, key, override=False, dump=False):
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)

def has_file(self, filename) -> bool:
def _has_file(self, filename) -> bool:
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
return os.path.exists(self._make_path(filename))

def get_file(self, filename) -> Optional[str]:
if self.has_file(filename):
if self._has_file(filename):
return self._make_path(filename)
else:
return None

def get_group(self, filename: str) -> Optional[Dict[str, str]]:
grp_filename = f"__grp__{filename}"
if not self.has_file(grp_filename):
def get_group(self, key: str) -> Optional[Dict[str, str]]:
grp_filename = f"__grp__{key}.json"
if not self._has_file(grp_filename):
return None
grp_filepath = self._make_path(grp_filename)
with open(grp_filepath) as f:
@@ -97,19 +98,17 @@ def get_group(self, filename: str) -> Optional[Dict[str, str]]:
return result

# Note a group of pushed files as being part of a group
def put_group(self, filename: str, group: Dict[str, str]) -> str:
def _put_group_metadata(self, key: str, group: Dict[str, str]) -> str:
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
grp_contents = json.dumps({"child_paths": group})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename, binary=False)
grp_contents = json.dumps({"child_paths": group}).encode("utf-8")
grp_filename = f"__grp__{key}.json"
return self.put(grp_filename, grp_contents)

def put(self, data, filename, binary=True) -> str:
def put(self, filename: str, data: bytes) -> str:
assert isinstance(data, bytes), f"{filename} data is not bytes: {type(data)}"
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
binary = isinstance(data, bytes)
if not binary:
data = str(data)
assert self.lock_path is not None
filepath = self._make_path(filename)
# Random ID to avoid any collisions
@@ -118,14 +117,20 @@
pid = os.getpid()
# use tempfile to be robust against program interruptions
temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
mode = "wb" if binary else "w"
with open(temp_path, mode) as f:
with open(temp_path, mode="wb") as f:
f.write(data)
# Replace is guaranteed to be atomic on POSIX systems if it succeeds
# so filepath cannot see a partial write
os.replace(temp_path, filepath)
return filepath

def put_group(self, key, group: Dict[str, bytes]) -> Dict[str, str]:
result = {}
for name, data in group.items():
result[name] = self.put(name, data)
self._put_group_metadata(key, result)
return result


class RemoteCacheBackend:
"""
@@ -136,11 +141,11 @@ def __init__(self, key: str):
pass

@abstractmethod
def get(self, filenames: List[str]) -> Dict[str, bytes]:
def get(self, key: str) -> Optional[bytes]:
pass

@abstractmethod
def put(self, filename: str, data: bytes):
def put(self, key: str, data: bytes):
pass


@@ -158,11 +163,10 @@ def __init__(self, key):
def _get_key(self, filename: str) -> str:
return self._key_fmt.format(key=self._key, filename=filename)

def get(self, filenames: List[str]) -> Dict[str, str]:
results = self._redis.mget([self._get_key(f) for f in filenames])
return {filename: result for filename, result in zip(filenames, results) if result is not None}
def get(self, filename: str) -> Optional[bytes]:
return self._redis.get(self._get_key(filename))

def put(self, filename: str, data: bytes) -> Dict[str, bytes]:
def put(self, filename: str, data: bytes):
self._redis.set(self._get_key(filename), data)


@@ -182,65 +186,45 @@ def __init__(self, key, override=False, dump=False):
# Use a `FileCacheManager` to materialize remote cache paths locally.
self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

def _materialize(self, filename: str, data: bytes):
# We use a backing `FileCacheManager` to provide the materialized data.
return self._file_cache_manager.put(data, filename, binary=True)

def get_file(self, filename: str) -> Optional[str]:
def get_file(self, key: str) -> Optional[str]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_file(filename)
return self._file_cache_manager.get_file(key)

# We always check the remote cache backend -- even if our internal file-
# based cache has the item -- to make sure LRU accounting works as
# expected.
results = self._backend.get([filename])
if len(results) == 0:
data = self._backend.get(key)
if data is None:
return None
(_, data), = results.items()
return self._materialize(filename, data)

def put(self, data, filename: str, binary=True) -> str:
return self._file_cache_manager.put(key, data)

def put(self, key: str, data: bytes) -> str:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put(data, filename, binary=binary)
return self._file_cache_manager.put(key, data)

if not isinstance(data, bytes):
data = str(data).encode("utf-8")
self._backend.put(filename, data)
return self._materialize(filename, data)
self._backend.put(key, data)

def get_group(self, filename: str) -> Optional[Dict[str, str]]:
return self._file_cache_manager.put(key, data)

def get_group(self, key: str) -> Optional[Dict[str, str]]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_group(filename)

grp_filename = f"__grp__{filename}"
grp_filepath = self.get_file(grp_filename)
if grp_filepath is None:
data = self._backend.get(key)
if data is None:
return None
with open(grp_filepath) as f:
grp_data = json.load(f)
child_paths = grp_data.get("child_paths", None)

result = None
return self._file_cache_manager.put_group(key, pickle.loads(data))

# Found group data.
if child_paths is not None:
result = {}
for child_path, data in self._backend.get(child_paths).items():
result[child_path] = self._materialize(child_path, data)

return result

def put_group(self, filename: str, group: Dict[str, str]):
def put_group(self, key: str, group: Dict[str, bytes]) -> Dict[str, str]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put_group(filename, group)
return self._file_cache_manager.put_group(key, group)

self._backend.put(key, pickle.dumps(group))

grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename)
return self._file_cache_manager.put_group(key, group)


__cache_cls = FileCacheManager
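Because `RemoteCacheManager` now pickles an entire group into one blob, a remote backend only needs per-key get/put of opaque bytes. A minimal in-memory backend sketch against the `RemoteCacheBackend` interface above (the class name and dict-based storage are illustrative; the import path assumes `python/triton/runtime/cache.py` maps to `triton.runtime.cache`):

from typing import Dict, Optional
from triton.runtime.cache import RemoteCacheBackend  # assumed import path

class InMemoryRemoteCacheBackend(RemoteCacheBackend):
    # Illustrative only: keeps blobs in a process-local dict rather than a
    # real remote store such as Redis.
    def __init__(self, key: str):
        self._store: Dict[str, bytes] = {}

    def get(self, key: str) -> Optional[bytes]:
        return self._store.get(key)

    def put(self, key: str, data: bytes):
        self._store[key] = data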
2 changes: 1 addition & 1 deletion third_party/amd/backend/driver.py
@@ -22,7 +22,7 @@ def compile_module_from_src(src, name):
f.write(src)
so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
cache_path = cache.put(f"{name}.so", f.read())
import importlib.util
spec = importlib.util.spec_from_file_location(name, cache_path)
mod = importlib.util.module_from_spec(spec)
2 changes: 1 addition & 1 deletion third_party/nvidia/backend/driver.py
@@ -55,7 +55,7 @@ def compile_module_from_src(src, name):
f.write(src)
so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
cache_path = cache.put(f"{name}.so", f.read())
import importlib.util
spec = importlib.util.spec_from_file_location(name, cache_path)
mod = importlib.util.module_from_spec(spec)
