Skip to content

Commit

Permalink
update tools for filtering datasets
Browse files: browse the repository at this point in the history
  • Loading branch information
Sophia Maedler committed Jan 31, 2024
1 parent c37abc0 commit d21e98d
Showing 1 changed file with 38 additions and 9 deletions.
47 changes: 38 additions & 9 deletions src/sparcstools/utils/dataset_manipulation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,6 +6,7 @@
from rasterio.features import rasterize
import numpy.ma as ma
from tqdm.auto import tqdm
import shutil

def get_indexes(project_location, cell_ids, return_annotation = False):

Expand All @@ -20,16 +21,22 @@ def get_indexes(project_location, cell_ids, return_annotation = False):

index_locs = []
for cell_id in tqdm(cell_ids, desc = "getting indexes"):
index_locs.append(lookup[cell_id])
try:
index_locs.append(lookup[cell_id])
except KeyError:
index_locs.append(np.nan)
print(f"cell_id {cell_id} not found in dataset. Skipping...")
continue

if return_annotation:
annotation = pd.DataFrame({"index_hdf5": index_locs, "cell_id":cell_ids})
annotation = annotation.sort_values("index_hdf5")
annotation = annotation.dropna() #remove nan values
return(np.array(annotation.index_hdf5.tolist()), np.array(annotation.cell_id.tolist()))
else:
return(np.sort(index_locs))

def save_cells_to_new_hdf5(project_location, name, cell_ids, annotation = "selected_cells", append = False):
def save_cells_to_new_hdf5(project_location, name, cell_ids, annotation = "selected_cells", append = False, temp_dir = "/tmp"):

#get output directory
outdir = f"{project_location}/extraction/filtered_data/{name}/"
Expand All @@ -49,9 +56,21 @@ def save_cells_to_new_hdf5(project_location, name, cell_ids, annotation = "selec

#get cell images we want to write to new location
print("getting cell images for selected cells...")

with h5py.File(f"{project_location}/extraction/data/single_cells.h5", "r") as hf_in:
cell_images = hf_in.get("single_cell_data")[indexes]


#generate container for single_cell_data
from alphabase.io import tempmmap
TEMP_DIR_NAME = tempmmap.redefine_temp_location(temp_dir)
_, c, x, y = hf_in["single_cell_data"].shape
cell_images = tempmmap.array(shape = (len(indexes), c, x, y), dtype = np.float16)

#actually add the single_cell_data to the container
i = 0
for index in tqdm(indexes):
cell_images[i] = hf_in.get("single_cell_data")[int(index)]
i += 1

#delete file if append is False and it already exists so that we generate a new one
if not append:
if os.path.isfile(outfile):
Expand Down Expand Up @@ -88,7 +107,6 @@ def save_cells_to_new_hdf5(project_location, name, cell_ids, annotation = "selec

#create index file
hf_out.create_dataset('single_cell_index', (n_cells, 2), maxshape = (None, 2), dtype="uint64")


#create dataset
hf_out.create_dataset('single_cell_data', (n_cells, n_channels, x, y),
Expand All @@ -105,6 +123,10 @@ def save_cells_to_new_hdf5(project_location, name, cell_ids, annotation = "selec
hf_out.get("annotation")[:] = annotation_df
print("results saved.")

#cleanup temp directories
del cell_images
shutil.rmtree(TEMP_DIR_NAME, ignore_errors=True)

def _read_napari_csv(path):
# read csv table
shapes = pd.read_csv(path, sep = ",")
Expand All @@ -117,8 +139,15 @@ def _read_napari_csv(path):

for shape_id in shape_ids:
_shapes = shapes.loc[shapes.index_shape == shape_id]
x = _shapes["axis-0"].tolist()
y = _shapes["axis-1"].tolist()
x = _shapes["axis-1"].tolist()
y = _shapes["axis-0"].tolist()

if len(x) < 3:
print("shape with less than 3 points found. Skipping...")
continue
if len(y) < 3:
print("shape with less than 3 points found. Skipping...")
continue

polygon = Polygon(zip(x, y))
polygons.append(polygon)
Expand All @@ -130,7 +159,7 @@ def _generate_mask_polygon(poly, outshape):
img = rasterize(poly, out_shape = (x, y))
return(img.astype("bool"))

def extract_single_cells_napari_area(napari_path, project_location):
def extract_single_cells_napari_area(napari_path, project_location, temp_dir = "/tmp"):

#get name from naparipath
name = os.path.basename(napari_path).split(".")[0]
Expand Down Expand Up @@ -171,4 +200,4 @@ def extract_single_cells_napari_area(napari_path, project_location):
cell_ids = list(cell_ids) #need list type for lookup of indexes

print(f"found {len(cell_ids)} unique cell_ids in selected area. Will now export to new HDF5 single cell dataset.")
save_cells_to_new_hdf5(project_location, cell_ids = cell_ids, name = name, annotation = name, append = False)
save_cells_to_new_hdf5(project_location, cell_ids = cell_ids, name = name, annotation = name, append = False, temp_dir=temp_dir)

0 comments on commit d21e98d

Please sign in to comment.