def get_params_cellsize_filtering(self, type):
    """Read the mask size-filtering parameters for one segmentation type.

    The parameters are looked up in ``self.config[f"{type}_segmentation"]``.
    If ``min_size`` and/or ``max_size`` are configured, absolute thresholds
    are returned and no confidence interval is used. Otherwise the
    ``confidence_interval`` entry is returned, falling back to 0.95 (with a
    log message) when it is not configured.

    Args:
        type: Prefix of the config section to read, e.g. ``"nucleus"`` or
            ``"cytosol"``. The name shadows the builtin ``type`` but is kept
            unchanged so that keyword callers keep working.

    Returns:
        tuple: ``(thresholds, confidence_interval)``. Either
        ``([min_size, max_size], None)`` when absolute thresholds are
        configured, or ``(None, confidence_interval)`` otherwise.

    Raises:
        KeyError: If the ``f"{type}_segmentation"`` section is missing from
            ``self.config``.
    """
    segmentation_config = self.config[f"{type}_segmentation"]

    if "min_size" in segmentation_config or "max_size" in segmentation_config:
        # A bound that is not configured is treated as open-ended. Previously
        # specifying only one of the two raised UnboundLocalError.
        # NOTE(review): assumes the size filter compares mask sizes
        # numerically against these bounds -- confirm against SizeFilter.
        thresholds = [
            segmentation_config.get("min_size", 0),
            segmentation_config.get("max_size", float("inf")),
        ]
        return (thresholds, None)

    # No absolute thresholds configured: hand back a confidence interval
    # instead of thresholds.
    if "confidence_interval" in segmentation_config:
        confidence_interval = segmentation_config["confidence_interval"]
    else:
        self.log(f"No confidence interval specified for {type} mask filtering, using default value of 0.95")
        confidence_interval = 0.95

    return (None, confidence_interval)
(this gives biologically more meaningful results) + self.filter_size = True + + # check to see if the cells should be filtered for matching nuclei/cytosols within the segmentation run if "filter_status" in self.config.keys(): self.filter_status = self.config["filter_status"] else: @@ -851,7 +885,48 @@ def cellpose_segmentation(self, input_image): masks_nucleus_unfiltered = masks_nucleus.copy() masks_cytosol_unfiltered = masks_cytosol.copy() - # add step which automatically removes very small masks/masks that are unconnected (cells must be connected) + ###################### + ### Perform Filtering to remove too small/too large masks if applicable + ###################### + + if self.filter_size: + self.log("Filtering generated nucleus and cytosol masks based on size.") + + #perform filtering for nucleus size + thresholds, confidence_interval = self.get_params_cellsize_filtering("nucleus") + + if thresholds is not None: + self.log(f"Performing filtering of nuclei with specified thresholds {thresholds} from config file.") + else: + self.log(f"Automatically calculating thresholds for filtering of nuclei based on a fitted normal distribution with a confidence interval of {confidence_interval * 100}%.") + + filter_nucleus = SizeFilter( + label="nucleus", + log=True, + plot_qc=self.debug, + directory=self.directory, + confidence_interval = confidence_interval, + filter_threshold = thresholds + ) + masks_nucleus = filter_nucleus.filter(masks_nucleus) + + #perform filtering for cytosol size + thresholds, confidence_interval = self.get_params_cellsize_filtering("cytosol") + + if thresholds is not None: + self.log(f"Performing filtering of cytosols with specified thresholds {thresholds} from config file.") + else: + self.log(f"Automatically calculating thresholds for filtering of cytosols based on a fitted normal distribution with a confidence interval of {confidence_interval * 100}%.") + + filter_cytosol = SizeFilter( + label="cytosol", + log=True, + plot_qc=self.debug, + 
directory=self.directory, + confidence_interval=confidence_interval, + filter_threshold = thresholds + ) + masks_cytosol = filter_cytosol.filter(masks_cytosol) if not self.filter_status: self.log(