Skip to content

Commit

Permalink
add more skeleton
Browse files Browse the repository at this point in the history
  • Loading branch information
namsaraeva committed Apr 22, 2024
1 parent 9efee01 commit 34276d6
Showing 1 changed file with 135 additions and 2 deletions.
137 changes: 135 additions & 2 deletions src/sparcscore/ml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(self, dir_list,
def add_hdf_to_index(self, current_label, path):
try:
input_hdf = h5py.File(path, 'r')
index_handle = input_hdf.get('single_cell_index')
index_handle = input_hdf.get('single_cell_index') # to float

handle_id = len(self.handle_list)
self.handle_list.append(input_hdf.get('single_cell_data'))
Expand Down Expand Up @@ -201,4 +201,137 @@ def __getitem__(self, idx):
class HDF5SingleCellDatasetRegression(Dataset):
    """
    Class for handling SPARCSpy single cell datasets stored in HDF5 files for regression tasks.

    Scans ``dir_list`` (paths relative to ``root_dir``) for HDF5 files — either direct
    file paths or directories that are searched recursively up to ``max_level`` deep —
    and builds a flat index of single cells, each associated with the regression
    target taken from ``dir_labels`` at the matching position.

    Parameters
    ----------
    dir_list : list of str
        HDF5 file paths or directories to scan, relative to ``root_dir``.
    dir_labels : list
        Regression target per entry of ``dir_list`` (parallel lists).
    root_dir : str
        Base directory that entries of ``dir_list`` are joined onto.
    max_level : int
        Maximum recursion depth when scanning directories.
    transform : callable or None
        Optional transform applied to each cell tensor.
    return_id : bool
        If True, ``__getitem__`` also returns the cell id stored in the index.
    return_fake_id : bool
        If True, ``__getitem__`` returns a constant 0 in place of the cell id.
        Mutually exclusive with ``return_id``.
    select_channel : int or None
        If given, only this channel of each cell image is returned.
    """

    HDF_FILETYPES = ["hdf", "hf", "h5", "hdf5"]  # supported hdf5 file extensions

    def __init__(self,
                 dir_list,
                 dir_labels,
                 root_dir,
                 max_level=5,
                 transform=None,
                 return_id=False,
                 return_fake_id=False,
                 select_channel=None):

        # fail fast on an inconsistent configuration instead of erroring
        # on the first __getitem__ call
        if return_id and return_fake_id:
            raise ValueError("either return_id or return_fake_id should be set")

        self.root_dir = root_dir
        self.dir_labels = dir_labels
        self.dir_list = dir_list
        self.transform = transform

        self.handle_list = []   # open hdf5 'single_cell_data' dataset handles
        self.data_locator = []  # per-cell rows: [label, handle_id, *index_row]

        self.select_channel = select_channel

        # scan all directories / files given in dir_list
        for i, directory in enumerate(dir_list):
            path = os.path.join(self.root_dir, directory)
            current_label = self.dir_labels[i]

            # check if "directory" is a path to a specific hdf5 file
            filetype = directory.split(".")[-1]

            if filetype in self.HDF_FILETYPES:
                self.add_hdf_to_index(current_label, directory)
            else:
                # recursively scan the directory for hdf5 files
                self.scan_directory(path, current_label, max_level)

        self.return_id = return_id
        self.return_fake_id = return_fake_id

        # print dataset stats at the end
        self.stats()

    def add_hdf_to_index(self, current_label, path):
        """Open one hdf5 file and append its cells to the index.

        Each row of the file's 'single_cell_index' dataset is recorded as
        ``[current_label, handle_id, *row]`` in ``self.data_locator``.
        Files that cannot be opened or lack the expected datasets are
        skipped (best-effort indexing, matching the classification dataset).
        """
        try:
            input_hdf = h5py.File(path, 'r')
            index_handle = input_hdf.get('single_cell_index')

            handle_id = len(self.handle_list)
            self.handle_list.append(input_hdf.get('single_cell_data'))

            for row in index_handle:
                self.data_locator.append([current_label, handle_id] + list(row))
        except Exception:
            # best-effort: silently skip unreadable / malformed files
            return

    def scan_directory(self, path, current_label, levels_left):
        """Recursively scan ``path`` for hdf5 files, at most ``levels_left`` deep.

        Files with a supported extension are added to the index; subfolders
        are scanned with the depth budget reduced by one.
        """
        if levels_left <= 0:
            return

        # list the directory once and partition into files / subdirectories
        entries = [os.path.join(path, name) for name in os.listdir(path)]
        subdirectories = [entry for entry in entries if os.path.isdir(entry)]
        files = [entry for entry in entries if os.path.isfile(entry)]

        for file in files:
            filetype = file.split(".")[-1]
            if filetype in self.HDF_FILETYPES:
                self.add_hdf_to_index(current_label, file)

        # recursively scan subdirectories
        for subdirectory in subdirectories:
            self.scan_directory(subdirectory, current_label, levels_left - 1)

    def stats(self):
        """Print the total cell count and the per-label cell counts."""
        labels = [el[0] for el in self.data_locator]

        print("Total: {}".format(len(labels)))

        for l in set(labels):
            print("{}: {}".format(l, labels.count(l)))

    def __len__(self):
        """Return the number of indexed single cells."""
        return len(self.data_locator)

    def __getitem__(self, idx):
        """Return one sample: ``(tensor, label)`` plus optionally an id.

        The cell image is read from the hdf5 handle recorded in the index,
        optionally restricted to ``select_channel``, converted to a float
        tensor, and passed through ``transform`` if one was given.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()  # convert tensor to a plain python index

        # index row: [label, handle_id, cell_index, cell_id, ...]
        data_info = self.data_locator[idx]

        if self.select_channel is not None:
            cell_tensor = self.handle_list[data_info[1]][data_info[2], self.select_channel]
            t = torch.from_numpy(cell_tensor)
            t = torch.unsqueeze(t, 0)  # restore the channel dimension
        else:
            cell_tensor = self.handle_list[data_info[1]][data_info[2]]
            t = torch.from_numpy(cell_tensor)

        t = t.float()  # convert image to float tensor

        if self.transform:
            t = self.transform(t)  # apply transformation

        # NOTE(review): for regression, torch.tensor(label) is int64 when labels
        # are python ints — callers using MSE-style losses may need float labels;
        # confirm the intended dtype of dir_labels.
        if self.return_id:
            ids = int(data_info[3])
            sample = (t, torch.tensor(data_info[0]), torch.tensor(ids))  # data, label, id
        elif self.return_fake_id:
            sample = (t, torch.tensor(data_info[0]), torch.tensor(0))  # data, label, fake id
        else:
            sample = (t, torch.tensor(data_info[0]))  # data, label

        return sample

0 comments on commit 34276d6

Please sign in to comment.