Commit

change stats method and add comments
namsaraeva committed Apr 22, 2024
1 parent 8746923 commit 23c64ee
Showing 1 changed file with 65 additions and 77 deletions.
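In summary, the diff below changes stats() from classification-style per-label counts to regression summary statistics over the target values (total, mean, median, SD, min, max), adds type hints to the __init__ signature, replaces leftover classification references (current_label / self.dir_labels) with current_target / self.target_values, and adds inline comments throughout.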
142 changes: 65 additions & 77 deletions src/sparcscore/ml/datasets.py
@@ -202,121 +202,109 @@ class HDF5SingleCellDatasetRegression(Dataset):
     """
     Class for handling SPARCSpy single cell datasets stored in HDF5 files for regression tasks.
     """
 
     HDF_FILETYPES = ["hdf", "hf", "h5", "hdf5"] # supported hdf5 filetypes
 
     def __init__(self,
-                 dir_list,
-                 target_values,
-                 root_dir,
-                 max_level=5,
-                 transform=None,
-                 return_id=False,
-                 return_fake_id=False,
-                 select_channel=None):
-
-        self.root_dir = root_dir
+                 dir_list: list[str],
+                 target_values: list[float],
+                 root_dir: str,
+                 max_level: int = 5,
+                 transform = None,
+                 return_id: bool = False,
+                 return_fake_id: bool = False,
+                 select_channel = None):
+
+        self.root_dir = root_dir
         self.target_values = target_values
         self.dir_list = dir_list
         self.transform = transform
+        self.select_channel = select_channel
         self.handle_list = []
        self.data_locator = []
-
-        self.select_channel = select_channel
 
-        # scan all directories
+        # scan all directories in dir_list
         for i, directory in enumerate(dir_list):
-            path = os.path.join(self.root_dir, directory)
-            current_label = self.dir_labels[i]
-
-            #check if "directory" is a path to specific hdf5
-            filetype = directory.split(".")[-1]
-            #filename = directory.split(".")[0]
+            path = os.path.join(self.root_dir, directory) # get full path
+            current_target = self.target_values[i] # get target value
+            filetype = directory.split(".")[-1] # get filetype
 
             if filetype in self.HDF_FILETYPES:
-                self.add_hdf_to_index(current_label, directory)
-
+                self.add_hdf_to_index(current_target, directory) # check if "directory" is a path to specific hdf5 and add to index
             else:
-                # recursively scan for files
-                self.scan_directory(path, current_label, max_level)
+                self.scan_directory(path, current_target, max_level) # recursively scan for files
 
-        # print dataset stats at the end
-        self.return_id = return_id
-        self.return_fake_id = return_fake_id
-        self.stats()
+        self.return_id = return_id # return id
+        self.return_fake_id = return_fake_id # return fake id
+        self.stats() # print dataset stats at the end
 
     def add_hdf_to_index(self, current_target, path):
         try:
             input_hdf = h5py.File(path, 'r') # read hdf5 file
-            index_handle = input_hdf.get('single_cell_index') # to float!!!
-
-            handle_id = len(self.handle_list)
-            self.handle_list.append(input_hdf.get('single_cell_data'))
+            index_handle = input_hdf.get('single_cell_index') # get index handle
+
+            handle_id = len(self.handle_list) # get handle id
+            self.handle_list.append(input_hdf.get('single_cell_data')) # append data handle
 
             for row in index_handle:
-                self.data_locator.append([current_target, handle_id] + list(row))
+                self.data_locator.append([current_target, handle_id] + list(row)) # append data locator with target, handle id, and row
         except:
             return

-    def scan_directory(self, path, current_target, levels_left):
-
-        # iterates over all files and folders in a directory
-        # hdf5 files are added to the index
-        # subfolders are recursively scanned
-
-        if levels_left > 0:
-
-            # get files and directories at current level
-            input_list = os.listdir(path)
-            current_level_directories = [os.path.join(path, name) for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
-
-            current_level_files = [ name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))]
+    def scan_directory(self, path, current_target, levels_left):
+        if levels_left > 0: # iterate over all files and folders in a directory if levels_left > 0
+            current_level_directories = [os.path.join(path, name) for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))] # get directories
+            current_level_files = [ name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))] # get files
 
-            for i, file in enumerate(current_level_files):
-                filetype = file.split(".")[-1]
-                filename = file.split(".")[0]
+            for i, file in enumerate(current_level_files): # iterate over files from current level
+                filetype = file.split(".")[-1] # get filetypes
 
                 if filetype in self.HDF_FILETYPES:
-                    self.add_hdf_to_index(current_target, os.path.join(path, file))
-
-            # recursively scan subdirectories
-            for subdirectory in current_level_directories:
-                self.scan_directory(subdirectory, current_target, levels_left-1)
-
+                    self.add_hdf_to_index(current_target, os.path.join(path, file)) # add hdf5 files to index if filetype is supported
+
+            for subdirectory in current_level_directories: # recursively scan subdirectories
+                self.scan_directory(subdirectory, current_target, levels_left-1)
         else:
             return

-    def stats(self): # print dataset statistics
-        labels = [el[0] for el in self.data_locator]
-        print("Total: {}".format(len(labels)))
-        for l in set(labels):
-            print("{}: {}".format(l,labels.count(l)))
+    def stats(self):
+        targets = [info[0] for info in self.data_locator]
+
+        targets = np.array(targets, dtype=float)
+
+        mean_target = np.mean(targets)
+        median_target = np.median(targets)
+        std_target = np.std(targets)
+        min_target = np.min(targets)
+        max_target = np.max(targets)
+
+        print(f"Total samples: {len(targets)}")
+        print(f"Mean of target values: {mean_target:.2f}")
+        print(f"Median of target values: {median_target:.2f}")
+        print(f"SD of targets: {std_target:.2f}")
+        print(f"Min target: {min_target:.2f}")
+        print(f"Max target: {max_target:.2f}")

     def __len__(self):
-        return len(self.data_locator)
+        return len(self.data_locator) # return length of data locator
 
     def __getitem__(self, idx):
         if torch.is_tensor(idx):
-            idx = idx.tolist() # convert tensor to list
+            idx = idx.tolist() # convert tensor to list
 
-        # get the label, filename and directory for the current dataset
-        data_info = self.data_locator[idx]
+        data_info = self.data_locator[idx] # get the data info for the current index, such as target, handle id, and row
 
-        if self.select_channel is not None:
+        if self.select_channel is not None: # select a specific channel
             cell_tensor = self.handle_list[data_info[1]][data_info[2], self.select_channel]
-            t = torch.from_numpy(cell_tensor).float()
-            t = torch.unsqueeze(t, 0)
-        else:
-            cell_tensor = self.handle_list[data_info[1]][data_info[2]]
-            t = torch.from_numpy(cell_tensor).float()
-
-        #t = t.float() # convert to float tensor
+            t = torch.from_numpy(cell_tensor).float() # convert to float tensor
+            t = torch.unsqueeze(t, 0) # add channel dimension to tensor
+        else:
+            cell_tensor = self.handle_list[data_info[1]][data_info[2]]
+            t = torch.from_numpy(cell_tensor).float() # convert to float tensor
 
         if self.transform:
-            t = self.transform(t) # apply transformation
+            t = self.transform(t) # apply transformation to the data
 
-        target = torch.tensor(data_info[0], dtype=torch.float)
+        target = torch.tensor(data_info[0], dtype=torch.float) # get target value
 
-        return (t, target)
+        return (t, target) # return data and target value
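For orientation, a minimal usage sketch against the post-commit class. The file names, root directory, and target values below are hypothetical placeholders; the constructor signature and the (tensor, target) return shape follow the diff above.

from torch.utils.data import DataLoader

from sparcscore.ml.datasets import HDF5SingleCellDatasetRegression

# hypothetical single-cell HDF5 files, one regression target per entry in dir_list
dataset = HDF5SingleCellDatasetRegression(
    dir_list=["plate1/cells.h5", "plate2/cells.h5"],  # placeholder paths relative to root_dir
    target_values=[0.5, 2.0],                         # placeholder per-dataset target values
    root_dir="/data/sparcs",                          # placeholder project root
    select_channel=0,                                 # optional: load only the first image channel
)
# __init__ calls stats() at the end, printing total/mean/median/SD/min/max of the targets

t, target = dataset[0]  # per __getitem__: float image tensor (with channel dim) and float target
loader = DataLoader(dataset, batch_size=64, shuffle=True)

Note that return_id and return_fake_id are stored but not consumed by __getitem__ in this version of the class, so each item is simply a (tensor, target) pair.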
