From 17c85fb9f2d421d3cff9a49aa084345d15e9c491 Mon Sep 17 00:00:00 2001 From: ff98li Date: Mon, 26 Feb 2024 15:14:28 -0500 Subject: [PATCH] Fix:slide ids turned into floats in split csv when names consist of only numerical characters --- datasets/dataset_generic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datasets/dataset_generic.py b/datasets/dataset_generic.py index 32da3967..d212c8e5 100755 --- a/datasets/dataset_generic.py +++ b/datasets/dataset_generic.py @@ -15,7 +15,7 @@ from utils.utils import generate_split, nth def save_splits(split_datasets, column_keys, filename, boolean_style=False): - splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))] + splits = [split_datasets[i].slide_data['slide_id'].astype(str) for i in range(len(split_datasets))] if not boolean_style: df = pd.concat(splits, ignore_index=True, axis=1) df.columns = column_keys @@ -188,7 +188,7 @@ def set_splits(self,start_from=None): def get_split_from_df(self, all_splits, split_key='train'): split = all_splits[split_key] - split = split.dropna().reset_index(drop=True) + split = split.dropna().reset_index(drop=True).astype(self.slide_data['slide_id'].dtype) if len(split) > 0: mask = self.slide_data['slide_id'].isin(split.tolist()) @@ -203,7 +203,7 @@ def get_merged_split_from_df(self, all_splits, split_keys=['train']): merged_split = [] for split_key in split_keys: split = all_splits[split_key] - split = split.dropna().reset_index(drop=True).tolist() + split = split.dropna().reset_index(drop=True).astype(self.slide_data['slide_id'].dtype).tolist() merged_split.extend(split) if len(split) > 0: @@ -244,7 +244,8 @@ def return_splits(self, from_id=True, csv_path=None): else: assert csv_path - all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype) # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01. + #all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype) # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01. + all_splits = pd.read_csv(csv_path, dtype=object) train_split = self.get_split_from_df(all_splits, 'train') val_split = self.get_split_from_df(all_splits, 'val') test_split = self.get_split_from_df(all_splits, 'test')