Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions datasets/dataset_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from utils.utils import generate_split, nth

def save_splits(split_datasets, column_keys, filename, boolean_style=False):
splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
splits = [split_datasets[i].slide_data['slide_id'].astype(str) for i in range(len(split_datasets))]
if not boolean_style:
df = pd.concat(splits, ignore_index=True, axis=1)
df.columns = column_keys
Expand Down Expand Up @@ -188,7 +188,7 @@ def set_splits(self,start_from=None):

def get_split_from_df(self, all_splits, split_key='train'):
split = all_splits[split_key]
split = split.dropna().reset_index(drop=True)
split = split.dropna().reset_index(drop=True).astype(self.slide_data['slide_id'].dtype)

if len(split) > 0:
mask = self.slide_data['slide_id'].isin(split.tolist())
Expand All @@ -203,7 +203,7 @@ def get_merged_split_from_df(self, all_splits, split_keys=['train']):
merged_split = []
for split_key in split_keys:
split = all_splits[split_key]
split = split.dropna().reset_index(drop=True).tolist()
split = split.dropna().reset_index(drop=True).astype(self.slide_data['slide_id'].dtype).tolist()
merged_split.extend(split)

if len(split) > 0:
Expand Down Expand Up @@ -244,7 +244,8 @@ def return_splits(self, from_id=True, csv_path=None):

else:
assert csv_path
all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype) # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
#all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype) # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
all_splits = pd.read_csv(csv_path, dtype=object)
train_split = self.get_split_from_df(all_splits, 'train')
val_split = self.get_split_from_df(all_splits, 'val')
test_split = self.get_split_from_df(all_splits, 'test')
Expand Down