diff --git a/python/test/backend/test_device_backend.py b/python/test/backend/test_device_backend.py
index 8b0e4605ef6..589fbe07e6c 100644
--- a/python/test/backend/test_device_backend.py
+++ b/python/test/backend/test_device_backend.py
@@ -100,7 +100,7 @@ def __init__(self):
                     f.write(src)
                 so = build_for_backend("ext_utils", src_path, tmpdir)
                 with open(so, "rb") as f:
-                    cache_path = cache.put(f.read(), fname, binary=True)
+                    cache_path = cache.put(fname, f.read())
        import importlib.util
        spec = importlib.util.spec_from_file_location("ext_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
@@ -185,7 +185,7 @@ def make_launcher_stub(self, name, signature, constants):
                    f.write(src)
                so = build_for_backend(name, src_path, tmpdir)
                with open(so, "rb") as f:
-                    so_path = so_cache_manager.put(f.read(), so_name, binary=True)
+                    so_path = so_cache_manager.put(so_name, f.read())
                type(self).stub_so_path = so_path
                return so_path
        else:
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index aeda2dd680f..2e171147290 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -237,12 +237,10 @@ def compile(src, target=None, options=None):
     enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
     fn_override_manager = get_override_manager(src.hash()) if enable_override else None
     fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
-    metadata_filename = f"{src.name}.json"
-    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
-    metadata_path = metadata_group.get(metadata_filename)
-    if metadata_path is not None:
+    cache_key = src.name
+    metadata_group = fn_cache_manager.get_group(cache_key)
+    if metadata_group is not None:
         # cache hit!
-        metadata = json.loads(Path(metadata_path).read_text())
         return CompiledKernel(src, metadata_group, hash)
     # initialize metadata
     metadata = {
@@ -266,21 +264,24 @@ def compile(src, target=None, options=None):
     except Exception as e:
         filter_traceback(e)
         raise
+    metadata_group = {}
     for ext, compile_ir in list(stages.items())[first_stage:]:
         next_module = compile_ir(module, metadata)
         ir_filename = f"{src.name}.{ext}"
-        metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
+        if isinstance(next_module, bytes):
+            data = next_module
+        else:
+            data = str(next_module).encode("utf-8")
+        metadata_group[ir_filename] = data
         if fn_dump_manager is not None:
-            fn_dump_manager.put(next_module, ir_filename)
-        if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)):
+            fn_dump_manager.put(ir_filename, data)
+        if fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None:
             print(f"\nOverriding kernel with file {ir_filename}")
-            full_name = fn_override_manager.get_file(ir_filename)
             next_module = parse(full_name, ext, context)
         module = next_module
     # write-back metadata
-    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
-                                                             binary=False)
-    fn_cache_manager.put_group(metadata_filename, metadata_group)
+    metadata_group[f"{src.name}.json"] = json.dumps(metadata, default=vars).encode("utf-8")
+    metadata_group = fn_cache_manager.put_group(cache_key, metadata_group)
     # return handle to compiled kernel
     return CompiledKernel(src, metadata_group, hash)
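Note: the compiler change above turns caching into a single group transaction: `compile()` now accumulates every IR stage as `{filename: bytes}` in memory, appends the metadata JSON as one more member, and commits everything with one `put_group` call, while a cache hit becomes a single `get_group` lookup. A minimal sketch of that round trip, assuming the default `FileCacheManager`; the key and file contents below are invented for illustration:

```python
from triton.runtime.cache import get_cache_manager

# Hypothetical hash key and artifact contents, for illustration only.
fn_cache_manager = get_cache_manager("0123456789abcdef0123456789abcdef")
group = {
    "kernel.ttir": b"module { ... }",      # every stage is staged as bytes
    "kernel.json": b'{"name": "kernel"}',  # metadata is just one more member
}
# One call persists all members plus a __grp__kernel.json manifest and
# returns {filename: on-disk path}.
paths = fn_cache_manager.put_group("kernel", group)
# A later compile() decides hit/miss with a single lookup on the same key.
assert fn_cache_manager.get_group("kernel") is not None
```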
diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py
index 2e8d70ea4ef..5376cf385fa 100644
--- a/python/triton/runtime/cache.py
+++ b/python/triton/runtime/cache.py
@@ -1,6 +1,7 @@
 import importlib
 import json
 import os
+import pickle
 import random
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -26,19 +27,19 @@ def __init__(self, key):
         pass

     @abstractmethod
-    def get_file(self, filename) -> Optional[str]:
+    def get_file(self, key: str) -> Optional[str]:
         pass

     @abstractmethod
-    def put(self, data, filename, binary=True) -> str:
+    def put(self, key: str, data: bytes) -> str:
         pass

     @abstractmethod
-    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
+    def put_group(self, key: str, files: Dict[str, bytes]) -> Dict[str, str]:
         pass

     @abstractmethod
-    def put_group(self, filename: str, group: Dict[str, str]):
+    def get_group(self, key: str) -> Optional[Dict[str, str]]:
         pass

@@ -68,20 +69,20 @@ def __init__(self, key, override=False, dump=False):
     def _make_path(self, filename) -> str:
         return os.path.join(self.cache_dir, filename)

-    def has_file(self, filename) -> bool:
+    def _has_file(self, filename) -> bool:
         if not self.cache_dir:
             raise RuntimeError("Could not create or locate cache dir")
         return os.path.exists(self._make_path(filename))

     def get_file(self, filename) -> Optional[str]:
-        if self.has_file(filename):
+        if self._has_file(filename):
             return self._make_path(filename)
         else:
             return None

-    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
-        grp_filename = f"__grp__{filename}"
-        if not self.has_file(grp_filename):
+    def get_group(self, key: str) -> Optional[Dict[str, str]]:
+        grp_filename = f"__grp__{key}.json"
+        if not self._has_file(grp_filename):
             return None
         grp_filepath = self._make_path(grp_filename)
         with open(grp_filepath) as f:
@@ -97,19 +98,17 @@ def get_group(self, filename: str) -> Optional[Dict[str, str]]:
         return result

     # Note a group of pushed files as being part of a group
-    def put_group(self, filename: str, group: Dict[str, str]) -> str:
+    def _put_group_metadata(self, key: str, group: Dict[str, str]) -> str:
         if not self.cache_dir:
             raise RuntimeError("Could not create or locate cache dir")
-        grp_contents = json.dumps({"child_paths": group})
-        grp_filename = f"__grp__{filename}"
-        return self.put(grp_contents, grp_filename, binary=False)
+        grp_contents = json.dumps({"child_paths": group}).encode("utf-8")
+        grp_filename = f"__grp__{key}.json"
+        return self.put(grp_filename, grp_contents)

-    def put(self, data, filename, binary=True) -> str:
+    def put(self, filename: str, data: bytes) -> str:
+        assert isinstance(data, bytes), f"{filename} data is not bytes: {type(data)}"
         if not self.cache_dir:
             raise RuntimeError("Could not create or locate cache dir")
-        binary = isinstance(data, bytes)
-        if not binary:
-            data = str(data)
         assert self.lock_path is not None
         filepath = self._make_path(filename)
         # Random ID to avoid any collisions
@@ -118,14 +117,20 @@
         pid = os.getpid()
         # use tempfile to be robust against program interruptions
         temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
-        mode = "wb" if binary else "w"
-        with open(temp_path, mode) as f:
+        with open(temp_path, mode="wb") as f:
             f.write(data)
         # Replace is guaranteed to be atomic on POSIX systems if it succeeds
         # so filepath cannot see a partial write
         os.replace(temp_path, filepath)
         return filepath

+    def put_group(self, key, group: Dict[str, bytes]) -> Dict[str, str]:
+        result = {}
+        for name, data in group.items():
+            result[name] = self.put(name, data)
+        self._put_group_metadata(key, result)
+        return result
+

 class RemoteCacheBackend:
     """
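The rewritten `put` drops the `binary` flag but keeps the crash-safety scheme: write to a uniquely named temp file, then `os.replace` over the target. The pattern in isolation, as a standalone sketch independent of the class:

```python
import os
import random

def atomic_write(filepath: str, data: bytes) -> str:
    # Mirrors FileCacheManager.put: write to a unique temp file in the same
    # directory, then rename over the destination. os.replace is atomic on
    # POSIX, so a concurrent reader never sees a partially written file.
    tmp = f"{filepath}.tmp.pid_{os.getpid()}_{random.randint(0, 1000000)}"
    with open(tmp, "wb") as f:
        f.write(data)
    os.replace(tmp, filepath)
    return filepath
```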
@@ -136,11 +141,11 @@ def __init__(self, key: str):
         pass

     @abstractmethod
-    def get(self, filenames: List[str]) -> Dict[str, bytes]:
+    def get(self, key: str) -> Optional[bytes]:
         pass

     @abstractmethod
-    def put(self, filename: str, data: bytes):
+    def put(self, key: str, data: bytes):
         pass

@@ -158,11 +163,10 @@ def __init__(self, key):
     def _get_key(self, filename: str) -> str:
         return self._key_fmt.format(key=self._key, filename=filename)

-    def get(self, filenames: List[str]) -> Dict[str, str]:
-        results = self._redis.mget([self._get_key(f) for f in filenames])
-        return {filename: result for filename, result in zip(filenames, results) if result is not None}
+    def get(self, filename: str) -> Optional[bytes]:
+        return self._redis.get(self._get_key(filename))

-    def put(self, filename: str, data: bytes) -> Dict[str, bytes]:
+    def put(self, filename: str, data: bytes):
         self._redis.set(self._get_key(filename), data)

@@ -182,65 +186,45 @@ def __init__(self, key, override=False, dump=False):
         # Use a `FileCacheManager` to materialize remote cache paths locally.
         self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

-    def _materialize(self, filename: str, data: bytes):
-        # We use a backing `FileCacheManager` to provide the materialized data.
-        return self._file_cache_manager.put(data, filename, binary=True)
-
-    def get_file(self, filename: str) -> Optional[str]:
+    def get_file(self, key: str) -> Optional[str]:
         # We don't handle the dump/override cases.
         if self._dump or self._override:
-            return self._file_cache_manager.get_file(filename)
+            return self._file_cache_manager.get_file(key)

-        # We always check the remote cache backend -- even if our internal file-
-        # based cache has the item -- to make sure LRU accounting works as
-        # expected.
-        results = self._backend.get([filename])
-        if len(results) == 0:
+        data = self._backend.get(key)
+        if data is None:
             return None
-        (_, data), = results.items()
-        return self._materialize(filename, data)
-
-    def put(self, data, filename: str, binary=True) -> str:
+        return self._file_cache_manager.put(key, data)
+
+    def put(self, key: str, data: bytes) -> str:
         # We don't handle the dump/override cases.
         if self._dump or self._override:
-            return self._file_cache_manager.put(data, filename, binary=binary)
+            return self._file_cache_manager.put(key, data)

-        if not isinstance(data, bytes):
-            data = str(data).encode("utf-8")
-        self._backend.put(filename, data)
-        return self._materialize(filename, data)
+        self._backend.put(key, data)

-    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
+        return self._file_cache_manager.put(key, data)
+
+    def get_group(self, key: str) -> Optional[Dict[str, str]]:
         # We don't handle the dump/override cases.
         if self._dump or self._override:
-            return self._file_cache_manager.get_group(filename)
+            return self._file_cache_manager.get_group(key)

-        grp_filename = f"__grp__{filename}"
-        grp_filepath = self.get_file(grp_filename)
-        if grp_filepath is None:
+        data = self._backend.get(key)
+        if data is None:
             return None
-        with open(grp_filepath) as f:
-            grp_data = json.load(f)
-        child_paths = grp_data.get("child_paths", None)
-        result = None
+        return self._file_cache_manager.put_group(key, pickle.loads(data))

-        # Found group data.
-        if child_paths is not None:
-            result = {}
-            for child_path, data in self._backend.get(child_paths).items():
-                result[child_path] = self._materialize(child_path, data)
-
-        return result
-
-    def put_group(self, filename: str, group: Dict[str, str]):
+    def put_group(self, key: str, group: Dict[str, bytes]) -> Dict[str, str]:
         # We don't handle the dump/override cases.
         if self._dump or self._override:
-            return self._file_cache_manager.put_group(filename, group)
+            return self._file_cache_manager.put_group(key, group)
+
+        self._backend.put(key, pickle.dumps(group))

-        grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
-        grp_filename = f"__grp__{filename}"
-        return self.put(grp_contents, grp_filename)
+        return self._file_cache_manager.put_group(key, group)


 __cache_cls = FileCacheManager
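Since `RemoteCacheManager` now pickles a whole `{filename: bytes}` group into one blob, a remote backend shrinks to scalar `get`/`put` by key and no longer needs a batched `mget`-style API. A hypothetical in-memory backend that satisfies the narrowed contract (for illustration; not part of the patch):

```python
from typing import Dict, Optional

class InMemoryCacheBackend:
    """Illustrative RemoteCacheBackend: one key in, one optional blob out.
    RemoteCacheManager takes care of pickling groups and materializing
    the returned bytes into local files."""

    def __init__(self, key: str):
        self._store: Dict[str, bytes] = {}

    def get(self, key: str) -> Optional[bytes]:
        return self._store.get(key)

    def put(self, key: str, data: bytes):
        self._store[key] = data
```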
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index 1247c91df4e..9d202153a79 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -22,7 +22,7 @@ def compile_module_from_src(src, name):
                 f.write(src)
             so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
             with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
+                cache_path = cache.put(f"{name}.so", f.read())
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 13cf8a620ef..21a85848daa 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -55,7 +55,7 @@ def compile_module_from_src(src, name):
                 f.write(src)
             so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
             with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
+                cache_path = cache.put(f"{name}.so", f.read())
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)
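Both driver call sites adapt the same way: the cache key moves to the first argument, the payload must already be bytes, and `binary=` disappears. A hedged sketch of the resulting pattern; `cache_shared_object` is a made-up helper, not part of either driver:

```python
import hashlib

from triton.runtime.cache import get_cache_manager

def cache_shared_object(name: str, so_bytes: bytes) -> str:
    # Key off the payload here for self-containment; the real drivers key
    # the manager off a hash of the generated C source instead.
    cache = get_cache_manager(hashlib.md5(so_bytes).hexdigest())
    path = cache.get_file(f"{name}.so")
    if path is None:
        path = cache.put(f"{name}.so", so_bytes)  # was put(data, filename, binary=True)
    return path
```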