Skip to content

Commit

Permalink
relocate tmp_seg to a reloadable HDF5
Browse files Browse the repository at this point in the history
Defining the mmap array location via a global variable does not work on macOS, so this commit implements a fix in which the array location is defined once and each processing thread reconnects to it.
  • Loading branch information
sophiamaedler committed Apr 27, 2024
1 parent ece990e commit eca69c7
Showing 1 changed file with 35 additions and 30 deletions.
65 changes: 35 additions & 30 deletions src/sparcscore/pipeline/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import gc

import sys
from alphabase.io import tempmmap

class Segmentation(ProcessingStep):
"""Segmentation helper class used for creating segmentation workflows.
Expand Down Expand Up @@ -84,20 +85,20 @@ def process(self):
]

def __init__(self, *args, **kwargs):
    """Initialize the segmentation step.

    The optional keyword arguments ``_tmp_seg`` and ``_tmp_seg_path`` are
    consumed here and stored as attributes of the same name. They are removed
    from ``kwargs`` so that the parent initializer does not need to account
    for them. All remaining positional and keyword arguments are forwarded
    unchanged to the parent class.
    """
    # These keys are only passed in some cases, so each one is handled
    # independently: the value is taken from kwargs (not from an undefined
    # local name) and only popped when it is actually present, which avoids
    # both a NameError and a KeyError.
    for optional_key in ("_tmp_seg", "_tmp_seg_path"):
        if optional_key in kwargs:
            setattr(self, optional_key, kwargs.pop(optional_key))

    super().__init__(*args, **kwargs)

    self.identifier = None
    self.window = None
    self.input_path = None



def save_classes(self, classes):

#define path where classes should be saved
Expand Down Expand Up @@ -196,6 +197,7 @@ def call_as_shard(self):
#cleanup generated temp dir and variables
del input_image
gc.collect()
shutil.rmtree(TEMP_DIR_NAME) #remove create temp directory to cleanup directory

#write out window location
self.log(f"Writing out window location to file at {self.directory}/window.csv")
Expand Down Expand Up @@ -1040,6 +1042,7 @@ class TimecourseSegmentation(Segmentation):
]

def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)

self.index = None
Expand All @@ -1050,7 +1053,7 @@ def __init__(self, *args, **kwargs):
"No BaseSegmentationType defined, please set attribute ``BaseSegmentationMethod``"
)

def initialize_as_shard(self, index, input_path):
def initialize_as_shard(self, index, input_path, _tmp_seg_path):
"""Initialize Segmentation Step with further parameters needed for federated segmentation.
Important:
Expand All @@ -1062,6 +1065,7 @@ def initialize_as_shard(self, index, input_path):
"""
self.index = index
self.input_path = input_path
self._tmp_seg_path = _tmp_seg_path

def call_as_shard(self):
"""Wrapper function for calling a sharded segmentation.
Expand All @@ -1070,8 +1074,6 @@ def call_as_shard(self):
This function is intended for internal use by the :class:`ShardedSegmentation` helper class. In most cases it is not relevant to the creation of custom segmentation workflows.
"""
global _tmp_seg

with h5py.File(self.input_path, "r") as hf:
hdf_input = hf.get("input_images")

Expand All @@ -1088,6 +1090,8 @@ def call_as_shard(self):
_result = super().__call__(input_image)
except Exception:
self.log(traceback.format_exc())

results.append(_result)
self.log(f"Segmentation on index {index} completed.")

return results
Expand All @@ -1106,31 +1110,33 @@ def save_segmentation(
labels (np.array): Numpy array of shape ``(height, width)``. Labels are all data which are saved as integer values. These are mostly segmentation maps with integer values corresponding to the labels of cells.
classes (list(int)): List of all classes in the labels array, which have passed the filtering step. All classes contained in this list will be extracted.
"""
global _tmp_seg
#reconnect to existing HDF5 for memory mapping segmentation results
_tmp_seg = tempmmap.mmap_array_from_path(self._tmp_seg_path)

# size (C, H, W) is expected
# dims are expanded in case (H, W) is passed
labels = np.expand_dims(labels, axis=0) if len(labels.shape) == 2 else labels
classes = np.array(list(classes))

self.log(f"transferring {self.current_index} to temmporray memory mapped array")
self._tmp_seg[self.current_index] = labels
_tmp_seg[self.current_index] = labels

def _initialize_tempmmap_array(self):
global _tmp_seg
# import tempmmap module and reset temp folder location
from alphabase.io import tempmmap
#close the connection to the tempmmap file again
del _tmp_seg

def _initialize_tempmmap_array(self):
# Prepare the temporary memory-mapped storage that segmentation results are written to.
# NOTE(review): this span is taken from a rendered diff and appears to contain both the
# removed lines (array handle stored on self._tmp_seg) and the added lines (path stored on
# self._tmp_seg_path) - confirm which pair is live in the actual file.
#
# Point tempmmap at the location configured under self.config["cache"] and keep the
# value it returns on the instance as TEMP_DIR_NAME.
TEMP_DIR_NAME = tempmmap.redefine_temp_location(self.config["cache"])
self.TEMP_DIR_NAME = TEMP_DIR_NAME

# Initialize a tempmmap array to save segmentation results to.
# Create an empty HDF5 file prepared for use as a memory-mapped temp array to save segmentation results to.
# This is required when segmenting so many images that the results can no longer fit into memory.
# Array of shape self.shape_segmentation with dtype int32, kept directly on the instance.
_tmp_seg = tempmmap.array(self.shape_segmentation, dtype=np.int32)
self._tmp_seg = _tmp_seg
# Same shape but dtype uint32; only the returned value is kept (presumably a file path that
# tempmmap.mmap_array_from_path can reconnect to - TODO confirm against alphabase).
_tmp_seg_path = tempmmap.create_empty_mmap(shape = self.shape_segmentation, dtype = np.uint32)
self._tmp_seg_path = _tmp_seg_path

def _transfer_tempmmap_to_hdf5(self):
global _tmp_seg

_tmp_seg = tempmmap.mmap_array_from_path(self._tmp_seg_path)
input_path = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)

# create hdf5 datasets with temp_arrays as input
Expand All @@ -1143,12 +1149,14 @@ def _transfer_tempmmap_to_hdf5(self):
)
hf.create_dataset(
"segmentation",
shape=self._tmp_seg.shape,
shape=_tmp_seg.shape,
chunks=(1, 2, self.shape_input_images[2], self.shape_input_images[3]),
dtype="uint32",
)

hf["segmentation"][:] = self._tmp_seg
#using this loop structure ensures that not all results are loaded in memory at any one timepoint
for i in range(_tmp_seg.shape[0]):
hf["segmentation"][i] = _tmp_seg[i]

dt = h5py.special_dtype(vlen=np.dtype("uint32"))

Expand All @@ -1166,13 +1174,12 @@ def _transfer_tempmmap_to_hdf5(self):
dtype=dt,
)


# delete tempobjects (to cleanup directory)
self.log(f"Tempmmap Folder location {self.TEMP_DIR_NAME} will now be removed.")
shutil.rmtree(self.TEMP_DIR_NAME, ignore_errors=True)

del _tmp_seg, self.TEMP_DIR_NAME, self._tmp_seg
gc. collect()
del _tmp_seg, self.TEMP_DIR_NAME
gc.collect()

def save_image(self, array, save_name="", cmap="magma", **kwargs):
if np.issubdtype(array.dtype.type, np.integer):
Expand Down Expand Up @@ -1306,7 +1313,7 @@ def process(self):
debug=self.debug,
overwrite=self.overwrite,
intermediate_output=self.intermediate_output,
_tmp_seg = self._tmp_seg,
_tmp_seg_path = self._tmp_seg_path,
)

current_shard.initialize_as_shard(indexes, input_path=input_path)
Expand Down Expand Up @@ -1339,7 +1346,6 @@ def initializer_function(self, gpu_id_list):
current_process().gpu_id_list = gpu_id_list

def process(self):
global _tmp_seg
input_path = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)

with h5py.File(input_path, "r") as hf:
Expand All @@ -1358,8 +1364,7 @@ def process(self):

# initialize temp object to write segmentations to
self._initialize_tempmmap_array()

segmentation_list = self.initialize_shard_list(indexes, input_path=input_path)
segmentation_list = self.initialize_shard_list(indexes, input_path=input_path, _tmp_seg_path = self._tmp_seg_path)

# make more verbose output for troubleshooting and timing purposes.
n_threads = self.config["threads"]
Expand Down Expand Up @@ -1389,7 +1394,7 @@ def process(self):
for _ in range(processes_per_GPU):
gpu_id_list.append(gpu_ids)

self.log(f"Beginning segmentation on {available_GPUs}.")
self.log(f"Beginning segmentation on {available_GPUs} available GPUs.")

with Pool(processes=n_processes, initializer=self.initializer_function, initargs=[gpu_id_list]) as pool:
results = list(
Expand All @@ -1413,7 +1418,8 @@ def process(self):
del results
gc.collect()

def initialize_shard_list(self, segmentation_list, input_path):
def initialize_shard_list(self, segmentation_list, input_path, _tmp_seg_path):

_shard_list = []

for i in tqdm(
Expand All @@ -1426,10 +1432,9 @@ def initialize_shard_list(self, segmentation_list, input_path):
debug=self.debug,
overwrite=self.overwrite,
intermediate_output=self.intermediate_output,
_tmp_seg = self._tmp_seg,
)

current_shard.initialize_as_shard(i, input_path)
current_shard.initialize_as_shard(i, input_path, _tmp_seg_path = _tmp_seg_path)
_shard_list.append(current_shard)

self.log(f"Shard list created with {len(_shard_list)} elements.")
Expand Down

0 comments on commit eca69c7

Please sign in to comment.