Fix: Slide IDs turned into floats in split CSV when names consist only of numbers #228

Open: wants to merge 1 commit into base: master
9 changes: 5 additions & 4 deletions datasets/dataset_generic.py
@@ -15,7 +15,7 @@
 from utils.utils import generate_split, nth

 def save_splits(split_datasets, column_keys, filename, boolean_style=False):
-    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
+    splits = [split_datasets[i].slide_data['slide_id'].astype(str) for i in range(len(split_datasets))]
     if not boolean_style:
         df = pd.concat(splits, ignore_index=True, axis=1)
         df.columns = column_keys
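
Note on the hunk above: a minimal sketch of why the `.astype(str)` cast appears to matter, assuming the splits have unequal lengths and the slide ids are purely numeric. The ids `101` to `104` and the names `train` and `val` below are made up for illustration and are not taken from the repository. When `pd.concat` pads the shorter column with NaN, an integer column is promoted to float, so the id is written to the CSV as `104.0`.

```python
import pandas as pd

# Hypothetical slide ids made only of digits (illustrative values).
train = pd.Series([101, 102, 103], name="slide_id")
val = pd.Series([104], name="slide_id")

# Unequal lengths force NaN padding, which promotes the short int64
# column to float64.
as_numbers = pd.concat([train, val], ignore_index=True, axis=1)
as_numbers.columns = ["train", "val"]

# Casting to str first keeps the ids as text; the padding cells are still NaN.
as_strings = pd.concat(
    [train.astype(str), val.astype(str)], ignore_index=True, axis=1
)
as_strings.columns = ["train", "val"]

numeric_csv = as_numbers.to_csv(index=False)
string_csv = as_strings.to_csv(index=False)

print(numeric_csv)  # the val id carries a trailing ".0"
print(string_csv)   # the val id is written exactly as it was stored
```

In both cases the padding cells are written as empty fields, so only the representation of the ids themselves changes.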
@@ -188,7 +188,7 @@ def set_splits(self,start_from=None):

     def get_split_from_df(self, all_splits, split_key='train'):
         split = all_splits[split_key]
-        split = split.dropna().reset_index(drop=True)
+        split = split.dropna().reset_index(drop=True).astype(self.slide_data['slide_id'].dtype)

         if len(split) > 0:
             mask = self.slide_data['slide_id'].isin(split.tolist())
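
Note on the hunk above: a small sketch of the comparison that the added `.astype(...)` call seems intended to protect, assuming `slide_data['slide_id']` was inferred as an integer column while the split column was read as text. The variable names and ids here are hypothetical stand-ins, not the repository's data.

```python
import pandas as pd

# Stand-in for self.slide_data['slide_id'] when the ids were parsed as integers.
slide_ids = pd.Series([101, 102, 103], name="slide_id")

# Stand-in for one split column read with dtype=object: ids arrive as strings.
split = pd.Series(["101", "103"], dtype=object)

# Comparing integers against strings matches nothing.
mask_mismatched = slide_ids.isin(split.tolist())

# Casting the split to the dtype of slide_ids makes the membership test work.
split_cast = split.astype(slide_ids.dtype)
mask_cast = slide_ids.isin(split_cast.tolist())

print(mask_mismatched.tolist())
print(mask_cast.tolist())
```

When the ids in `slide_data` are already strings, the cast is a no-op on an object column, so the same line covers both cases in this sketch.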
@@ -203,7 +203,7 @@ def get_merged_split_from_df(self, all_splits, split_keys=['train']):
         merged_split = []
         for split_key in split_keys:
             split = all_splits[split_key]
-            split = split.dropna().reset_index(drop=True).tolist()
+            split = split.dropna().reset_index(drop=True).astype(self.slide_data['slide_id'].dtype).tolist()
             merged_split.extend(split)

         if len(split) > 0:
@@ -244,7 +244,8 @@ def return_splits(self, from_id=True, csv_path=None):

         else:
             assert csv_path
-            all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype) # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
+            #all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype) # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
+            all_splits = pd.read_csv(csv_path, dtype=object)
             train_split = self.get_split_from_df(all_splits, 'train')
             val_split = self.get_split_from_df(all_splits, 'val')
             test_split = self.get_split_from_df(all_splits, 'test')
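
Note on the last hunk: the code comment says that zero-padding can be lost if the split columns are not read as text from the start. The sketch below illustrates that claim under assumed data; the CSV contents and the zero-padded ids are invented for the example. It reads the same text once with default type inference and once with `dtype=object`, then compares each result against string slide ids.

```python
import io
import pandas as pd

# Hypothetical split file whose ids are zero-padded digit strings.
csv_text = "train,val\n0012,0007\n0034,\n"

# Default inference parses the ids as numbers and drops the leading zeros.
inferred = pd.read_csv(io.StringIO(csv_text))

# dtype=object keeps every cell exactly as written in the file.
as_text = pd.read_csv(io.StringIO(csv_text), dtype=object)

# Stand-in for self.slide_data['slide_id'] holding string ids.
slide_ids = pd.Series(["0012", "0034", "0007"])

# Converting back to str after the fact cannot restore the padding.
inferred_train = inferred["train"].dropna().astype(str).tolist()
text_train = as_text["train"].dropna().tolist()

mask_inferred = slide_ids.isin(inferred_train)
mask_text = slide_ids.isin(text_train)

print(inferred_train, mask_inferred.tolist())
print(text_train, mask_text.tolist())
```

Reading as `object` and casting to the `slide_id` dtype afterwards, as the other hunks do, avoids depending on what `read_csv` infers for each column.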