From 17fa81c7ab7cbf5a4262b8946f1fc1748af21976 Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Tue, 13 Jun 2023 18:17:45 +0200 Subject: [PATCH 01/22] move model classes to 'training' module + refactor --- .../predictions/functional/array_predictor.py | 2 +- plantseg/predictions/functional/utils.py | 2 +- plantseg/training/__init__.py | 0 plantseg/training/embeddings.py | 78 +++++ plantseg/{models => training}/model.py | 273 +++++++----------- tests/test_model.py | 31 ++ tests/test_model_zoo.py | 2 +- 7 files changed, 215 insertions(+), 173 deletions(-) create mode 100644 plantseg/training/__init__.py create mode 100644 plantseg/training/embeddings.py rename plantseg/{models => training}/model.py (67%) create mode 100644 tests/test_model.py diff --git a/plantseg/predictions/functional/array_predictor.py b/plantseg/predictions/functional/array_predictor.py index a32c777e..4c8104ba 100644 --- a/plantseg/predictions/functional/array_predictor.py +++ b/plantseg/predictions/functional/array_predictor.py @@ -6,7 +6,7 @@ from torch import nn from torch.utils.data import DataLoader, Dataset -from plantseg.models.model import UNet2D +from plantseg.training.model import UNet2D from plantseg.pipeline import gui_logger from plantseg.predictions.functional.array_dataset import ArrayDataset, default_prediction_collate diff --git a/plantseg/predictions/functional/utils.py b/plantseg/predictions/functional/utils.py index f7f93c0f..8f6a35d7 100644 --- a/plantseg/predictions/functional/utils.py +++ b/plantseg/predictions/functional/utils.py @@ -2,7 +2,7 @@ from plantseg import plantseg_global_path, PLANTSEG_MODELS_DIR, home_path from plantseg.augment.transforms import get_test_augmentations -from plantseg.models.model import get_model +from plantseg.training.model import get_model from plantseg.pipeline import gui_logger from plantseg.predictions.functional.array_dataset import ArrayDataset from plantseg.predictions.functional.slice_builder import SliceBuilder diff --git a/plantseg/training/__init__.py b/plantseg/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plantseg/training/embeddings.py b/plantseg/training/embeddings.py new file mode 100644 index 00000000..d73ff3a2 --- /dev/null +++ b/plantseg/training/embeddings.py @@ -0,0 +1,78 @@ +import torch +from torch import nn as nn + + +def shift_tensor(tensor: torch.Tensor, offset: tuple) -> torch.Tensor: + """ Shift a tensor by the given (spatial) offset. + + Args: + tensor: 4D (=2 spatial dims) or 5D (=3 spatial dims) tensor. + Needs to be of float type. + offset: 2d or 3d spatial offset used for shifting the tensor + + Returns: + Shifted tensor of the same shape as the input tensor. + """ + + ndim = len(offset) + assert ndim in (2, 3) + diff = tensor.dim() - ndim + + # don't pad for the first dimensions + # (usually batch and/or channel dimension) + slice_ = diff * [slice(None)] + + # torch padding behaviour is a bit weird. 
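+    # for example, offset = (1, 0, -1) ends up as padding = (0, 1, 0, 0, 1, 0) below:
+    # (left, right) pairs per spatial dim starting from the last one, so the negative
+    # x-offset pads the right border and the positive z-offset pads the front border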
+ # we use nn.ReplicationPadND + # (torch.nn.functional.pad is even weirder and ReflectionPad is not supported in 3d) + # still, padding needs to be given in the inverse spatial order + + # add padding in inverse spatial order + padding = [] + for off in offset[::-1]: + # if we have a negative offset, we need to shift "to the left", + # which means padding at the right border + # if we have a positive offset, we need to shift "to the right", + # which means padding to the left border + padding.extend([max(0, off), max(0, -off)]) + + # add slicing in the normal spatial order + for off in offset: + if off == 0: + slice_.append(slice(None)) + elif off > 0: + slice_.append(slice(None, -off)) + else: + slice_.append(slice(-off, None)) + + # pad the spatial part of the tensor with replication padding + slice_ = tuple(slice_) + padding = tuple(padding) + padder = nn.ReplicationPad2d if ndim == 2 else nn.ReplicationPad3d + padder = padder(padding) + shifted = padder(tensor) + + # slice the padded tensor to get the spatially shifted tensor + shifted = shifted[slice_] + assert shifted.shape == tensor.shape + + return shifted + + +def invert_offsets(offsets: tuple) -> tuple: + return [[-off for off in offset] for offset in offsets] + + +def embeddings_to_affinities(embeddings: torch.Tensor, offsets: list, delta: float) -> torch.Tensor: + """ Transform embeddings to affinities. + """ + # shift the embeddings by the offsets and stack them along a new axis + # we need to shift in the opposite direction of the offsets, so we invert them + # before applying the shift + offsets_ = invert_offsets(offsets) + shifted = torch.cat([shift_tensor(embeddings, off).unsqueeze(1) for off in offsets_], dim=1) + # substract the embeddings from the shifted embeddings, take the norm and + # transform to affinities based on the delta distance + affs = (2 * delta - torch.norm(embeddings.unsqueeze(1) - shifted, dim=2)) / (2 * delta) + affs = torch.clamp(affs, min=0) ** 2 + return affs diff --git a/plantseg/models/model.py b/plantseg/training/model.py similarity index 67% rename from plantseg/models/model.py rename to plantseg/training/model.py index a89bd8ea..1a9750d5 100644 --- a/plantseg/models/model.py +++ b/plantseg/training/model.py @@ -2,6 +2,7 @@ # https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/model.py import importlib +from typing import List import torch.nn as nn @@ -158,58 +159,6 @@ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr padding=padding, is3d=is3d)) -class ResNetBlock(nn.Module): - """ - Residual block that can be used instead of standard DoubleConv in the Encoder module. - Motivated by: https://arxiv.org/pdf/1706.00120.pdf - - Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
- """ - - def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs): - super(ResNetBlock, self).__init__() - - if in_channels != out_channels: - # conv1x1 for increasing the number of channels - if is3d: - self.conv1 = nn.Conv3d(in_channels, out_channels, 1) - else: - self.conv1 = nn.Conv2d(in_channels, out_channels, 1) - else: - self.conv1 = nn.Identity() - - # residual block - self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups, - is3d=is3d) - # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual - n_order = order - for c in 'rel': - n_order = n_order.replace(c, '') - self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, - num_groups=num_groups, is3d=is3d) - - # create non-linearity separately - if 'l' in order: - self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) - elif 'e' in order: - self.non_linearity = nn.ELU(inplace=True) - else: - self.non_linearity = nn.ReLU(inplace=True) - - def forward(self, x): - # apply first convolution to bring the number of channels to out_channels - residual = self.conv1(x) - - # residual block - out = self.conv2(residual) - out = self.conv3(out) - - out += residual - out = self.non_linearity(out) - - return out - - class Encoder(nn.Module): """ A single module from the encoder path consisting of the optional max @@ -225,7 +174,6 @@ class Encoder(nn.Module): apply_pooling (bool): if True use MaxPool3d before DoubleConv pool_kernel_size (int or tuple): the size of the window pool_type (str): pooling layer: 'max' or 'avg' - basic_module(nn.Module): either ResNetBlock or DoubleConv conv_layer_order (string): determines the order of layers in `DoubleConv` module. See `DoubleConv` for more info. num_groups (int): number of groups for the GroupNorm @@ -233,9 +181,8 @@ class Encoder(nn.Module): is3d (bool): use 3d or 2d convolutions/pooling operation """ - def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, - pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr', - num_groups=8, padding=1, is3d=True): + def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, pool_kernel_size=2, + pool_type='max', conv_layer_order='gcr', num_groups=8, padding=1, is3d=True): super(Encoder, self).__init__() assert pool_type in ['max', 'avg'] if apply_pooling: @@ -252,13 +199,13 @@ def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling= else: self.pooling = None - self.basic_module = basic_module(in_channels, out_channels, - encoder=True, - kernel_size=conv_kernel_size, - order=conv_layer_order, - num_groups=num_groups, - padding=padding, - is3d=is3d) + self.basic_module = DoubleConv(in_channels, out_channels, + encoder=True, + kernel_size=conv_kernel_size, + order=conv_layer_order, + num_groups=num_groups, + padding=padding, + is3d=is3d) def forward(self, x): if self.pooling is not None: @@ -280,7 +227,6 @@ class Decoder(nn.Module): scale_factor (tuple): used as the multiplier for the image H/W/D in case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation from the corresponding encoder - basic_module(nn.Module): either ResNetBlock or DoubleConv conv_layer_order (string): determines the order of layers in `DoubleConv` module. See `DoubleConv` for more info. 
num_groups (int): number of groups for the GroupNorm @@ -288,37 +234,28 @@ class Decoder(nn.Module): upsample (bool): should the input be upsampled """ - def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv, - conv_layer_order='gcr', num_groups=8, mode='nearest', padding=1, upsample=True, is3d=True): + def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), conv_layer_order='gcr', + num_groups=8, mode='nearest', padding=1, upsample=True, is3d=True): super(Decoder, self).__init__() if upsample: - if basic_module == DoubleConv: - # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining - self.upsampling = InterpolateUpsampling(mode=mode) - # concat joining - self.joining = partial(self._joining, concat=True) - else: - # if basic_module=ResNetBlock use transposed convolution upsampling and summation joining - self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels, - kernel_size=conv_kernel_size, scale_factor=scale_factor) - # sum joining - self.joining = partial(self._joining, concat=False) - # adapt the number of in_channels for the ResNetBlock - in_channels = out_channels + self.upsampling = InterpolateUpsampling(mode=mode) + # concat joining + self.joining = partial(self._joining, concat=True) + else: # no upsampling self.upsampling = NoUpsampling() # concat joining self.joining = partial(self._joining, concat=True) - self.basic_module = basic_module(in_channels, out_channels, - encoder=False, - kernel_size=conv_kernel_size, - order=conv_layer_order, - num_groups=num_groups, - padding=padding, - is3d=is3d) + self.basic_module = DoubleConv(in_channels, out_channels, + encoder=False, + kernel_size=conv_kernel_size, + order=conv_layer_order, + num_groups=num_groups, + padding=padding, + is3d=is3d) def forward(self, encoder_features, x): x = self.upsampling(encoder_features=encoder_features, x=x) @@ -334,55 +271,35 @@ def _joining(encoder_features, x, concat): return encoder_features + x -def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, - pool_kernel_size, is3d): +def create_encoders(in_channels, f_maps, conv_kernel_size, conv_padding, layer_order, num_groups, pool_kernel_size, + is3d): # create encoder path consisting of Encoder modules. 
Depth of the encoder is equal to `len(f_maps)` encoders = [] for i, out_feature_num in enumerate(f_maps): if i == 0: # apply conv_coord only in the first encoder if any - encoder = Encoder(in_channels, out_feature_num, - apply_pooling=False, # skip pooling in the firs encoder - basic_module=basic_module, - conv_layer_order=layer_order, - conv_kernel_size=conv_kernel_size, - num_groups=num_groups, - padding=conv_padding, - is3d=is3d) + encoder = Encoder(in_channels, out_feature_num, conv_kernel_size=conv_kernel_size, apply_pooling=False, + conv_layer_order=layer_order, num_groups=num_groups, padding=conv_padding, is3d=is3d) else: - encoder = Encoder(f_maps[i - 1], out_feature_num, - basic_module=basic_module, - conv_layer_order=layer_order, - conv_kernel_size=conv_kernel_size, - num_groups=num_groups, - pool_kernel_size=pool_kernel_size, - padding=conv_padding, - is3d=is3d) + encoder = Encoder(f_maps[i - 1], out_feature_num, conv_kernel_size=conv_kernel_size, + pool_kernel_size=pool_kernel_size, conv_layer_order=layer_order, num_groups=num_groups, + padding=conv_padding, is3d=is3d) encoders.append(encoder) return nn.ModuleList(encoders) -def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, is3d): +def create_decoders(f_maps, conv_kernel_size, conv_padding, layer_order, num_groups, is3d): # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1` decoders = [] reversed_f_maps = list(reversed(f_maps)) for i in range(len(reversed_f_maps) - 1): - if basic_module == DoubleConv: - in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] - else: - in_feature_num = reversed_f_maps[i] - + in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] out_feature_num = reversed_f_maps[i + 1] - decoder = Decoder(in_feature_num, out_feature_num, - basic_module=basic_module, - conv_layer_order=layer_order, - conv_kernel_size=conv_kernel_size, - num_groups=num_groups, - padding=conv_padding, - is3d=is3d) + decoder = Decoder(in_feature_num, out_feature_num, conv_kernel_size=conv_kernel_size, + conv_layer_order=layer_order, num_groups=num_groups, padding=conv_padding, is3d=is3d) decoders.append(decoder) return nn.ModuleList(decoders) @@ -421,27 +338,6 @@ def _interpolate(x, size, mode): return F.interpolate(x, size=size, mode=mode) -class TransposeConvUpsampling(AbstractUpsampling): - """ - Args: - in_channels (int): number of input channels for transposed conv - used only if transposed_conv is True - out_channels (int): number of output channels for transpose conv - used only if transposed_conv is True - kernel_size (int or tuple): size of the convolving kernel - used only if transposed_conv is True - scale_factor (int or tuple): stride of the convolution - used only if transposed_conv is True - - """ - - def __init__(self, in_channels=None, out_channels=None, kernel_size=3, scale_factor=(2, 2, 2)): - # make sure that the output size reverses the MaxPool3d from the corresponding encoder - upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, - padding=1) - super().__init__(upsample) - - class NoUpsampling(AbstractUpsampling): def __init__(self): super().__init__(self._no_upsampling) @@ -467,7 +363,6 @@ class AbstractUNet(nn.Module): of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax. 
In effect only if `self.training == False`, i.e. during validation/testing - basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....) layer_order (string): determines the order of layers in `SingleConv` module. E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. See `SingleConv` for more info num_groups (int): number of groups for the GroupNorm @@ -475,15 +370,14 @@ class AbstractUNet(nn.Module): default: 4 is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied after the final convolution; if False (regression problem) the normalization layer is skipped - conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module + conv_kernel_size (int or tuple): size of the convolving kernel in the conv layer pool_kernel_size (int or tuple): the size of the window conv_padding (int or tuple): add zero-padding added to all three sides of the input is3d (bool): if True the model is 3D, otherwise 2D, default: True """ - def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', - num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2, - conv_padding=1, is3d=True): + def __init__(self, in_channels, out_channels, final_sigmoid, f_maps=64, layer_order='gcr', num_groups=8, + num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2, conv_padding=1, is3d=True): super(AbstractUNet, self).__init__() if isinstance(f_maps, int): @@ -495,12 +389,11 @@ def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_map assert num_groups is not None, "num_groups must be specified if GroupNorm is used" # create encoder path - self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, - num_groups, pool_kernel_size, is3d) + self.encoders = create_encoders(in_channels, f_maps, conv_kernel_size, conv_padding, layer_order, num_groups, + pool_kernel_size, is3d) # create decoder path - self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, - is3d) + self.decoders = create_decoders(f_maps, conv_kernel_size, conv_padding, layer_order, num_groups, is3d) # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels if is3d: @@ -551,22 +444,13 @@ class UNet3D(AbstractUNet): 3DUnet model from `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" `. 
- - Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder """ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs): - super(UNet3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - final_sigmoid=final_sigmoid, - basic_module=DoubleConv, - f_maps=f_maps, - layer_order=layer_order, - num_groups=num_groups, - num_levels=num_levels, - is_segmentation=is_segmentation, - conv_padding=conv_padding, + super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid, + f_maps=f_maps, layer_order=layer_order, num_groups=num_groups, + num_levels=num_levels, is_segmentation=is_segmentation, conv_padding=conv_padding, is3d=True) @@ -578,19 +462,50 @@ class UNet2D(AbstractUNet): def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs): - super(UNet2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - final_sigmoid=final_sigmoid, - basic_module=DoubleConv, - f_maps=f_maps, - layer_order=layer_order, - num_groups=num_groups, - num_levels=num_levels, - is_segmentation=is_segmentation, - conv_padding=conv_padding, + super(UNet2D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid, + f_maps=f_maps, layer_order=layer_order, num_groups=num_groups, + num_levels=num_levels, is_segmentation=is_segmentation, conv_padding=conv_padding, is3d=False) +class SpocoNet(nn.Module): + """ + Wrapper around the f-network and the moving average g-network. + """ + + def __init__(self, net_f, net_g, m=0.999, init_equal=True): + super(SpocoNet, self).__init__() + + self.net_f = net_f + self.net_g = net_g + self.m = m + + if init_equal: + # initialize g weights to be equal to f weights + for param_f, param_g in zip(self.net_f.parameters(), self.net_g.parameters()): + param_g.data.copy_(param_f.data) # initialize + param_g.requires_grad = False # freeze g parameters + + @torch.no_grad() + def _momentum_update(self): + """ + Momentum update of the g + """ + for param_f, param_g in zip(self.net_f.parameters(), self.net_g.parameters()): + param_g.data = param_g.data * self.m + param_f.data * (1. 
- self.m) + + def forward(self, im_f, im_g): + # compute f-embeddings + emb_f = self.net_f(im_f) + + # compute g-embeddings + with torch.no_grad(): # no gradient to g-embeddings + self._momentum_update() # momentum update of g + emb_g = self.net_g(im_g) + + return emb_f, emb_g + + def number_of_features_per_level(init_channel_number, num_levels): return [init_channel_number * 2 ** k for k in range(num_levels)] @@ -605,5 +520,23 @@ def get_class(class_name, modules): def get_model(model_config): - model_class = get_class(model_config['name'], modules=['plantseg.models.model']) + model_class = get_class(model_config['name'], modules=['plantseg.training.model']) return model_class(**model_config) + + +def get_spoco(in_channels: int, out_channels: int, f_maps: List[int], layer_order='bcr') -> SpocoNet: + net_f = UNet2D( + in_channels=in_channels, + out_channels=out_channels, + f_maps=f_maps, + layer_order=layer_order + ) + + net_g = UNet2D( + in_channels=in_channels, + out_channels=out_channels, + f_maps=f_maps, + layer_order=layer_order + ) + + return SpocoNet(net_f, net_g) diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 00000000..2cf61d37 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,31 @@ +import torch + +from plantseg.training.embeddings import embeddings_to_affinities +from plantseg.training.model import UNet2D, get_spoco + + +class TestModelPrediction: + def test_UNet2D(self): + model = UNet2D(in_channels=3, out_channels=1) + model.eval() + x = torch.randn(4, 3, 260, 260) + y = model(x) + assert y.shape == (4, 1, 260, 260) + assert torch.all(y >= 0) and torch.all(y <= 1) + + def test_SpocoNet(self): + model = get_spoco(in_channels=1, out_channels=8, f_maps=[16, 32, 64, 128, 256]) + model.eval() + x1 = torch.randn(4, 1, 260, 260) + x2 = torch.rand(4, 1, 260, 260) + y1, y2 = model(x1, x2) + assert y1.shape == (4, 8, 260, 260) + assert y2.shape == (4, 8, 260, 260) + + def test_embeddings_to_affinities(self): + x = torch.randn(4, 8, 128, 128) + offsets = [[-1, 0], [0, -1]] + delta = .5 + affs = embeddings_to_affinities(x, offsets, delta) + assert affs.shape == (4, 2, 128, 128) + assert torch.all(affs >= 0) and torch.all(affs <= 1) diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py index 0fe80c6f..fd66175f 100644 --- a/tests/test_model_zoo.py +++ b/tests/test_model_zoo.py @@ -1,7 +1,7 @@ import pytest import torch -from plantseg.models.model import UNet2D +from plantseg.training.model import UNet2D from plantseg.predictions.functional.utils import get_model_config # test some modes (3D and 2D) From d65c789ca700bf3a0a990473fc5d83734a9026e4 Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Thu, 15 Jun 2023 17:32:03 +0200 Subject: [PATCH 02/22] make prediction work with embedding models --- .../predictions/functional/array_predictor.py | 31 +++++++++++++++++-- plantseg/predictions/predict.py | 4 ++- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/plantseg/predictions/functional/array_predictor.py b/plantseg/predictions/functional/array_predictor.py index 4c8104ba..7a8f4eac 100644 --- a/plantseg/predictions/functional/array_predictor.py +++ b/plantseg/predictions/functional/array_predictor.py @@ -6,6 +6,7 @@ from torch import nn from torch.utils.data import DataLoader, Dataset +from plantseg.training.embeddings import embeddings_to_affinities from plantseg.training.model import UNet2D from plantseg.pipeline import gui_logger from plantseg.predictions.functional.array_dataset import ArrayDataset, default_prediction_collate @@ -77,11 
+78,13 @@ class ArrayPredictor: patch_halo (tuple): mirror padding around the patch single_batch_mode (bool): if True, the batch size will be set to 1 headless (bool): if True, DataParallel will be used if multiple GPUs are available + is_embedding (bool): if True, the model returns embeddings instead of probabilities """ def __init__(self, model: nn.Module, in_channels: int, out_channels: int, device: str, patch: Tuple[int, int, int], - patch_halo: Tuple[int, int, int], single_batch_mode: bool, headless: bool, + patch_halo: Tuple[int, int, int], single_batch_mode: bool, headless: bool, is_embedding: bool = False, verbose_logging: bool = False, disable_tqdm: bool = False): + self.device = device if single_batch_mode: self.batch_size = 1 @@ -101,6 +104,7 @@ def __init__(self, model: nn.Module, in_channels: int, out_channels: int, device self.patch_halo = patch_halo self.verbose_logging = verbose_logging self.disable_tqdm = disable_tqdm + self.is_embedding = is_embedding def __call__(self, test_dataset: Dataset) -> np.ndarray: assert isinstance(test_dataset, ArrayDataset) @@ -113,7 +117,19 @@ def __call__(self, test_dataset: Dataset) -> np.ndarray: # dimensionality of the output predictions volume_shape = self.volume_shape(test_dataset) - prediction_maps_shape = (self.out_channels,) + volume_shape + if self.is_embedding: + if _is_2d_model(self.model): + out_channels = 2 + else: + out_channels = 3 + else: + out_channels = self.out_channels + + if self.is_embedding: + # embeddings will be converted to affinities + prediction_maps_shape = (out_channels,) + volume_shape + else: + prediction_maps_shape = (out_channels,) + volume_shape if self.verbose_logging: gui_logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}') gui_logger.info(f'Using patch_halo: {self.patch_halo}') @@ -145,11 +161,20 @@ def __call__(self, test_dataset: Dataset) -> np.ndarray: prediction = torch.unsqueeze(prediction, dim=-3) else: prediction = self.model(input) + + if self.is_embedding: + if _is_2d_model(self.model): + offsets = [[-1, 0], [0, -1]] + else: + offsets = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] + # convert embeddings to affinities + prediction = embeddings_to_affinities(prediction, offsets, delta=0.5) + # TODO: invert affinities and get the mean across the affinity channels # unpad the prediction prediction = _unpad(prediction, self.patch_halo) # convert to numpy array prediction = prediction.cpu().numpy() - channel_slice = slice(0, self.out_channels) + channel_slice = slice(0, out_channels) # for each batch sample for pred, index in zip(prediction, indices): # add channel dimension to the index diff --git a/plantseg/predictions/predict.py b/plantseg/predictions/predict.py index e74fe0d5..7fc0f555 100755 --- a/plantseg/predictions/predict.py +++ b/plantseg/predictions/predict.py @@ -58,9 +58,11 @@ def __init__(self, input_paths, model_name, patch=(80, 160, 160), stride_ratio=0 model.load_state_dict(state) patch_halo = get_patch_halo(model_name) + is_embedding = not model_config.get('is_segmentation', True) self.predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'], out_channels=model_config['out_channels'], device=device, patch=self.patch, - patch_halo=patch_halo, single_batch_mode=False, headless=True) + patch_halo=patch_halo, single_batch_mode=False, headless=True, + is_embedding=is_embedding) def process(self, raw: np.ndarray) -> np.ndarray: dataset = get_array_dataset(raw, self.model_name, patch=self.patch, stride_ratio=self.stride_ratio) From 
b759b510cb871643e4a4900a371f5b3dd35f55f7 Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Sat, 17 Jun 2023 13:56:29 +0200 Subject: [PATCH 03/22] average and invert affinities in array predictor --- plantseg/predictions/functional/array_predictor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plantseg/predictions/functional/array_predictor.py b/plantseg/predictions/functional/array_predictor.py index 7a8f4eac..f07b8950 100644 --- a/plantseg/predictions/functional/array_predictor.py +++ b/plantseg/predictions/functional/array_predictor.py @@ -169,7 +169,8 @@ def __call__(self, test_dataset: Dataset) -> np.ndarray: offsets = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] # convert embeddings to affinities prediction = embeddings_to_affinities(prediction, offsets, delta=0.5) - # TODO: invert affinities and get the mean across the affinity channels + # average across channels (i.e. 1-affinities) and invert + prediction = 1 - prediction.mean(dim=1) # unpad the prediction prediction = _unpad(prediction, self.patch_halo) # convert to numpy array From 10dadb95358ab67a7d94e29a97fc2d8497956757 Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Sat, 17 Jun 2023 15:44:42 +0200 Subject: [PATCH 04/22] simplify trainer class --- plantseg/training/trainer.py | 204 +++++++++++++++++++++++++++++++++++ plantseg/training/utils.py | 13 +++ 2 files changed, 217 insertions(+) create mode 100644 plantseg/training/trainer.py create mode 100644 plantseg/training/utils.py diff --git a/plantseg/training/trainer.py b/plantseg/training/trainer.py new file mode 100644 index 00000000..83f95895 --- /dev/null +++ b/plantseg/training/trainer.py @@ -0,0 +1,204 @@ +import os +import shutil + +import torch +from torch import nn +from torch.optim.lr_scheduler import ReduceLROnPlateau +from tqdm import tqdm + +from plantseg.pipeline import gui_logger +from plantseg.training.model import UNet2D +from plantseg.training.utils import RunningAverage + + +class UNetTrainer: + """UNet trainer. + + Args: + model (Unet3D): UNet 3D model to be trained + optimizer (nn.optim.Optimizer): optimizer used for training + lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler + WARN: bear in mind that lr_scheduler.step() is invoked after every validation step + (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30 + the learning rate will be adjusted after every 30 * validate_after_iters iterations. 
+ loss_criterion (callable): loss function + loaders (dict): 'train' and 'val' loaders + checkpoint_dir (string): dir for saving checkpoints and tensorboard logs + max_num_epochs (int): maximum number of epochs + max_num_iterations (int): maximum number of iterations + validate_after_iters (int): validate after that many iterations + log_after_iters (int): number of iterations before logging to tensorboard + pre_trained(str): path to the pre-trained model + """ + + def __init__(self, model, optimizer, lr_scheduler, loss_criterion, loaders, checkpoint_dir, + max_num_epochs, max_num_iterations, validate_after_iters=500, log_after_iters=100, + pre_trained=None): + + self.model = model + self.optimizer = optimizer + self.scheduler = lr_scheduler + self.loss_criterion = loss_criterion + self.loaders = loaders + self.checkpoint_dir = checkpoint_dir + self.max_num_epochs = max_num_epochs + self.max_num_iterations = max_num_iterations + self.validate_after_iters = validate_after_iters + self.log_after_iters = log_after_iters + self.best_eval_loss = float('+inf') + + self.num_iterations = 1 + if pre_trained is not None: + gui_logger.info(f"Logging pre-trained model from '{pre_trained}'...") + state = torch.load(pre_trained, map_location='cpu') + self.model.load_state_dict(state) + + def train(self): + for epoch in range(self.max_num_epochs): + print(f'Epoch [{epoch}/{self.max_num_epochs}]') + # train for one epoch + should_terminate = self.train_epoch(epoch) + + if should_terminate: + gui_logger.info('Stopping criterion is satisfied. Finishing training') + return + + print('Validating...') + # set the model in eval mode + self.model.eval() + # evaluate on validation set + eval_loss = self.validate() + gui_logger.info(f'Val Loss: {eval_loss}.') + # set the model back to training mode + self.model.train() + + # adjust learning rate if necessary + if isinstance(self.scheduler, ReduceLROnPlateau): + self.scheduler.step(eval_loss) + else: + self.scheduler.step() + # remember best validation metric + is_best = eval_loss < self.best_eval_loss + if is_best: + self.best_eval_loss = eval_loss + + # save checkpoint + self._save_checkpoint(is_best) + + gui_logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...") + + def train_epoch(self): + """Trains the model for 1 epoch. 
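+        Iterates once over the 'train' loader and logs the running training loss every
+        `log_after_iters` iterations.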
+ + Returns: + True if the training should be terminated immediately, False otherwise + """ + train_losses = RunningAverage() + + # sets the model in training mode + self.model.train() + + for t in tqdm(self.loaders['train']): + input, target, weight = self._split_training_batch(t) + + output, loss = self._forward_pass(input, target, weight) + + train_losses.update(loss.item(), self._batch_size(input)) + + # compute gradients and update parameters + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + if self.num_iterations % self.log_after_iters == 0: + # log stats, params and images + gui_logger.info(f'Train Loss: {train_losses.avg}.') + + if self.should_stop(): + return True + + self.num_iterations += 1 + + return False + + def should_stop(self): + """ + Training will terminate if maximum number of iterations is exceeded or the learning rate drops below + some predefined threshold (1e-6 in our case) + """ + if self.max_num_iterations < self.num_iterations: + gui_logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.') + return True + + min_lr = 1e-6 + lr = self.optimizer.param_groups[0]['lr'] + if lr < min_lr: + gui_logger.info(f'Learning rate below the minimum {min_lr}.') + return True + + return False + + def validate(self): + val_losses = RunningAverage() + + with torch.no_grad(): + for t in tqdm(self.loaders['val']): + input, target, weight = self._split_training_batch(t) + + output, loss = self._forward_pass(input, target, weight) + val_losses.update(loss.item(), self._batch_size(input)) + + return val_losses.avg + + def _split_training_batch(self, t): + def _move_to_gpu(input): + if isinstance(input, tuple) or isinstance(input, list): + return tuple([_move_to_gpu(x) for x in input]) + else: + if torch.cuda.is_available(): + input = input.cuda(non_blocking=True) + return input + + t = _move_to_gpu(t) + weight = None + if len(t) == 2: + input, target = t + else: + input, target, weight = t + return input, target, weight + + def _forward_pass(self, input, target): + if isinstance(self.model, UNet2D): + # remove the singleton z-dimension from the input + input = torch.squeeze(input, dim=-3) + # forward pass + output = self.model(input) + # add the singleton z-dimension to the output + output = torch.unsqueeze(output, dim=-3) + else: + # forward pass + output = self.model(input) + + loss = self.loss_criterion(output, target) + return output, loss + + def _save_checkpoint(self, is_best): + if isinstance(self.model, nn.DataParallel): + state_dict = self.model.module.state_dict() + else: + state_dict = self.model.state_dict() + + last_file_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pytorch') + gui_logger.info(f"Saving checkpoint to '{last_file_path}'") + + torch.save(state_dict, last_file_path) + if is_best: + best_file_path = os.path.join(self.checkpoint_dir, 'best_checkpoint.pytorch') + shutil.copyfile(last_file_path, best_file_path) + + @staticmethod + def _batch_size(input): + if isinstance(input, list) or isinstance(input, tuple): + return input[0].size(0) + else: + return input.size(0) diff --git a/plantseg/training/utils.py b/plantseg/training/utils.py new file mode 100644 index 00000000..6354b172 --- /dev/null +++ b/plantseg/training/utils.py @@ -0,0 +1,13 @@ +class RunningAverage: + """Computes and stores the average + """ + + def __init__(self): + self.count = 0 + self.sum = 0 + self.avg = 0 + + def update(self, value, n=1): + self.count += n + self.sum += value * n + self.avg = self.sum / self.count From 
63d3337b9262c5cc66f2d18904ec8cc901d68700 Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Mon, 19 Jun 2023 16:04:54 +0200 Subject: [PATCH 05/22] add simple training widget --- plantseg/training/train.py | 2 + plantseg/training/trainer.py | 53 +++++++-------------- plantseg/viewer/containers.py | 8 ++++ plantseg/viewer/viewer.py | 3 +- plantseg/viewer/widget/training.py | 74 ++++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 37 deletions(-) create mode 100644 plantseg/training/train.py create mode 100644 plantseg/viewer/widget/training.py diff --git a/plantseg/training/train.py b/plantseg/training/train.py new file mode 100644 index 00000000..ee7eb373 --- /dev/null +++ b/plantseg/training/train.py @@ -0,0 +1,2 @@ +def unet_training(dataset_dir, model_name, patch_size, dimensionality, sparse, device, **kwargs): + pass diff --git a/plantseg/training/trainer.py b/plantseg/training/trainer.py index 83f95895..78a9ba74 100644 --- a/plantseg/training/trainer.py +++ b/plantseg/training/trainer.py @@ -1,8 +1,10 @@ import os import shutil +from typing import Tuple import torch from torch import nn +import torch.optim as optim from torch.optim.lr_scheduler import ReduceLROnPlateau from tqdm import tqdm @@ -17,23 +19,20 @@ class UNetTrainer: Args: model (Unet3D): UNet 3D model to be trained optimizer (nn.optim.Optimizer): optimizer used for training - lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler - WARN: bear in mind that lr_scheduler.step() is invoked after every validation step - (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30 - the learning rate will be adjusted after every 30 * validate_after_iters iterations. - loss_criterion (callable): loss function + lr_scheduler (torch.optim.lr_scheduler.LRScheduler): learning rate scheduler + loss_criterion (nn.Module): loss function loaders (dict): 'train' and 'val' loaders checkpoint_dir (string): dir for saving checkpoints and tensorboard logs max_num_epochs (int): maximum number of epochs max_num_iterations (int): maximum number of iterations - validate_after_iters (int): validate after that many iterations + device (str): device to use for training log_after_iters (int): number of iterations before logging to tensorboard pre_trained(str): path to the pre-trained model """ - def __init__(self, model, optimizer, lr_scheduler, loss_criterion, loaders, checkpoint_dir, - max_num_epochs, max_num_iterations, validate_after_iters=500, log_after_iters=100, - pre_trained=None): + def __init__(self, model: nn.Module, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler.LRScheduler, + loss_criterion: nn.Module, loaders: dict, checkpoint_dir: str, max_num_epochs: int, + max_num_iterations: int, device: str = 'cuda', log_after_iters: int = 100, pre_trained: int = None): self.model = model self.optimizer = optimizer @@ -43,7 +42,7 @@ def __init__(self, model, optimizer, lr_scheduler, loss_criterion, loaders, chec self.checkpoint_dir = checkpoint_dir self.max_num_epochs = max_num_epochs self.max_num_iterations = max_num_iterations - self.validate_after_iters = validate_after_iters + self.device = device self.log_after_iters = log_after_iters self.best_eval_loss = float('+inf') @@ -57,7 +56,7 @@ def train(self): for epoch in range(self.max_num_epochs): print(f'Epoch [{epoch}/{self.max_num_epochs}]') # train for one epoch - should_terminate = self.train_epoch(epoch) + should_terminate = self.train_epoch() if should_terminate: gui_logger.info('Stopping criterion is 
satisfied. Finishing training') @@ -98,10 +97,9 @@ def train_epoch(self): # sets the model in training mode self.model.train() - for t in tqdm(self.loaders['train']): - input, target, weight = self._split_training_batch(t) - - output, loss = self._forward_pass(input, target, weight) + for input, target in tqdm(self.loaders['train']): + input, target = input.to(self.device), target.to(self.device) + output, loss = self._forward_pass(input, target) train_losses.update(loss.item(), self._batch_size(input)) @@ -142,32 +140,15 @@ def validate(self): val_losses = RunningAverage() with torch.no_grad(): - for t in tqdm(self.loaders['val']): - input, target, weight = self._split_training_batch(t) + for input, target in tqdm(self.loaders['val']): + input, target = input.to(self.device), target.to(self.device) - output, loss = self._forward_pass(input, target, weight) + output, loss = self._forward_pass(input, target) val_losses.update(loss.item(), self._batch_size(input)) return val_losses.avg - def _split_training_batch(self, t): - def _move_to_gpu(input): - if isinstance(input, tuple) or isinstance(input, list): - return tuple([_move_to_gpu(x) for x in input]) - else: - if torch.cuda.is_available(): - input = input.cuda(non_blocking=True) - return input - - t = _move_to_gpu(t) - weight = None - if len(t) == 2: - input, target = t - else: - input, target, weight = t - return input, target, weight - - def _forward_pass(self, input, target): + def _forward_pass(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if isinstance(self.model, UNet2D): # remove the singleton z-dimension from the input input = torch.squeeze(input, dim=-3) diff --git a/plantseg/viewer/containers.py b/plantseg/viewer/containers.py index 3c1bb1b5..c8169ae2 100644 --- a/plantseg/viewer/containers.py +++ b/plantseg/viewer/containers.py @@ -15,6 +15,7 @@ from plantseg.viewer.widget.segmentation import widget_fix_over_under_segmentation_from_nuclei from plantseg.viewer.widget.segmentation import widget_lifted_multicut from plantseg.viewer.widget.segmentation import widget_simple_dt_ws +from plantseg.viewer.widget.training import widget_unet_training def setup_menu(container, path=None): @@ -62,6 +63,13 @@ def get_gasp_workflow(): return container +def get_training_workflow(): + container = MainWindow(widgets=[widget_unet_training], + labels=False) + container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/UNet-Training') + return container + + def get_extra_seg(): container = MainWindow(widgets=[widget_dt_ws, widget_lifted_multicut, diff --git a/plantseg/viewer/viewer.py b/plantseg/viewer/viewer.py index 8e271a09..325dce5a 100644 --- a/plantseg/viewer/viewer.py +++ b/plantseg/viewer/viewer.py @@ -1,6 +1,6 @@ import napari -from plantseg.viewer.containers import get_extra_seg, get_extra_pred +from plantseg.viewer.containers import get_extra_seg, get_extra_pred, get_training_workflow from plantseg.viewer.containers import get_gasp_workflow, get_preprocessing_workflow, get_main from plantseg.viewer.logging import napari_formatted_logging from plantseg.viewer.widget.proofreading.proofreading import setup_proofreading_keybindings @@ -15,6 +15,7 @@ def run_viewer(): for _containers, name in [(get_preprocessing_workflow(), 'Data - Processing'), (get_gasp_workflow(), 'UNet + Segmentation'), + (get_training_workflow(), 'Training'), (get_extra_pred(), 'Extra-Pred'), (get_extra_seg(), 'Extra-Seg'), ]: diff --git a/plantseg/viewer/widget/training.py 
b/plantseg/viewer/widget/training.py new file mode 100644 index 00000000..e3125ff1 --- /dev/null +++ b/plantseg/viewer/widget/training.py @@ -0,0 +1,74 @@ +from concurrent.futures import Future +from pathlib import Path +from typing import Tuple + +from magicgui import magicgui +from napari.types import LayerDataTuple + +from plantseg import PLANTSEG_MODELS_DIR +from plantseg.training.train import unet_training +from plantseg.utils import list_all_dimensionality +from plantseg.viewer.widget.predictions import ALL_DEVICES +from plantseg.viewer.widget.utils import create_layer_name, start_threading_process, return_value_if_widget + + +def unet_training_wrapper(dataset_dir, model_name, patch_size, dimensionality, sparse, device, **kwargs): + """ + Wrapper to run unet_training in a thread_worker, this is needed to allow the user to select the device + in the headless mode. + """ + return unet_training(dataset_dir, model_name, patch_size, dimensionality, sparse, device, **kwargs) + + +@magicgui(call_button='Run Training', + dataset_dir={'label': 'Path to the dataset directory', + 'mode': 'd', + 'tooltip': 'Select a directory containing train and val subfolders'}, + model_name={'label': 'Trained model name', + 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, + dimensionality={'label': 'Dimensionality', + 'tooltip': 'Dimensionality of the data (2D or 3D). ', + 'widget_type': 'ComboBox', + 'choices': list_all_dimensionality()}, + sparse={'label': 'Sparse', + 'tooltip': 'If True, SPOCO spare training algorithm will be used', + 'widget_type': 'CheckBox'}, + device={'label': 'Device', + 'choices': ALL_DEVICES} + ) +def widget_unet_training(dataset_dir: Path = Path.home(), + model_name: str = 'my-model', + dimensionality: str = '3D', + patch_size: Tuple[int, int, int] = (80, 160, 160), + sparse: bool = False, + device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]: + out_name = create_layer_name(model_name, 'training') + step_kwargs = dict(model_name=model_name, sparse=sparse, dimensionality=dimensionality) + return start_threading_process(unet_training_wrapper, + runtime_kwargs={ + 'dataset_dir': dataset_dir, + 'model_name': model_name, + 'patch_size': patch_size, + 'dimensionality': dimensionality, + 'sparse': sparse, + 'device': device + }, + step_name='UNet training', + widgets_to_update=[], + input_keys=(model_name, 'training'), + out_name=out_name, + layer_kwarg={'name': out_name}, + layer_type='image', + statics_kwargs=step_kwargs + ) + + +@widget_unet_training.dimensionality.changed.connect +def _on_dimensionality_changed(dimensionality: str): + dimensionality = return_value_if_widget(dimensionality) + if dimensionality == '2D': + patch_size = (1, 256, 256) + else: + patch_size = (80, 160, 160) + + widget_unet_training.patch_size.value = patch_size From ef0d80dc2d4f82a93f509ebe7ae9dcb53c9c4357 Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Mon, 19 Jun 2023 21:55:59 +0200 Subject: [PATCH 06/22] add training widget --- .../predictions/functional/predictions.py | 2 +- .../predictions/functional/slice_builder.py | 46 +- plantseg/predictions/functional/utils.py | 2 +- plantseg/training/augs.py | 553 ++++++++++++++++++ plantseg/training/h5dataset.py | 109 ++++ plantseg/training/losses.py | 98 ++++ plantseg/training/train.py | 82 ++- plantseg/training/trainer.py | 1 + plantseg/viewer/widget/training.py | 25 +- 9 files changed, 901 insertions(+), 17 deletions(-) create mode 100644 plantseg/training/augs.py create mode 100644 plantseg/training/h5dataset.py create mode 
100644 plantseg/training/losses.py diff --git a/plantseg/predictions/functional/predictions.py b/plantseg/predictions/functional/predictions.py index 479b1695..77b98dd4 100644 --- a/plantseg/predictions/functional/predictions.py +++ b/plantseg/predictions/functional/predictions.py @@ -47,7 +47,7 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int] raw = raw.astype('float32') stride = get_stride_shape(patch) augs = get_test_augmentations(raw) - slice_builder = SliceBuilder(raw, label_dataset=None, weight_dataset=None, patch_shape=patch, stride_shape=stride) + slice_builder = SliceBuilder(raw, label_dataset=None, patch_shape=patch, stride_shape=stride) test_dataset = ArrayDataset(raw, slice_builder, augs, verbose_logging=False) pmaps = predictor(test_dataset) diff --git a/plantseg/predictions/functional/slice_builder.py b/plantseg/predictions/functional/slice_builder.py index 64ac156c..b1044188 100644 --- a/plantseg/predictions/functional/slice_builder.py +++ b/plantseg/predictions/functional/slice_builder.py @@ -1,3 +1,6 @@ +import numpy as np + + class SliceBuilder: """ Builds the position of the patches in a given raw/label/weight ndarray based on the patch and stride shape. @@ -5,12 +8,11 @@ class SliceBuilder: Args: raw_dataset (ndarray): raw data label_dataset (ndarray): ground truth labels - weight_dataset (ndarray): weights for the labels patch_shape (tuple): the shape of the patch DxHxW stride_shape (tuple): the shape of the stride DxHxW """ - def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape): + def __init__(self, raw_dataset, label_dataset, patch_shape, stride_shape): patch_shape = tuple(patch_shape) stride_shape = tuple(stride_shape) self._check_patch_shape(patch_shape) @@ -22,11 +24,6 @@ def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stri # take the first element in the label_dataset to build slices self._label_slices = self._build_slices(label_dataset, patch_shape, stride_shape) assert len(self._raw_slices) == len(self._label_slices) - if weight_dataset is None: - self._weight_slices = None - else: - self._weight_slices = self._build_slices(weight_dataset, patch_shape, stride_shape) - assert len(self.raw_slices) == len(self._weight_slices) @property def raw_slices(self): @@ -36,10 +33,6 @@ def raw_slices(self): def label_slices(self): return self._label_slices - @property - def weight_slices(self): - return self._weight_slices - @staticmethod def _build_slices(dataset, patch_shape, stride_shape): """Iterates over a given n-dim dataset patch-by-patch with a given stride @@ -86,3 +79,34 @@ def _gen_indices(i, k, s): def _check_patch_shape(patch_shape): assert len(patch_shape) == 3, 'patch_shape must be a 3D tuple' assert patch_shape[1] >= 64 and patch_shape[2] >= 64, 'Height and Width must be greater or equal 64' + + +class FilterSliceBuilder(SliceBuilder): + """ + Filter patches containing less than `threshold` non-zero labels. 
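+    Patches whose fraction of non-ignored labels is below `threshold` are still kept with
+    probability `slack_acceptance` (drawn from a fixed random state), so a small share of
+    mostly-background patches remains in the dataset.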
+ """ + + def __init__(self, raw_dataset, label_dataset, patch_shape, stride_shape, ignore_index=(0,), + threshold=0.1, slack_acceptance=0.01): + super().__init__(raw_dataset, label_dataset, patch_shape, stride_shape) + if label_dataset is None: + return + + rand_state = np.random.RandomState(47) + + def ignore_predicate(raw_label_idx): + label_idx = raw_label_idx[1] + patch = np.copy(label_dataset[label_idx]) + for ii in ignore_index: + patch[patch == ii] = 0 + non_ignore_counts = np.count_nonzero(patch != 0) + non_ignore_counts = non_ignore_counts / patch.size + return non_ignore_counts > threshold or rand_state.rand() < slack_acceptance + + zipped_slices = zip(self.raw_slices, self.label_slices) + # ignore slices containing too much ignore_index + filtered_slices = list(filter(ignore_predicate, zipped_slices)) + # unzip and save slices + raw_slices, label_slices = zip(*filtered_slices) + self._raw_slices = list(raw_slices) + self._label_slices = list(label_slices) diff --git a/plantseg/predictions/functional/utils.py b/plantseg/predictions/functional/utils.py index 8f6a35d7..2ac5f1da 100644 --- a/plantseg/predictions/functional/utils.py +++ b/plantseg/predictions/functional/utils.py @@ -44,7 +44,7 @@ def get_array_dataset(raw, model_name, patch, stride_ratio, global_normalization augs = get_test_augmentations(None) stride = get_stride_shape(patch, stride_ratio) - slice_builder = SliceBuilder(raw, label_dataset=None, weight_dataset=None, patch_shape=patch, stride_shape=stride) + slice_builder = SliceBuilder(raw, label_dataset=None, patch_shape=patch, stride_shape=stride) return ArrayDataset(raw, slice_builder, augs, verbose_logging=False) diff --git a/plantseg/training/augs.py b/plantseg/training/augs.py new file mode 100644 index 00000000..04e299be --- /dev/null +++ b/plantseg/training/augs.py @@ -0,0 +1,553 @@ +import random + +import numpy as np +import torch +from scipy.ndimage import rotate, map_coordinates, gaussian_filter, convolve +from skimage import measure +from skimage.filters import gaussian +from skimage.segmentation import find_boundaries + +# WARN: use fixed random state for reproducibility; if you want to randomize on each run seed with `time.time()` e.g. +GLOBAL_RANDOM_STATE = np.random.RandomState(47) + + +# copied from https://github.com/wolny/pytorch-3dunet +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, m): + for t in self.transforms: + m = t(m) + return m + + +class RandomFlip: + """ + Randomly flips the image across the given axes. Image can be either 3D (DxHxW) or 4D (CxDxHxW). + + When creating make sure that the provided RandomStates are consistent between raw and labeled datasets, + otherwise the models won't converge. + """ + + def __init__(self, random_state, axis_prob=0.5, **kwargs): + assert random_state is not None, 'RandomState cannot be None' + self.random_state = random_state + self.axes = (0, 1, 2) + self.axis_prob = axis_prob + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + + for axis in self.axes: + if self.random_state.uniform() > self.axis_prob: + if m.ndim == 3: + m = np.flip(m, axis) + else: + channels = [np.flip(m[c], axis) for c in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomRotate90: + """ + Rotate an array by 90 degrees around a randomly chosen plane. Image can be either 3D (DxHxW) or 4D (CxDxHxW). 
+ + When creating make sure that the provided RandomStates are consistent between raw and labeled datasets, + otherwise the models won't converge. + + IMPORTANT: assumes DHW axis order (that's why rotation is performed across (1,2) axis) + """ + + def __init__(self, random_state, **kwargs): + self.random_state = random_state + # always rotate around z-axis + self.axis = (1, 2) + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + + # pick number of rotations at random + k = self.random_state.randint(0, 4) + # rotate k times around a given plane + if m.ndim == 3: + m = np.rot90(m, k, self.axis) + else: + channels = [np.rot90(m[c], k, self.axis) for c in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomRotate: + """ + Rotate an array by a random degrees from taken from (-angle_spectrum, angle_spectrum) interval. + Rotation axis is picked at random from the list of provided axes. + """ + + def __init__(self, random_state, angle_spectrum=30, axes=None, mode='reflect', order=0, **kwargs): + if axes is None: + axes = [(1, 0), (2, 1), (2, 0)] + else: + assert isinstance(axes, list) and len(axes) > 0 + + self.random_state = random_state + self.angle_spectrum = angle_spectrum + self.axes = axes + self.mode = mode + self.order = order + + def __call__(self, m): + axis = self.axes[self.random_state.randint(len(self.axes))] + angle = self.random_state.randint(-self.angle_spectrum, self.angle_spectrum) + + if m.ndim == 3: + m = rotate(m, angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) + else: + channels = [rotate(m[c], angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) for c + in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomContrast: + """ + Adjust contrast by scaling each voxel to `mean + alpha * (v - mean)`. + """ + + def __init__(self, random_state, alpha=(0.5, 1.5), mean=0.0, execution_probability=0.1, **kwargs): + self.random_state = random_state + assert len(alpha) == 2 + self.alpha = alpha + self.mean = mean + self.execution_probability = execution_probability + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + alpha = self.random_state.uniform(self.alpha[0], self.alpha[1]) + result = self.mean + alpha * (m - self.mean) + return np.clip(result, -1, 1) + + return m + + +# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader +# remember to use spline_order=0 when transforming the labels +class ElasticDeformation: + """ + Apply elasitc deformations of 3D patches on a per-voxel mesh. Assumes ZYX axis order (or CZYX if the data is 4D). 
+ Based on: https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62 + """ + + def __init__(self, random_state, spline_order, alpha=2000, sigma=50, execution_probability=0.1, apply_3d=True, + **kwargs): + """ + :param spline_order: the order of spline interpolation (use 0 for labeled images) + :param alpha: scaling factor for deformations + :param sigma: smoothing factor for Gaussian filter + :param execution_probability: probability of executing this transform + :param apply_3d: if True apply deformations in each axis + """ + self.random_state = random_state + self.spline_order = spline_order + self.alpha = alpha + self.sigma = sigma + self.execution_probability = execution_probability + self.apply_3d = apply_3d + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + assert m.ndim in [3, 4] + + if m.ndim == 3: + volume_shape = m.shape + else: + volume_shape = m[0].shape + + if self.apply_3d: + dz = gaussian_filter(self.random_state.randn(*volume_shape), self.sigma, mode="reflect") * self.alpha + else: + dz = np.zeros_like(m) + + dy, dx = [ + gaussian_filter( + self.random_state.randn(*volume_shape), + self.sigma, mode="reflect" + ) * self.alpha for _ in range(2) + ] + + z_dim, y_dim, x_dim = volume_shape + z, y, x = np.meshgrid(np.arange(z_dim), np.arange(y_dim), np.arange(x_dim), indexing='ij') + indices = z + dz, y + dy, x + dx + + if m.ndim == 3: + return map_coordinates(m, indices, order=self.spline_order, mode='reflect') + else: + channels = [map_coordinates(c, indices, order=self.spline_order, mode='reflect') for c in m] + return np.stack(channels, axis=0) + + return m + + +class CropToFixed: + def __init__(self, random_state, size=(256, 256), centered=False, **kwargs): + self.random_state = random_state + self.crop_y, self.crop_x = size + self.centered = centered + + def __call__(self, m): + def _padding(pad_total): + half_total = pad_total // 2 + return (half_total, pad_total - half_total) + + def _rand_range_and_pad(crop_size, max_size): + """ + Returns a tuple: + max_value (int) for the corner dimension. 
The corner dimension is chosen as `self.random_state(max_value)` + pad (int): padding in both directions; if crop_size is lt max_size the pad is 0 + """ + if crop_size < max_size: + return max_size - crop_size, (0, 0) + else: + return 1, _padding(crop_size - max_size) + + def _start_and_pad(crop_size, max_size): + if crop_size < max_size: + return (max_size - crop_size) // 2, (0, 0) + else: + return 0, _padding(crop_size - max_size) + + assert m.ndim in (3, 4) + if m.ndim == 3: + _, y, x = m.shape + else: + _, _, y, x = m.shape + + if not self.centered: + y_range, y_pad = _rand_range_and_pad(self.crop_y, y) + x_range, x_pad = _rand_range_and_pad(self.crop_x, x) + + y_start = self.random_state.randint(y_range) + x_start = self.random_state.randint(x_range) + + else: + y_start, y_pad = _start_and_pad(self.crop_y, y) + x_start, x_pad = _start_and_pad(self.crop_x, x) + + if m.ndim == 3: + result = m[:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x] + return np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect') + else: + channels = [] + for c in range(m.shape[0]): + result = m[c][:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x] + channels.append(np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect')) + return np.stack(channels, axis=0) + + +class AbstractLabelToBoundary: + AXES_TRANSPOSE = [ + (0, 1, 2), # X + (0, 2, 1), # Y + (2, 0, 1) # Z + ] + + def __init__(self, ignore_index=None, aggregate_affinities=False, append_label=False, **kwargs): + """ + :param ignore_index: label to be ignored in the output, i.e. after computing the boundary the label ignore_index + will be restored where is was in the patch originally + :param aggregate_affinities: aggregate affinities with the same offset across Z,Y,X axes + :param append_label: if True append the orignal ground truth labels to the last channel + :param blur: Gaussian blur the boundaries + :param sigma: standard deviation for Gaussian kernel + """ + self.ignore_index = ignore_index + self.aggregate_affinities = aggregate_affinities + self.append_label = append_label + + def __call__(self, m): + """ + Extract boundaries from a given 3D label tensor. 
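
A shape-only sketch of CropToFixed (module path assumed as plantseg.training.augs): the Y and X axes are either randomly cropped or reflect-padded to the requested size, while the leading Z axis is left untouched.

    import numpy as np
    from plantseg.training.augs import CropToFixed

    crop = CropToFixed(np.random.RandomState(0), size=(64, 64))
    small = np.zeros((10, 40, 50), dtype='float32')    # Y, X smaller -> reflect-padded
    large = np.zeros((10, 300, 300), dtype='float32')  # Y, X larger  -> randomly cropped
    assert crop(small).shape == (10, 64, 64)
    assert crop(large).shape == (10, 64, 64)
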
+ :param m: input 3D tensor + :return: binary mask, with 1-label corresponding to the boundary and 0-label corresponding to the background + """ + assert m.ndim == 3 + + kernels = self.get_kernels() + boundary_arr = [np.where(np.abs(convolve(m, kernel)) > 0, 1, 0) for kernel in kernels] + channels = np.stack(boundary_arr) + results = [] + if self.aggregate_affinities: + assert len(kernels) % 3 == 0, "Number of kernels must be divided by 3 (one kernel per offset per Z,Y,X axes" + # aggregate affinities with the same offset + for i in range(0, len(kernels), 3): + # merge across X,Y,Z axes (logical OR) + xyz_aggregated_affinities = np.logical_or.reduce(channels[i:i + 3, ...]).astype(np.int32) + # recover ignore index + xyz_aggregated_affinities = _recover_ignore_index(xyz_aggregated_affinities, m, self.ignore_index) + results.append(xyz_aggregated_affinities) + else: + results = [_recover_ignore_index(channels[i], m, self.ignore_index) for i in range(channels.shape[0])] + + if self.append_label: + # append original input data + results.append(m) + + # stack across channel dim + return np.stack(results, axis=0) + + @staticmethod + def create_kernel(axis, offset): + # create conv kernel + k_size = offset + 1 + k = np.zeros((1, 1, k_size), dtype=np.int32) + k[0, 0, 0] = 1 + k[0, 0, offset] = -1 + return np.transpose(k, axis) + + def get_kernels(self): + raise NotImplementedError + + +class StandardLabelToBoundary: + def __init__(self, ignore_index=None, append_label=False, mode='thick', foreground=False, + **kwargs): + self.ignore_index = ignore_index + self.append_label = append_label + self.mode = mode + self.foreground = foreground + + def __call__(self, m): + assert m.ndim == 3 + + boundaries = find_boundaries(m, connectivity=2, mode=self.mode) + boundaries = boundaries.astype('int32') + + results = [] + if self.foreground: + foreground = (m > 0).astype('uint8') + results.append(_recover_ignore_index(foreground, m, self.ignore_index)) + + results.append(_recover_ignore_index(boundaries, m, self.ignore_index)) + + if self.append_label: + # append original input data + results.append(m) + + return np.stack(results, axis=0) + + +class Standardize: + """ + Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std. 
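
For a toy label volume with two touching segments, StandardLabelToBoundary above returns a (1, D, H, W) stack whose single channel marks the voxels on both sides of the interface ('thick' mode of skimage's find_boundaries); a small sketch:

    import numpy as np
    from plantseg.training.augs import StandardLabelToBoundary

    labels = np.zeros((4, 8, 8), dtype='uint32')
    labels[:, :, 4:] = 1                      # two touching segments along X
    out = StandardLabelToBoundary()(labels)
    assert out.shape == (1, 4, 8, 8)
    assert out[0, :, :, 3:5].all()            # interface voxels are flagged as boundary
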
+ """ + + def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs): + if mean is not None or std is not None: + assert mean is not None and std is not None + self.mean = mean + self.std = std + self.eps = eps + self.channelwise = channelwise + + def __call__(self, m): + if self.mean is not None: + mean, std = self.mean, self.std + else: + if self.channelwise: + # normalize per-channel + axes = list(range(m.ndim)) + # average across channels + axes = tuple(axes[1:]) + mean = np.mean(m, axis=axes, keepdims=True) + std = np.std(m, axis=axes, keepdims=True) + else: + mean = np.mean(m) + std = np.std(m) + + return (m - mean) / np.clip(std, a_min=self.eps, a_max=None) + + +class PercentileNormalizer: + def __init__(self, pmin=1, pmax=99.6, channelwise=False, eps=1e-10, **kwargs): + self.eps = eps + self.pmin = pmin + self.pmax = pmax + self.channelwise = channelwise + + def __call__(self, m): + if self.channelwise: + axes = list(range(m.ndim)) + # average across channels + axes = tuple(axes[1:]) + pmin = np.percentile(m, self.pmin, axis=axes, keepdims=True) + pmax = np.percentile(m, self.pmax, axis=axes, keepdims=True) + else: + pmin = np.percentile(m, self.pmin) + pmax = np.percentile(m, self.pmax) + + return (m - pmin) / (pmax - pmin + self.eps) + + +class Normalize: + """ + Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1]. + """ + + def __init__(self, min_value, max_value, **kwargs): + assert max_value > min_value + self.min_value = min_value + self.value_range = max_value - min_value + + def __call__(self, m): + norm_0_1 = (m - self.min_value) / self.value_range + return np.clip(2 * norm_0_1 - 1, -1, 1) + + +class AdditiveGaussianNoise: + def __init__(self, random_state, scale=(0.0, 1.0), execution_probability=0.1, **kwargs): + self.execution_probability = execution_probability + self.random_state = random_state + self.scale = scale + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + std = self.random_state.uniform(self.scale[0], self.scale[1]) + gaussian_noise = self.random_state.normal(0, std, size=m.shape) + return m + gaussian_noise + return m + + +class AdditivePoissonNoise: + def __init__(self, random_state, lam=(0.0, 1.0), execution_probability=0.1, **kwargs): + self.execution_probability = execution_probability + self.random_state = random_state + self.lam = lam + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + lam = self.random_state.uniform(self.lam[0], self.lam[1]) + poisson_noise = self.random_state.poisson(lam, size=m.shape) + return m + poisson_noise + return m + + +class ToTensor: + """ + Converts a given input numpy.ndarray into torch.Tensor. + + Args: + expand_dims (bool): if True, adds a channel dimension to the input data + dtype (np.dtype): the desired output data type + """ + + def __init__(self, expand_dims, dtype=np.float32, **kwargs): + self.expand_dims = expand_dims + self.dtype = dtype + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + # add channel dimension + if self.expand_dims and m.ndim == 3: + m = np.expand_dims(m, axis=0) + + return torch.from_numpy(m.astype(dtype=self.dtype)) + + +class Relabel: + """ + Relabel a numpy array of labels into a consecutive numbers, e.g. + [10, 10, 0, 6, 6] -> [2, 2, 0, 1, 1]. Useful when one has an instance segmentation volume + at hand and would like to create a one-hot-encoding for it. 
Without a consecutive labeling the task would be harder. + """ + + def __init__(self, append_original=False, run_cc=True, ignore_label=None, **kwargs): + self.append_original = append_original + self.ignore_label = ignore_label + self.run_cc = run_cc + + if ignore_label is not None: + assert append_original, "ignore_label present, so append_original must be true, so that one can localize the ignore region" + + def __call__(self, m): + orig = m + if self.run_cc: + # assign 0 to the ignore region + m = measure.label(m, background=self.ignore_label) + + _, unique_labels = np.unique(m, return_inverse=True) + result = unique_labels.reshape(m.shape) + if self.append_original: + result = np.stack([result, orig]) + return result + + +class Identity: + def __init__(self, **kwargs): + pass + + def __call__(self, m): + return m + + +class RgbToLabel: + def __call__(self, img): + img = np.array(img) + assert img.ndim == 3 and img.shape[2] == 3 + result = img[..., 0] * 65536 + img[..., 1] * 256 + img[..., 2] + return result + + +class LabelToTensor: + def __call__(self, m): + m = np.array(m) + return torch.from_numpy(m.astype(dtype='int64')) + + +class GaussianBlur3D: + def __init__(self, sigma=[.1, 2.], execution_probability=0.5, **kwargs): + self.sigma = sigma + self.execution_probability = execution_probability + + def __call__(self, x): + if random.random() < self.execution_probability: + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = gaussian(x, sigma=sigma) + return x + return x + + +def _recover_ignore_index(input, orig, ignore_index): + if ignore_index is not None: + mask = orig == ignore_index + input[mask] = ignore_index + + return input + + +class Augmenter: + def __init__(self): + self.seed = GLOBAL_RANDOM_STATE.randint(10000000) + + def raw_transform(self, stats): + return Compose([ + Standardize(stats['mean'], stats['std']), + RandomFlip(np.random.RandomState(self.seed)), + RandomRotate90(np.random.RandomState(self.seed)), + RandomRotate(np.random.RandomState(self.seed), axes=[[2, 1]], angle_spectrum=45, mode='reflect'), + GaussianBlur3D(), + AdditiveGaussianNoise(np.random.RandomState()), + AdditivePoissonNoise(np.random.RandomState()), + ToTensor(expand_dims=True) + ]) + + def label_transform(self): + return Compose([ + RandomFlip(np.random.RandomState(self.seed)), + RandomRotate90(np.random.RandomState(self.seed)), + RandomRotate(np.random.RandomState(self.seed), axes=[[2, 1]], angle_spectrum=45, mode='reflect'), + StandardLabelToBoundary(), + ToTensor(expand_dims=False) + ]) diff --git a/plantseg/training/h5dataset.py b/plantseg/training/h5dataset.py new file mode 100644 index 00000000..e365a12c --- /dev/null +++ b/plantseg/training/h5dataset.py @@ -0,0 +1,109 @@ +import h5py +import numpy as np +from torch.utils.data import Dataset + +from plantseg.pipeline import gui_logger +from plantseg.predictions.functional.slice_builder import FilterSliceBuilder + + +# copied from https://github.com/wolny/pytorch-3dunet +class HDF5Dataset(Dataset): + """ + Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets + patch by patch with a given stride. 
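
The docstring example of Relabel can be reproduced directly; with run_cc=False no connected-component pass is run and the ids are only made consecutive.

    import numpy as np
    from plantseg.training.augs import Relabel

    m = np.array([10, 10, 0, 6, 6]).reshape(1, 1, 5)
    relabeled = Relabel(run_cc=False)(m)
    print(relabeled.ravel())   # -> [2 2 0 1 1]
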
+ + Args: + file_path (str): path to H5 file containing raw data as well as labels and per pixel weights (optional) + augmenter (transforms.Augmenter): list of augmentations to be applied to the raw and label data sets + patch_shape (tuple): shape of the patch to be extracted from the raw data set + raw_internal_path (str or list): H5 internal path to the raw dataset + label_internal_path (str or list): H5 internal path to the label dataset + global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset + """ + + def __init__(self, file_path, augmenter, patch_shape, + raw_internal_path='raw', label_internal_path='label', global_normalization=True): + + self.file_path = file_path + + with h5py.File(file_path, 'r') as f: + self.raw = self.load_dataset(f, raw_internal_path) + stats = calculate_stats(self.raw, global_normalization) + self.augmenter = augmenter + self.raw_transform = self.augmenter.raw_transform(stats) + + # create label/weight transform only in train/val phase + self.label_transform = self.augmenter.label_transform() + self.label = self.load_dataset(f, label_internal_path) + self._check_volume_sizes(self.raw, self.label) + + # build slice indices for raw and label data sets + slice_builder = FilterSliceBuilder(self.raw, self.label, patch_shape=patch_shape, + stride_shape=tuple(i // 2 for i in patch_shape)) + self.raw_slices = slice_builder.raw_slices + self.label_slices = slice_builder.label_slices + + self.patch_count = len(self.raw_slices) + gui_logger.info(f'{self.patch_count} patches found in {file_path}') + + @staticmethod + def load_dataset(input_file, internal_path): + ds = input_file[internal_path][:] + assert ds.ndim in [3, 4], \ + f"Invalid dataset dimension: {ds.ndim}. Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)" + return ds + + def __getitem__(self, idx): + if idx >= len(self): + raise StopIteration + + # get the slice for a given index 'idx' + raw_idx = self.raw_slices[idx] + # get the raw data patch for a given slice + raw_patch_transformed = self.raw_transform(self.raw[raw_idx]) + + # get the slice for a given index 'idx' + label_idx = self.label_slices[idx] + label_patch_transformed = self.label_transform(self.label[label_idx]) + # return the transformed raw and label patches + return raw_patch_transformed, label_patch_transformed + + def __len__(self): + return self.patch_count + + @staticmethod + def create_h5_file(file_path): + raise NotImplementedError + + @staticmethod + def _check_volume_sizes(raw, label): + def _volume_shape(volume): + if volume.ndim == 3: + return volume.shape + return volume.shape[1:] + + assert raw.ndim in [3, 4], 'Raw dataset must be 3D (DxHxW) or 4D (CxDxHxW)' + assert label.ndim in [3, 4], 'Label dataset must be 3D (DxHxW) or 4D (CxDxHxW)' + + assert _volume_shape(raw) == _volume_shape(label), 'Raw and labels have to be of the same size' + + +def calculate_stats(images, global_normalization=True): + """ + Calculates min, max, mean, std given a list of nd-arrays + """ + if global_normalization: + # flatten first since the images might not be the same size + flat = np.concatenate( + [img.ravel() for img in images] + ) + pmin, pmax, mean, std = np.percentile(flat, 1), np.percentile(flat, 99.6), np.mean(flat), np.std(flat) + else: + pmin, pmax, mean, std = None, None, None, None + + return { + 'pmin': pmin, + 'pmax': pmax, + 'mean': mean, + 'std': std + } diff --git a/plantseg/training/losses.py b/plantseg/training/losses.py new file mode 100644 index 00000000..182c2f2a --- 
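
HDF5Dataset above expects every file to provide 'raw' and 'label' volumes (the default internal paths) with matching spatial shape; a toy file written as below (hypothetical file name) passes the _check_volume_sizes assertions.

    import h5py
    import numpy as np

    with h5py.File('stack_00.h5', 'w') as f:   # hypothetical file name
        f.create_dataset('raw', data=np.random.rand(100, 256, 256).astype('float32'))
        f.create_dataset('label', data=np.zeros((100, 256, 256), dtype='uint32'))
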
/dev/null +++ b/plantseg/training/losses.py @@ -0,0 +1,98 @@ +"""""" + +import torch +from torch import nn + + +# copied from https://github.com/wolny/pytorch-3dunet +def flatten(tensor): + """Flattens a given tensor such that the channel axis is first. + The shapes are transformed as follows: + (N, C, D, H, W) -> (C, N * D * H * W) + """ + # number of channels + C = tensor.size(1) + # new axis order + axis_order = (1, 0) + tuple(range(2, tensor.dim())) + # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) + transposed = tensor.permute(axis_order) + # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) + return transposed.contiguous().view(C, -1) + + +def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): + """ + Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. + Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. + + Args: + input (torch.Tensor): NxCxSpatial input tensor + target (torch.Tensor): NxCxSpatial target tensor + epsilon (float): prevents division by zero + weight (torch.Tensor): Cx1 tensor of weight per channel/class + """ + + # input and target shapes must match + assert input.size() == target.size(), "'input' and 'target' must have the same shape" + + input = flatten(input) + target = flatten(target) + target = target.float() + + # compute per channel Dice Coefficient + intersect = (input * target).sum(-1) + if weight is not None: + intersect = weight * intersect + + # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) + denominator = (input * input).sum(-1) + (target * target).sum(-1) + return 2 * (intersect / denominator.clamp(min=epsilon)) + + +class _AbstractDiceLoss(nn.Module): + """ + Base class for different implementations of Dice loss. + """ + + def __init__(self, weight=None, normalization='sigmoid'): + super(_AbstractDiceLoss, self).__init__() + self.register_buffer('weight', weight) + # The output from the network during training is assumed to be un-normalized probabilities and we would + # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data, + # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. + # However if one would like to apply Softmax in order to get the proper probability distribution from the + # output, just specify `normalization=Softmax` + assert normalization in ['sigmoid', 'softmax', 'none'] + if normalization == 'sigmoid': + self.normalization = nn.Sigmoid() + elif normalization == 'softmax': + self.normalization = nn.Softmax(dim=1) + else: + self.normalization = lambda x: x + + def dice(self, input, target, weight): + # actual Dice score computation; to be implemented by the subclass + raise NotImplementedError + + def forward(self, input, target): + # get probabilities from logits + input = self.normalization(input) + + # compute per channel Dice coefficient + per_channel_dice = self.dice(input, target, weight=self.weight) + + # average Dice score across all channels/classes + return 1. - torch.mean(per_channel_dice) + + +class DiceLoss(_AbstractDiceLoss): + """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. + For multi-class segmentation `weight` parameter can be used to assign different weights per class. + The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function. 
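
A small sanity check for the loss (module path plantseg.training.losses as in the diff): a confident correct prediction drives the loss towards 0 and the inverted prediction towards 1, since the logits are passed through a Sigmoid inside the loss.

    import torch
    from plantseg.training.losses import DiceLoss

    target = torch.zeros(1, 1, 8, 8, 8)
    target[..., 4:] = 1
    good_logits = (target * 2 - 1) * 10    # strongly positive where target == 1
    bad_logits = -good_logits

    loss = DiceLoss()
    assert loss(good_logits, target).item() < 0.01
    assert loss(bad_logits, target).item() > 0.99
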
+ """ + + def __init__(self, weight=None, normalization='sigmoid'): + super().__init__(weight, normalization) + + def dice(self, input, target, weight): + return compute_per_channel_dice(input, target, weight=self.weight) diff --git a/plantseg/training/train.py b/plantseg/training/train.py index ee7eb373..ffdad81f 100644 --- a/plantseg/training/train.py +++ b/plantseg/training/train.py @@ -1,2 +1,80 @@ -def unet_training(dataset_dir, model_name, patch_size, dimensionality, sparse, device, **kwargs): - pass +import glob +import os +from itertools import chain +from typing import Tuple + +import torch +from PIL import Image +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader, ConcatDataset + +from plantseg import PLANTSEG_MODELS_DIR +from plantseg.training.augs import Augmenter +from plantseg.training.h5dataset import HDF5Dataset +from plantseg.training.losses import DiceLoss +from plantseg.training.model import UNet2D, UNet3D +from plantseg.training.trainer import UNetTrainer + + +def unet_training(dataset_dir: str, model_name: str, in_channels: int, out_channels: int, + patch_size: Tuple[int, int, int], dimensionality: str, + sparse: bool, device: str, **kwargs) -> Image: + # create loaders + train_datasets = create_datasets(dataset_dir, 'train', patch_size) + val_datasets = create_datasets(dataset_dir, 'val', patch_size) + loaders = { + 'train': DataLoader(ConcatDataset(train_datasets), batch_size=1, shuffle=True, pin_memory=True, + num_workers=4), + # don't shuffle during validation: useful when showing how predictions for a given batch get better over time + 'val': DataLoader(ConcatDataset(val_datasets), batch_size=1, shuffle=False, pin_memory=True, + num_workers=1) + } + + # create model + # set final activation to sigmoid if not sparse (i.e. 
not embedding model) + final_sigmoid = not sparse + if dimensionality == '2D': + model = UNet2D(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid) + else: + model = UNet3D(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid) + model = model.to(device) + + # create optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5) + + # create trainer + trainer = UNetTrainer( + model=model, + optimizer=optimizer, + lr_scheduler=ReduceLROnPlateau(optimizer, factor=0.5, patience=10), + loss_criterion=DiceLoss(), + loaders=loaders, + checkpoint_dir=os.path.join(PLANTSEG_MODELS_DIR, model_name), + max_num_epochs=10000, + max_num_iterations=100000, + device=device + ) + + return trainer.train() + + +def create_datasets(dataset_dir, phase, patch_shape): + assert phase in ['train', 'val'], f'Phase {phase} not supported' + phase_dir = os.path.join(dataset_dir, phase) + file_paths = traverse_h5_paths(phase_dir) + return [HDF5Dataset(file_path=file_path, augmenter=Augmenter(), patch_shape=patch_shape) for file_path in + file_paths] + + +def traverse_h5_paths(file_paths): + assert isinstance(file_paths, list) + results = [] + for file_path in file_paths: + if os.path.isdir(file_path): + # if file path is a directory take all H5 files in that directory + iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']] + for fp in chain(*iters): + results.append(fp) + else: + results.append(file_path) + return results diff --git a/plantseg/training/trainer.py b/plantseg/training/trainer.py index 78a9ba74..6b7be59a 100644 --- a/plantseg/training/trainer.py +++ b/plantseg/training/trainer.py @@ -13,6 +13,7 @@ from plantseg.training.utils import RunningAverage +# copied from https://github.com/wolny/pytorch-3dunet class UNetTrainer: """UNet trainer. diff --git a/plantseg/viewer/widget/training.py b/plantseg/viewer/widget/training.py index e3125ff1..4d80ca43 100644 --- a/plantseg/viewer/widget/training.py +++ b/plantseg/viewer/widget/training.py @@ -12,12 +12,14 @@ from plantseg.viewer.widget.utils import create_layer_name, start_threading_process, return_value_if_widget -def unet_training_wrapper(dataset_dir, model_name, patch_size, dimensionality, sparse, device, **kwargs): +def unet_training_wrapper(dataset_dir, model_name, in_channels, out_channels, patch_size, dimensionality, sparse, + device, **kwargs): """ Wrapper to run unet_training in a thread_worker, this is needed to allow the user to select the device in the headless mode. """ - return unet_training(dataset_dir, model_name, patch_size, dimensionality, sparse, device, **kwargs) + return unet_training(dataset_dir, model_name, in_channels, out_channels, patch_size, dimensionality, sparse, device, + **kwargs) @magicgui(call_button='Run Training', @@ -26,10 +28,16 @@ def unet_training_wrapper(dataset_dir, model_name, patch_size, dimensionality, s 'tooltip': 'Select a directory containing train and val subfolders'}, model_name={'label': 'Trained model name', 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, + in_channels={'label': 'Input channels', + 'tooltip': 'Number of input channels', }, + out_channels={'label': 'Output channels', + 'tooltip': 'Number of output channels', }, dimensionality={'label': 'Dimensionality', 'tooltip': 'Dimensionality of the data (2D or 3D). 
', 'widget_type': 'ComboBox', 'choices': list_all_dimensionality()}, + patch_size={'label': 'Patch size', + 'tooltip': 'Patch size use to processed the data.'}, sparse={'label': 'Sparse', 'tooltip': 'If True, SPOCO spare training algorithm will be used', 'widget_type': 'CheckBox'}, @@ -38,6 +46,8 @@ def unet_training_wrapper(dataset_dir, model_name, patch_size, dimensionality, s ) def widget_unet_training(dataset_dir: Path = Path.home(), model_name: str = 'my-model', + in_channels: int = 1, + out_channels: int = 1, dimensionality: str = '3D', patch_size: Tuple[int, int, int] = (80, 160, 160), sparse: bool = False, @@ -48,6 +58,8 @@ def widget_unet_training(dataset_dir: Path = Path.home(), runtime_kwargs={ 'dataset_dir': dataset_dir, 'model_name': model_name, + 'in_channels': in_channels, + 'out_channels': out_channels, 'patch_size': patch_size, 'dimensionality': dimensionality, 'sparse': sparse, @@ -72,3 +84,12 @@ def _on_dimensionality_changed(dimensionality: str): patch_size = (80, 160, 160) widget_unet_training.patch_size.value = patch_size + + +@widget_unet_training.sparse.changed.connect +def _on_sparse_change(sparse: bool): + sparse = return_value_if_widget(sparse) + if sparse: + widget_unet_training.out_channels.value = 8 + else: + widget_unet_training.out_channels.value = 1 From 989e59a343b63125ba71faa0fc63ecd0ee2495ab Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Tue, 20 Jun 2023 14:17:39 +0200 Subject: [PATCH 07/22] fix widget-training interface --- plantseg/resources/config_train_template.yaml | 82 +++++++++++++ plantseg/training/augs.py | 4 +- plantseg/training/train.py | 111 +++++++++++++----- plantseg/training/trainer.py | 28 +++-- plantseg/viewer/widget/training.py | 23 ++-- 5 files changed, 199 insertions(+), 49 deletions(-) create mode 100644 plantseg/resources/config_train_template.yaml diff --git a/plantseg/resources/config_train_template.yaml b/plantseg/resources/config_train_template.yaml new file mode 100644 index 00000000..fbc996cb --- /dev/null +++ b/plantseg/resources/config_train_template.yaml @@ -0,0 +1,82 @@ +model: + name: UNet3D + # number of input channels to the model + in_channels: 1 + # number of output channels + out_channels: 1 + # determines the order of operators in a single layer (crg - Conv3d+ReLU+GroupNorm) + layer_order: gcr + # initial number of feature maps + f_maps: 32 + # number of groups in the groupnorm + num_groups: 8 + # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax + final_sigmoid: true +# loss function to be used during training +loss: + name: DiceLoss +optimizer: + # initial learning rate + learning_rate: 0.0001 + # weight decay + weight_decay: 0.00001 +# evaluation metric +lr_scheduler: + name: ReduceLROnPlateau + # make sure to use the 'min' mode cause lower AdaptedRandError is better + mode: min + factor: 0.2 + patience: 10 +trainer: + checkpoint_dir: CHECKPOINT_DIR + # path to the best_checkpoint.pytorch; to be used for fine-tuning the model with additional ground truth + pre_trained: null + # how many iterations between tensorboard logging + log_after_iters: 500 + # max number of epochs + max_num_epochs: 5000 + # max number of iterations + max_num_iterations: 50000 +# Configure training and validation loaders +loaders: + # how many subprocesses to use for data loading + num_workers: 8 + # path to the raw data within the H5 + raw_internal_path: raw + # path to the label data withtin the H5 + label_internal_path: label + # configuration of the train loader + train: + # path to the 
training datasets + file_paths: + - PATH_TO_TRAIN_DIR + + # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch + slice_builder: + name: FilterSliceBuilder + # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better) + patch_shape: [ 80, 160, 160 ] + # train stride between patches + stride_shape: [ 40, 80, 80 ] + # minimum volume of the labels in the patch + threshold: 0.1 + # probability of accepting patches which do not fulfil the threshold criterion + slack_acceptance: 0.01 + + # configuration of the val loader + val: + # path to the val datasets + file_paths: + - PATH_TO_VAL_DIR + + # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch + slice_builder: + name: FilterSliceBuilder + # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better) + patch_shape: [ 80, 160, 160 ] + # train stride between patches + stride_shape: [ 80, 160, 160 ] + # minimum volume of the labels in the patch + threshold: 0.1 + # probability of accepting patches which do not fulfil the threshold criterion + slack_acceptance: 0.01 \ No newline at end of file diff --git a/plantseg/training/augs.py b/plantseg/training/augs.py index 04e299be..b75ea1db 100644 --- a/plantseg/training/augs.py +++ b/plantseg/training/augs.py @@ -344,7 +344,7 @@ class Standardize: Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std. """ - def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs): + def __init__(self, mean=None, std=None, channelwise=False, eps=1e-10, **kwargs): if mean is not None or std is not None: assert mean is not None and std is not None self.mean = mean @@ -533,7 +533,7 @@ def __init__(self): def raw_transform(self, stats): return Compose([ - Standardize(stats['mean'], stats['std']), + Standardize(mean=stats['mean'], std=stats['std']), RandomFlip(np.random.RandomState(self.seed)), RandomRotate90(np.random.RandomState(self.seed)), RandomRotate(np.random.RandomState(self.seed), axes=[[2, 1]], angle_spectrum=45, mode='reflect'), diff --git a/plantseg/training/train.py b/plantseg/training/train.py index ffdad81f..190497ac 100644 --- a/plantseg/training/train.py +++ b/plantseg/training/train.py @@ -3,12 +3,17 @@ from itertools import chain from typing import Tuple +import numpy as np import torch +import yaml from PIL import Image +from torch import nn from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.data import DataLoader, ConcatDataset +import matplotlib.pyplot as plt -from plantseg import PLANTSEG_MODELS_DIR +from plantseg import PLANTSEG_MODELS_DIR, plantseg_global_path +from plantseg.pipeline import gui_logger from plantseg.training.augs import Augmenter from plantseg.training.h5dataset import HDF5Dataset from plantseg.training.losses import DiceLoss @@ -16,65 +21,111 @@ from plantseg.training.trainer import UNetTrainer +def create_model_config(checkpoint_dir, in_channels, out_channels, patch_size, dimensionality, sparse, f_maps): + os.makedirs(checkpoint_dir, exist_ok=True) + train_template_path = os.path.join(plantseg_global_path, + "resources", + "config_train_template.yaml") + with open(train_template_path, 'r') as f: + train_template = yaml.load(f, Loader=yaml.FullLoader) + + train_template['model']['in_channels'] = in_channels + train_template['model']['out_channels'] = out_channels + train_template['model']['f_maps'] = f_maps + if dimensionality == 
'2D': + train_template['model']['name'] = 'UNet2D' + else: + train_template['model']['name'] = 'UNet3D' + train_template['model']['final_sigmoid'] = not sparse + train_template['trainer']['checkpoint_dir'] = checkpoint_dir + train_template['loaders']['train']['slice_builder']['patch_shape'] = patch_size + train_template['loaders']['val']['slice_builder']['patch_shape'] = patch_size + + out_path = os.path.join(checkpoint_dir, 'config_train.yml') + with open(out_path, 'w') as yaml_file: + yaml.dump(train_template, yaml_file, default_flow_style=False) + + +def plot_curves(learning_curves, checkpoint_dir): + plt.figure() + plt.plot(list(learning_curves['train_loss'].keys()), list(learning_curves['train_loss'].values()), + label='train_loss', c='y', marker='o') + plt.plot(list(learning_curves['val_loss'].keys()), list(learning_curves['val_loss'].values()), + label='val_loss', c='b', marker='o') + plt.legend(loc='upper right') + plot_path = os.path.join(checkpoint_dir, 'learning_curves.png') + plt.savefig(plot_path) + return np.array(Image.open(plot_path)) + + def unet_training(dataset_dir: str, model_name: str, in_channels: int, out_channels: int, - patch_size: Tuple[int, int, int], dimensionality: str, - sparse: bool, device: str, **kwargs) -> Image: + patch_size: Tuple[int, int, int], max_num_iters: int, dimensionality: str, + sparse: bool, device: str, headless: bool = False, **kwargs) -> Image: + # create model + batch_size = 1 + # set final activation to sigmoid if not sparse (i.e. not embedding model) + final_sigmoid = not sparse + f_maps = [32, 64, 128, 256, 512] + if dimensionality == '2D': + model = UNet2D(in_channels=in_channels, out_channels=out_channels, f_maps=f_maps, final_sigmoid=final_sigmoid) + else: + model = UNet3D(in_channels=in_channels, out_channels=out_channels, f_maps=f_maps, final_sigmoid=final_sigmoid) + + if torch.cuda.device_count() > 1 and device != 'cpu' and headless: + model = nn.DataParallel(model) + gui_logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction.') + batch_size *= torch.cuda.device_count() + device = 'cuda' + + model = model.to(device) # create loaders train_datasets = create_datasets(dataset_dir, 'train', patch_size) val_datasets = create_datasets(dataset_dir, 'val', patch_size) loaders = { - 'train': DataLoader(ConcatDataset(train_datasets), batch_size=1, shuffle=True, pin_memory=True, + 'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4), # don't shuffle during validation: useful when showing how predictions for a given batch get better over time - 'val': DataLoader(ConcatDataset(val_datasets), batch_size=1, shuffle=False, pin_memory=True, + 'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=1) } - # create model - # set final activation to sigmoid if not sparse (i.e. not embedding model) - final_sigmoid = not sparse - if dimensionality == '2D': - model = UNet2D(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid) - else: - model = UNet3D(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid) - model = model.to(device) - # create optimizer optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5) # create trainer + home_path = os.path.expanduser("~") + checkpoint_dir = os.path.join(home_path, PLANTSEG_MODELS_DIR, model_name) + if os.path.exists(checkpoint_dir): + gui_logger.warn(f'Checkpoint dir {checkpoint_dir} already exists! 
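
The curves returned by the trainer are plain dicts keyed by iteration number, so plot_curves can be exercised on its own; the values and the checkpoint directory below are made up, and the directory must already exist.

    from plantseg.training.train import plot_curves

    curves = {'train_loss': {500: 0.62, 1000: 0.41, 1500: 0.33},
              'val_loss': {1000: 0.45, 1500: 0.38}}
    img = plot_curves(curves, '/tmp/my-model')   # hypothetical checkpoint dir
    print(img.shape)                             # RGBA array read back from learning_curves.png
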
Overwriting...') + create_model_config(checkpoint_dir, in_channels, out_channels, patch_size, dimensionality, sparse, f_maps) + trainer = UNetTrainer( model=model, optimizer=optimizer, - lr_scheduler=ReduceLROnPlateau(optimizer, factor=0.5, patience=10), + lr_scheduler=ReduceLROnPlateau(optimizer, factor=0.2, patience=10), loss_criterion=DiceLoss(), loaders=loaders, - checkpoint_dir=os.path.join(PLANTSEG_MODELS_DIR, model_name), - max_num_epochs=10000, - max_num_iterations=100000, + checkpoint_dir=checkpoint_dir, + max_num_iterations=max_num_iters, device=device ) - return trainer.train() + learning_curves = trainer.train() + return plot_curves(learning_curves, checkpoint_dir) def create_datasets(dataset_dir, phase, patch_shape): assert phase in ['train', 'val'], f'Phase {phase} not supported' phase_dir = os.path.join(dataset_dir, phase) - file_paths = traverse_h5_paths(phase_dir) + file_paths = find_h5_files(phase_dir) return [HDF5Dataset(file_path=file_path, augmenter=Augmenter(), patch_shape=patch_shape) for file_path in file_paths] -def traverse_h5_paths(file_paths): - assert isinstance(file_paths, list) +def find_h5_files(data_dir): + assert os.path.isdir(data_dir), f'Not a directory {data_dir}' results = [] - for file_path in file_paths: - if os.path.isdir(file_path): - # if file path is a directory take all H5 files in that directory - iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']] - for fp in chain(*iters): - results.append(fp) - else: - results.append(file_path) + iters = [glob.glob(os.path.join(data_dir, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']] + for fp in chain(*iters): + results.append(fp) return results diff --git a/plantseg/training/trainer.py b/plantseg/training/trainer.py index 6b7be59a..408e3be6 100644 --- a/plantseg/training/trainer.py +++ b/plantseg/training/trainer.py @@ -20,20 +20,19 @@ class UNetTrainer: Args: model (Unet3D): UNet 3D model to be trained optimizer (nn.optim.Optimizer): optimizer used for training - lr_scheduler (torch.optim.lr_scheduler.LRScheduler): learning rate scheduler + lr_scheduler (torch.optim.lr_scheduler.ReduceLROnPlateau): learning rate scheduler loss_criterion (nn.Module): loss function loaders (dict): 'train' and 'val' loaders checkpoint_dir (string): dir for saving checkpoints and tensorboard logs - max_num_epochs (int): maximum number of epochs max_num_iterations (int): maximum number of iterations device (str): device to use for training log_after_iters (int): number of iterations before logging to tensorboard pre_trained(str): path to the pre-trained model """ - def __init__(self, model: nn.Module, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler.LRScheduler, - loss_criterion: nn.Module, loaders: dict, checkpoint_dir: str, max_num_epochs: int, - max_num_iterations: int, device: str = 'cuda', log_after_iters: int = 100, pre_trained: int = None): + def __init__(self, model: nn.Module, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler.ReduceLROnPlateau, + loss_criterion: nn.Module, loaders: dict, checkpoint_dir: str, + max_num_iterations: int, device: str = 'cuda', log_after_iters: int = 100, pre_trained: str = None): self.model = model self.optimizer = optimizer @@ -41,8 +40,8 @@ def __init__(self, model: nn.Module, optimizer: optim.Optimizer, lr_scheduler: o self.loss_criterion = loss_criterion self.loaders = loaders self.checkpoint_dir = checkpoint_dir - self.max_num_epochs = max_num_epochs self.max_num_iterations = max_num_iterations + 
self.max_num_epochs = max_num_iterations // len(loaders['train']) + 1 self.device = device self.log_after_iters = log_after_iters self.best_eval_loss = float('+inf') @@ -53,7 +52,13 @@ def __init__(self, model: nn.Module, optimizer: optim.Optimizer, lr_scheduler: o state = torch.load(pre_trained, map_location='cpu') self.model.load_state_dict(state) - def train(self): + # init learning curves + self.learning_curves = { + 'train_loss': {}, + 'val_loss': {}, + } + + def train(self) -> dict: for epoch in range(self.max_num_epochs): print(f'Epoch [{epoch}/{self.max_num_epochs}]') # train for one epoch @@ -61,7 +66,7 @@ def train(self): if should_terminate: gui_logger.info('Stopping criterion is satisfied. Finishing training') - return + return self.learning_curves print('Validating...') # set the model in eval mode @@ -69,6 +74,7 @@ def train(self): # evaluate on validation set eval_loss = self.validate() gui_logger.info(f'Val Loss: {eval_loss}.') + self.learning_curves['val_loss'][self.num_iterations] = eval_loss # set the model back to training mode self.model.train() @@ -86,6 +92,7 @@ def train(self): self._save_checkpoint(is_best) gui_logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...") + return self.learning_curves def train_epoch(self): """Trains the model for 1 epoch. @@ -112,9 +119,10 @@ def train_epoch(self): if self.num_iterations % self.log_after_iters == 0: # log stats, params and images gui_logger.info(f'Train Loss: {train_losses.avg}.') + self.learning_curves['train_loss'][self.num_iterations] = train_losses.avg - if self.should_stop(): - return True + if self.should_stop(): + return True self.num_iterations += 1 diff --git a/plantseg/viewer/widget/training.py b/plantseg/viewer/widget/training.py index 4d80ca43..dcc24b9d 100644 --- a/plantseg/viewer/widget/training.py +++ b/plantseg/viewer/widget/training.py @@ -3,23 +3,25 @@ from typing import Tuple from magicgui import magicgui +from napari import Viewer from napari.types import LayerDataTuple from plantseg import PLANTSEG_MODELS_DIR from plantseg.training.train import unet_training from plantseg.utils import list_all_dimensionality from plantseg.viewer.widget.predictions import ALL_DEVICES -from plantseg.viewer.widget.utils import create_layer_name, start_threading_process, return_value_if_widget +from plantseg.viewer.widget.utils import create_layer_name, start_threading_process, return_value_if_widget, \ + layer_properties -def unet_training_wrapper(dataset_dir, model_name, in_channels, out_channels, patch_size, dimensionality, sparse, - device, **kwargs): +def unet_training_wrapper(dataset_dir, model_name, in_channels, out_channels, patch_size, max_num_iters, dimensionality, + sparse, device, **kwargs): """ Wrapper to run unet_training in a thread_worker, this is needed to allow the user to select the device in the headless mode. 
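
With the iteration-budget interface above, the epoch cap inside UNetTrainer is derived from the training-loader length; a worked example with made-up numbers:

    max_num_iterations = 40_000
    iters_per_epoch = 1_250      # len(loaders['train']) for a hypothetical dataset
    max_num_epochs = max_num_iterations // iters_per_epoch + 1
    print(max_num_epochs)        # 33
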
""" - return unet_training(dataset_dir, model_name, in_channels, out_channels, patch_size, dimensionality, sparse, device, - **kwargs) + return unet_training(dataset_dir, model_name, in_channels, out_channels, patch_size, max_num_iters, dimensionality, + sparse, device, **kwargs) @magicgui(call_button='Run Training', @@ -38,22 +40,27 @@ def unet_training_wrapper(dataset_dir, model_name, in_channels, out_channels, pa 'choices': list_all_dimensionality()}, patch_size={'label': 'Patch size', 'tooltip': 'Patch size use to processed the data.'}, + max_num_iterations={'label': 'Max number of iterations'}, sparse={'label': 'Sparse', 'tooltip': 'If True, SPOCO spare training algorithm will be used', 'widget_type': 'CheckBox'}, device={'label': 'Device', 'choices': ALL_DEVICES} ) -def widget_unet_training(dataset_dir: Path = Path.home(), +def widget_unet_training(viewer: Viewer, + dataset_dir: Path = Path.home(), model_name: str = 'my-model', in_channels: int = 1, out_channels: int = 1, dimensionality: str = '3D', patch_size: Tuple[int, int, int] = (80, 160, 160), + max_num_iterations: int = 40000, sparse: bool = False, device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]: out_name = create_layer_name(model_name, 'training') step_kwargs = dict(model_name=model_name, sparse=sparse, dimensionality=dimensionality) + layer_kwargs = layer_properties(name=out_name, + scale=[1, 1, 1]) return start_threading_process(unet_training_wrapper, runtime_kwargs={ 'dataset_dir': dataset_dir, @@ -61,6 +68,7 @@ def widget_unet_training(dataset_dir: Path = Path.home(), 'in_channels': in_channels, 'out_channels': out_channels, 'patch_size': patch_size, + 'max_num_iters': max_num_iterations, 'dimensionality': dimensionality, 'sparse': sparse, 'device': device @@ -69,8 +77,9 @@ def widget_unet_training(dataset_dir: Path = Path.home(), widgets_to_update=[], input_keys=(model_name, 'training'), out_name=out_name, - layer_kwarg={'name': out_name}, + layer_kwarg=layer_kwargs, layer_type='image', + viewer=viewer, statics_kwargs=step_kwargs ) From 0713530d202908391c9c91b7c09a78e66b7d0f2f Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Tue, 20 Jun 2023 15:26:07 +0200 Subject: [PATCH 08/22] minor change --- plantseg/viewer/widget/training.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/plantseg/viewer/widget/training.py b/plantseg/viewer/widget/training.py index dcc24b9d..f016b58e 100644 --- a/plantseg/viewer/widget/training.py +++ b/plantseg/viewer/widget/training.py @@ -59,8 +59,6 @@ def widget_unet_training(viewer: Viewer, device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]: out_name = create_layer_name(model_name, 'training') step_kwargs = dict(model_name=model_name, sparse=sparse, dimensionality=dimensionality) - layer_kwargs = layer_properties(name=out_name, - scale=[1, 1, 1]) return start_threading_process(unet_training_wrapper, runtime_kwargs={ 'dataset_dir': dataset_dir, @@ -77,7 +75,7 @@ def widget_unet_training(viewer: Viewer, widgets_to_update=[], input_keys=(model_name, 'training'), out_name=out_name, - layer_kwarg=layer_kwargs, + layer_kwarg={'name': out_name, 'scale': None}, layer_type='image', viewer=viewer, statics_kwargs=step_kwargs From 8d4566fcbe34bc4fff8ffe13bf5ce2810c2d4a94 Mon Sep 17 00:00:00 2001 From: Adrian Wolny Date: Wed, 21 Jun 2023 11:41:48 +0200 Subject: [PATCH 09/22] minor --- plantseg/predictions/functional/array_predictor.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/plantseg/predictions/functional/array_predictor.py 
b/plantseg/predictions/functional/array_predictor.py index f07b8950..fe1625a6 100644 --- a/plantseg/predictions/functional/array_predictor.py +++ b/plantseg/predictions/functional/array_predictor.py @@ -125,11 +125,8 @@ def __call__(self, test_dataset: Dataset) -> np.ndarray: else: out_channels = self.out_channels - if self.is_embedding: - # embeddings will be converted to affinities - prediction_maps_shape = (out_channels,) + volume_shape - else: - prediction_maps_shape = (out_channels,) + volume_shape + prediction_maps_shape = (out_channels,) + volume_shape + if self.verbose_logging: gui_logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}') gui_logger.info(f'Using patch_halo: {self.patch_halo}') From 1b8ba9a88e95037b33f41c24e7054aa04a0c12d8 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Wed, 9 Aug 2023 11:59:07 +0200 Subject: [PATCH 10/22] restructure of default plantseg local model directory --- plantseg/__init__.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/plantseg/__init__.py b/plantseg/__init__.py index 8aa7aa69..7e11f6b3 100644 --- a/plantseg/__init__.py +++ b/plantseg/__init__.py @@ -1,26 +1,34 @@ -import os from pathlib import Path import yaml # Find the global path of plantseg -plantseg_global_path = Path(__file__).parent.absolute() +PLANTSEG_GLOBAL_PATH = Path(__file__).parent.absolute() # Create configs directory at startup -home_path = os.path.expanduser("~") -PLANTSEG_MODELS_DIR = ".plantseg_models" +USER_HOME_PATH = Path.home() -configs_path = os.path.join(home_path, PLANTSEG_MODELS_DIR, "configs") -os.makedirs(configs_path, exist_ok=True) +PLANTSEG_LOCAL_DIR = USER_HOME_PATH / '.plantseg' +PLANTSEG_MODELS_DIR = PLANTSEG_LOCAL_DIR / 'models' -# create custom zoo if does not exist -custom_zoo = os.path.join(home_path, PLANTSEG_MODELS_DIR, 'custom_zoo.yaml') +CONFIGS_PATH = USER_HOME_PATH / PLANTSEG_MODELS_DIR / 'configs' +CONFIGS_PATH.mkdir(parents=True, exist_ok=True) -if not os.path.exists(custom_zoo): - with open(custom_zoo, 'w') as f: +# create a user zoo config if does not exist +USER_MODEL_ZOO_CONFIG = USER_HOME_PATH / PLANTSEG_MODELS_DIR / 'user_model_zoo.yaml' + +if not USER_MODEL_ZOO_CONFIG.exists(): + with open(USER_MODEL_ZOO_CONFIG, 'w') as f: + yaml.dump({}, f) + +# create a custom datasets config if does not exist +USER_DATASETS_CONFIG = USER_HOME_PATH / PLANTSEG_MODELS_DIR / 'user_datasets.yaml' + +if not USER_DATASETS_CONFIG.exists(): + with open(USER_DATASETS_CONFIG, 'w') as f: yaml.dump({}, f) # Resources directory -RESOURCES_DIR = "resources" -model_zoo_path = os.path.join(plantseg_global_path, RESOURCES_DIR, "models_zoo.yaml") -standard_config_template = os.path.join(plantseg_global_path, RESOURCES_DIR, "config_gui_template.yaml") +RESOURCES_DIR = 'resources' +MODEL_ZOO_PATH = PLANTSEG_GLOBAL_PATH / RESOURCES_DIR / 'models_zoo.yaml' +STANDARD_CONFIG_TEMPLATE = PLANTSEG_GLOBAL_PATH / RESOURCES_DIR / 'config_gui_template.yaml' From 8a591a1679fcc85c8aa1585520ae6e227e58e66d Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Wed, 9 Aug 2023 12:09:48 +0200 Subject: [PATCH 11/22] fix small bugs --- plantseg/__init__.py | 11 ++++++----- plantseg/legacy_gui/gui_tools.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/plantseg/__init__.py b/plantseg/__init__.py index 7e11f6b3..7384fc6c 100644 --- a/plantseg/__init__.py +++ b/plantseg/__init__.py @@ -10,19 +10,20 @@ PLANTSEG_LOCAL_DIR = USER_HOME_PATH / '.plantseg' 
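
Under the restructured module the key locations are plain pathlib.Path objects, so downstream code can compose them directly; a sketch of where a trained model would end up (the user home is machine-specific):

    from pathlib import Path

    USER_HOME_PATH = Path.home()
    PLANTSEG_LOCAL_DIR = USER_HOME_PATH / '.plantseg'
    PLANTSEG_MODELS_DIR = PLANTSEG_LOCAL_DIR / 'models'
    print(PLANTSEG_MODELS_DIR / 'my-model' / 'best_checkpoint.pytorch')
    # e.g. /home/<user>/.plantseg/models/my-model/best_checkpoint.pytorch
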
PLANTSEG_MODELS_DIR = PLANTSEG_LOCAL_DIR / 'models' - -CONFIGS_PATH = USER_HOME_PATH / PLANTSEG_MODELS_DIR / 'configs' -CONFIGS_PATH.mkdir(parents=True, exist_ok=True) +PLANTSEG_MODELS_DIR.mkdir(parents=True, exist_ok=True) # create a user zoo config if does not exist -USER_MODEL_ZOO_CONFIG = USER_HOME_PATH / PLANTSEG_MODELS_DIR / 'user_model_zoo.yaml' +USER_MODEL_ZOO_CONFIG = PLANTSEG_MODELS_DIR / 'user_model_zoo.yaml' if not USER_MODEL_ZOO_CONFIG.exists(): with open(USER_MODEL_ZOO_CONFIG, 'w') as f: yaml.dump({}, f) +CONFIGS_PATH = PLANTSEG_LOCAL_DIR / 'configs' +CONFIGS_PATH.mkdir(parents=True, exist_ok=True) + # create a custom datasets config if does not exist -USER_DATASETS_CONFIG = USER_HOME_PATH / PLANTSEG_MODELS_DIR / 'user_datasets.yaml' +USER_DATASETS_CONFIG = PLANTSEG_LOCAL_DIR / 'user_datasets.yaml' if not USER_DATASETS_CONFIG.exists(): with open(USER_DATASETS_CONFIG, 'w') as f: diff --git a/plantseg/legacy_gui/gui_tools.py b/plantseg/legacy_gui/gui_tools.py index 3aa13335..520b348b 100644 --- a/plantseg/legacy_gui/gui_tools.py +++ b/plantseg/legacy_gui/gui_tools.py @@ -5,7 +5,7 @@ import yaml -from plantseg import custom_zoo, home_path, PLANTSEG_MODELS_DIR, model_zoo_path +from plantseg import USER_MODEL_ZOO_CONFIG, USER_HOME_PATH, PLANTSEG_MODELS_DIR, MODEL_ZOO_PATH from plantseg.__version__ import __version__ from plantseg.io import read_tiff_voxel_size, TIFF_EXTENSIONS from plantseg.legacy_gui import stick_all, stick_ew, var_to_tkinter, convert_rgb, PLANTSEG_GREEN @@ -482,7 +482,7 @@ def auto_rescale(self): factor from the resolution given by the user""" global current_model - model_config = load_config(model_zoo_path) + model_config = load_config(MODEL_ZOO_PATH) net_resolution = model_config[current_model]["resolution"] AutoResPopup(net_resolution, current_model, self.tk_value, self.font) @@ -561,7 +561,7 @@ def __init__(self, config): """ Browse for file and directory """ self.files = tkinter.StringVar() if config["path"] is None: - self.files.set(home_path) + self.files.set(USER_HOME_PATH) else: self.files.set(config["path"]) self.config = config @@ -569,6 +569,7 @@ def __init__(self, config): def browse_for_file(self): """ browse for file """ current_file_dir, _ = os.path.split(self.files.get()) + home_path = str(USER_HOME_PATH) current_file_dir = (home_path if len(home_path) > len(current_file_dir) else current_file_dir) @@ -584,6 +585,7 @@ def browse_for_file(self): def browse_for_directory(self): """ browse for directory """ current_file_dir, _ = os.path.split(self.files.get()) + home_path = str(USER_HOME_PATH) current_file_dir = (home_path if len(home_path) > len(current_file_dir) else current_file_dir) dire_name = filedialog.askdirectory(initialdir=current_file_dir, @@ -914,7 +916,7 @@ def remove_model(self, row=0, column=0): def delete_model(self): # Delete entry in zoo custom self.file_to_remove = self.file_to_remove.get() - custom_zoo_dict = load_config(custom_zoo) + custom_zoo_dict = load_config(USER_MODEL_ZOO_CONFIG) if custom_zoo_dict is None: custom_zoo_dict = {} @@ -927,10 +929,10 @@ def delete_model(self): self.popup.destroy() raise RuntimeError(msg) - with open(custom_zoo, 'w') as f: + with open(USER_MODEL_ZOO_CONFIG, 'w') as f: yaml.dump(custom_zoo_dict, f) - self.join = os.path.join(home_path, PLANTSEG_MODELS_DIR, self.file_to_remove) + self.join = os.path.join(USER_HOME_PATH, PLANTSEG_MODELS_DIR, self.file_to_remove) file_directory = self.join if os.path.exists(file_directory): From 410753fa7fb34739b924a2262d5f4e27d46b1c1d Mon Sep 17 00:00:00 2001 
From: lorenzocerrone Date: Wed, 9 Aug 2023 12:10:13 +0200 Subject: [PATCH 12/22] add template for dataset handling --- plantseg/legacy_gui/plantsegapp.py | 16 +++---- plantseg/pipeline/__init__.py | 4 +- plantseg/predictions/functional/utils.py | 6 +-- plantseg/run_plantseg.py | 6 +++ plantseg/training/train.py | 4 +- plantseg/utils.py | 18 ++++---- plantseg/viewer/containers.py | 12 ++++-- plantseg/viewer/training.py | 14 ++++++ plantseg/viewer/viewer.py | 5 ++- plantseg/viewer/widget/train_dataset.py | 54 ++++++++++++++++++++++++ 10 files changed, 109 insertions(+), 30 deletions(-) create mode 100644 plantseg/viewer/training.py create mode 100644 plantseg/viewer/widget/train_dataset.py diff --git a/plantseg/legacy_gui/plantsegapp.py b/plantseg/legacy_gui/plantsegapp.py index b8692707..55edd99e 100644 --- a/plantseg/legacy_gui/plantsegapp.py +++ b/plantseg/legacy_gui/plantsegapp.py @@ -9,7 +9,7 @@ import yaml from plantseg.utils import load_config -from plantseg import plantseg_global_path, configs_path, RESOURCES_DIR, standard_config_template +from plantseg import PLANTSEG_GLOBAL_PATH, CONFIGS_PATH, RESOURCES_DIR, STANDARD_CONFIG_TEMPLATE from plantseg.legacy_gui import convert_rgb from plantseg.legacy_gui.gui_tools import Files2Process, report_error, version_popup, LoadModelPopup, RemovePopup from plantseg.pipeline import gui_logger @@ -297,19 +297,19 @@ def display(self, record): @staticmethod def get_last_config_path(name="config_gui_last.yaml"): # Working directory path + relative dir structure to yaml file - config_path = os.path.join(configs_path, name) + config_path = os.path.join(CONFIGS_PATH, name) return config_path @staticmethod def get_app_config_path(name="gui_configuration.yaml"): # Working directory path + relative dir structure to yaml file - config_path = os.path.join(plantseg_global_path, RESOURCES_DIR, name) + config_path = os.path.join(PLANTSEG_GLOBAL_PATH, RESOURCES_DIR, name) return config_path @staticmethod def get_icon_path(name="FOR2581_Logo_FINAL_no_text.png"): # Working directory path + relative dir structure to yaml file - icon_path = os.path.join(plantseg_global_path, RESOURCES_DIR, name) + icon_path = os.path.join(PLANTSEG_GLOBAL_PATH, RESOURCES_DIR, name) return icon_path def load_config(self, name="config_gui_last.yaml"): @@ -320,7 +320,7 @@ def load_config(self, name="config_gui_last.yaml"): plantseg_config = load_config(plant_config_path) else: # Do not modify this location - plant_config_path = os.path.join(standard_config_template) + plant_config_path = os.path.join(STANDARD_CONFIG_TEMPLATE) plantseg_config = load_config(plant_config_path) return plant_config_path, plantseg_config @@ -333,7 +333,7 @@ def load_app_config(self, config="gui_configuration.yaml"): def reset_config(self): """ reset to default config, do not change path""" - plant_config_path = os.path.join(standard_config_template) + plant_config_path = os.path.join(STANDARD_CONFIG_TEMPLATE) self.plantseg_config = load_config(plant_config_path) (self.pre_proc_obj, @@ -343,7 +343,7 @@ def reset_config(self): def open_config(self): """ open new config""" - default_start = os.path.join(configs_path) + default_start = os.path.join(CONFIGS_PATH) os.makedirs(default_start, exist_ok=True) plant_config_path = tkinter.filedialog.askopenfilename(initialdir=default_start, title="Select file", @@ -359,7 +359,7 @@ def open_config(self): def save_config(self): """ save yaml from current entries in the legacy_gui""" self.update_config() - default_start = os.path.join(configs_path) + default_start = 
os.path.join(CONFIGS_PATH) os.makedirs(default_start, exist_ok=True) save_path = tkinter.filedialog.asksaveasfilename(initialdir=default_start, diff --git a/plantseg/pipeline/__init__.py b/plantseg/pipeline/__init__.py index c17cc37b..2caf851a 100644 --- a/plantseg/pipeline/__init__.py +++ b/plantseg/pipeline/__init__.py @@ -2,7 +2,7 @@ import os import sys -from plantseg import plantseg_global_path +from plantseg import PLANTSEG_GLOBAL_PATH gui_logger = logging.getLogger("PlantSeg") # hardcode the log-level for now @@ -17,4 +17,4 @@ # Resources directory RESOURCES_DIR = "resources" -raw2seg_config_template = os.path.join(plantseg_global_path, RESOURCES_DIR, "raw2seg_template.yaml") +raw2seg_config_template = os.path.join(PLANTSEG_GLOBAL_PATH, RESOURCES_DIR, "raw2seg_template.yaml") diff --git a/plantseg/predictions/functional/utils.py b/plantseg/predictions/functional/utils.py index 2ac5f1da..85100dc5 100644 --- a/plantseg/predictions/functional/utils.py +++ b/plantseg/predictions/functional/utils.py @@ -1,6 +1,6 @@ import os -from plantseg import plantseg_global_path, PLANTSEG_MODELS_DIR, home_path +from plantseg import PLANTSEG_GLOBAL_PATH, PLANTSEG_MODELS_DIR, USER_HOME_PATH from plantseg.augment.transforms import get_test_augmentations from plantseg.training.model import get_model from plantseg.pipeline import gui_logger @@ -11,7 +11,7 @@ def get_predict_template(): - predict_template_path = os.path.join(plantseg_global_path, + predict_template_path = os.path.join(PLANTSEG_GLOBAL_PATH, "resources", "config_predict_template.yaml") predict_template = load_config(predict_template_path) @@ -23,7 +23,7 @@ def get_model_config(model_name, model_update=False): config_train = get_train_config(model_name) model_config = config_train.pop('model') model = get_model(model_config) - model_path = os.path.join(home_path, + model_path = os.path.join(USER_HOME_PATH, PLANTSEG_MODELS_DIR, model_name, "best_checkpoint.pytorch") diff --git a/plantseg/run_plantseg.py b/plantseg/run_plantseg.py index d20807e5..113a4cbd 100644 --- a/plantseg/run_plantseg.py +++ b/plantseg/run_plantseg.py @@ -6,6 +6,7 @@ def parser(): arg_parser.add_argument('--config', type=str, help='Path to the YAML config file', required=False) arg_parser.add_argument('--gui', action='store_true', help='Launch GUI configurator', required=False) arg_parser.add_argument('--napari', action='store_true', help='Napari Viewer', required=False) + arg_parser.add_argument('--training', action='store_true', help='Train a plantseg model', required=False) arg_parser.add_argument('--headless', type=str, help='Path to a .pkl workflow', required=False) arg_parser.add_argument('--version', action='store_true', help='PlantSeg version', required=False) arg_parser.add_argument('--clean', action='store_true', @@ -27,6 +28,10 @@ def main(): from plantseg.viewer.viewer import run_viewer run_viewer() + elif args.training: + from plantseg.viewer.training import run_training_headless + run_training_headless() + elif args.headless: from plantseg.viewer.headless import run_workflow_headless run_workflow_headless(args.headless) @@ -48,6 +53,7 @@ def main(): else: raise ValueError("Not enough arguments. 
Please use: \n" " --napari for launching the napari image viewer or \n" + " --training for launching the training configurator or \n" " --headless 'path_to_workflow.pkl' for launching a saved workflow or \n" " --gui for launching the graphical pipeline configurator or \n" " --config 'path_to_config.yaml' for launching the pipeline from command line or \n" diff --git a/plantseg/training/train.py b/plantseg/training/train.py index 190497ac..e3a18c24 100644 --- a/plantseg/training/train.py +++ b/plantseg/training/train.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader, ConcatDataset import matplotlib.pyplot as plt -from plantseg import PLANTSEG_MODELS_DIR, plantseg_global_path +from plantseg import PLANTSEG_MODELS_DIR, PLANTSEG_GLOBAL_PATH from plantseg.pipeline import gui_logger from plantseg.training.augs import Augmenter from plantseg.training.h5dataset import HDF5Dataset @@ -23,7 +23,7 @@ def create_model_config(checkpoint_dir, in_channels, out_channels, patch_size, dimensionality, sparse, f_maps): os.makedirs(checkpoint_dir, exist_ok=True) - train_template_path = os.path.join(plantseg_global_path, + train_template_path = os.path.join(PLANTSEG_GLOBAL_PATH, "resources", "config_train_template.yaml") with open(train_template_path, 'r') as f: diff --git a/plantseg/utils.py b/plantseg/utils.py index e73b902e..2247ea1a 100644 --- a/plantseg/utils.py +++ b/plantseg/utils.py @@ -9,7 +9,7 @@ import requests import yaml -from plantseg import model_zoo_path, custom_zoo, home_path, PLANTSEG_MODELS_DIR, plantseg_global_path +from plantseg import MODEL_ZOO_PATH, USER_MODEL_ZOO_CONFIG, USER_HOME_PATH, PLANTSEG_MODELS_DIR, PLANTSEG_GLOBAL_PATH from plantseg.__version__ import __version__ as current_version from plantseg.pipeline import gui_logger @@ -39,11 +39,11 @@ def get_model_zoo() -> dict: ... } """ - zoo_config = os.path.join(model_zoo_path) + zoo_config = os.path.join(MODEL_ZOO_PATH) zoo_config = load_config(zoo_config) - custom_zoo_config = load_config(custom_zoo) + custom_zoo_config = load_config(USER_MODEL_ZOO_CONFIG) if custom_zoo_config is None: custom_zoo_config = {} @@ -147,7 +147,7 @@ def add_custom_model(new_model_name: str, :return: """ - dest_dir = os.path.join(home_path, PLANTSEG_MODELS_DIR, new_model_name) + dest_dir = os.path.join(USER_HOME_PATH, PLANTSEG_MODELS_DIR, new_model_name) os.makedirs(dest_dir, exist_ok=True) all_files = glob.glob(os.path.join(location, "*")) all_expected_files = ['config_train.yml', @@ -169,7 +169,7 @@ def add_custom_model(new_model_name: str, f'the model can not be loaded.' 
return False, msg - custom_zoo_dict = load_config(custom_zoo) + custom_zoo_dict = load_config(USER_MODEL_ZOO_CONFIG) if custom_zoo_dict is None: custom_zoo_dict = {} @@ -182,7 +182,7 @@ def add_custom_model(new_model_name: str, custom_zoo_dict[new_model_name]["modality"] = modality custom_zoo_dict[new_model_name]["output_type"] = output_type - with open(custom_zoo, 'w') as f: + with open(USER_MODEL_ZOO_CONFIG, 'w') as f: yaml.dump(custom_zoo_dict, f) return True, None @@ -200,7 +200,7 @@ def get_train_config(model_name: str) -> dict: """ check_models(model_name, config_only=True) # Load train config and add missing info - train_config_path = os.path.join(home_path, + train_config_path = os.path.join(USER_HOME_PATH, PLANTSEG_MODELS_DIR, model_name, CONFIG_TRAIN_YAML) @@ -260,7 +260,7 @@ def check_models(model_name: str, update_files: bool = False, config_only: bool update_files): # Read config - model_file = os.path.join(plantseg_global_path, "resources", "models_zoo.yaml") + model_file = os.path.join(PLANTSEG_GLOBAL_PATH, "resources", "models_zoo.yaml") config = load_config(model_file) if model_name in config: @@ -282,7 +282,7 @@ def clean_models(): "make sure to copy all custom models you want to preserve before continuing.\n" "Are you sure you want to continue? (y/n) ") if answer == 'y': - ps_models_dir = os.path.join(home_path, PLANTSEG_MODELS_DIR) + ps_models_dir = os.path.join(USER_HOME_PATH, PLANTSEG_MODELS_DIR) shutil.rmtree(ps_models_dir) print("All models deleted... PlantSeg will now close") return None diff --git a/plantseg/viewer/containers.py b/plantseg/viewer/containers.py index c8169ae2..ae0950a1 100644 --- a/plantseg/viewer/containers.py +++ b/plantseg/viewer/containers.py @@ -15,7 +15,8 @@ from plantseg.viewer.widget.segmentation import widget_fix_over_under_segmentation_from_nuclei from plantseg.viewer.widget.segmentation import widget_lifted_multicut from plantseg.viewer.widget.segmentation import widget_simple_dt_ws -from plantseg.viewer.widget.training import widget_unet_training +from plantseg.viewer.widget.train_dataset import widget_create_dataset, widget_print_dataset +from plantseg.viewer.widget.train_dataset import widget_add_stack, widget_delete_dataset def setup_menu(container, path=None): @@ -63,10 +64,13 @@ def get_gasp_workflow(): return container -def get_training_workflow(): - container = MainWindow(widgets=[widget_unet_training], +def get_dataset_workflow(): + container = MainWindow(widgets=[widget_create_dataset, + widget_print_dataset, + widget_add_stack, + widget_delete_dataset], labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/UNet-Training') + container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Dataset-Managment') return container diff --git a/plantseg/viewer/training.py b/plantseg/viewer/training.py new file mode 100644 index 00000000..bded2c09 --- /dev/null +++ b/plantseg/viewer/training.py @@ -0,0 +1,14 @@ +import multiprocessing + +from plantseg.viewer.widget.predictions import ALL_DEVICES, ALL_CUDA_DEVICES +from plantseg.viewer.widget.training import widget_unet_training + +all_gpus_str = f'all gpus: {len(ALL_CUDA_DEVICES)}' +ALL_GPUS = [all_gpus_str] if len(ALL_CUDA_DEVICES) > 0 else [] +ALL_DEVICES_HEADLESS = ALL_DEVICES + ALL_GPUS + +MAX_WORKERS = len(ALL_CUDA_DEVICES) if len(ALL_CUDA_DEVICES) > 0 else multiprocessing.cpu_count() + + +def run_training_headless(): + widget_unet_training.show(run=True) diff --git a/plantseg/viewer/viewer.py 
b/plantseg/viewer/viewer.py index 325dce5a..6c882825 100644 --- a/plantseg/viewer/viewer.py +++ b/plantseg/viewer/viewer.py @@ -1,7 +1,8 @@ import napari -from plantseg.viewer.containers import get_extra_seg, get_extra_pred, get_training_workflow +from plantseg.viewer.containers import get_extra_seg, get_extra_pred from plantseg.viewer.containers import get_gasp_workflow, get_preprocessing_workflow, get_main +from plantseg.viewer.containers import get_dataset_workflow from plantseg.viewer.logging import napari_formatted_logging from plantseg.viewer.widget.proofreading.proofreading import setup_proofreading_keybindings @@ -15,7 +16,7 @@ def run_viewer(): for _containers, name in [(get_preprocessing_workflow(), 'Data - Processing'), (get_gasp_workflow(), 'UNet + Segmentation'), - (get_training_workflow(), 'Training'), + (get_dataset_workflow(), 'Dataset'), (get_extra_pred(), 'Extra-Pred'), (get_extra_seg(), 'Extra-Seg'), ]: diff --git a/plantseg/viewer/widget/train_dataset.py b/plantseg/viewer/widget/train_dataset.py new file mode 100644 index 00000000..69135353 --- /dev/null +++ b/plantseg/viewer/widget/train_dataset.py @@ -0,0 +1,54 @@ +from pathlib import Path + +from magicgui import magicgui +from napari import Viewer + +from plantseg import PLANTSEG_MODELS_DIR + + +@magicgui(call_button='Initialize Dataset', + name={'label': 'Dataset name', + 'tooltip': f'Initialize an empty dataset with name model_name'}, + dataset_dir={'label': 'Path to the dataset directory', + 'mode': 'd', + 'tooltip': 'Select a directory containing where the dataset will be created, ' + '{dataset_dir}/model_name/.'} + ) +def widget_create_dataset(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): + dataset_dir = dataset_dir / name + + + dataset_dir.mkdir(parents=True, exist_ok=True) + + +@magicgui(call_button='Create Dataset', + name={'label': 'Dataset name', + 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, + dataset_dir={'label': 'Path to the dataset directory', + 'mode': 'd', + 'tooltip': 'Select a directory containing train and val subfolders'}, + ) +def widget_print_dataset(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): + pass + + +@magicgui(call_button='Create Dataset', + name={'label': 'Dataset name', + 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, + dataset_dir={'label': 'Path to the dataset directory', + 'mode': 'd', + 'tooltip': 'Select a directory containing train and val subfolders'}, + ) +def widget_add_stack(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): + pass + + +@magicgui(call_button='Delete Dataset', + name={'label': 'Dataset name', + 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, + dataset_dir={'label': 'Path to the dataset directory', + 'mode': 'd', + 'tooltip': 'Select a directory containing train and val subfolders'}, + ) +def widget_delete_dataset(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): + pass From bfe5a47724f473e0f3bb878f87a24ac442353c13 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Wed, 9 Aug 2023 16:33:42 +0200 Subject: [PATCH 13/22] fix minor bug in add new custom model --- plantseg/__init__.py | 6 +- plantseg/utils.py | 135 +++++++++++++++++--------- plantseg/viewer/widget/predictions.py | 6 ++ 3 files changed, 96 insertions(+), 51 deletions(-) diff --git a/plantseg/__init__.py b/plantseg/__init__.py index 7384fc6c..2ca5d3d9 100644 --- a/plantseg/__init__.py +++ 
b/plantseg/__init__.py @@ -30,6 +30,6 @@ yaml.dump({}, f) # Resources directory -RESOURCES_DIR = 'resources' -MODEL_ZOO_PATH = PLANTSEG_GLOBAL_PATH / RESOURCES_DIR / 'models_zoo.yaml' -STANDARD_CONFIG_TEMPLATE = PLANTSEG_GLOBAL_PATH / RESOURCES_DIR / 'config_gui_template.yaml' +RESOURCES_DIR = PLANTSEG_GLOBAL_PATH / 'resources' +MODEL_ZOO_PATH = RESOURCES_DIR / 'models_zoo.yaml' +STANDARD_CONFIG_TEMPLATE = RESOURCES_DIR / 'config_gui_template.yaml' diff --git a/plantseg/utils.py b/plantseg/utils.py index 2247ea1a..c5301150 100644 --- a/plantseg/utils.py +++ b/plantseg/utils.py @@ -1,15 +1,14 @@ -import glob -import os import shutil from pathlib import Path from shutil import copy2 -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from warnings import warn import requests import yaml -from plantseg import MODEL_ZOO_PATH, USER_MODEL_ZOO_CONFIG, USER_HOME_PATH, PLANTSEG_MODELS_DIR, PLANTSEG_GLOBAL_PATH +from plantseg import MODEL_ZOO_PATH, USER_MODEL_ZOO_CONFIG, PLANTSEG_MODELS_DIR, PLANTSEG_LOCAL_DIR +from plantseg import USER_DATASETS_CONFIG from plantseg.__version__ import __version__ as current_version from plantseg.pipeline import gui_logger @@ -17,7 +16,7 @@ BEST_MODEL_PYTORCH = "best_checkpoint.pytorch" -def load_config(config_path: str) -> dict: +def load_config(config_path: Union[str, Path]) -> dict: """ load a yaml config in a dictionary """ @@ -39,9 +38,7 @@ def get_model_zoo() -> dict: ... } """ - zoo_config = os.path.join(MODEL_ZOO_PATH) - - zoo_config = load_config(zoo_config) + zoo_config = load_config(MODEL_ZOO_PATH) custom_zoo_config = load_config(USER_MODEL_ZOO_CONFIG) @@ -147,40 +144,43 @@ def add_custom_model(new_model_name: str, :return: """ - dest_dir = os.path.join(USER_HOME_PATH, PLANTSEG_MODELS_DIR, new_model_name) - os.makedirs(dest_dir, exist_ok=True) - all_files = glob.glob(os.path.join(location, "*")) - all_expected_files = ['config_train.yml', - 'last_checkpoint.pytorch', - 'best_checkpoint.pytorch'] + # check if all the required files are present + location = Path(location) + all_files = location.glob('*') + all_expected_files = [CONFIG_TRAIN_YAML, BEST_MODEL_PYTORCH] - recommended_patch_size = [80, 170, 170] + to_copy = [] for file in all_files: - if os.path.basename(file) == 'config_train.yaml': - config_train = load_config(file) - recommended_patch_size = list(config_train['loaders']['train']['slice_builder']['patch_shape']) - - if os.path.basename(file) in all_expected_files: - copy2(file, dest_dir) - all_expected_files.remove(os.path.basename(file)) + if file.name in all_expected_files: + to_copy.append(file) - if len(all_expected_files) != 0: + if len(to_copy) != len(all_expected_files): msg = f'It was not possible to find in the directory specified {all_expected_files}, ' \ f'the model can not be loaded.' 
return False, msg + # copy model files to the model zoo + dest_dir = PLANTSEG_MODELS_DIR / new_model_name + dest_dir.mkdir(parents=True, exist_ok=True) + for file in to_copy: + copy2(file, dest_dir) + + _config = load_config(location / CONFIG_TRAIN_YAML) + recommended_patch_size = list(_config['loaders']['train']['slice_builder']['patch_shape']) + + new_model_dict = {'path': str(location), + 'resolution': resolution, + 'description': description, + 'recommended_patch_size': recommended_patch_size, + 'dimensionality': dimensionality, + 'modality': modality, + 'output_type': output_type} + + # add model to the user model zoo custom_zoo_dict = load_config(USER_MODEL_ZOO_CONFIG) if custom_zoo_dict is None: custom_zoo_dict = {} - - custom_zoo_dict[new_model_name] = {} - custom_zoo_dict[new_model_name]["path"] = str(location) - custom_zoo_dict[new_model_name]["resolution"] = resolution - custom_zoo_dict[new_model_name]["description"] = description - custom_zoo_dict[new_model_name]["recommended_patch_size"] = recommended_patch_size - custom_zoo_dict[new_model_name]["dimensionality"] = dimensionality - custom_zoo_dict[new_model_name]["modality"] = modality - custom_zoo_dict[new_model_name]["output_type"] = output_type + custom_zoo_dict[new_model_name] = new_model_dict with open(USER_MODEL_ZOO_CONFIG, 'w') as f: yaml.dump(custom_zoo_dict, f) @@ -200,10 +200,7 @@ def get_train_config(model_name: str) -> dict: """ check_models(model_name, config_only=True) # Load train config and add missing info - train_config_path = os.path.join(USER_HOME_PATH, - PLANTSEG_MODELS_DIR, - model_name, - CONFIG_TRAIN_YAML) + train_config_path = PLANTSEG_MODELS_DIR / model_name / CONFIG_TRAIN_YAML config_train = load_config(train_config_path) return config_train @@ -229,9 +226,10 @@ def download_model_config(model_url: str, out_dir: str) -> None: def download_files(urls: dict, out_dir: str) -> None: + out_dir = Path(out_dir) for filename, url in urls.items(): with requests.get(url, allow_redirects=True) as r: - with open(os.path.join(out_dir, filename), 'wb') as f: + with open(out_dir / filename, 'wb') as f: f.write(r.content) @@ -242,17 +240,19 @@ def check_models(model_name: str, update_files: bool = False, config_only: bool :param update_files: if true force the re-download of the model :param config_only: if true only downloads the config file and skips the model file """ + assert isinstance(model_name, str), "model_name must be a string" + + if Path(model_name).is_dir(): + model_dir = Path(model_name) - if os.path.isdir(model_name): - model_dir = model_name else: - model_dir = os.path.join(os.path.expanduser("~"), PLANTSEG_MODELS_DIR, model_name) + model_dir = PLANTSEG_MODELS_DIR / model_name # Check if model directory exist if not create it - if ~os.path.exists(model_dir): - os.makedirs(model_dir, exist_ok=True) + if not model_dir.exists(): + model_dir.mkdir(parents=True, exist_ok=True) - model_config_path = os.path.exists(os.path.join(model_dir, CONFIG_TRAIN_YAML)) - model_best_path = os.path.exists(os.path.join(model_dir, BEST_MODEL_PYTORCH)) + model_config_path = (model_dir / CONFIG_TRAIN_YAML).exists() + model_best_path = (model_dir / BEST_MODEL_PYTORCH).exists() # Check if files are there, if not download them if (not model_config_path or @@ -260,8 +260,7 @@ def check_models(model_name: str, update_files: bool = False, config_only: bool update_files): # Read config - model_file = os.path.join(PLANTSEG_GLOBAL_PATH, "resources", "models_zoo.yaml") - config = load_config(model_file) + config = 
load_config(MODEL_ZOO_PATH) if model_name in config: model_url = config[model_name]["model_url"] @@ -282,8 +281,7 @@ def clean_models(): "make sure to copy all custom models you want to preserve before continuing.\n" "Are you sure you want to continue? (y/n) ") if answer == 'y': - ps_models_dir = os.path.join(USER_HOME_PATH, PLANTSEG_MODELS_DIR) - shutil.rmtree(ps_models_dir) + shutil.rmtree(PLANTSEG_LOCAL_DIR) print("All models deleted... PlantSeg will now close") return None @@ -326,3 +324,44 @@ def check_version(plantseg_url=' https://api.github.com/repos/hci-unihd/plant-se print(f"New version of PlantSeg available: {latest_version}.\n" f"Please update your version to the latest one!") return None + + +def list_datasets(): + """ + List all available datasets created by the user + """ + datasets = load_config(USER_DATASETS_CONFIG) + return list(datasets.keys()) + + +def get_dataset(key: str): + """ + Get a dataset from the user dataset config file + """ + datasets = load_config(USER_DATASETS_CONFIG) + if key not in datasets: + raise ValueError(f"Dataset {key} not found. Please check the spelling. Available datasets: {list_datasets()}") + return datasets[key] + + +def save_dataset(key: str, dataset: dict): + """ + Save a dataset to the user dataset config file, if the dataset already exists it will be overwritten + """ + datasets = load_config(USER_DATASETS_CONFIG) + datasets[key] = dataset + + with open(USER_DATASETS_CONFIG, 'w') as f: + yaml.dump(datasets, f) + + +def delete_dataset(key: str): + """ + Delete a dataset from the user dataset config file + """ + datasets = load_config(USER_DATASETS_CONFIG) + if key not in datasets: + raise ValueError(f"Dataset {key} not found. Please check the spelling. Available datasets: {list_datasets()}") + del datasets[key] + with open(USER_DATASETS_CONFIG, 'w') as f: + yaml.dump(datasets, f) diff --git a/plantseg/viewer/widget/predictions.py b/plantseg/viewer/widget/predictions.py index d5dbdea1..b1a73c2c 100644 --- a/plantseg/viewer/widget/predictions.py +++ b/plantseg/viewer/widget/predictions.py @@ -319,3 +319,9 @@ def widget_add_custom_model(new_model_name: str = 'custom_model', f'{error_msg}', level='error', thread='Add Custom Model') + + +@widget_add_custom_model.called.connect +def _on_add_custom_model_called(): + widget_unet_predictions.model_name.choices = list_models() + widget_iterative_unet_predictions.model_name.choices = list_models() From 52d1509ba0208ad9cf109a8cb781aea06d3957c6 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Wed, 9 Aug 2023 17:30:28 +0200 Subject: [PATCH 14/22] finalize first draft of the dataset tools --- plantseg/viewer/widget/dataset_handling.py | 201 +++++++++++++++++++++ plantseg/viewer/widget/train_dataset.py | 54 ------ 2 files changed, 201 insertions(+), 54 deletions(-) create mode 100644 plantseg/viewer/widget/dataset_handling.py delete mode 100644 plantseg/viewer/widget/train_dataset.py diff --git a/plantseg/viewer/widget/dataset_handling.py b/plantseg/viewer/widget/dataset_handling.py new file mode 100644 index 00000000..1560977f --- /dev/null +++ b/plantseg/viewer/widget/dataset_handling.py @@ -0,0 +1,201 @@ +from pathlib import Path + +from magicgui import magicgui +from napari.layers import Labels, Image + +from plantseg import PLANTSEG_MODELS_DIR +from plantseg.io import create_h5 +from plantseg.utils import list_datasets, save_dataset, get_dataset, delete_dataset +from plantseg.viewer.logging import napari_formatted_logging + +empty_dataset = ['none'] +startup_list_datasets = list_datasets() or 
empty_dataset + + +@magicgui(call_button='Initialize Dataset', + dataset_name={'label': 'Dataset name', + 'tooltip': f'Initialize an empty dataset with name model_name'}, + dataset_dir={'label': 'Path to the dataset directory', + 'mode': 'd', + 'tooltip': 'Select a directory containing where the dataset will be created, ' + '{dataset_dir}/model_name/.'} + ) +def widget_create_dataset(dataset_name: str = 'my-dataset', dataset_dir: Path = Path.home()): + dataset_dir = dataset_dir / dataset_name + dataset_dir.mkdir(parents=True, exist_ok=True) + + new_dataset = {'name': dataset_name, + 'dataset_dir': str(dataset_dir), + 'task': None, + 'dimensionality': None, # 2D or 3D + 'image_channels': None, + 'image_key': 'raw', + 'labels_key': 'labels', + 'is_sparse': False, + 'train': [], + 'val': [], + 'test': [], + } + + if dataset_name not in list_datasets(): + save_dataset(dataset_name, new_dataset) + return new_dataset + + raise ValueError(f'Dataset {dataset_name} already exists.') + + +@magicgui(call_button='Create Dataset', + dataset_name={'label': 'Dataset name', + 'choices': startup_list_datasets, + 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, + phase={'label': 'Phase', + 'choices': ['train', 'val', 'test'], + 'tooltip': f'Define if the stack will be used for training, validation or testing'}, + is_sparse={'label': 'Sparse dataset', + 'tooltip': 'If checked, the dataset will be saved in sparse format.'} + ) +def widget_add_stack(dataset_name: str = startup_list_datasets[0], + image: Image = None, + labels: Labels = None, + phase: str = 'train', + is_sparse: bool = False): + dataset_config = get_dataset(dataset_name) + + if image is None or labels is None: + napari_formatted_logging(message=f'To add a stack to the dataset, please select an image and a labels layer.', + thread='widget_add_stack', + level='warning') + return None + + if is_sparse: + # if a single dataset is sparse, all the others should be threaded as sparse + dataset_config['is_sparse'] = True + + image_data = image.data + labels_data = labels.data + + # Validation of the image and labels data + # check if the image and labels have the same shape, + # dimensionality and number of channels as the rest of the dataset + + if image_data.ndim == 3: + image_channels = 1 + dimensionality = '2D' if image_data.shape[0] == 1 else '3D' + assert image_data.shape == labels_data.shape, f'Image and labels should have the same shape, found ' \ + f'{image_data.shape} and {labels_data.shape}.' + + elif image_data.ndim == 4: + image_channels = image_data.shape[0] + dimensionality = '2D' if image_data.shape[1] == 1 else '3D' + assert image_data.shape[1:] == labels_data.shape, f'Image and labels should have the same shape, found ' \ + f'{image_data.shape} and {labels_data.shape}.' 
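        # A concrete sketch of the shapes this branch accepts (illustrative values, not
        # taken from the patch): a single-channel 3D stack arrives as an image of shape
        # (64, 256, 256) with labels of the same shape, while a two-channel 3D stack
        # arrives as an image of shape (2, 64, 256, 256) with labels of shape
        # (64, 256, 256), i.e. the labels always match the spatial shape only.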
+ + else: + raise ValueError(f'Image data should be 3D or multichannel 3D, found {image_data.ndim}D.') + + dataset_image_channels = dataset_config['image_channels'] + if dataset_image_channels is None: + dataset_config['image_channels'] = image_channels + elif dataset_image_channels != image_channels: + raise ValueError(f'Image data should have {dataset_image_channels} channels, found {image_channels}.') + + dataset_dimensionality = dataset_config['dimensionality'] + if dataset_dimensionality is None: + dataset_config['dimensionality'] = dimensionality + elif dataset_dimensionality != dimensionality: + raise ValueError(f'Image data should be {dataset_dimensionality}, found {dimensionality}.') + + if is_sparse: + dataset_config['is_sparse'] = True + + # Check if the stack name already exists in the dataset + # If so, add a number to the end of the name until it is unique + stack_name = image.name + existing_stacks = dataset_config[phase] + + idx = 0 + while True: + if stack_name in existing_stacks: + stack_name = f'{stack_name}_{idx}' + else: + break + idx += 1 + + dataset_config[phase].append(stack_name) + + # Save the data to disk + dataset_dir = Path(dataset_config['dataset_dir']) / phase + dataset_dir.mkdir(parents=True, exist_ok=True) + + image_path = str(dataset_dir / f'{stack_name}.h5') + create_h5(image_path, image_data, key=dataset_config['image_key']) + create_h5(image_path, labels_data, key=dataset_config['labels_key']) + save_dataset(dataset_name, dataset_config) + napari_formatted_logging(message=f'Stack {stack_name} added to dataset {dataset_name}.', + thread='widget_add_stack', + level='info') + + +@magicgui(call_button='Validata Dataset', + dataset_name={'label': 'Dataset name', + 'choices': startup_list_datasets, + 'tooltip': f'Name of the dataset to be validated'}, + ) +def widget_validata_dataset(dataset_name: str = startup_list_datasets[0]): + dataset_config = get_dataset(dataset_name) + + # check all stacks are present + dataset_dir = Path(dataset_config['dataset_dir']) + for phase in ['train', 'val', 'test']: + phase_dir = dataset_dir / phase + stacks_expected = dataset_config[phase] + stacks_found = [file.stem for file in phase_dir.glob('*.h5')] + if len(stacks_found) != len(stacks_expected): + napari_formatted_logging(message=f'Found {len(stacks_found)} stacks in {phase} phase, ' + f'expected {len(stacks_expected)}.', + thread='widget_validata_dataset', + level='warning') + + dataset_config[phase] = stacks_found + + # check all stacks have the same shape and dimensionality + for key, value in dataset_config.items(): + napari_formatted_logging(message=f'Dataset info {key}: {value}', + thread='widget_validata_dataset', + level='info') + + +@magicgui(call_button='Delete Dataset', + dataset_name={'label': 'Dataset name', + 'choices': startup_list_datasets, + 'tooltip': f'Name of the dataset to be deleted'}, + ) +def widget_delete_dataset(dataset_name: str = startup_list_datasets[0]): + delete_dataset(dataset_name) + + +@widget_create_dataset.called.connect +def _on_create_dataset_called(new_dataset: dict): + new_dataset_list = list_datasets() + if not widget_add_stack.visible: + widget_add_stack.show() + widget_add_stack.dataset_name.choices = new_dataset_list + widget_add_stack.dataset_name.value = new_dataset['name'] + + if not widget_delete_dataset.visible: + widget_delete_dataset.show() + + widget_delete_dataset.dataset_name.choices = new_dataset_list + widget_delete_dataset.dataset_name.value = new_dataset['name'] + + if not widget_validata_dataset.visible: + 
widget_validata_dataset.show() + + widget_validata_dataset.dataset_name.choices = new_dataset_list + widget_validata_dataset.dataset_name.value = new_dataset['name'] + + +if startup_list_datasets == empty_dataset: + widget_add_stack.hide() + widget_delete_dataset.hide() + widget_validata_dataset.hide() diff --git a/plantseg/viewer/widget/train_dataset.py b/plantseg/viewer/widget/train_dataset.py deleted file mode 100644 index 69135353..00000000 --- a/plantseg/viewer/widget/train_dataset.py +++ /dev/null @@ -1,54 +0,0 @@ -from pathlib import Path - -from magicgui import magicgui -from napari import Viewer - -from plantseg import PLANTSEG_MODELS_DIR - - -@magicgui(call_button='Initialize Dataset', - name={'label': 'Dataset name', - 'tooltip': f'Initialize an empty dataset with name model_name'}, - dataset_dir={'label': 'Path to the dataset directory', - 'mode': 'd', - 'tooltip': 'Select a directory containing where the dataset will be created, ' - '{dataset_dir}/model_name/.'} - ) -def widget_create_dataset(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): - dataset_dir = dataset_dir / name - - - dataset_dir.mkdir(parents=True, exist_ok=True) - - -@magicgui(call_button='Create Dataset', - name={'label': 'Dataset name', - 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, - dataset_dir={'label': 'Path to the dataset directory', - 'mode': 'd', - 'tooltip': 'Select a directory containing train and val subfolders'}, - ) -def widget_print_dataset(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): - pass - - -@magicgui(call_button='Create Dataset', - name={'label': 'Dataset name', - 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, - dataset_dir={'label': 'Path to the dataset directory', - 'mode': 'd', - 'tooltip': 'Select a directory containing train and val subfolders'}, - ) -def widget_add_stack(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): - pass - - -@magicgui(call_button='Delete Dataset', - name={'label': 'Dataset name', - 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, - dataset_dir={'label': 'Path to the dataset directory', - 'mode': 'd', - 'tooltip': 'Select a directory containing train and val subfolders'}, - ) -def widget_delete_dataset(viewer: Viewer, name: str = 'my-dataset', dataset_dir: Path = Path.home()): - pass From 156b1e82765caa98cfd509704b9b91c67e4c8978 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Wed, 9 Aug 2023 17:30:37 +0200 Subject: [PATCH 15/22] finalize first draft of the dataset tools --- plantseg/viewer/containers.py | 6 +++--- plantseg/viewer/training.py | 9 --------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/plantseg/viewer/containers.py b/plantseg/viewer/containers.py index ae0950a1..e82bc17e 100644 --- a/plantseg/viewer/containers.py +++ b/plantseg/viewer/containers.py @@ -15,8 +15,8 @@ from plantseg.viewer.widget.segmentation import widget_fix_over_under_segmentation_from_nuclei from plantseg.viewer.widget.segmentation import widget_lifted_multicut from plantseg.viewer.widget.segmentation import widget_simple_dt_ws -from plantseg.viewer.widget.train_dataset import widget_create_dataset, widget_print_dataset -from plantseg.viewer.widget.train_dataset import widget_add_stack, widget_delete_dataset +from plantseg.viewer.widget.dataset_handling import widget_create_dataset +from plantseg.viewer.widget.dataset_handling import widget_add_stack, widget_delete_dataset, 
widget_validata_dataset def setup_menu(container, path=None): @@ -66,8 +66,8 @@ def get_gasp_workflow(): def get_dataset_workflow(): container = MainWindow(widgets=[widget_create_dataset, - widget_print_dataset, widget_add_stack, + widget_validata_dataset, widget_delete_dataset], labels=False) container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Dataset-Managment') diff --git a/plantseg/viewer/training.py b/plantseg/viewer/training.py index bded2c09..005195fd 100644 --- a/plantseg/viewer/training.py +++ b/plantseg/viewer/training.py @@ -1,14 +1,5 @@ -import multiprocessing - -from plantseg.viewer.widget.predictions import ALL_DEVICES, ALL_CUDA_DEVICES from plantseg.viewer.widget.training import widget_unet_training -all_gpus_str = f'all gpus: {len(ALL_CUDA_DEVICES)}' -ALL_GPUS = [all_gpus_str] if len(ALL_CUDA_DEVICES) > 0 else [] -ALL_DEVICES_HEADLESS = ALL_DEVICES + ALL_GPUS - -MAX_WORKERS = len(ALL_CUDA_DEVICES) if len(ALL_CUDA_DEVICES) > 0 else multiprocessing.cpu_count() - def run_training_headless(): widget_unet_training.show(run=True) From db462871cc59f6248e04f9f749cae38ce4e51452 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Thu, 10 Aug 2023 23:52:16 +0200 Subject: [PATCH 16/22] finalize first draft of the dataset tools --- .../widget => dataset_tools}/__init__.py | 0 plantseg/dataset_tools/dataset_handler.py | 380 ++++++++++++++++++ plantseg/dataset_tools/images.py | 309 ++++++++++++++ plantseg/dataset_tools/validators.py | 75 ++++ plantseg/io/h5.py | 21 +- plantseg/legacy_gui/gui_tools.py | 16 +- plantseg/legacy_gui/plantsegapp.py | 4 +- plantseg/pipeline/utils.py | 2 +- plantseg/run_plantseg.py | 8 +- plantseg/{viewer => ui}/__init__.py | 0 plantseg/{viewer => ui}/containers.py | 28 +- plantseg/{viewer => ui}/dag_handler.py | 0 plantseg/{viewer => ui}/headless.py | 4 +- plantseg/{viewer => ui}/logging.py | 0 plantseg/{viewer => ui}/training.py | 2 +- plantseg/{viewer => ui}/viewer.py | 10 +- .../proofreading => ui/widgets}/__init__.py | 0 .../widget => ui/widgets}/dataprocessing.py | 10 +- .../widgets/dataset_tools.py} | 14 +- plantseg/{viewer/widget => ui/widgets}/io.py | 6 +- .../widget => ui/widgets}/predictions.py | 10 +- plantseg/ui/widgets/proofreading/__init__.py | 0 .../widgets}/proofreading/proofreading.py | 16 +- .../proofreading/split_merge_tools.py | 4 +- .../widgets}/proofreading/utils.py | 0 .../widget => ui/widgets}/segmentation.py | 8 +- .../{viewer/widget => ui/widgets}/training.py | 4 +- .../{viewer/widget => ui/widgets}/utils.py | 4 +- plantseg/utils.py | 13 +- 29 files changed, 858 insertions(+), 90 deletions(-) rename plantseg/{viewer/widget => dataset_tools}/__init__.py (100%) create mode 100644 plantseg/dataset_tools/dataset_handler.py create mode 100644 plantseg/dataset_tools/images.py create mode 100644 plantseg/dataset_tools/validators.py rename plantseg/{viewer => ui}/__init__.py (100%) rename plantseg/{viewer => ui}/containers.py (70%) rename plantseg/{viewer => ui}/dag_handler.py (100%) rename plantseg/{viewer => ui}/headless.py (96%) rename plantseg/{viewer => ui}/logging.py (100%) rename plantseg/{viewer => ui}/training.py (52%) rename plantseg/{viewer => ui}/viewer.py (69%) rename plantseg/{viewer/widget/proofreading => ui/widgets}/__init__.py (100%) rename plantseg/{viewer/widget => ui/widgets}/dataprocessing.py (97%) rename plantseg/{viewer/widget/dataset_handling.py => ui/widgets/dataset_tools.py} (95%) rename plantseg/{viewer/widget => ui/widgets}/io.py (98%) rename plantseg/{viewer/widget => 
ui/widgets}/predictions.py (97%) create mode 100644 plantseg/ui/widgets/proofreading/__init__.py rename plantseg/{viewer/widget => ui/widgets}/proofreading/proofreading.py (95%) rename plantseg/{viewer/widget => ui/widgets}/proofreading/split_merge_tools.py (95%) rename plantseg/{viewer/widget => ui/widgets}/proofreading/utils.py (100%) rename plantseg/{viewer/widget => ui/widgets}/segmentation.py (98%) rename plantseg/{viewer/widget => ui/widgets}/training.py (96%) rename plantseg/{viewer/widget => ui/widgets}/utils.py (96%) diff --git a/plantseg/viewer/widget/__init__.py b/plantseg/dataset_tools/__init__.py similarity index 100% rename from plantseg/viewer/widget/__init__.py rename to plantseg/dataset_tools/__init__.py diff --git a/plantseg/dataset_tools/dataset_handler.py b/plantseg/dataset_tools/dataset_handler.py new file mode 100644 index 00000000..e83cfd9d --- /dev/null +++ b/plantseg/dataset_tools/dataset_handler.py @@ -0,0 +1,380 @@ +from pathlib import Path +from shutil import rmtree +from typing import Union, Protocol +from warnings import warn + +from plantseg.dataset_tools.images import Stack +from plantseg.io.h5 import H5_EXTENSIONS +from plantseg.utils import dump_dataset_dict, get_dataset_dict, delist_dataset, list_datasets + + +class DatasetValidator(Protocol): + def __init__(self, *args, **kwargs): + ... + + def __call__(self, dataset: object) -> tuple[bool, str]: + ... + + +class StackValidator(Protocol): + def __init__(self, dataset: object): + ... + + def __call__(self, stack_path: Union[str, Path]) -> tuple[bool, str]: + ... + + +class ComposeDatasetValidators: + """ + Compose multiple dataset validators into a single one. + """ + success_msg = 'All tests passed.' + + def __init__(self, *validators: DatasetValidator): + self.validators = validators + + def __call__(self, dataset: object) -> tuple[bool, str]: + return self.apply(dataset) + + def apply(self, dataset: object) -> tuple[bool, str]: + """ + Apply all the validators to the dataset. + Args: + dataset: dataset to validate + Returns: + tuple[bool, str]: (valid, msg) where valid is True if all the tests passed, False otherwise. + """ + for validator in self.validators: + valid, msg = validator(dataset) + if not valid: + return valid, msg + return True, self.success_msg + + def batch_apply(self, list_dataset: list[object]) -> tuple[bool, str]: + """ + Apply all the validators to a list of datasets. + Args: + list_dataset: list of datasets to validate + Returns: + tuple[bool, str]: (valid, msg) where valid is True if all the tests passed, False otherwise. + """ + for dataset in list_dataset: + valid, msg = self.apply(dataset) + if not valid: + msg = f'Validation failed for {dataset}.\nWith msg: {msg}' + return False, msg + + return True, self.success_msg + + +class ComposeStackValidators: + """ + Compose multiple stack validators into a single one. + """ + success_msg = 'All tests passed.' + + def __init__(self, *validators: StackValidator): + self.validators = validators + + def __call__(self, stack_path: Union[str, Path]) -> tuple[bool, str]: + return self.apply(stack_path) + + def apply(self, stack_path: Union[str, Path]) -> tuple[bool, str]: + """ + Apply all the validators to the stack. + Args: + stack_path: path to the stack to validate + + Returns: + tuple[bool, str]: (valid, msg) where valid is True if all the tests passed, False otherwise. 
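        Example (illustrative sketch; CheckH5Keys is the stack validator defined in
        plantseg.dataset_tools.validators and the stack path is hypothetical):

            validator = ComposeStackValidators(CheckH5Keys(expected_h5_keys=('raw', 'labels')))
            valid, msg = validator.apply('train/stack_0.h5')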
+ + """ + for validator in self.validators: + valid, msg = validator(stack_path) + if not valid: + return valid, msg + return True, self.success_msg + + def batch_apply(self, list_stack_path: list[Union[str, Path]]) -> tuple[bool, str]: + """ + Apply all the validators to a list of stacks. + Args: + list_stack_path: list of paths to the stacks to validate + + Returns: + tuple[bool, str]: (valid, msg) where valid is True if all the tests passed, False otherwise. + + """ + for stack_path in list_stack_path: + valid, msg = self.apply(stack_path) + if not valid: + msg = f'Validation failed for {stack_path}.\nWith msg: {msg}' + return False, msg + + return True, self.success_msg + + +class DatasetHandler: + """ + DatasetHandler is a class that contains all the information about a dataset. + It is used to create a dataset from a directory, to save a dataset to a directory and manage the dataset. + """ + name: str + keys: tuple[str, ...] + default_phases: tuple[str, ...] + default_file_formats: tuple[str, ...] + default_phases: tuple[str, ...] = ('train', 'val', 'test') + train: list[str] + val: list[str] + test: list[str] + + _default_keys = {'task': None, + 'dimensionality': None, # 2D or 3D + 'image_channels': None, + 'keys': ('raw', 'labels'), # keys of the h5 file (raw, labels, etc. ) + 'is_sparse': False, + 'default_file_formats': H5_EXTENSIONS + } + + def __init__(self, name: str, dataset_dir: Union[str, Path], **kwargs): + self.name = name + self.dataset_dir = Path(dataset_dir) + self.dataset_dir.mkdir(parents=True, exist_ok=True) + self.train = [] + self.val = [] + self.test = [] + + for atr, default in self._default_keys.items(): + setattr(self, atr, default) + + for atr, value in kwargs.items(): + if atr in self._default_keys.keys(): + setattr(self, atr, value) + else: + raise ValueError(f'Attribute {atr} does not exists for {self.__class__.__name__}.') + + @classmethod + def from_dict(cls, dataset_dict: dict): + """ + Create a DatasetHandler from a dictionary. + """ + assert 'name' in dataset_dict.keys(), 'Dataset name not found' + assert 'dataset_dir' in dataset_dict.keys(), 'Dataset directory not found' + dataset = cls(name=dataset_dict['name'], dataset_dir=dataset_dict['dataset_dir']) + + for atr, default in cls._default_keys.items(): + if atr in dataset_dict.keys(): + setattr(dataset, atr, dataset_dict.get(atr)) + else: + warn(f'Attribute {atr} not found in dataset {dataset.name}. Setting to default value {default}') + setattr(dataset, atr, default) + + dataset.update_stack_from_disk() + return dataset + + def to_dict(self) -> dict: + """ + Convert a DatasetHandler to a dictionary for serialization. + """ + dataset_dict = {'name': self.name, 'dataset_dir': str(self.dataset_dir)} + for atr in self._default_keys.keys(): + dataset_dict[atr] = getattr(self, atr) + return dataset_dict + + def __repr__(self) -> str: + return f'DatasetHandler {self.name}, location: {self.dataset_dir}' + + def info(self) -> str: + """ + Nice print of the dataset information. + """ + info = f'{self.__repr__()}:\n' + for atr in self._default_keys.keys(): + if atr in self.default_phases: + info += f' {atr}: #{len(getattr(self, atr))} stacks\n' + else: + info += f' {atr}: {getattr(self, atr)}\n' + return info + + def validate(self, *dataset_validators: DatasetValidator) -> tuple[bool, str]: + """ + Validate the dataset using the dataset validators. 
+ Returns: + (bool, str): a boolean (True for success) and a message with the result of the validation + """ + return ComposeDatasetValidators(*dataset_validators)(self) + + def validate_stack(self, *stack_validators: StackValidator) -> tuple[bool, str]: + """ + Validate all the stacks in the dataset using the stack validators. + Returns: + (bool, str): a boolean (True for success) and a message with the result of the validation + """ + files = self.find_stored_files() + return ComposeStackValidators(*stack_validators).batch_apply(files) + + def update_stack_from_disk(self, *validators: StackValidator, phase: str = None): + """ + Update the stacks in the dataset from the disk. + """ + if phase is None: + phases = self.default_phases + else: + phases = [phase] + + for phase in phases: + stacks = self.find_stored_files(phase=phase) + result, msg = ComposeStackValidators(*validators).batch_apply(stacks) + if result: + stacks = [stack.name for stack in stacks] + setattr(self, phase, stacks) + else: + warn(f'Update failed for {phase} phase. {msg}') + + def find_stored_files(self, phase: str = None, ignore_default_file_format: bool = False) -> list[Path]: + """ + Find files in the dataset directory, by default it will only look at the defaults file extensions. + Args: + phase: a string with the phase of the dataset, if None all phases are searched + ignore_default_file_format: set to True to ignore the default file format + Returns: + a list of paths to the stacks found + """ + if phase is None: + phases = self.default_phases + elif isinstance(phase, str): + assert phase in self.default_phases, f'Phase {phase} not found in {self.default_phases}' + phases = (phase,) + + else: + raise ValueError(f'Phase must be a string or None, found {type(phase)}') + + found_files = [] + file_formats = self.default_file_formats if not ignore_default_file_format else ('*',) + + for phase in phases: + phase_dir = self.dataset_dir / phase + assert phase_dir.exists(), f'Phase {phase} not found in {self.dataset_dir}' + + for file_format in file_formats: + stacks_found = [file for file in phase_dir.glob(f'*{file_format}')] + found_files.extend(stacks_found) + + return found_files + + def find_stacks(self, phase: str = None) -> list[str]: + """ + Find the name of the stacks in the dataset directory. + """ + stacks = self.find_stored_files(phase=phase) + return [stack.stem for stack in stacks] + + def add_stack(self, stack_name: str, + phase: str, + data: Stack, + unique_name=True): + """ + Add a stack to the dataset. + Args: + stack_name: string with the name of the stack + phase: string with the phase of the dataset (train, val, test) + data: dictionary with the data to be saved in the stack + {'raw': raw_data, 'labels': labels_data, etc...} + unique_name: if True, the stack name will be changed to a unique name if already exists, + otherwise it will error out. + + Returns: None + """ + phase_dir = self.dataset_dir / phase + stack_path = phase_dir / f'{stack_name}.h5' + idx = 1 + while stack_path.exists() and unique_name: + stack_name += f'_{idx}' + stack_path = phase_dir / f'{stack_name}.h5' + idx += 1 + + data.dump_to_h5(stack_path) + + def remove_stack(self, stack_name: str): + """ + Remove a stack from the dataset. 
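        Example (illustrative; the dataset and stack names are hypothetical):

            dataset = load_dataset('my-dataset')
            dataset.remove_stack('stack_0')
            save_dataset(dataset)  # persist the updated stack lists to the user registry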
+ Args: + stack_name: string with the name of the stack + + Returns: None + """ + for phase in self.default_phases: + stacks = self.find_stacks(phase=phase) + if stack_name in stacks: + stack_path = self.dataset_dir / phase / f'{stack_name}.h5' + if stack_path.exists(): + stack_path.unlink() + self.update_stack_from_disk(phase=phase) + return None + else: + raise FileNotFoundError(f'Stack {stack_name} not found in {phase} phase.') + + raise ValueError(f'Stack {stack_name} not found in dataset {self.name}.') + + def rename_stack(self, stack_name: str, new_name: str): + """ + Rename a stack from the dataset. + Args: + stack_name: string with the name of the stack + new_name: string with the new name of the stack + + Returns: None + """ + for phase in self.default_phases: + stacks = self.find_stacks(phase=phase) + if stack_name in stacks: + stack_path = self.dataset_dir / phase / f'{stack_name}.h5' + if stack_path.exists(): + new_stack_path = self.dataset_dir / phase / f'{new_name}.h5' + stack_path.rename(new_stack_path) + self.update_stack_from_disk(phase=phase) + return None + else: + raise FileNotFoundError(f'Stack {stack_name} not found in {phase} phase.') + + raise ValueError(f'Stack {stack_name} not found in dataset {self.name}.') + + +def load_dataset(dataset_name: str) -> DatasetHandler: + """ + Load a dataset from the user dataset config file. + Args: + dataset_name: string with the name of the dataset + + Returns: + a DatasetHandler object + """ + if dataset_name not in list_datasets(): + raise ValueError(f'Dataset {dataset_name} not found in existing datasets: {list_datasets()}') + + dataset_dict = get_dataset_dict(dataset_name) + dataset = DatasetHandler.from_dict(dataset_dict) + return dataset + + +def save_dataset(dataset: DatasetHandler): + """ + Save a dataset to the user dataset config file. + Args: + dataset: a DatasetHandler object + + Returns: None + """ + dump_dataset_dict(dataset.name, dataset.to_dict()) + + +def delete_dataset(dataset: DatasetHandler): + """ + Delete a dataset from the user dataset config file and delete all the files + Args: + dataset: a DatasetHandler object + + Returns: None + """ + delist_dataset(dataset.name) + rmtree(dataset.dataset_dir) diff --git a/plantseg/dataset_tools/images.py b/plantseg/dataset_tools/images.py new file mode 100644 index 00000000..e6f3532c --- /dev/null +++ b/plantseg/dataset_tools/images.py @@ -0,0 +1,309 @@ +from pathlib import Path +from typing import Union + +import h5py +import numpy as np + +from plantseg.io.h5 import create_h5, load_h5 + + +class GenericImage: + dimensionality: str + key: str + num_channels: int + data: Union[np.ndarray, h5py.Dataset] + layout: str = 'xy' + + def __init__(self, data: np.ndarray, + key: str = 'raw', + dimensionality: str = '3D'): + """ + Generic image class to handle 2D and 3D images consistently. + Args: + data (np.ndarray): image data + key (str): key to use when saving to h5 + dimensionality (str): '2D' or '3D' + """ + assert dimensionality in ['2D', '3D'], f'Invalid dimensionality: {dimensionality}, valid values are: 2D, 3D.' 
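        # The subclasses below record the axis order in a small layout string
        # ('xy', 'cxy', 'c1xy', 'xyz', 'cxyz', '1xy', '11xy', '1xyz'), where 'c' marks the
        # channel axis and '1' a singleton axis; the clean_shape property later drops both
        # so that only the spatial extent is compared. For example (illustrative shape, not
        # from the patch), data of shape (2, 64, 128, 128) with layout 'cxyz' has
        # clean_shape (64, 128, 128).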
+ self.data = data + self.key = key + self.dimensionality = dimensionality + self.num_channels = 1 + + def __repr__(self): + return (f'{self.__class__.__name__}(dimensionality={self.dimensionality},' + f' shape={self.shape},' + f' layout={self.layout})') + + @property + def ndim(self) -> int: + return self.data.ndim + + @property + def shape(self) -> tuple[int, ...]: + return self.data.shape + + @property + def clean_shape(self) -> tuple[int, ...]: + """ + Returns the shape without singleton dimensions and channel dimensions + """ + assert len(self.shape) == len(self.layout), f'Shape and layout do not match: {self.shape} vs {self.layout}.' + clean_shape = [s for s, l in zip(self.shape, self.layout) if 'c' != l and '1' != l] + return tuple(clean_shape) + + def load_data(self): + """ + Load the data from the h5 file + """ + if isinstance(self.data, h5py.Dataset): + self.data = self.data[...] + + def remove_singletons(self, remove_channel: bool = False): + """ + Remove singleton dimensions from the data, and optionally remove the + channel dimension if channel dimension is 1. + Args: + remove_channel (bool): if True, remove the channel dimension if it is 1. + Returns: + GenericImage: self + + """ + data = self.data + axis_to_squeeze = [] + for i, l in enumerate(self.layout): + if 'c' == l and remove_channel: + axis_to_squeeze.append(i) + if '1' == l: + axis_to_squeeze.append(i) + + if axis_to_squeeze: + data = np.squeeze(data, axis=tuple(axis_to_squeeze)) + return type(self)(data=data, key=self.key, dimensionality=self.dimensionality) + + return self + + @classmethod + def from_h5(cls, path: Union[str, Path], key: str, dimensionality: str, load_data: bool = False): + """ + Instantiate a GenericImage from a h5 file + """ + if load_data: + data, _ = load_h5(path=path, key=key) + else: + data = h5py.File(path, mode='r')[key] + return cls(data=data, key=key, dimensionality=dimensionality) + + def to_h5(self, path: Union[str, Path], + voxel_size: tuple[float, float, float] = (1.0, 1.0, 1.0), + mode: str = 'a'): + """ + Save the data to a h5 file + Args: + path (str, Path): path to the h5 file + voxel_size (tuple): voxel size + mode (str): 'a' to append to an existing file, 'w' to overwrite an existing file + """ + self.load_data() + create_h5(path, stack=self.data, key=self.key, voxel_size=voxel_size, mode=mode) + + +class Image(GenericImage): + def __init__(self, data: np.ndarray, + key: str = 'raw', + dimensionality: str = '3D'): + """ + Args: + data: numpy array + key: internal key of the dataset + dimensionality: 2D or 3D + """ + super().__init__(data=data, key=key, dimensionality=dimensionality) + + if dimensionality == '2D': + if data.ndim == 2: + self.num_channels = 1 + self.layout = 'xy' + elif data.ndim == 3: + self.num_channels = data.shape[0] + self.layout = 'cxy' + elif data.ndim == 4: + self.num_channels = data.shape[0] + assert data.shape[1] == 1, f'Invalid number of channels: {data.shape[1]}, expected 1.' 
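                # 4D input in 2D mode is interpreted with layout 'c1xy': a multichannel
                # image carrying an explicit singleton axis between the channel axis and
                # the two spatial axes.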
+ self.layout = 'c1xy' + else: + raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 2 or 3 or 4.') + + elif dimensionality == '3D': + if data.ndim == 3: + self.num_channels = 1 + self.layout = 'xyz' + elif data.ndim == 4: + self.num_channels = data.shape[0] + self.layout = 'cxyz' + raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 3 or 4.') + + else: + raise ValueError(f'Invalid dimensionality: {dimensionality}, valid values are: 2D, 3D.') + + +class Labels(GenericImage): + def __init__(self, data: np.ndarray, + key: str = 'labels', + dimensionality: str = '3D'): + """ + Args: + data: numpy array + key: internal key of the dataset + dimensionality: 2D or 3D + """ + super().__init__(data=data, key=key, dimensionality=dimensionality) + + if dimensionality == '2D': + if data.ndim == 2: + self.num_channels = 1 + self.layout = 'xy' + elif data.ndim == 3: + assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' + self.layout = '1xy' + elif data.ndim == 4: + assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' + assert data.shape[1] == 1, f'Invalid number of channels: {data.shape[1]}, expected 1.' + self.layout = '11xy' + else: + raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 2 or 3 or 4.') + + elif dimensionality == '3D': + if data.ndim == 3: + self.num_channels = 1 + self.layout = 'xyz' + elif data.ndim == 4: + assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' + self.layout = '1xyz' + + else: + raise ValueError(f'Invalid dimensionality: {dimensionality}, valid values are: 2D, 3D.') + + +class Stack: + dimensionality: str + layout: str + data: {} + + def __init__(self, *images: GenericImage, + dimensionality: str = '3D', + strict: bool = True): + """ + Args: + *images (GenericImage): list of images + dimensionality (str): 2D or 3D + strict (bool): if True, raise an error if the images do not have the same dimensionality + """ + self.dimensionality = dimensionality + + data = {} + dimensionality = images[0].dimensionality + for image in images: + assert image.dimensionality == dimensionality, (f'Invalid dimensionality: {image.dimensionality},' + f' all images must have the same dimensionality.') + data[image.key] = image + + self.data = data + result, msg = self.validate() + + if not result and strict: + raise ValueError(msg) + + @property + def keys(self) -> list[str]: + """ + list all the keys of the stack + """ + return list(self.data.keys()) + + @property + def clean_shape(self) -> tuple[int, ...]: + """ + Return the shape of the stack without the channel dimension and singleton dimensions + """ + key = self.keys[0] + return self.data[key].clean_shape + + def validate_dimensionality(self) -> tuple[bool, str]: + for image in self.data.values(): + if image.dimensionality != self.dimensionality: + msg = (f'Invalid dimensionality: {image.dimensionality}, all' + f' images must have the same dimensionality.') + return False, msg + + return True, '' + + def validate_layout(self) -> tuple[bool, str]: + for type_image in [Image, Labels]: + list_image = [image for image in self.data.values() if isinstance(image, type_image)] + if list_image: + layout = list_image[0].layout + for image in list_image: + if image.layout != layout: + msg = (f'Invalid layout: {image.layout}, all' + f' images of type {self.__class__.__name__} must have the same layout.') + return False, msg + return True, '' + + def validate_shape(self) -> tuple[bool, str]: + 
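        # Shapes are compared through the clean_shape property, i.e. after channel and
        # singleton axes have been dropped. Illustrative example (shapes are hypothetical):
        # an Image of shape (2, 64, 128, 128) with layout 'cxyz' and a Labels volume of
        # shape (64, 128, 128) with layout 'xyz' both reduce to (64, 128, 128) and pass.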
if len(self.data) == 0: + return False, 'Empty stack.' + + for image in self.data.values(): + if image.clean_shape != self.clean_shape: + msg = (f'Invalid clean shape: {image.shape},' + f' all images must have the clean same' + f' shape {self.clean_shape}.') + return False, msg + return True, '' + + def validate(self) -> tuple[bool, str]: + """ + Validate the stack to ensure that all images have the same dimensionality, layout and shape. + """ + for test in [self.validate_dimensionality, self.validate_layout, self.validate_shape]: + result, msg = test() + if not result: + return False, msg + return True, '' + + def dump_to_h5(self, path: Union[str, Path], mode: str = 'a'): + """ + Dump the full stack to an HDF5 file. + Args: + path: path to the HDF5 file + mode: write mode, one of ['w', 'a', 'r+', 'w-'] + + """ + assert mode in ['w', 'a', 'r+', 'w-'], f'Invalid mode: {mode}, valid values are: [w, a, r+, w-].' + for key, stack in self.data.items(): + stack.to_h5(path=path, mode=mode) + # switch to append mode after first iteration + mode = 'a' + + @classmethod + def from_h5(cls, path: Union[str, Path], + keys: tuple[tuple[str, GenericImage]], + dimensionality: str, + load_data: bool = False, + strict: bool = True): + """ + Load the full stack from an HDF5 file. + Args: + path: path to the HDF5 file + keys: list of (keys, type of data) to load + dimensionality: 2D or 3D + load_data: if True, load the data from the HDF5 file + strict: if True, raise an error if the images do not have the same dimensionality + """ + data = [] + for key, type_image in keys: + im = type_image.from_h5(path=path, key=key, dimensionality=dimensionality, load_data=load_data) + data.append(im) + + return cls(*data, dimensionality=dimensionality, strict=strict) diff --git a/plantseg/dataset_tools/validators.py b/plantseg/dataset_tools/validators.py new file mode 100644 index 00000000..7c63e00a --- /dev/null +++ b/plantseg/dataset_tools/validators.py @@ -0,0 +1,75 @@ +from pathlib import Path +from typing import Union + +from plantseg.dataset_tools.dataset_handler import DatasetHandler +from plantseg.io.h5 import list_keys + + +class CheckDatasetDirectoryStructure: + + def __call__(self, dataset: DatasetHandler) -> tuple[bool, str]: + # Check if dataset directory exists + if not dataset.dataset_dir.exists(): + return False, f'Dataset directory {dataset.dataset_dir} does not exist.' + + # Check if dataset directory contains all expected subdirectories + for phase in dataset.default_phases: + if not (dataset.dataset_dir / phase).exists(): + return False, f'Dataset directory {dataset.dataset_dir} does not contain {phase} directory.' + + return True, '' + + +class CheckH5Keys: + def __init__(self, expected_h5_keys: tuple[str, ...] = ('raw', 'labels')): + self.expected_h5_keys = expected_h5_keys + + def __call__(self, stack: Union[str, Path]) -> tuple[bool, str]: + found_keys = list_keys(stack) + for key in self.expected_h5_keys: + if key not in found_keys: + return False, f'Key {key} not found in {stack}. Expected keys: {self.expected_h5_keys}' + + return True, '' + + +class CheckH5shapes: + def __init__(self, dimensionality: str = '3D', + expected_h5_keys: tuple[str, ...] = (('raw', 'image'), + ('labels', 'labels') + )): + """ + Check if the shape of the data in the h5 file matches the expected shape. 
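        Example (construction only, mirroring the defaults; illustrative):

            shape_check = CheckH5shapes(dimensionality='3D',
                                        expected_h5_keys=(('raw', 'image'), ('labels', 'labels')))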
+ Args: + dimensionality: '2D' or '3D' + expected_h5_keys: tuple of tuples, each tuple contains the key and the expected type of data + possible types are: 'image', 'labels' + """ + assert dimensionality in ['2D', '3D'], f'Invalid dimensionality: {dimensionality}, ' \ + f'valid values are: 2D, 3D' + + self.expected_shapes = {} + if dimensionality == '2D': + for key, data_type in expected_h5_keys: + assert data_type in ['image', 'labels'], f'Invalid data type: {data_type}, ' \ + f'valid values are: image, labels' + if data_type == 'image': + self.expected_shapes[key] = [{'ndim': 2, 'shape': 'xy'}, + {'ndim': 3, 'shape': 'cxy'}, + {'ndim': 4, 'shape': 'c1xy'}, + {'ndim': 4, 'shape': '1xy'}] + elif data_type == 'labels': + self.expected_shapes[key] = [{'ndim': 2, 'shape': 'xy'}, + {'ndim': 3, 'shape': '1xy'}] + elif dimensionality == '3D': + for key, data_type in expected_h5_keys: + assert data_type in ['image', 'labels'], f'Invalid data type: {data_type}, ' \ + f'valid values are: image, labels' + if data_type == 'image': + self.expected_shapes[key] = [{'ndim': 3, 'shape': 'zxy'}, + {'ndim': 4, 'shape': 'czxy'}, + {'ndim': 4, 'shape': '1xy'}, + {'ndim': 4, 'shape': '1xy'}] + elif data_type == 'labels': + self.expected_shapes[key] = [{'ndim': 2, 'shape': 'xy'}, + {'ndim': 3, 'shape': '1xy'}] diff --git a/plantseg/io/h5.py b/plantseg/io/h5.py index b67dd740..8fe4dbd7 100644 --- a/plantseg/io/h5.py +++ b/plantseg/io/h5.py @@ -1,6 +1,7 @@ import warnings from typing import Optional, Union +from pathlib import Path import h5py import numpy as np @@ -51,7 +52,7 @@ def visitor_func(name, node): f"plantseg expects only one dataset to be present in input H5.") -def load_h5(path: str, +def load_h5(path: Union[str, Path], key: str, slices: Optional[slice] = None, info_only: bool = False) -> Union[tuple, tuple[np.array, tuple]]: @@ -67,7 +68,7 @@ def load_h5(path: str, Returns: Union[tuple, tuple[np.array, tuple]]: dataset as numpy array and infos """ - with h5py.File(path, 'r') as f: + with h5py.File(path, mode='r') as f: if key is None: key = _find_input_key(f) @@ -83,7 +84,7 @@ def load_h5(path: str, return file, infos -def create_h5(path: str, +def create_h5(path: Union[str, Path], stack: np.array, key: str, voxel_size: tuple[float, float, float] = (1.0, 1.0, 1.0), @@ -101,7 +102,7 @@ def create_h5(path: str, None """ - with h5py.File(path, mode) as f: + with h5py.File(path, mode=mode) as f: if key in f: del f[key] f.create_dataset(key, data=stack, compression='gzip') @@ -109,7 +110,7 @@ def create_h5(path: str, f[key].attrs['element_size_um'] = voxel_size -def list_keys(path): +def list_keys(path: Union[str, Path]) -> list[str]: """ List all keys in a h5 file Args: @@ -129,23 +130,23 @@ def _recursive_find_keys(f, base='/'): _list_keys.append(f'{base}{key}') return _list_keys - with h5py.File(path, 'r') as h5_f: + with h5py.File(path, mode='r') as h5_f: return _recursive_find_keys(h5_f) -def del_h5_key(path: str, key: str, mode: str = 'a') -> None: +def del_h5_key(path: Union[str, Path], key: str, mode: str = 'a') -> None: """ helper function to delete a dataset from a h5file """ - with h5py.File(path, mode) as f: + with h5py.File(path, mode=mode) as f: if key in f: del f[key] f.close() -def rename_h5_key(path: str, old_key: str, new_key: str, mode='r+') -> None: +def rename_h5_key(path: Union[str, Path], old_key: str, new_key: str, mode='r+') -> None: """ Rename the 'old_key' dataset to 'new_key' """ - with h5py.File(path, mode) as f: + with h5py.File(path, mode=mode) as f: if old_key in f: 
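            # assigning an existing dataset to a new key creates a hard link inside the
            # HDF5 file, so the data itself is not copied before the old key is removed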
f[new_key] = f[old_key] del f[old_key] diff --git a/plantseg/legacy_gui/gui_tools.py b/plantseg/legacy_gui/gui_tools.py index 520b348b..ee87a84d 100644 --- a/plantseg/legacy_gui/gui_tools.py +++ b/plantseg/legacy_gui/gui_tools.py @@ -24,7 +24,7 @@ class SimpleEntry: - """ Standard open entry widget """ + """ Standard open entry widgets """ def __init__(self, frame, text="Text", large_bar=False, row=0, column=0, _type=str, _font=None): self.frame = tkinter.Frame(frame) @@ -77,7 +77,7 @@ def __call__(self, value, obj_collection): class SliderEntry: - """ Standard open entry widget """ + """ Standard open entry widgets """ def __init__(self, frame, text="Text", row=0, column=0, data_range=(0.01, 0.99, 0.01), @@ -135,7 +135,7 @@ def __call__(self, value, obj_collection): class MenuEntry: - """ Standard menu widget """ + """ Standard menu widgets """ def __init__(self, frame, text="Text", row=0, column=0, menu=(), is_model=False, is_segmentation=False, default=None, font=None): @@ -216,7 +216,7 @@ def update_segmentation_name(self, value): class BoolEntry: - """ Standard boolean widget """ + """ Standard boolean widgets """ def __init__(self, frame, text="Text", row=0, column=0, font=None): self.frame = tkinter.Frame(frame) @@ -263,7 +263,7 @@ def __call__(self, value, obj_collection): class FilterEntry: - """ Special widget for filter """ + """ Special widgets for filter """ def __init__(self, frame, text="Text", row=0, column=0, font=None): self.frame = tkinter.Frame(frame) @@ -333,7 +333,7 @@ def __call__(self, value, obj_collection): class MenuEntryStride: - """ Standard menu widget """ + """ Standard menu widgets """ def __init__(self, frame, text="Text", row=0, column=0, menu=(), is_model=False, default=None, font=None): self.frame = tkinter.Frame(frame) @@ -401,7 +401,7 @@ def update_model_name(self, value): class RescaleEntry: - """ Special widget for rescale """ + """ Special widgets for rescale """ def __init__(self, frame, text="Text", row=0, column=0, type=float, font=None): self.frame = tkinter.Frame(frame) @@ -489,7 +489,7 @@ def auto_rescale(self): class ListEntry: - """ Standard triplet list widget """ + """ Standard triplet list widgets """ def __init__(self, frame, text="Text", row=0, column=0, type=float, font=None): self.frame = tkinter.Frame(frame) diff --git a/plantseg/legacy_gui/plantsegapp.py b/plantseg/legacy_gui/plantsegapp.py index 55edd99e..7e43c379 100644 --- a/plantseg/legacy_gui/plantsegapp.py +++ b/plantseg/legacy_gui/plantsegapp.py @@ -420,7 +420,7 @@ def open_postprocessing(): webbrowser.open("https://github.com/hci-unihd/plant-seg/wiki/Classic-Data-Processing") def size_up(self): - """ adjust font size in the main widget""" + """ adjust font size in the main widgets""" self.font_size += 2 self.font_size = min(100, self.font_size) self.update_font(self.font_size) @@ -429,7 +429,7 @@ def size_up(self): self.build_all() def size_down(self): - """ adjust font size in the main widget""" + """ adjust font size in the main widgets""" self.font_size -= 2 self.font_size = max(0, self.font_size) self.update_font(self.font_size) diff --git a/plantseg/pipeline/utils.py b/plantseg/pipeline/utils.py index c0142bdf..dab7baad 100644 --- a/plantseg/pipeline/utils.py +++ b/plantseg/pipeline/utils.py @@ -40,7 +40,7 @@ def load_paths(base_path): class QueueHandler(logging.Handler): """Class to send logging records to a queue It can be used from different threads - The ConsoleUi class polls this queue to display records in a ScrolledText widget + The ConsoleUi class polls this 
queue to display records in a ScrolledText widgets """ def __init__(self, log_queue): diff --git a/plantseg/run_plantseg.py b/plantseg/run_plantseg.py index 113a4cbd..d83696ee 100644 --- a/plantseg/run_plantseg.py +++ b/plantseg/run_plantseg.py @@ -25,15 +25,15 @@ def main(): PlantSegApp() elif args.napari: - from plantseg.viewer.viewer import run_viewer + from plantseg.ui.viewer import run_viewer run_viewer() elif args.training: - from plantseg.viewer.training import run_training_headless + from plantseg.ui.training import run_training_headless run_training_headless() elif args.headless: - from plantseg.viewer.headless import run_workflow_headless + from plantseg.ui.headless import run_workflow_headless run_workflow_headless(args.headless) elif args.config is not None: @@ -52,7 +52,7 @@ def main(): else: raise ValueError("Not enough arguments. Please use: \n" - " --napari for launching the napari image viewer or \n" + " --napari for launching the napari image ui or \n" " --training for launching the training configurator or \n" " --headless 'path_to_workflow.pkl' for launching a saved workflow or \n" " --gui for launching the graphical pipeline configurator or \n" diff --git a/plantseg/viewer/__init__.py b/plantseg/ui/__init__.py similarity index 100% rename from plantseg/viewer/__init__.py rename to plantseg/ui/__init__.py diff --git a/plantseg/viewer/containers.py b/plantseg/ui/containers.py similarity index 70% rename from plantseg/viewer/containers.py rename to plantseg/ui/containers.py index e82bc17e..b6c0735c 100644 --- a/plantseg/viewer/containers.py +++ b/plantseg/ui/containers.py @@ -3,20 +3,20 @@ from PyQt5.QtCore import Qt from magicgui.widgets import MainWindow -from plantseg.viewer.widget.dataprocessing import widget_cropping, widget_add_layers -from plantseg.viewer.widget.dataprocessing import widget_label_processing -from plantseg.viewer.widget.dataprocessing import widget_rescaling, widget_gaussian_smoothing -from plantseg.viewer.widget.io import open_file, export_stacks -from plantseg.viewer.widget.predictions import widget_iterative_unet_predictions, widget_add_custom_model -from plantseg.viewer.widget.predictions import widget_unet_predictions, widget_test_all_unet_predictions -from plantseg.viewer.widget.proofreading.proofreading import widget_clean_scribble, widget_filter_segmentation -from plantseg.viewer.widget.proofreading.proofreading import widget_split_and_merge_from_scribbles -from plantseg.viewer.widget.segmentation import widget_dt_ws, widget_agglomeration -from plantseg.viewer.widget.segmentation import widget_fix_over_under_segmentation_from_nuclei -from plantseg.viewer.widget.segmentation import widget_lifted_multicut -from plantseg.viewer.widget.segmentation import widget_simple_dt_ws -from plantseg.viewer.widget.dataset_handling import widget_create_dataset -from plantseg.viewer.widget.dataset_handling import widget_add_stack, widget_delete_dataset, widget_validata_dataset +from plantseg.ui.widgets.dataprocessing import widget_cropping, widget_add_layers +from plantseg.ui.widgets.dataprocessing import widget_label_processing +from plantseg.ui.widgets.dataprocessing import widget_rescaling, widget_gaussian_smoothing +from plantseg.ui.widgets.io import open_file, export_stacks +from plantseg.ui.widgets.predictions import widget_iterative_unet_predictions, widget_add_custom_model +from plantseg.ui.widgets.predictions import widget_unet_predictions, widget_test_all_unet_predictions +from plantseg.ui.widgets.proofreading.proofreading import 
widget_clean_scribble, widget_filter_segmentation +from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles +from plantseg.ui.widgets.segmentation import widget_dt_ws, widget_agglomeration +from plantseg.ui.widgets.segmentation import widget_fix_over_under_segmentation_from_nuclei +from plantseg.ui.widgets.segmentation import widget_lifted_multicut +from plantseg.ui.widgets.segmentation import widget_simple_dt_ws +from plantseg.ui.widgets.dataset_tools import widget_create_dataset +from plantseg.ui.widgets.dataset_tools import widget_add_stack, widget_delete_dataset, widget_validata_dataset def setup_menu(container, path=None): diff --git a/plantseg/viewer/dag_handler.py b/plantseg/ui/dag_handler.py similarity index 100% rename from plantseg/viewer/dag_handler.py rename to plantseg/ui/dag_handler.py diff --git a/plantseg/viewer/headless.py b/plantseg/ui/headless.py similarity index 96% rename from plantseg/viewer/headless.py rename to plantseg/ui/headless.py index 83bc2133..d9c03e4b 100644 --- a/plantseg/viewer/headless.py +++ b/plantseg/ui/headless.py @@ -7,8 +7,8 @@ from magicgui import magicgui from tqdm import tqdm -from plantseg.viewer.dag_handler import DagHandler -from plantseg.viewer.widget.predictions import ALL_DEVICES, ALL_CUDA_DEVICES +from plantseg.ui.dag_handler import DagHandler +from plantseg.ui.widgets.predictions import ALL_DEVICES, ALL_CUDA_DEVICES all_gpus_str = f'all gpus: {len(ALL_CUDA_DEVICES)}' ALL_GPUS = [all_gpus_str] if len(ALL_CUDA_DEVICES) > 0 else [] diff --git a/plantseg/viewer/logging.py b/plantseg/ui/logging.py similarity index 100% rename from plantseg/viewer/logging.py rename to plantseg/ui/logging.py diff --git a/plantseg/viewer/training.py b/plantseg/ui/training.py similarity index 52% rename from plantseg/viewer/training.py rename to plantseg/ui/training.py index 005195fd..d7e2de17 100644 --- a/plantseg/viewer/training.py +++ b/plantseg/ui/training.py @@ -1,4 +1,4 @@ -from plantseg.viewer.widget.training import widget_unet_training +from plantseg.ui.widgets.training import widget_unet_training def run_training_headless(): diff --git a/plantseg/viewer/viewer.py b/plantseg/ui/viewer.py similarity index 69% rename from plantseg/viewer/viewer.py rename to plantseg/ui/viewer.py index 6c882825..ea6c0f40 100644 --- a/plantseg/viewer/viewer.py +++ b/plantseg/ui/viewer.py @@ -1,10 +1,10 @@ import napari -from plantseg.viewer.containers import get_extra_seg, get_extra_pred -from plantseg.viewer.containers import get_gasp_workflow, get_preprocessing_workflow, get_main -from plantseg.viewer.containers import get_dataset_workflow -from plantseg.viewer.logging import napari_formatted_logging -from plantseg.viewer.widget.proofreading.proofreading import setup_proofreading_keybindings +from plantseg.ui.containers import get_extra_seg, get_extra_pred +from plantseg.ui.containers import get_gasp_workflow, get_preprocessing_workflow, get_main +from plantseg.ui.containers import get_dataset_workflow +from plantseg.ui.logging import napari_formatted_logging +from plantseg.ui.widgets.proofreading.proofreading import setup_proofreading_keybindings def run_viewer(): diff --git a/plantseg/viewer/widget/proofreading/__init__.py b/plantseg/ui/widgets/__init__.py similarity index 100% rename from plantseg/viewer/widget/proofreading/__init__.py rename to plantseg/ui/widgets/__init__.py diff --git a/plantseg/viewer/widget/dataprocessing.py b/plantseg/ui/widgets/dataprocessing.py similarity index 97% rename from 
plantseg/viewer/widget/dataprocessing.py rename to plantseg/ui/widgets/dataprocessing.py index ec9c7ae2..8226b798 100644 --- a/plantseg/viewer/widget/dataprocessing.py +++ b/plantseg/ui/widgets/dataprocessing.py @@ -13,10 +13,10 @@ from plantseg.dataprocessing.functional.labelprocessing import relabel_segmentation as _relabel_segmentation from plantseg.dataprocessing.functional.labelprocessing import set_background_to_value from plantseg.utils import list_models, get_model_resolution -from plantseg.viewer.widget.predictions import widget_unet_predictions -from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws -from plantseg.viewer.widget.utils import return_value_if_widget -from plantseg.viewer.widget.utils import start_threading_process, create_layer_name, layer_properties +from plantseg.ui.widgets.predictions import widget_unet_predictions +from plantseg.ui.widgets.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws +from plantseg.ui.widgets.utils import return_value_if_widget +from plantseg.ui.widgets.utils import start_threading_process, create_layer_name, layer_properties @magicgui(call_button='Run Gaussian Smoothing', @@ -193,7 +193,7 @@ def _cropping(data, crop_slices): crop_roi={'label': 'Crop ROI', 'tooltip': 'This must be a shape layer with a rectangle XY overlaying the area to crop.'}, # FloatRangeSlider and RangeSlider are not working very nicely with napari, they are usable but not very - # nice. maybe we should use a custom widget for this. + # nice. maybe we should use a custom widgets for this. crop_z={'label': 'Z slices', 'tooltip': 'Numer of z slices to take next to the current selection.', 'widget_type': 'FloatRangeSlider', 'max': 100, 'min': 0, 'step': 1, diff --git a/plantseg/viewer/widget/dataset_handling.py b/plantseg/ui/widgets/dataset_tools.py similarity index 95% rename from plantseg/viewer/widget/dataset_handling.py rename to plantseg/ui/widgets/dataset_tools.py index 1560977f..35ae6f02 100644 --- a/plantseg/viewer/widget/dataset_handling.py +++ b/plantseg/ui/widgets/dataset_tools.py @@ -5,8 +5,8 @@ from plantseg import PLANTSEG_MODELS_DIR from plantseg.io import create_h5 -from plantseg.utils import list_datasets, save_dataset, get_dataset, delete_dataset -from plantseg.viewer.logging import napari_formatted_logging +from plantseg.utils import list_datasets, dump_dataset_dict, get_dataset_dict, delist_dataset +from plantseg.ui.logging import napari_formatted_logging empty_dataset = ['none'] startup_list_datasets = list_datasets() or empty_dataset @@ -38,7 +38,7 @@ def widget_create_dataset(dataset_name: str = 'my-dataset', dataset_dir: Path = } if dataset_name not in list_datasets(): - save_dataset(dataset_name, new_dataset) + dump_dataset_dict(dataset_name, new_dataset) return new_dataset raise ValueError(f'Dataset {dataset_name} already exists.') @@ -59,7 +59,7 @@ def widget_add_stack(dataset_name: str = startup_list_datasets[0], labels: Labels = None, phase: str = 'train', is_sparse: bool = False): - dataset_config = get_dataset(dataset_name) + dataset_config = get_dataset_dict(dataset_name) if image is None or labels is None: napari_formatted_logging(message=f'To add a stack to the dataset, please select an image and a labels layer.', @@ -130,7 +130,7 @@ def widget_add_stack(dataset_name: str = startup_list_datasets[0], image_path = str(dataset_dir / f'{stack_name}.h5') create_h5(image_path, image_data, key=dataset_config['image_key']) create_h5(image_path, 
labels_data, key=dataset_config['labels_key']) - save_dataset(dataset_name, dataset_config) + dump_dataset_dict(dataset_name, dataset_config) napari_formatted_logging(message=f'Stack {stack_name} added to dataset {dataset_name}.', thread='widget_add_stack', level='info') @@ -142,7 +142,7 @@ def widget_add_stack(dataset_name: str = startup_list_datasets[0], 'tooltip': f'Name of the dataset to be validated'}, ) def widget_validata_dataset(dataset_name: str = startup_list_datasets[0]): - dataset_config = get_dataset(dataset_name) + dataset_config = get_dataset_dict(dataset_name) # check all stacks are present dataset_dir = Path(dataset_config['dataset_dir']) @@ -171,7 +171,7 @@ def widget_validata_dataset(dataset_name: str = startup_list_datasets[0]): 'tooltip': f'Name of the dataset to be deleted'}, ) def widget_delete_dataset(dataset_name: str = startup_list_datasets[0]): - delete_dataset(dataset_name) + delist_dataset(dataset_name) @widget_create_dataset.called.connect diff --git a/plantseg/viewer/widget/io.py b/plantseg/ui/widgets/io.py similarity index 98% rename from plantseg/viewer/widget/io.py rename to plantseg/ui/widgets/io.py index cb089d22..7c0dfab2 100644 --- a/plantseg/viewer/widget/io.py +++ b/plantseg/ui/widgets/io.py @@ -14,9 +14,9 @@ from plantseg.io import load_tiff, load_h5, load_pill, load_zarr from plantseg.io.h5 import list_keys as list_h5_keys from plantseg.io.zarr import list_keys as list_zarr_keys -from plantseg.viewer.dag_handler import dag_manager -from plantseg.viewer.logging import napari_formatted_logging -from plantseg.viewer.widget.utils import layer_properties, return_value_if_widget +from plantseg.ui.dag_handler import dag_manager +from plantseg.ui.logging import napari_formatted_logging +from plantseg.ui.widgets.utils import layer_properties, return_value_if_widget def _check_layout_string(layout): diff --git a/plantseg/viewer/widget/predictions.py b/plantseg/ui/widgets/predictions.py similarity index 97% rename from plantseg/viewer/widget/predictions.py rename to plantseg/ui/widgets/predictions.py index b1a73c2c..f6d9a71b 100644 --- a/plantseg/viewer/widget/predictions.py +++ b/plantseg/ui/widgets/predictions.py @@ -14,11 +14,11 @@ from plantseg.predictions.functional import unet_predictions from plantseg.utils import list_all_modality, list_all_dimensionality, list_all_output_type from plantseg.utils import list_models, add_custom_model, get_train_config, get_model_zoo, get_model_description -from plantseg.viewer.logging import napari_formatted_logging -from plantseg.viewer.widget.proofreading.proofreading import widget_split_and_merge_from_scribbles -from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws -from plantseg.viewer.widget.utils import return_value_if_widget -from plantseg.viewer.widget.utils import start_threading_process, create_layer_name, layer_properties +from plantseg.ui.logging import napari_formatted_logging +from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles +from plantseg.ui.widgets.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws +from plantseg.ui.widgets.utils import return_value_if_widget +from plantseg.ui.widgets.utils import start_threading_process, create_layer_name, layer_properties ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())] MPS = ['mps'] if torch.backends.mps.is_available() else [] diff --git a/plantseg/ui/widgets/proofreading/__init__.py 
b/plantseg/ui/widgets/proofreading/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plantseg/viewer/widget/proofreading/proofreading.py b/plantseg/ui/widgets/proofreading/proofreading.py similarity index 95% rename from plantseg/viewer/widget/proofreading/proofreading.py rename to plantseg/ui/widgets/proofreading/proofreading.py index 967cf19d..79529c81 100644 --- a/plantseg/viewer/widget/proofreading/proofreading.py +++ b/plantseg/ui/widgets/proofreading/proofreading.py @@ -8,9 +8,9 @@ from napari.qt.threading import thread_worker from napari.types import LayerDataTuple -from plantseg.viewer.logging import napari_formatted_logging -from plantseg.viewer.widget.proofreading.split_merge_tools import split_merge_from_seeds -from plantseg.viewer.widget.proofreading.utils import get_bboxes +from plantseg.ui.logging import napari_formatted_logging +from plantseg.ui.widgets.proofreading.split_merge_tools import split_merge_from_seeds +from plantseg.ui.widgets.proofreading.utils import get_bboxes DEFAULT_KEY_BINDING_PROOFREAD = 'n' DEFAULT_KEY_BINDING_CLEAN = 'b' @@ -191,11 +191,11 @@ def reset_corrected_cells_mask(self): @magicgui(call_button=f'Clean scribbles - < {DEFAULT_KEY_BINDING_CLEAN} >') def widget_clean_scribble(viewer: napari.Viewer): if not segmentation_handler.status: - napari_formatted_logging('Proofreading widget not initialized. Run the proofreading widget tool once first', + napari_formatted_logging('Proofreading widgets not initialized. Run the proofreading widgets tool once first', thread='Clean scribble') if 'Scribbles' not in viewer.layers: - napari_formatted_logging('Scribble Layer not defined. Run the proofreading widget tool once first', + napari_formatted_logging('Scribble Layer not defined. Run the proofreading widgets tool once first', thread='Clean scribble') return None @@ -254,7 +254,7 @@ def widget_split_and_merge_from_scribbles(viewer: napari.Viewer, elif 'pmap' not in image.metadata: napari_formatted_logging('Pmap/Image layer appears to be a raw image and not a boundary probability map. ' 'For the best proofreading results, try to use a boundaries probability layer ' - '(e.g. from the Run Prediction widget)', + '(e.g. from the Run Prediction widgets)', thread='Proofreading tool', level='warning') if initialize_proofreading(viewer, segmentation): @@ -293,9 +293,9 @@ def func(): @magicgui(call_button=f'Extract correct labels') def widget_filter_segmentation() -> Future[LayerDataTuple]: if not segmentation_handler.status: - napari_formatted_logging('Proofreading widget not initialized. Run the proofreading widget tool once first', + napari_formatted_logging('Proofreading widgets not initialized. Run the proofreading widgets tool once first', thread='Export correct labels', level='error') - raise ValueError('Proofreading widget not initialized. Run the proofreading widget tool once first') + raise ValueError('Proofreading widgets not initialized. 
Run the proofreading widgets tool once first') future = Future() diff --git a/plantseg/viewer/widget/proofreading/split_merge_tools.py b/plantseg/ui/widgets/proofreading/split_merge_tools.py similarity index 95% rename from plantseg/viewer/widget/proofreading/split_merge_tools.py rename to plantseg/ui/widgets/proofreading/split_merge_tools.py index 1bf5f578..27906b2f 100644 --- a/plantseg/viewer/widget/proofreading/split_merge_tools.py +++ b/plantseg/ui/widgets/proofreading/split_merge_tools.py @@ -1,8 +1,8 @@ import numpy as np from skimage.segmentation import watershed -from plantseg.viewer.logging import napari_formatted_logging -from plantseg.viewer.widget.proofreading.utils import get_bboxes, get_idx_slice +from plantseg.ui.logging import napari_formatted_logging +from plantseg.ui.widgets.proofreading.utils import get_bboxes, get_idx_slice def _merge_from_seeds(segmentation, region_slice, region_bbox, bboxes, all_idx): diff --git a/plantseg/viewer/widget/proofreading/utils.py b/plantseg/ui/widgets/proofreading/utils.py similarity index 100% rename from plantseg/viewer/widget/proofreading/utils.py rename to plantseg/ui/widgets/proofreading/utils.py diff --git a/plantseg/viewer/widget/segmentation.py b/plantseg/ui/widgets/segmentation.py similarity index 98% rename from plantseg/viewer/widget/segmentation.py rename to plantseg/ui/widgets/segmentation.py index 85695d02..a49f5356 100644 --- a/plantseg/viewer/widget/segmentation.py +++ b/plantseg/ui/widgets/segmentation.py @@ -11,15 +11,15 @@ from plantseg.dataprocessing.functional.dataprocessing import normalize_01 from plantseg.segmentation.functional import gasp, multicut, dt_watershed, mutex_ws from plantseg.segmentation.functional import lifted_multicut_from_nuclei_segmentation, lifted_multicut_from_nuclei_pmaps -from plantseg.viewer.widget.proofreading.proofreading import widget_split_and_merge_from_scribbles -from plantseg.viewer.logging import napari_formatted_logging -from plantseg.viewer.widget.utils import start_threading_process, create_layer_name, layer_properties +from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles +from plantseg.ui.logging import napari_formatted_logging +from plantseg.ui.widgets.utils import start_threading_process, create_layer_name, layer_properties def _pmap_warn(thread: str): napari_formatted_logging('Pmap/Image layer appears to be a raw image and not a pmap. For the best segmentation ' 'results, try to use a boundaries probability layer ' - '(e.g. from the Run Prediction widget)', + '(e.g. 
from the Run Prediction widgets)', thread=thread, level='warning') diff --git a/plantseg/viewer/widget/training.py b/plantseg/ui/widgets/training.py similarity index 96% rename from plantseg/viewer/widget/training.py rename to plantseg/ui/widgets/training.py index f016b58e..6aaef7c9 100644 --- a/plantseg/viewer/widget/training.py +++ b/plantseg/ui/widgets/training.py @@ -9,8 +9,8 @@ from plantseg import PLANTSEG_MODELS_DIR from plantseg.training.train import unet_training from plantseg.utils import list_all_dimensionality -from plantseg.viewer.widget.predictions import ALL_DEVICES -from plantseg.viewer.widget.utils import create_layer_name, start_threading_process, return_value_if_widget, \ +from plantseg.ui.widgets.predictions import ALL_DEVICES +from plantseg.ui.widgets.utils import create_layer_name, start_threading_process, return_value_if_widget, \ layer_properties diff --git a/plantseg/viewer/widget/utils.py b/plantseg/ui/widgets/utils.py similarity index 96% rename from plantseg/viewer/widget/utils.py rename to plantseg/ui/widgets/utils.py index 40d4fdd1..b39a8a50 100644 --- a/plantseg/viewer/widget/utils.py +++ b/plantseg/ui/widgets/utils.py @@ -7,8 +7,8 @@ from napari import Viewer from napari.qt.threading import thread_worker -from plantseg.viewer.dag_handler import dag_manager -from plantseg.viewer.logging import napari_formatted_logging +from plantseg.ui.dag_handler import dag_manager +from plantseg.ui.logging import napari_formatted_logging def identity(*args, **kwargs): diff --git a/plantseg/utils.py b/plantseg/utils.py index c5301150..441a354a 100644 --- a/plantseg/utils.py +++ b/plantseg/utils.py @@ -278,7 +278,8 @@ def check_models(model_name: str, update_files: bool = False, config_only: bool def clean_models(): for _ in range(3): answer = input("This will delete all models in the model zoo, " - "make sure to copy all custom models you want to preserve before continuing.\n" + "dataset, and config files in the PlantSeg local directory.\n" + "Make sure to copy all files you want to preserve before continuing.\n" "Are you sure you want to continue? 
(y/n) ") if answer == 'y': shutil.rmtree(PLANTSEG_LOCAL_DIR) @@ -292,6 +293,8 @@ def clean_models(): else: print("Invalid input, please type 'y' or 'n'.") + print("Too many invalid inputs, PlantSeg will now close.") + def check_version(plantseg_url=' https://api.github.com/repos/hci-unihd/plant-seg/releases/latest'): try: @@ -334,7 +337,7 @@ def list_datasets(): return list(datasets.keys()) -def get_dataset(key: str): +def get_dataset_dict(key: str): """ Get a dataset from the user dataset config file """ @@ -344,7 +347,7 @@ def get_dataset(key: str): return datasets[key] -def save_dataset(key: str, dataset: dict): +def dump_dataset_dict(key: str, dataset: dict): """ Save a dataset to the user dataset config file, if the dataset already exists it will be overwritten """ @@ -355,9 +358,9 @@ def save_dataset(key: str, dataset: dict): yaml.dump(datasets, f) -def delete_dataset(key: str): +def delist_dataset(key: str): """ - Delete a dataset from the user dataset config file + Delete a dataset from the user dataset config file but keep the files """ datasets = load_config(USER_DATASETS_CONFIG) if key not in datasets: From 527c70626aa6cb1a156f77fb71fb9a9cb4abf521 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Sat, 12 Aug 2023 19:20:55 +0200 Subject: [PATCH 17/22] finalize first draft of the dataset tools --- conda-recipe/meta.yaml | 1 + plantseg/dataset_tools/dataset_handler.py | 219 +++++------ plantseg/dataset_tools/images.py | 446 +++++++++++++++++----- plantseg/dataset_tools/validators.py | 59 --- plantseg/io/h5.py | 63 ++- plantseg/ui/containers.py | 8 +- plantseg/ui/widgets/dataset_tools.py | 355 +++++++++-------- 7 files changed, 728 insertions(+), 423 deletions(-) diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 9247b92c..af27b2b3 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -31,6 +31,7 @@ requirements: - python-elf - napari - python-graphviz + - pyqt test: imports: diff --git a/plantseg/dataset_tools/dataset_handler.py b/plantseg/dataset_tools/dataset_handler.py index e83cfd9d..4ba47905 100644 --- a/plantseg/dataset_tools/dataset_handler.py +++ b/plantseg/dataset_tools/dataset_handler.py @@ -3,7 +3,7 @@ from typing import Union, Protocol from warnings import warn -from plantseg.dataset_tools.images import Stack +from plantseg.dataset_tools.images import Stack, StackSpecs from plantseg.io.h5 import H5_EXTENSIONS from plantseg.utils import dump_dataset_dict, get_dataset_dict, delist_dataset, list_datasets @@ -16,14 +16,6 @@ def __call__(self, dataset: object) -> tuple[bool, str]: ... -class StackValidator(Protocol): - def __init__(self, dataset: object): - ... - - def __call__(self, stack_path: Union[str, Path]) -> tuple[bool, str]: - ... - - class ComposeDatasetValidators: """ Compose multiple dataset validators into a single one. @@ -67,91 +59,41 @@ def batch_apply(self, list_dataset: list[object]) -> tuple[bool, str]: return True, self.success_msg -class ComposeStackValidators: - """ - Compose multiple stack validators into a single one. - """ - success_msg = 'All tests passed.' - - def __init__(self, *validators: StackValidator): - self.validators = validators - - def __call__(self, stack_path: Union[str, Path]) -> tuple[bool, str]: - return self.apply(stack_path) - - def apply(self, stack_path: Union[str, Path]) -> tuple[bool, str]: - """ - Apply all the validators to the stack. - Args: - stack_path: path to the stack to validate - - Returns: - tuple[bool, str]: (valid, msg) where valid is True if all the tests passed, False otherwise. 
- - """ - for validator in self.validators: - valid, msg = validator(stack_path) - if not valid: - return valid, msg - return True, self.success_msg - - def batch_apply(self, list_stack_path: list[Union[str, Path]]) -> tuple[bool, str]: - """ - Apply all the validators to a list of stacks. - Args: - list_stack_path: list of paths to the stacks to validate - - Returns: - tuple[bool, str]: (valid, msg) where valid is True if all the tests passed, False otherwise. - - """ - for stack_path in list_stack_path: - valid, msg = self.apply(stack_path) - if not valid: - msg = f'Validation failed for {stack_path}.\nWith msg: {msg}' - return False, msg - - return True, self.success_msg - - class DatasetHandler: """ DatasetHandler is a class that contains all the information about a dataset. It is used to create a dataset from a directory, to save a dataset to a directory and manage the dataset. """ - name: str - keys: tuple[str, ...] - default_phases: tuple[str, ...] - default_file_formats: tuple[str, ...] default_phases: tuple[str, ...] = ('train', 'val', 'test') train: list[str] val: list[str] test: list[str] + default_file_formats = H5_EXTENSIONS - _default_keys = {'task': None, - 'dimensionality': None, # 2D or 3D - 'image_channels': None, - 'keys': ('raw', 'labels'), # keys of the h5 file (raw, labels, etc. ) - 'is_sparse': False, - 'default_file_formats': H5_EXTENSIONS - } + def __init__(self, + name: str, + dataset_dir: Union[str, Path], + expected_stack_specs: StackSpecs): - def __init__(self, name: str, dataset_dir: Union[str, Path], **kwargs): + assert isinstance(name, str), 'name must be a string' self.name = name + + assert isinstance(dataset_dir, (str, Path)), 'dataset_dir must be a string or a Path' self.dataset_dir = Path(dataset_dir) - self.dataset_dir.mkdir(parents=True, exist_ok=True) + self.train = [] self.val = [] self.test = [] - for atr, default in self._default_keys.items(): - setattr(self, atr, default) + self.expected_stack_specs = expected_stack_specs + self.init_datastructure() + self.update_stack_from_disk() + self.is_sparse = self._is_sparse() - for atr, value in kwargs.items(): - if atr in self._default_keys.keys(): - setattr(self, atr, value) - else: - raise ValueError(f'Attribute {atr} does not exists for {self.__class__.__name__}.') + def init_datastructure(self): + self.dataset_dir.mkdir(parents=True, exist_ok=True) + for phase in self.default_phases: + (self.dataset_dir / phase).mkdir(exist_ok=True) @classmethod def from_dict(cls, dataset_dict: dict): @@ -160,27 +102,30 @@ def from_dict(cls, dataset_dict: dict): """ assert 'name' in dataset_dict.keys(), 'Dataset name not found' assert 'dataset_dir' in dataset_dict.keys(), 'Dataset directory not found' - dataset = cls(name=dataset_dict['name'], dataset_dir=dataset_dict['dataset_dir']) - - for atr, default in cls._default_keys.items(): - if atr in dataset_dict.keys(): - setattr(dataset, atr, dataset_dict.get(atr)) - else: - warn(f'Attribute {atr} not found in dataset {dataset.name}. 
Setting to default value {default}') - setattr(dataset, atr, default) + assert 'stack_specs' in dataset_dict.keys(), 'Dataset stack_specs not found' + assert 'list_specs' in dataset_dict['stack_specs'].keys(), 'Dataset list_specs not found' - dataset.update_stack_from_disk() - return dataset + name = dataset_dict['name'] + dataset_dir = dataset_dict['dataset_dir'] + stack_specs = StackSpecs.from_dict(dataset_dict['stack_specs']) + return cls(name=name, dataset_dir=dataset_dir, expected_stack_specs=stack_specs) def to_dict(self) -> dict: """ Convert a DatasetHandler to a dictionary for serialization. """ - dataset_dict = {'name': self.name, 'dataset_dir': str(self.dataset_dir)} - for atr in self._default_keys.keys(): - dataset_dict[atr] = getattr(self, atr) + dataset_dict = { + 'name': self.name, + 'dataset_dir': str(self.dataset_dir), + 'stack_specs': self.expected_stack_specs.to_dict() + } return dataset_dict + def _is_sparse(self, stacks: list[Stack] = None) -> bool: + if stacks is None: + stacks = self.get_stacks() + return all([stack.is_sparse for stack in stacks]) + def __repr__(self) -> str: return f'DatasetHandler {self.name}, location: {self.dataset_dir}' @@ -189,11 +134,9 @@ def info(self) -> str: Nice print of the dataset information. """ info = f'{self.__repr__()}:\n' - for atr in self._default_keys.keys(): - if atr in self.default_phases: - info += f' {atr}: #{len(getattr(self, atr))} stacks\n' - else: - info += f' {atr}: {getattr(self, atr)}\n' + info += f'Dimensionality: {self.expected_stack_specs.dimensionality}\n' + info += f'Is sparse: {self.is_sparse}\n' + info += f'Number of stacks: {len(self.find_stacks_names())}\n' return info def validate(self, *dataset_validators: DatasetValidator) -> tuple[bool, str]: @@ -204,16 +147,17 @@ def validate(self, *dataset_validators: DatasetValidator) -> tuple[bool, str]: """ return ComposeDatasetValidators(*dataset_validators)(self) - def validate_stack(self, *stack_validators: StackValidator) -> tuple[bool, str]: + def get_stack(self, path: Union[str, Path]) -> tuple[Stack, bool, str]: """ - Validate all the stacks in the dataset using the stack validators. + Get a stack from the dataset. + Args: + path: path to the stack Returns: - (bool, str): a boolean (True for success) and a message with the result of the validation + Stack: the stack """ - files = self.find_stored_files() - return ComposeStackValidators(*stack_validators).batch_apply(files) + return Stack.from_h5(path=path, expected_stack_specs=self.expected_stack_specs) - def update_stack_from_disk(self, *validators: StackValidator, phase: str = None): + def update_stack_from_disk(self, phase: str = None): """ Update the stacks in the dataset from the disk. """ @@ -222,14 +166,13 @@ def update_stack_from_disk(self, *validators: StackValidator, phase: str = None) else: phases = [phase] + is_sparse_phase = [] for phase in phases: - stacks = self.find_stored_files(phase=phase) - result, msg = ComposeStackValidators(*validators).batch_apply(stacks) - if result: - stacks = [stack.name for stack in stacks] - setattr(self, phase, stacks) - else: - warn(f'Update failed for {phase} phase. 
{msg}') + stacks_found = self.get_stacks(phase=phase) + setattr(self, phase, stacks_found) + is_sparse_phase.append(self._is_sparse(stacks_found)) + + self.is_sparse = all(is_sparse_phase) def find_stored_files(self, phase: str = None, ignore_default_file_format: bool = False) -> list[Path]: """ @@ -262,23 +205,38 @@ def find_stored_files(self, phase: str = None, ignore_default_file_format: bool return found_files - def find_stacks(self, phase: str = None) -> list[str]: + def find_stacks_names(self, phase: str = None) -> list[str]: """ Find the name of the stacks in the dataset directory. """ stacks = self.find_stored_files(phase=phase) return [stack.stem for stack in stacks] + def get_stacks(self, phase: str = None) -> list[Stack]: + """ + Get the stacks in the dataset directory. + """ + stacks = self.find_stored_files(phase=phase) + all_stacks = [] + for stack in stacks: + stack, result, msg = self.get_stack(stack) + if result: + all_stacks.append(stack) + else: + warn(f'Stack {stack} seems to not be compatible with the dataset specs. Error {msg}, skipping it.') + + return all_stacks + def add_stack(self, stack_name: str, phase: str, - data: Stack, + stack: Stack, unique_name=True): """ Add a stack to the dataset. Args: stack_name: string with the name of the stack phase: string with the phase of the dataset (train, val, test) - data: dictionary with the data to be saved in the stack + stack: dictionary with the data to be saved in the stack {'raw': raw_data, 'labels': labels_data, etc...} unique_name: if True, the stack name will be changed to a unique name if already exists, otherwise it will error out. @@ -293,7 +251,12 @@ def add_stack(self, stack_name: str, stack_path = phase_dir / f'{stack_name}.h5' idx += 1 - data.dump_to_h5(stack_path) + result, msg = stack.check_compatibility(self.expected_stack_specs) + if result: + stack.dump_to_h5(stack_path) + return None + + raise ValueError(f'Could not add stack to dataset. 
{msg}') def remove_stack(self, stack_name: str): """ @@ -304,7 +267,7 @@ def remove_stack(self, stack_name: str): Returns: None """ for phase in self.default_phases: - stacks = self.find_stacks(phase=phase) + stacks = self.find_stacks_names(phase=phase) if stack_name in stacks: stack_path = self.dataset_dir / phase / f'{stack_name}.h5' if stack_path.exists(): @@ -326,7 +289,7 @@ def rename_stack(self, stack_name: str, new_name: str): Returns: None """ for phase in self.default_phases: - stacks = self.find_stacks(phase=phase) + stacks = self.find_stacks_names(phase=phase) if stack_name in stacks: stack_path = self.dataset_dir / phase / f'{stack_name}.h5' if stack_path.exists(): @@ -368,13 +331,35 @@ def save_dataset(dataset: DatasetHandler): dump_dataset_dict(dataset.name, dataset.to_dict()) -def delete_dataset(dataset: DatasetHandler): +def delete_dataset(dataset_name: str, dataset_dir: Union[str, Path]): """ Delete a dataset from the user dataset config file and delete all the files Args: - dataset: a DatasetHandler object + dataset_name: string with the name of the dataset + dataset_dir: path to the dataset directory Returns: None """ - delist_dataset(dataset.name) - rmtree(dataset.dataset_dir) + delist_dataset(dataset_name) + rmtree(dataset_dir) + + +def change_dataset_location(dataset_name: str, new_location: Union[str, Path]): + """ + Change the location of a dataset in the user dataset config file and move all the files + Args: + dataset_name: string with the name of the dataset + new_location: new location of the dataset + + Returns: + None + """ + new_location = Path(new_location) + if not new_location.exists(): + raise ValueError(f'New location {new_location} does not exist.') + + assert new_location.is_dir(), f'New location {new_location} is not a directory.' 
+ dataset_dict = get_dataset_dict(dataset_name) + dataset_dict['dataset_dir'] = str(new_location) + dataset = DatasetHandler.from_dict(dataset_dict) + save_dataset(dataset) diff --git a/plantseg/dataset_tools/images.py b/plantseg/dataset_tools/images.py index e6f3532c..53962832 100644 --- a/plantseg/dataset_tools/images.py +++ b/plantseg/dataset_tools/images.py @@ -1,40 +1,131 @@ +from dataclasses import dataclass, field, asdict from pathlib import Path from typing import Union import h5py import numpy as np -from plantseg.io.h5 import create_h5, load_h5 +from plantseg.io.h5 import create_h5, load_h5, read_attribute_h5, write_attribute_h5 +from plantseg.io.h5 import list_keys as list_keys_h5 + + +class MockData: + ndim: int + shape: tuple[int] + key: str + path: Path + + def __init__(self, path: Path, key: str): + assert path.exists(), f'Path does not exist: {path}' + with h5py.File(path, mode='r') as f: + assert key in f, f'Key not found in file: {key}' + data = f[key] + self.ndim = data.ndim + self.shape = data.shape + + self.key = key + self.path = path + + def load(self): + return load_h5(self.path, key=self.key) + + +@dataclass +class ImageSpecs: + key: str + data_type: str + dimensionality: str + num_channels: int = 1 + is_sparse: bool = None + + def __post_init__(self): + assert self.data_type in ('image', 'labels'), f'data_type must be either image or label, found {self.data_type}' + assert self.dimensionality in ('2D', '3D'), (f'dimensionality must be either' + f' 2D or 3D, found {self.dimensionality}') + assert isinstance(self.num_channels, int), f'num_channels must be an integer, found {self.num_channels}' + assert self.num_channels > 0, f'num_channels must be greater than 0, found {self.num_channels}' + assert isinstance(self.key, str), f'key must be a string, found {self.key}' + + def to_dict(self): + return asdict(self) + + @classmethod + def from_h5(cls, path: Union[str, Path], key: str): + attrs = read_attribute_h5(path=path, key=key) + attrs = {k: v for k, v in attrs.items() if k in cls.__annotations__.keys()} + attrs['num_channels'] = int(attrs['num_channels']) + + if attrs['is_sparse'] is not None: + attrs['is_sparse'] = bool(attrs['is_sparse']) + return cls(**attrs) + + +@dataclass +class StackSpecs: + dimensionality: str = '3D' + list_specs: list[ImageSpecs] = field(default_factory=list) + + def __post_init__(self): + assert self.dimensionality in ('2D', '3D'), (f'dimensionality must be either' + f' 2D or 3D, found {self.dimensionality}') + for image in self.list_specs: + assert isinstance(image, ImageSpecs), f'list_images must contain ImageSpec objects, found {image}' + + @classmethod + def from_dict(cls, dict_specs): + list_specs = dict_specs.pop('list_specs') + list_specs = [ImageSpecs(**spec) for spec in list_specs] + return cls(list_specs=list_specs, **dict_specs) + + def to_dict(self): + return asdict(self) class GenericImage: dimensionality: str key: str num_channels: int - data: Union[np.ndarray, h5py.Dataset] + data: Union[np.ndarray, MockData] layout: str = 'xy' + is_sparse: bool + data_type: str def __init__(self, data: np.ndarray, - key: str = 'raw', - dimensionality: str = '3D'): + spec: ImageSpecs): """ Generic image class to handle 2D and 3D images consistently. 
Args: data (np.ndarray): image data - key (str): key to use when saving to h5 - dimensionality (str): '2D' or '3D' + spec (ImageSpecs): image specifications template to be used to create the image """ - assert dimensionality in ['2D', '3D'], f'Invalid dimensionality: {dimensionality}, valid values are: 2D, 3D.' self.data = data - self.key = key - self.dimensionality = dimensionality - self.num_channels = 1 + self.spec = spec def __repr__(self): return (f'{self.__class__.__name__}(dimensionality={self.dimensionality},' f' shape={self.shape},' f' layout={self.layout})') + @property + def key(self): + return self.spec.key + + @property + def is_sparse(self): + return self.spec.is_sparse + + @property + def data_type(self): + return self.spec.data_type + + @property + def dimensionality(self): + return self.spec.dimensionality + + @property + def num_channels(self): + return self.spec.num_channels + @property def ndim(self) -> int: return self.data.ndim @@ -48,16 +139,17 @@ def clean_shape(self) -> tuple[int, ...]: """ Returns the shape without singleton dimensions and channel dimensions """ - assert len(self.shape) == len(self.layout), f'Shape and layout do not match: {self.shape} vs {self.layout}.' clean_shape = [s for s, l in zip(self.shape, self.layout) if 'c' != l and '1' != l] return tuple(clean_shape) - def load_data(self): + def load_data(self) -> np.ndarray: """ Load the data from the h5 file """ - if isinstance(self.data, h5py.Dataset): - self.data = self.data[...] + if isinstance(self.data, MockData): + return self.data.load() + + return self.data def remove_singletons(self, remove_channel: bool = False): """ @@ -79,134 +171,211 @@ def remove_singletons(self, remove_channel: bool = False): if axis_to_squeeze: data = np.squeeze(data, axis=tuple(axis_to_squeeze)) - return type(self)(data=data, key=self.key, dimensionality=self.dimensionality) + spec = ImageSpecs(key=self.key, + data_type=self.data_type, + dimensionality=self.dimensionality, + num_channels=self.num_channels, + is_sparse=self.is_sparse) + return type(self)(data=data, spec=spec) return self + def check_compatibility(self, spec: ImageSpecs) -> tuple[bool, str]: + """ + Validate the image against a specification + Args: + spec (ImageSpecs): specification to validate against + Returns: + bool: True if the image is valid, False otherwise + str: error message + """ + if self.data_type != spec.data_type: + return False, f'Invalid data type: {self.data_type} vs {spec.data_type}' + + if self.dimensionality != spec.dimensionality: + return False, f'Invalid dimensionality: {self.dimensionality} vs {spec.dimensionality}' + + if self.num_channels != spec.num_channels: + return False, f'Invalid number of channels: {self.num_channels} vs {spec.num_channels}' + + if self.data_type == 'labels' and self.is_sparse != spec.is_sparse: + return False, f'Invalid sparsity: {self.is_sparse} vs {spec.is_sparse}' + + return True, '' + @classmethod - def from_h5(cls, path: Union[str, Path], key: str, dimensionality: str, load_data: bool = False): + def from_h5(cls, path: Union[str, Path], key: str = None, spec: ImageSpecs = None): """ Instantiate a GenericImage from a h5 file + Args: + path (str, Path): path to the h5 file + key (str): key of the image in the h5 file + spec (ImageSpecs): image specifications template to be used to create the image + Returns: + GenericImage: instance of GenericImage """ - if load_data: - data, _ = load_h5(path=path, key=key) - else: - data = h5py.File(path, mode='r')[key] - return cls(data=data, key=key, 
dimensionality=dimensionality) + assert key is not None or spec is not None, 'Either key or spec must be provided.' + + if spec is None: + spec = ImageSpecs.from_h5(path=path, key=key) + + data = MockData(path=path, key=spec.key) + return cls(data=data, spec=spec) def to_h5(self, path: Union[str, Path], - voxel_size: tuple[float, float, float] = (1.0, 1.0, 1.0), mode: str = 'a'): """ Save the data to a h5 file Args: path (str, Path): path to the h5 file - voxel_size (tuple): voxel size mode (str): 'a' to append to an existing file, 'w' to overwrite an existing file """ - self.load_data() - create_h5(path, stack=self.data, key=self.key, voxel_size=voxel_size, mode=mode) + data = self.load_data() + create_h5(path, stack=data, key=self.key, mode=mode) + write_attribute_h5(path=path, key=self.key, atr_dict=self.spec.to_dict()) class Image(GenericImage): def __init__(self, data: np.ndarray, - key: str = 'raw', - dimensionality: str = '3D'): + spec: ImageSpecs): """ Args: data: numpy array - key: internal key of the dataset - dimensionality: 2D or 3D + spec: ImageSpecs containing the specifications of the image """ - super().__init__(data=data, key=key, dimensionality=dimensionality) + assert spec.data_type == 'image', f'Invalid data type: {spec.data_type}, expected image.' + assert spec.is_sparse is None, f'Invalid sparsity: {spec.is_sparse}, expected None for images.' + dimensionality = spec.dimensionality if dimensionality == '2D': if data.ndim == 2: - self.num_channels = 1 - self.layout = 'xy' + num_channels = 1 + layout = 'xy' elif data.ndim == 3: - self.num_channels = data.shape[0] - self.layout = 'cxy' + num_channels = data.shape[0] + layout = 'cxy' elif data.ndim == 4: - self.num_channels = data.shape[0] + num_channels = data.shape[0] assert data.shape[1] == 1, f'Invalid number of channels: {data.shape[1]}, expected 1.' - self.layout = 'c1xy' + layout = 'c1xy' else: raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 2 or 3 or 4.') elif dimensionality == '3D': if data.ndim == 3: - self.num_channels = 1 - self.layout = 'xyz' + num_channels = 1 + layout = 'xyz' elif data.ndim == 4: - self.num_channels = data.shape[0] - self.layout = 'cxyz' - raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 3 or 4.') + num_channels = data.shape[0] + layout = 'cxyz' + else: + raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 3 or 4.') else: raise ValueError(f'Invalid dimensionality: {dimensionality}, valid values are: 2D, 3D.') + spec.num_channels = num_channels + spec.layout = layout + super().__init__(data=data, spec=spec) + class Labels(GenericImage): def __init__(self, data: np.ndarray, - key: str = 'labels', - dimensionality: str = '3D'): + spec: ImageSpecs): """ Args: data: numpy array - key: internal key of the dataset - dimensionality: 2D or 3D + spec: ImageSpecs containing the specifications of a label image """ - super().__init__(data=data, key=key, dimensionality=dimensionality) + assert spec.data_type == 'labels', f'Invalid data type: {spec.data_type}, expected labels.' + assert spec.is_sparse is not None, 'Sparse flag must be set for labels.' + assert spec.num_channels == 1, f'Invalid number of channels: {spec.num_channels}, expected 1.' + dimensionality = spec.dimensionality if dimensionality == '2D': if data.ndim == 2: - self.num_channels = 1 - self.layout = 'xy' + layout = 'xy' elif data.ndim == 3: assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' 
- self.layout = '1xy' + layout = '1xy' elif data.ndim == 4: assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' assert data.shape[1] == 1, f'Invalid number of channels: {data.shape[1]}, expected 1.' - self.layout = '11xy' + layout = '11xy' else: raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 2 or 3 or 4.') elif dimensionality == '3D': if data.ndim == 3: - self.num_channels = 1 - self.layout = 'xyz' + layout = 'xyz' elif data.ndim == 4: assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' - self.layout = '1xyz' + layout = '1xyz' + else: + raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 3 or 4.') else: raise ValueError(f'Invalid dimensionality: {dimensionality}, valid values are: 2D, 3D.') + spec.layout = layout + super().__init__(data=data, spec=spec) + + +available_image_types = {'image': Image, 'labels': Labels} + + +def from_spec(data: np.ndarray, spec: ImageSpecs) -> GenericImage: + """ + Create an image from a specification + Args: + data: numpy array + spec: ImageSpecs containing the specifications of the image + Returns: + GenericImage: image + """ + data_type = spec.data_type + if data_type not in available_image_types: + raise ValueError(f'Invalid image type: {data_type}, valid values are: {available_image_types.keys()}') + return available_image_types[data_type](data=data, spec=spec) + + +def from_h5(path: Union[str, Path], key: str) -> GenericImage: + """ + Load an image from a h5 file + Args: + path (str, Path): path to the h5 file + key (str): key of the image in the h5 file + Returns: + GenericImage: image + """ + spec = ImageSpecs.from_h5(path=path, key=key) + image_type = available_image_types[spec.data_type] + data = image_type.from_h5(path=path, key=key) + return data + class Stack: dimensionality: str - layout: str - data: {} + list_specs: list[ImageSpecs] + data: dict[str, GenericImage] = {} - def __init__(self, *images: GenericImage, - dimensionality: str = '3D', + def __init__(self, *images: Union[GenericImage, np.ndarray], + spec: StackSpecs, strict: bool = True): """ Args: *images (GenericImage): list of images - dimensionality (str): 2D or 3D - strict (bool): if True, raise an error if the images do not have the same dimensionality + spec (StackSpecs): specification of the stack + strict (bool): if True, raise an error if the stack is invalid """ - self.dimensionality = dimensionality - + self.spec = spec + assert len(images) == len(self.list_specs), (f'Invalid number of images: {len(images)}, ' + f'expected {len(self.list_specs)}.') data = {} - dimensionality = images[0].dimensionality - for image in images: - assert image.dimensionality == dimensionality, (f'Invalid dimensionality: {image.dimensionality},' - f' all images must have the same dimensionality.') - data[image.key] = image + for image, spec in zip(images, self.list_specs): + if isinstance(image, np.ndarray): + image = from_spec(data=image, spec=spec) + data[spec.key] = image self.data = data result, msg = self.validate() @@ -214,6 +383,30 @@ def __init__(self, *images: GenericImage, if not result and strict: raise ValueError(msg) + @property + def dimensionality(self) -> str: + """ + Return the dimensionality of the stack + """ + return self.spec.dimensionality + + @property + def list_specs(self) -> list[ImageSpecs]: + """ + Return the list of image specifications + """ + return self.spec.list_specs + + @property + def is_sparse(self) -> bool: + """ + Check if the stack is sparse + """ + 
labels_is_sparse = [spec.is_sparse for spec in self.list_specs if spec.data_type == 'labels'] + if len(labels_is_sparse) == 0: + return False + return all(labels_is_sparse) + @property def keys(self) -> list[str]: """ @@ -226,9 +419,22 @@ def clean_shape(self) -> tuple[int, ...]: """ Return the shape of the stack without the channel dimension and singleton dimensions """ + if len(self.keys) == 0: + return () key = self.keys[0] return self.data[key].clean_shape + def validate_images(self, list_specs: list[ImageSpecs] = None) -> tuple[bool, str]: + if list_specs is None: + list_specs = self.list_specs + + for image, spec in zip(self.data.values(), list_specs): + result, msg = image.check_compatibility(spec) + if not result: + return False, f'Error in {image.key}: {msg}' + + return True, '' + def validate_dimensionality(self) -> tuple[bool, str]: for image in self.data.values(): if image.dimensionality != self.dimensionality: @@ -239,15 +445,28 @@ def validate_dimensionality(self) -> tuple[bool, str]: return True, '' def validate_layout(self) -> tuple[bool, str]: - for type_image in [Image, Labels]: - list_image = [image for image in self.data.values() if isinstance(image, type_image)] - if list_image: - layout = list_image[0].layout - for image in list_image: - if image.layout != layout: - msg = (f'Invalid layout: {image.layout}, all' - f' images of type {self.__class__.__name__} must have the same layout.') - return False, msg + # check image layout + images_layout, labels_layout = [], [] + for image, spec in zip(self.data.values(), self.list_specs): + if spec.data_type == 'image': + images_layout.append(image.layout) + elif spec.data_type == 'labels': + labels_layout.append(image.layout) + else: + raise ValueError(f'Invalid image type: {spec.data_type}, valid values are: image, labels.') + + if len(images_layout) > 0: + for layout in images_layout: + if layout != images_layout[0]: + msg = f'Invalid layout found in images: {images_layout}' + return False, msg + + if len(labels_layout) > 0: + for layout in labels_layout: + if layout != labels_layout[0]: + msg = f'Invalid layout found in labels: {labels_layout}' + return False, msg + return True, '' def validate_shape(self) -> tuple[bool, str]: @@ -266,12 +485,33 @@ def validate(self) -> tuple[bool, str]: """ Validate the stack to ensure that all images have the same dimensionality, layout and shape. """ - for test in [self.validate_dimensionality, self.validate_layout, self.validate_shape]: + for test in [self.validate_images, + self.validate_dimensionality, + self.validate_layout, + self.validate_shape]: result, msg = test() if not result: return False, msg return True, '' + def check_compatibility(self, stack_spec: StackSpecs) -> tuple[bool, str]: + + if self.dimensionality != stack_spec.dimensionality: + msg = (f'Invalid dimensionality: {self.dimensionality},' + f' expected {stack_spec.dimensionality}.') + return False, msg + + if len(self.list_specs) != len(stack_spec.list_specs): + msg = (f'Invalid number of images: {len(self.list_specs)},' + f' expected {len(stack_spec.list_specs)}.') + return False, msg + + result, msg = self.validate_images(stack_spec.list_specs) + if not result: + return False, msg + + return True, '' + def dump_to_h5(self, path: Union[str, Path], mode: str = 'a'): """ Dump the full stack to an HDF5 file. @@ -281,29 +521,61 @@ def dump_to_h5(self, path: Union[str, Path], mode: str = 'a'): """ assert mode in ['w', 'a', 'r+', 'w-'], f'Invalid mode: {mode}, valid values are: [w, a, r+, w-].' 
+ + keys, image_types = [], [] for key, stack in self.data.items(): stack.to_h5(path=path, mode=mode) # switch to append mode after first iteration mode = 'a' + keys.append(key) + + for _name, _image_type in available_image_types.items(): + if isinstance(stack, _image_type): + image_types.append(_name) + + specs_dict = self.spec.to_dict() + specs_dict.pop('list_specs') + write_attribute_h5(path, atr_dict=specs_dict, key=None) @classmethod def from_h5(cls, path: Union[str, Path], - keys: tuple[tuple[str, GenericImage]], - dimensionality: str, - load_data: bool = False, + expected_stack_specs: StackSpecs = None, strict: bool = True): """ Load the full stack from an HDF5 file. Args: path: path to the HDF5 file - keys: list of (keys, type of data) to load - dimensionality: 2D or 3D - load_data: if True, load the data from the HDF5 file + expected_stack_specs: stack specifications strict: if True, raise an error if the images do not have the same dimensionality """ - data = [] - for key, type_image in keys: - im = type_image.from_h5(path=path, key=key, dimensionality=dimensionality, load_data=load_data) - data.append(im) + if expected_stack_specs is None: + stack_attrs = read_attribute_h5(path, key=None) + stack_attrs = {k: v for k, v in stack_attrs.items() if k in StackSpecs.__annotations__.keys()} + stack_attrs['list_specs'] = [] + stack_spec = StackSpecs.from_dict(stack_attrs) + list_keys = list_keys_h5(path) + else: + stack_spec = expected_stack_specs + list_keys = [s.key for s in stack_spec.list_specs] + + list_data = [] + list_specs_found = [] + for key in list_keys: + data = from_h5(path, key=key) + list_data.append(data) + list_specs_found.append(data.spec) + + stack_spec.list_specs = list_specs_found + stack = cls(*list_data, spec=stack_spec) + + if strict: + result, msg = stack.validate() + if not result: + return stack, result, msg + + if expected_stack_specs is not None: + result, msg = stack.check_compatibility(expected_stack_specs) + if not result: + return stack, result, msg - return cls(*data, dimensionality=dimensionality, strict=strict) + return stack, True, '' diff --git a/plantseg/dataset_tools/validators.py b/plantseg/dataset_tools/validators.py index 7c63e00a..388ba4d8 100644 --- a/plantseg/dataset_tools/validators.py +++ b/plantseg/dataset_tools/validators.py @@ -1,8 +1,4 @@ -from pathlib import Path -from typing import Union - from plantseg.dataset_tools.dataset_handler import DatasetHandler -from plantseg.io.h5 import list_keys class CheckDatasetDirectoryStructure: @@ -18,58 +14,3 @@ def __call__(self, dataset: DatasetHandler) -> tuple[bool, str]: return False, f'Dataset directory {dataset.dataset_dir} does not contain {phase} directory.' return True, '' - - -class CheckH5Keys: - def __init__(self, expected_h5_keys: tuple[str, ...] = ('raw', 'labels')): - self.expected_h5_keys = expected_h5_keys - - def __call__(self, stack: Union[str, Path]) -> tuple[bool, str]: - found_keys = list_keys(stack) - for key in self.expected_h5_keys: - if key not in found_keys: - return False, f'Key {key} not found in {stack}. Expected keys: {self.expected_h5_keys}' - - return True, '' - - -class CheckH5shapes: - def __init__(self, dimensionality: str = '3D', - expected_h5_keys: tuple[str, ...] = (('raw', 'image'), - ('labels', 'labels') - )): - """ - Check if the shape of the data in the h5 file matches the expected shape. 
- Args: - dimensionality: '2D' or '3D' - expected_h5_keys: tuple of tuples, each tuple contains the key and the expected type of data - possible types are: 'image', 'labels' - """ - assert dimensionality in ['2D', '3D'], f'Invalid dimensionality: {dimensionality}, ' \ - f'valid values are: 2D, 3D' - - self.expected_shapes = {} - if dimensionality == '2D': - for key, data_type in expected_h5_keys: - assert data_type in ['image', 'labels'], f'Invalid data type: {data_type}, ' \ - f'valid values are: image, labels' - if data_type == 'image': - self.expected_shapes[key] = [{'ndim': 2, 'shape': 'xy'}, - {'ndim': 3, 'shape': 'cxy'}, - {'ndim': 4, 'shape': 'c1xy'}, - {'ndim': 4, 'shape': '1xy'}] - elif data_type == 'labels': - self.expected_shapes[key] = [{'ndim': 2, 'shape': 'xy'}, - {'ndim': 3, 'shape': '1xy'}] - elif dimensionality == '3D': - for key, data_type in expected_h5_keys: - assert data_type in ['image', 'labels'], f'Invalid data type: {data_type}, ' \ - f'valid values are: image, labels' - if data_type == 'image': - self.expected_shapes[key] = [{'ndim': 3, 'shape': 'zxy'}, - {'ndim': 4, 'shape': 'czxy'}, - {'ndim': 4, 'shape': '1xy'}, - {'ndim': 4, 'shape': '1xy'}] - elif data_type == 'labels': - self.expected_shapes[key] = [{'ndim': 2, 'shape': 'xy'}, - {'ndim': 3, 'shape': '1xy'}] diff --git a/plantseg/io/h5.py b/plantseg/io/h5.py index 8fe4dbd7..e4332617 100644 --- a/plantseg/io/h5.py +++ b/plantseg/io/h5.py @@ -1,7 +1,7 @@ import warnings +from pathlib import Path from typing import Optional, Union -from pathlib import Path import h5py import numpy as np @@ -55,6 +55,7 @@ def visitor_func(name, node): def load_h5(path: Union[str, Path], key: str, slices: Optional[slice] = None, + info_only: bool = False) -> Union[tuple, tuple[np.array, tuple]]: """ Load a dataset from a h5 file and returns some meta info about it. 
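A minimal usage sketch of the attribute helpers introduced in the next hunk (illustrative only, not part of the diff): it assumes the write_attribute_h5 / read_attribute_h5 signatures shown below together with the existing create_h5 helper, and the file name is hypothetical.

import numpy as np
from plantseg.io.h5 import create_h5, write_attribute_h5, read_attribute_h5

path = 'example_stack.h5'  # hypothetical file
create_h5(path, np.zeros((4, 32, 32), dtype='float32'), key='raw')

# attach stack-level attributes to the file root (key=None); None values
# are stored as the string 'none' and converted back when reading
write_attribute_h5(path, {'dimensionality': '3D', 'task': None}, key=None)
attrs = read_attribute_h5(path, key=None)
print(attrs)  # expected to round-trip to {'dimensionality': '3D', 'task': None}
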
@@ -110,6 +111,65 @@ def create_h5(path: Union[str, Path], f[key].attrs['element_size_um'] = voxel_size +def write_attribute_h5(path: Union[str, Path], atr_dict: dict, key: str = None) -> None: + """ + Helper function to add attributes to a h5 file + Args: + path (str): file path + atr_dict (dict): dictionary of attributes to add + key (str): key of the dataset in the h5 file + + Returns: + None + """ + assert Path(path).suffix in H5_EXTENSIONS, f"File {path} is not a h5 file" + assert Path(path).exists(), f"File {path} does not exist" + assert isinstance(atr_dict, dict), "atr_dict must be a dictionary" + assert isinstance(key, str) or key is None, "key must be a string or None" + + with h5py.File(path, mode='r+') as f: + if key is None: + file = f + elif key in f: + file = f[key] + else: + raise KeyError(f"Key {key} not found in {path}") + + for k, v in atr_dict.items(): + if v is None: + v = 'none' + file.attrs[k] = v + + +def read_attribute_h5(path: Union[str, Path], key: str = None) -> dict: + """ + Helper function to read attributes from a h5 file + Args: + path (str): file path + key (str): key of the dataset in the h5 file + + Returns: + dict: dictionary of attributes + """ + assert Path(path).suffix in H5_EXTENSIONS, f"File {path} is not a h5 file" + assert Path(path).exists(), f"File {path} does not exist" + assert isinstance(key, str) or key is None, "key must be a string or None" + with h5py.File(path, mode='r') as f: + if key is None: + attrs = f.attrs + elif key in f: + attrs = f[key].attrs + else: + raise KeyError(f"Key {key} not found in {path}") + + attrs_dict = {} + for k, v in attrs.items(): + if isinstance(v, str) and v == 'none': + v = None + attrs_dict[k] = v + return attrs_dict + + def list_keys(path: Union[str, Path]) -> list[str]: """ List all keys in a h5 file @@ -119,6 +179,7 @@ def list_keys(path: Union[str, Path]) -> list[str]: Returns: list of keys """ + def _recursive_find_keys(f, base='/'): _list_keys = [] for key, dataset in f.items(): diff --git a/plantseg/ui/containers.py b/plantseg/ui/containers.py index b6c0735c..e6ed0161 100644 --- a/plantseg/ui/containers.py +++ b/plantseg/ui/containers.py @@ -15,8 +15,7 @@ from plantseg.ui.widgets.segmentation import widget_fix_over_under_segmentation_from_nuclei from plantseg.ui.widgets.segmentation import widget_lifted_multicut from plantseg.ui.widgets.segmentation import widget_simple_dt_ws -from plantseg.ui.widgets.dataset_tools import widget_create_dataset -from plantseg.ui.widgets.dataset_tools import widget_add_stack, widget_delete_dataset, widget_validata_dataset +from plantseg.ui.widgets.dataset_tools import widget_create_dataset, widget_edit_dataset def setup_menu(container, path=None): @@ -66,9 +65,8 @@ def get_gasp_workflow(): def get_dataset_workflow(): container = MainWindow(widgets=[widget_create_dataset, - widget_add_stack, - widget_validata_dataset, - widget_delete_dataset], + widget_edit_dataset + ], labels=False) container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Dataset-Managment') return container diff --git a/plantseg/ui/widgets/dataset_tools.py b/plantseg/ui/widgets/dataset_tools.py index 35ae6f02..64880307 100644 --- a/plantseg/ui/widgets/dataset_tools.py +++ b/plantseg/ui/widgets/dataset_tools.py @@ -1,50 +1,135 @@ +from enum import Enum from pathlib import Path +from typing import Optional from magicgui import magicgui -from napari.layers import Labels, Image +from napari.layers import Labels, Image, Layer from plantseg import PLANTSEG_MODELS_DIR -from 
plantseg.io import create_h5 -from plantseg.utils import list_datasets, dump_dataset_dict, get_dataset_dict, delist_dataset +from plantseg.dataset_tools.dataset_handler import DatasetHandler, save_dataset, load_dataset +from plantseg.dataset_tools.images import Image as PlantSegImage +from plantseg.dataset_tools.images import Labels as PlantSegLabels +from plantseg.dataset_tools.images import Stack, ImageSpecs, StackSpecs from plantseg.ui.logging import napari_formatted_logging +from plantseg.utils import list_datasets empty_dataset = ['none'] startup_list_datasets = list_datasets() or empty_dataset +class ImageType(Enum): + IMAGE: str = 'image' + LABELS: str = 'labels' + + @magicgui(call_button='Initialize Dataset', dataset_name={'label': 'Dataset name', 'tooltip': f'Initialize an empty dataset with name model_name'}, dataset_dir={'label': 'Path to the dataset directory', 'mode': 'd', 'tooltip': 'Select a directory containing where the dataset will be created, ' - '{dataset_dir}/model_name/.'} + '{dataset_dir}/model_name/.'}, + dimensionality={'label': 'Dimensionality', + 'choices': ['2D', '3D'], + 'tooltip': f'Initialize an empty dataset with name model_name'}, + images_format={'label': 'Expected images format \n (Name, Channel, Type)', + 'layout': 'vertical', + 'tooltip': f'Initialize an empty dataset with name model_name'}, + is_sparse={'label': 'Sparse dataset', + 'tooltip': 'If checked, the dataset will be saved in sparse format.'}, ) -def widget_create_dataset(dataset_name: str = 'my-dataset', dataset_dir: Path = Path.home()): +def widget_create_dataset(dataset_name: str = 'my-dataset', + dataset_dir: Path = Path.home(), + dimensionality: str = '2D', + images_format: list[tuple[str, int, ImageType]] = (('raw', 1, ImageType.IMAGE), + ('labels', 1, ImageType.LABELS)), + is_sparse: bool = False): + if dataset_name in list_datasets(): + napari_formatted_logging(message='Dataset already exists.', thread='widget_create_dataset', level='warning') + return None + + list_images = [] + for key, num_channels, im_format in images_format: + + if im_format == ImageType.IMAGE: + image_spec = ImageSpecs(key=key, + num_channels=num_channels, + dimensionality=dimensionality, + data_type='image') + list_images.append(image_spec) + elif im_format == ImageType.LABELS: + assert num_channels == 1, 'Labels must have only one channel.' 
+ labels_spec = ImageSpecs(key=key, + num_channels=1, + dimensionality=dimensionality, + data_type='labels', + is_sparse=is_sparse) + list_images.append(labels_spec) + + else: + raise ValueError(f'Image format {im_format} not supported.') + dataset_dir = dataset_dir / dataset_name - dataset_dir.mkdir(parents=True, exist_ok=True) - new_dataset = {'name': dataset_name, - 'dataset_dir': str(dataset_dir), - 'task': None, - 'dimensionality': None, # 2D or 3D - 'image_channels': None, - 'image_key': 'raw', - 'labels_key': 'labels', - 'is_sparse': False, - 'train': [], - 'val': [], - 'test': [], - } + stack_specs = StackSpecs(dimensionality=dimensionality, + list_specs=list_images) + + dataset = DatasetHandler(name=dataset_name, + dataset_dir=dataset_dir, + expected_stack_specs=stack_specs) + + save_dataset(dataset) + return dataset + + +def _add_stack(dataset_name: str = startup_list_datasets[0], + images: list[tuple[str, Optional[Layer]]] = (), + phase: str = 'train', + is_sparse: bool = False, + **kwargs): + dataset = load_dataset(dataset_name) + image_specs = dataset.expected_stack_specs.list_specs + stack_specs = dataset.expected_stack_specs + + list_images = [] + for image_name, layer in images: + reference_spec = [spec for spec in image_specs if spec.key == image_name][0] + if isinstance(layer, Image) and reference_spec.data_type == 'image': + image_data = layer.data + image = PlantSegImage(image_data, spec=reference_spec) + elif isinstance(layer, Labels) and reference_spec.data_type == 'labels': + labels_data = layer.data + reference_spec.is_sparse = is_sparse + image = PlantSegLabels(labels_data, spec=reference_spec) + else: + raise ValueError(f'Layer type {type(layer)} not supported.') + + list_images.append(image) + + stack_name = images[0][1].name + stack = Stack(*list_images, spec=stack_specs) + dataset.add_stack(stack_name=stack_name, stack=stack, phase=phase) - if dataset_name not in list_datasets(): - dump_dataset_dict(dataset_name, new_dataset) - return new_dataset - raise ValueError(f'Dataset {dataset_name} already exists.') +def _remove_stack(dataset_name, stack_name: str, **kwargs): + dataset = load_dataset(dataset_name) + dataset.remove_stack(stack_name) -@magicgui(call_button='Create Dataset', +available_modes = { + 'Add stack to dataset': _add_stack, + 'Remove stack from dataset': _remove_stack, + 'Remove dataset': None, + 'De-list dataset': None, + 'Move dataset': None, + 'Rename dataset': None +} + + +@magicgui(call_button='Edit Dataset', + action={'label': 'Action', + 'choices': list(available_modes.keys()), + 'tooltip': f'Define if the stack will be added or removed from the dataset'}, dataset_name={'label': 'Dataset name', 'choices': startup_list_datasets, 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, @@ -52,150 +137,112 @@ def widget_create_dataset(dataset_name: str = 'my-dataset', dataset_dir: Path = 'choices': ['train', 'val', 'test'], 'tooltip': f'Define if the stack will be used for training, validation or testing'}, is_sparse={'label': 'Sparse dataset', - 'tooltip': 'If checked, the dataset will be saved in sparse format.'} + 'tooltip': 'If checked, the dataset will be saved in sparse format.'}, + stack_name={'label': 'Stack name', + 'tooltip': f'Name of the stack to be added to be edited', + 'choices': ['', ]}, ) -def widget_add_stack(dataset_name: str = startup_list_datasets[0], - image: Image = None, - labels: Labels = None, - phase: str = 'train', - is_sparse: bool = False): - dataset_config = get_dataset_dict(dataset_name) - 
- if image is None or labels is None: - napari_formatted_logging(message=f'To add a stack to the dataset, please select an image and a labels layer.', - thread='widget_add_stack', - level='warning') +def widget_edit_dataset(action: str = list(available_modes.keys())[0], + dataset_name: str = startup_list_datasets[0], + images: list[tuple[str, Optional[Layer]]] = (), + phase: str = 'train', + is_sparse: bool = False, + stack_name: str = '') -> str: + func = available_modes[action] + kwargs = {'dataset_name': dataset_name, + 'images': images, + 'phase': phase, + 'is_sparse': is_sparse, + 'stack_name': stack_name} + func(**kwargs) + return action + + +@widget_create_dataset.called.connect +def update_dataset_name(dataset: DatasetHandler): + if dataset is None: return None - if is_sparse: - # if a single dataset is sparse, all the others should be threaded as sparse - dataset_config['is_sparse'] = True - - image_data = image.data - labels_data = labels.data - - # Validation of the image and labels data - # check if the image and labels have the same shape, - # dimensionality and number of channels as the rest of the dataset - - if image_data.ndim == 3: - image_channels = 1 - dimensionality = '2D' if image_data.shape[0] == 1 else '3D' - assert image_data.shape == labels_data.shape, f'Image and labels should have the same shape, found ' \ - f'{image_data.shape} and {labels_data.shape}.' - - elif image_data.ndim == 4: - image_channels = image_data.shape[0] - dimensionality = '2D' if image_data.shape[1] == 1 else '3D' - assert image_data.shape[1:] == labels_data.shape, f'Image and labels should have the same shape, found ' \ - f'{image_data.shape} and {labels_data.shape}.' - - else: - raise ValueError(f'Image data should be 3D or multichannel 3D, found {image_data.ndim}D.') - - dataset_image_channels = dataset_config['image_channels'] - if dataset_image_channels is None: - dataset_config['image_channels'] = image_channels - elif dataset_image_channels != image_channels: - raise ValueError(f'Image data should have {dataset_image_channels} channels, found {image_channels}.') - - dataset_dimensionality = dataset_config['dimensionality'] - if dataset_dimensionality is None: - dataset_config['dimensionality'] = dimensionality - elif dataset_dimensionality != dimensionality: - raise ValueError(f'Image data should be {dataset_dimensionality}, found {dimensionality}.') - - if is_sparse: - dataset_config['is_sparse'] = True - - # Check if the stack name already exists in the dataset - # If so, add a number to the end of the name until it is unique - stack_name = image.name - existing_stacks = dataset_config[phase] - - idx = 0 - while True: - if stack_name in existing_stacks: - stack_name = f'{stack_name}_{idx}' - else: - break - idx += 1 + widget_edit_dataset.dataset_name.choices = list_datasets() + widget_edit_dataset.dataset_name.value = dataset.name - dataset_config[phase].append(stack_name) - # Save the data to disk - dataset_dir = Path(dataset_config['dataset_dir']) / phase - dataset_dir.mkdir(parents=True, exist_ok=True) +def _update_stack_name_choices(dataset: DatasetHandler = None): + if dataset is None: + dataset = load_dataset(widget_edit_dataset.dataset_name.value) + stacks_options = dataset.find_stacks_names() + stacks_options = stacks_options if stacks_options else [''] + widget_edit_dataset.stack_name.choices = stacks_options + widget_edit_dataset.stack_name.value = stacks_options[0] if stacks_options else '' - image_path = str(dataset_dir / f'{stack_name}.h5') - create_h5(image_path, 
image_data, key=dataset_config['image_key']) - create_h5(image_path, labels_data, key=dataset_config['labels_key']) - dump_dataset_dict(dataset_name, dataset_config) - napari_formatted_logging(message=f'Stack {stack_name} added to dataset {dataset_name}.', - thread='widget_add_stack', - level='info') +def _update_images_choices(dataset: DatasetHandler = None): + if len(list_datasets()) == 0: + return None -@magicgui(call_button='Validata Dataset', - dataset_name={'label': 'Dataset name', - 'choices': startup_list_datasets, - 'tooltip': f'Name of the dataset to be validated'}, - ) -def widget_validata_dataset(dataset_name: str = startup_list_datasets[0]): - dataset_config = get_dataset_dict(dataset_name) - - # check all stacks are present - dataset_dir = Path(dataset_config['dataset_dir']) - for phase in ['train', 'val', 'test']: - phase_dir = dataset_dir / phase - stacks_expected = dataset_config[phase] - stacks_found = [file.stem for file in phase_dir.glob('*.h5')] - if len(stacks_found) != len(stacks_expected): - napari_formatted_logging(message=f'Found {len(stacks_found)} stacks in {phase} phase, ' - f'expected {len(stacks_expected)}.', - thread='widget_validata_dataset', - level='warning') - - dataset_config[phase] = stacks_found - - # check all stacks have the same shape and dimensionality - for key, value in dataset_config.items(): - napari_formatted_logging(message=f'Dataset info {key}: {value}', - thread='widget_validata_dataset', - level='info') - - -@magicgui(call_button='Delete Dataset', - dataset_name={'label': 'Dataset name', - 'choices': startup_list_datasets, - 'tooltip': f'Name of the dataset to be deleted'}, - ) -def widget_delete_dataset(dataset_name: str = startup_list_datasets[0]): - delist_dataset(dataset_name) + if dataset is None: + dataset = load_dataset(widget_edit_dataset.dataset_name.value) + images_default = [] + for image in dataset.expected_stack_specs.list_specs: + images_default.append((image.key, None)) -@widget_create_dataset.called.connect -def _on_create_dataset_called(new_dataset: dict): - new_dataset_list = list_datasets() - if not widget_add_stack.visible: - widget_add_stack.show() - widget_add_stack.dataset_name.choices = new_dataset_list - widget_add_stack.dataset_name.value = new_dataset['name'] + widget_edit_dataset.images.value = images_default + + +_update_images_choices() + + +@widget_edit_dataset.dataset_name.changed.connect +def update_dataset_name(dataset_name: str): + dataset = load_dataset(dataset_name) + widget_edit_dataset.phase.choices = dataset.default_phases + widget_edit_dataset.is_sparse.value = dataset.is_sparse + + _update_images_choices(dataset) + _update_stack_name_choices(dataset) + + +def _add_stack_callback(): + widget_edit_dataset.images.show() + widget_edit_dataset.phase.show() + widget_edit_dataset.stack_name.hide() + + +def _remove_stack_callback(): + widget_edit_dataset.images.hide() + widget_edit_dataset.phase.hide() + widget_edit_dataset.is_sparse.hide() + widget_edit_dataset.stack_name.show() + + _update_stack_name_choices() + + +_add_stack_callback() + + +available_modes_callbacks = { + 'Add stack to dataset': _add_stack_callback, + 'Remove stack from dataset': _remove_stack_callback, +} + + +@widget_edit_dataset.action.changed.connect +def update_mode(action: str): + if action in available_modes_callbacks.keys(): + available_modes_callbacks[action]() - if not widget_delete_dataset.visible: - widget_delete_dataset.show() - widget_delete_dataset.dataset_name.choices = new_dataset_list - 
widget_delete_dataset.dataset_name.value = new_dataset['name'] +def _remove_stack_update_choices(): + _update_stack_name_choices() - if not widget_validata_dataset.visible: - widget_validata_dataset.show() - widget_validata_dataset.dataset_name.choices = new_dataset_list - widget_validata_dataset.dataset_name.value = new_dataset['name'] +available_actions_on_done = { + 'Remove stack from dataset': _remove_stack_update_choices +} -if startup_list_datasets == empty_dataset: - widget_add_stack.hide() - widget_delete_dataset.hide() - widget_validata_dataset.hide() +@widget_edit_dataset.called.connect +def update_state_after_edit(action: str): + if action in available_actions_on_done.keys(): + available_actions_on_done[action]() From d0b02fdb1f9180c389a2fe5724b757b0a696eb8f Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Mon, 14 Aug 2023 16:22:45 +0200 Subject: [PATCH 18/22] fix bugs and polish gui --- plantseg/dataset_tools/dataset_handler.py | 59 ++- plantseg/dataset_tools/images.py | 29 +- plantseg/ui/widgets/dataset_tools.py | 523 ++++++++++++++++++---- 3 files changed, 502 insertions(+), 109 deletions(-) diff --git a/plantseg/dataset_tools/dataset_handler.py b/plantseg/dataset_tools/dataset_handler.py index 4ba47905..bed637b5 100644 --- a/plantseg/dataset_tools/dataset_handler.py +++ b/plantseg/dataset_tools/dataset_handler.py @@ -136,7 +136,8 @@ def info(self) -> str: info = f'{self.__repr__()}:\n' info += f'Dimensionality: {self.expected_stack_specs.dimensionality}\n' info += f'Is sparse: {self.is_sparse}\n' - info += f'Number of stacks: {len(self.find_stacks_names())}\n' + info += f'Num of stacks: {len(self.find_stacks_names())} (train: {len(self.train)}, val: {len(self.val)}, ' \ + f'test: {len(self.test)}) \n' return info def validate(self, *dataset_validators: DatasetValidator) -> tuple[bool, str]: @@ -157,6 +158,13 @@ def get_stack(self, path: Union[str, Path]) -> tuple[Stack, bool, str]: """ return Stack.from_h5(path=path, expected_stack_specs=self.expected_stack_specs) + def get_stack_from_name(self, stack_name: str) -> tuple[Stack, bool, str]: + for h5 in self.find_stored_files(): + if stack_name == f'{h5.parent.name}/{h5.stem}': + return self.get_stack(h5) + + raise ValueError(f'Stack {stack_name} not found in dataset {self.name}') + def update_stack_from_disk(self, phase: str = None): """ Update the stacks in the dataset from the disk. @@ -196,13 +204,16 @@ def find_stored_files(self, phase: str = None, ignore_default_file_format: bool file_formats = self.default_file_formats if not ignore_default_file_format else ('*',) for phase in phases: + phase_found_files = [] phase_dir = self.dataset_dir / phase assert phase_dir.exists(), f'Phase {phase} not found in {self.dataset_dir}' for file_format in file_formats: stacks_found = [file for file in phase_dir.glob(f'*{file_format}')] - found_files.extend(stacks_found) + phase_found_files.extend(stacks_found) + phase_found_files = sorted(phase_found_files, key=lambda x: x.stem) + found_files.extend(phase_found_files) return found_files def find_stacks_names(self, phase: str = None) -> list[str]: @@ -210,7 +221,7 @@ def find_stacks_names(self, phase: str = None) -> list[str]: Find the name of the stacks in the dataset directory. 
""" stacks = self.find_stored_files(phase=phase) - return [stack.stem for stack in stacks] + return [f'{stack.parent.name}/{stack.stem}' for stack in stacks] def get_stacks(self, phase: str = None) -> list[Stack]: """ @@ -246,8 +257,15 @@ def add_stack(self, stack_name: str, phase_dir = self.dataset_dir / phase stack_path = phase_dir / f'{stack_name}.h5' idx = 1 + while stack_path.exists() and unique_name: - stack_name += f'_{idx}' + if stack_name.find('_') == -1: + stack_name += f'_{idx}' + else: + *name_base, idx_name = stack_name.split('_') + name_base = '_'.join(name_base) + stack_name = f'{name_base}_{idx}' + stack_path = phase_dir / f'{stack_name}.h5' idx += 1 @@ -269,7 +287,7 @@ def remove_stack(self, stack_name: str): for phase in self.default_phases: stacks = self.find_stacks_names(phase=phase) if stack_name in stacks: - stack_path = self.dataset_dir / phase / f'{stack_name}.h5' + stack_path = self.dataset_dir / f'{stack_name}.h5' if stack_path.exists(): stack_path.unlink() self.update_stack_from_disk(phase=phase) @@ -291,7 +309,7 @@ def rename_stack(self, stack_name: str, new_name: str): for phase in self.default_phases: stacks = self.find_stacks_names(phase=phase) if stack_name in stacks: - stack_path = self.dataset_dir / phase / f'{stack_name}.h5' + stack_path = self.dataset_dir / f'{stack_name}.h5' if stack_path.exists(): new_stack_path = self.dataset_dir / phase / f'{new_name}.h5' stack_path.rename(new_stack_path) @@ -302,6 +320,35 @@ def rename_stack(self, stack_name: str, new_name: str): raise ValueError(f'Stack {stack_name} not found in dataset {self.name}.') + def change_phase_stack(self, stack_name: str, new_phase: str): + """ + Change the phase of a stack in the dataset. + Args: + stack_name: string with the name of the stack + new_phase: string with the new phase of the stack + + Returns: None + """ + assert new_phase in self.default_phases, f'Phase {new_phase} not found in dataset {self.name}.' 
+ for phase in self.default_phases: + stacks = self.find_stacks_names(phase=phase) + if stack_name in stacks: + stack_path = self.dataset_dir / f'{stack_name}.h5' + if stack_path.exists(): + if phase == new_phase: + return None + + stack_name = stack_name.split('/')[-1] + new_stack_path = self.dataset_dir / new_phase / f'{stack_name}.h5' + stack_path.rename(new_stack_path) + self.update_stack_from_disk() + return None + else: + raise FileNotFoundError(f'Stack {stack_name} not found in {phase} phase.') + + raise ValueError(f'Stack {stack_name} not found in dataset {self.name}.') + + def load_dataset(dataset_name: str) -> DatasetHandler: """ diff --git a/plantseg/dataset_tools/images.py b/plantseg/dataset_tools/images.py index 53962832..eb2480e3 100644 --- a/plantseg/dataset_tools/images.py +++ b/plantseg/dataset_tools/images.py @@ -27,7 +27,8 @@ def __init__(self, path: Path, key: str): self.path = path def load(self): - return load_h5(self.path, key=self.key) + data, infos = load_h5(self.path, key=self.key) + return data, infos @dataclass @@ -89,6 +90,7 @@ class GenericImage: layout: str = 'xy' is_sparse: bool data_type: str + infos: tuple = None def __init__(self, data: np.ndarray, spec: ImageSpecs): @@ -147,7 +149,9 @@ def load_data(self) -> np.ndarray: Load the data from the h5 file """ if isinstance(self.data, MockData): - return self.data.load() + data, infos = self.data.load() + self.infos = infos + return data return self.data @@ -253,10 +257,15 @@ def __init__(self, data: np.ndarray, layout = 'xy' elif data.ndim == 3: num_channels = data.shape[0] + assert num_channels == spec.num_channels, (f'Invalid shape for 2D image: expected number of channels ' + f'{spec.num_channels}, got {num_channels}') layout = 'cxy' elif data.ndim == 4: num_channels = data.shape[0] - assert data.shape[1] == 1, f'Invalid number of channels: {data.shape[1]}, expected 1.' + assert num_channels == spec.num_channels, (f'Invalid shape for 2D image: expected number of channels ' + f'{spec.num_channels}, got {num_channels}') + assert data.shape[1] == 1, (f'Invalid shape for 2D image: {data.shape}, expected (C, 1, X, Y), ' + f'got (c, {data.shape[1]}, x, y') layout = 'c1xy' else: raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 2 or 3 or 4.') @@ -267,6 +276,8 @@ def __init__(self, data: np.ndarray, layout = 'xyz' elif data.ndim == 4: num_channels = data.shape[0] + assert num_channels == spec.num_channels, (f'Invalid shape for 3D image: expected number of channels ' + f'{spec.num_channels}, got {num_channels}') layout = 'cxyz' else: raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 3 or 4.') @@ -296,11 +307,13 @@ def __init__(self, data: np.ndarray, if data.ndim == 2: layout = 'xy' elif data.ndim == 3: - assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' + assert data.shape[0] == 1, (f'Invalid shape for 2D labels. ' + f'Expected shape: (1, y, x) got ({data.shape[0]}, y, x).') layout = '1xy' elif data.ndim == 4: - assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' - assert data.shape[1] == 1, f'Invalid number of channels: {data.shape[1]}, expected 1.' + assert data.shape[0] == 1 and data.shape[1], (f'Invalid shape for 2D labels. 
' + f'Expected shape: (1, 1, y, x) got' + f' ({data.shape[0]}, {data.shape[1]} y, x).') layout = '11xy' else: raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 2 or 3 or 4.') @@ -309,7 +322,9 @@ def __init__(self, data: np.ndarray, if data.ndim == 3: layout = 'xyz' elif data.ndim == 4: - assert data.shape[0] == 1, f'Invalid number of channels: {data.shape[0]}, expected 1.' + assert data.shape[0] == 1, (f'Invalid shape for 3D labels. ' + f'Expected shape: (1, z, y, x) got' + f' ({data.shape[0]}, z, y, x).') layout = '1xyz' else: raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 3 or 4.') diff --git a/plantseg/ui/widgets/dataset_tools.py b/plantseg/ui/widgets/dataset_tools.py index 64880307..df9832ea 100644 --- a/plantseg/ui/widgets/dataset_tools.py +++ b/plantseg/ui/widgets/dataset_tools.py @@ -1,17 +1,20 @@ from enum import Enum from pathlib import Path from typing import Optional +from typing import Protocol +import napari from magicgui import magicgui from napari.layers import Labels, Image, Layer from plantseg import PLANTSEG_MODELS_DIR from plantseg.dataset_tools.dataset_handler import DatasetHandler, save_dataset, load_dataset +from plantseg.dataset_tools.dataset_handler import delete_dataset, change_dataset_location from plantseg.dataset_tools.images import Image as PlantSegImage from plantseg.dataset_tools.images import Labels as PlantSegLabels from plantseg.dataset_tools.images import Stack, ImageSpecs, StackSpecs from plantseg.ui.logging import napari_formatted_logging -from plantseg.utils import list_datasets +from plantseg.utils import list_datasets, get_dataset_dict, delist_dataset empty_dataset = ['none'] startup_list_datasets = list_datasets() or empty_dataset @@ -34,9 +37,9 @@ class ImageType(Enum): 'tooltip': f'Initialize an empty dataset with name model_name'}, images_format={'label': 'Expected images format \n (Name, Channel, Type)', 'layout': 'vertical', - 'tooltip': f'Initialize an empty dataset with name model_name'}, - is_sparse={'label': 'Sparse dataset', - 'tooltip': 'If checked, the dataset will be saved in sparse format.'}, + 'tooltip': f'Define the expected images format for the dataset.\n'}, + is_sparse={'label': 'Is Dataset sparse?', + 'tooltip': 'If checked, this info will be saved for training.'}, ) def widget_create_dataset(dataset_name: str = 'my-dataset', dataset_dir: Path = Path.home(), @@ -82,53 +85,256 @@ def widget_create_dataset(dataset_name: str = 'my-dataset', return dataset -def _add_stack(dataset_name: str = startup_list_datasets[0], - images: list[tuple[str, Optional[Layer]]] = (), - phase: str = 'train', - is_sparse: bool = False, - **kwargs): - dataset = load_dataset(dataset_name) - image_specs = dataset.expected_stack_specs.list_specs - stack_specs = dataset.expected_stack_specs - - list_images = [] - for image_name, layer in images: - reference_spec = [spec for spec in image_specs if spec.key == image_name][0] - if isinstance(layer, Image) and reference_spec.data_type == 'image': - image_data = layer.data - image = PlantSegImage(image_data, spec=reference_spec) - elif isinstance(layer, Labels) and reference_spec.data_type == 'labels': - labels_data = layer.data - reference_spec.is_sparse = is_sparse - image = PlantSegLabels(labels_data, spec=reference_spec) - else: - raise ValueError(f'Layer type {type(layer)} not supported.') - - list_images.append(image) - - stack_name = images[0][1].name - stack = Stack(*list_images, spec=stack_specs) - dataset.add_stack(stack_name=stack_name, stack=stack, 
phase=phase) +class Action(Protocol): + name: str + + @staticmethod + def edit(**kwargs): + pass + @staticmethod + def on_action_changed(): + pass -def _remove_stack(dataset_name, stack_name: str, **kwargs): - dataset = load_dataset(dataset_name) - dataset.remove_stack(stack_name) + @staticmethod + def on_edit_done(): + pass + + +class SingletonAction: + _instance = None + name: str = 'Abstract action' + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) -available_modes = { - 'Add stack to dataset': _add_stack, - 'Remove stack from dataset': _remove_stack, - 'Remove dataset': None, - 'De-list dataset': None, - 'Move dataset': None, - 'Rename dataset': None -} - - -@magicgui(call_button='Edit Dataset', + return cls._instance + + @staticmethod + def edit(**kwargs): + pass + + @staticmethod + def on_action_changed(): + pass + + @staticmethod + def on_edit_done(): + pass + + +class AddStack(SingletonAction): + name = 'Add stack to dataset' + + @staticmethod + def edit(dataset_name: str = startup_list_datasets[0], + images: list[tuple[str, Optional[Layer]]] = (), + new_stack_name: str = None, + phase: str = 'train', + is_sparse: bool = False, + strict: bool = False, + **kwargs): + dataset = load_dataset(dataset_name) + image_specs = dataset.expected_stack_specs.list_specs + stack_specs = dataset.expected_stack_specs + + list_images = [] + for image_name, layer in images: + reference_spec = [spec for spec in image_specs if spec.key == image_name][0] + + if isinstance(layer, Image) and reference_spec.data_type == 'image': + image_data = layer.data + image = PlantSegImage(image_data, spec=reference_spec) + elif isinstance(layer, Labels) and reference_spec.data_type == 'labels': + labels_data = layer.data + reference_spec.is_sparse = is_sparse + image = PlantSegLabels(labels_data, spec=reference_spec) + else: + napari_formatted_logging(message=f'Layer {image_name} is not of the expected type.', + thread='AddStack', + level='warning') + return None + + list_images.append(image) + + stack_name = new_stack_name if new_stack_name != '' else images[0][1].name + stack = Stack(*list_images, spec=stack_specs) + dataset.add_stack(stack_name=stack_name, stack=stack, phase=phase) + + +class RemoveStack(SingletonAction): + name = 'Remove stack from dataset' + + @staticmethod + def edit(dataset_name: str, stack_name: str, **kwargs): + dataset = load_dataset(dataset_name) + if stack_name not in dataset.find_stacks_names(): + napari_formatted_logging(message=f'Stack {stack_name} does not exist.', + thread='RemoveStack', + level='warning') + return None + dataset.remove_stack(stack_name) + + +class RenameStack(SingletonAction): + name = 'Rename stack from dataset' + + @staticmethod + def edit(dataset_name: str, stack_name: str, new_stack_name: str, **kwargs): + dataset = load_dataset(dataset_name) + if stack_name not in dataset.find_stacks_names(): + napari_formatted_logging(message=f'Stack {stack_name} does not exist.', + thread='RemoveStack', + level='warning') + return None + dataset.rename_stack(stack_name, new_stack_name) + + +class ChangePhaseStack(SingletonAction): + name = 'Change phase of stack from dataset' + + @staticmethod + def edit(dataset_name: str, stack_name: str, new_phase: str, **kwargs): + dataset = load_dataset(dataset_name) + if stack_name not in dataset.find_stacks_names(): + napari_formatted_logging(message=f'Stack {stack_name} does not exist.', + thread='RemoveStack', + level='warning') + return None + dataset.change_phase_stack(stack_name, new_phase) + + 
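The action classes above are thin GUI wrappers around the DatasetHandler API extended earlier in this patch; below is a short, hedged sketch of the equivalent scripted calls (illustrative only, not part of the diff — dataset and stack names are hypothetical, method signatures are taken from the hunks above).

from plantseg.dataset_tools.dataset_handler import load_dataset

dataset = load_dataset('my-dataset')    # hypothetical dataset name
print(dataset.info())                   # stack counts per train/val/test phase

# stacks are addressed as '<phase>/<name>', as returned by find_stacks_names()
for stack_name in dataset.find_stacks_names(phase='train'):
    print(stack_name)

dataset.change_phase_stack('train/my_stack', 'val')  # what ChangePhaseStack.edit does
dataset.remove_stack('val/my_stack')                 # what RemoveStack.edit does
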
+class LinkNewLocation(SingletonAction): + name = 'Link dataset to a new folder' + + @staticmethod + def edit(dataset_name: str, new_location: Path, **kwargs): + if dataset_name not in list_datasets(): + napari_formatted_logging(message=f'Dataset {dataset_name} does not exist.', + thread='LinkNewLocation', + level='warning') + return None + + if not new_location.exists(): + napari_formatted_logging(message=f'Location does not exist.', + thread='LinkNewLocation', + level='warning') + return None + + change_dataset_location(dataset_name, new_location) + + +class DeleteDataset(SingletonAction): + name = 'Delete dataset (cannot be undone!)' + + @staticmethod + def edit(dataset_name: str, **kwargs): + if dataset_name not in list_datasets(): + napari_formatted_logging(message=f'Dataset {dataset_name} does not exist.', + thread='DeleteDataset', + level='warning') + return None + dataset_dict = get_dataset_dict(dataset_name) + assert dataset_dict is not None, f'Dataset {dataset_name} does not exist.' + assert 'dataset_dir' in dataset_dict, f'Dataset {dataset_name} does not contain a dataset_dir.' + assert Path(dataset_dict['dataset_dir']).exists(), f'Dataset {dataset_name} does not exist.' + dataset_dir = dataset_dict['dataset_dir'] + delete_dataset(dataset_name, dataset_dir) + + +class DeListDataset(SingletonAction): + name = 'De-List dataset (h5 will not be deleted)' + + @staticmethod + def edit(dataset_name: str, **kwargs): + if dataset_name not in list_datasets(): + napari_formatted_logging(message=f'Dataset {dataset_name} does not exist.', + thread='DeListDataset', + level='warning') + return None + delist_dataset(dataset_name) + + +class PrintDataset(SingletonAction): + name = 'Print dataset information' + + @staticmethod + def edit(dataset_name: str, **kwargs): + if dataset_name not in list_datasets(): + napari_formatted_logging(message=f'Dataset {dataset_name} does not exist.', + thread='PrintDataset', + level='warning') + return None + dataset = load_dataset(dataset_name) + napari_formatted_logging(message=f'Dataset infos:\n{dataset.info()}', + thread='PrintDataset', + level='info') + + +class VisualizeStack(SingletonAction): + name = 'Visualize stack' + + @staticmethod + def edit(viewer: napari.Viewer, dataset_name: str, stack_name: str, **kwargs): + dataset = load_dataset(dataset_name) + if stack_name not in dataset.find_stacks_names(): + napari_formatted_logging(message=f'Stack {stack_name} does not exist.', + thread='RemoveStack', + level='warning') + return None + + stack, result, msg = dataset.get_stack_from_name(stack_name=stack_name) + if result is False: + napari_formatted_logging(message=msg, + thread='OpenStack', + level='warning') + return None + + for spec in stack.list_specs: + image = stack.data[spec.key] + raw_data = image.load_data() + if image.infos is not None: + voxel_size, _, _, unit = image.infos + + else: + voxel_size, unit = [1., 1., 1.], 'µm' + + metadata = {'original_voxel_size': voxel_size, + 'voxel_size_unit': unit, + 'root_name': stack_name} + + image_type = spec.data_type + if image_type == 'image': + viewer.add_image(raw_data, name=spec.key, scale=voxel_size, metadata=metadata) + elif image_type == 'labels': + viewer.add_labels(raw_data, name=spec.key, scale=voxel_size, metadata=metadata) + + +add_stack_action = AddStack() +visualize_stack_action = VisualizeStack() +remove_stack_action = RemoveStack() +rename_stack_action = RenameStack() +change_phase_stack_action = ChangePhaseStack() +print_dataset_action = PrintDataset() +link_new_location_action = 
LinkNewLocation() +delete_dataset_action = DeleteDataset() +delist_dataset_action = DeListDataset() + +available_actions: dict[str, Action] = {action.name: action for action in [add_stack_action, + visualize_stack_action, + remove_stack_action, + rename_stack_action, + change_phase_stack_action, + print_dataset_action, + link_new_location_action, + delete_dataset_action, + delist_dataset_action]} + + +@magicgui(call_button=list(available_actions.values())[0].name, action={'label': 'Action', - 'choices': list(available_modes.keys()), + 'choices': list(available_actions.keys()), 'tooltip': f'Define if the stack will be added or removed from the dataset'}, dataset_name={'label': 'Dataset name', 'choices': startup_list_datasets, @@ -136,40 +342,85 @@ def _remove_stack(dataset_name, stack_name: str, **kwargs): phase={'label': 'Phase', 'choices': ['train', 'val', 'test'], 'tooltip': f'Define if the stack will be used for training, validation or testing'}, - is_sparse={'label': 'Sparse dataset', - 'tooltip': 'If checked, the dataset will be saved in sparse format.'}, + is_sparse={'label': 'Is Stack sparse?', + 'tooltip': 'If checked, this info will be saved for training.'}, stack_name={'label': 'Stack name', 'tooltip': f'Name of the stack to be added to be edited', 'choices': ['', ]}, + new_stack_name={'label': 'New stack name', + 'tooltip': f'Name of the stack to be added to be edited'}, + new_dataset_location={'label': 'New dataset location', + 'mode': 'd', + 'tooltip': f'New location of the dataset'}, + new_phase={'label': 'New phase', + 'choices': ['train', 'val', 'test'], + 'tooltip': f'Define if the stack will be used for training, validation or testing'}, ) -def widget_edit_dataset(action: str = list(available_modes.keys())[0], +def widget_edit_dataset(viewer: napari.Viewer, + action: str = list(available_actions.keys())[0], dataset_name: str = startup_list_datasets[0], images: list[tuple[str, Optional[Layer]]] = (), phase: str = 'train', is_sparse: bool = False, - stack_name: str = '') -> str: - func = available_modes[action] - kwargs = {'dataset_name': dataset_name, + stack_name: str = '', + new_stack_name: str = '', + new_dataset_location: Path = None, + new_phase: str = 'val', + ) -> str: + action_class = available_actions[action] + + kwargs = {'viewer': viewer, + 'dataset_name': dataset_name, + 'new_stack_name': new_stack_name, + 'new_dataset_location': new_dataset_location, 'images': images, 'phase': phase, 'is_sparse': is_sparse, - 'stack_name': stack_name} - func(**kwargs) + 'stack_name': stack_name, + 'new_phase': new_phase, + } + action_class.edit(**kwargs) + napari_formatted_logging(message=f'Action {action} applied to dataset {dataset_name}.', + thread='EditDataset', + level='info') return action -@widget_create_dataset.called.connect -def update_dataset_name(dataset: DatasetHandler): - if dataset is None: +def safe_load_current_dataset(dataset_name: str = None) -> Optional[DatasetHandler]: + _list_datasets = list_datasets() + if len(_list_datasets) == 0: return None - widget_edit_dataset.dataset_name.choices = list_datasets() + if dataset_name is None: + dataset_name = widget_edit_dataset.dataset_name.value + if dataset_name == empty_dataset[0]: + return None + + if dataset_name not in _list_datasets: + dataset_name = _list_datasets[0] + + return load_dataset(dataset_name) + + +def _update_dataset_name_choices(dataset: DatasetHandler = None): + _list_datasets = list_datasets() + + if len(_list_datasets) == 0: + widget_edit_dataset.dataset_name.choices = empty_dataset + 
widget_edit_dataset.dataset_name.value = empty_dataset[0] + return None + + if dataset is None: + dataset = safe_load_current_dataset() + + widget_edit_dataset.dataset_name.choices = _list_datasets widget_edit_dataset.dataset_name.value = dataset.name def _update_stack_name_choices(dataset: DatasetHandler = None): if dataset is None: - dataset = load_dataset(widget_edit_dataset.dataset_name.value) + dataset = safe_load_current_dataset() + stacks_options = dataset.find_stacks_names() stacks_options = stacks_options if stacks_options else [''] widget_edit_dataset.stack_name.choices = stacks_options @@ -177,11 +428,8 @@ def _update_stack_name_choices(dataset: DatasetHandler = None): def _update_images_choices(dataset: DatasetHandler = None): - if len(list_datasets()) == 0: - return None - if dataset is None: - dataset = load_dataset(widget_edit_dataset.dataset_name.value) + dataset = safe_load_current_dataset() images_default = [] for image in dataset.expected_stack_specs.list_specs: @@ -190,59 +438,142 @@ def _update_images_choices(dataset: DatasetHandler = None): widget_edit_dataset.images.value = images_default -_update_images_choices() +def _on_action_changed_add_stack(): + widget_edit_dataset.dataset_name.show() + widget_edit_dataset.phase.show() + widget_edit_dataset.is_sparse.show() + widget_edit_dataset.images.show() + widget_edit_dataset.new_stack_name.show() + widget_edit_dataset.new_stack_name.value = '' -@widget_edit_dataset.dataset_name.changed.connect -def update_dataset_name(dataset_name: str): - dataset = load_dataset(dataset_name) - widget_edit_dataset.phase.choices = dataset.default_phases - widget_edit_dataset.is_sparse.value = dataset.is_sparse + current_dataset = safe_load_current_dataset() + if current_dataset is None: + return None + _update_dataset_name_choices(current_dataset) + _update_images_choices(current_dataset) - _update_images_choices(dataset) - _update_stack_name_choices(dataset) +add_stack_action.on_action_changed = _on_action_changed_add_stack -def _add_stack_callback(): - widget_edit_dataset.images.show() - widget_edit_dataset.phase.show() - widget_edit_dataset.stack_name.hide() +def _on_action_changed_remove_stack(): + widget_edit_dataset.stack_name.show() + widget_edit_dataset.dataset_name.show() -def _remove_stack_callback(): - widget_edit_dataset.images.hide() + current_dataset = safe_load_current_dataset() + if current_dataset is None: + return None + _update_dataset_name_choices(current_dataset) + _update_stack_name_choices(current_dataset) + + +def _on_action_rename_stack(): + widget_edit_dataset.stack_name.show() + widget_edit_dataset.dataset_name.show() + widget_edit_dataset.new_stack_name.show() + + current_dataset = safe_load_current_dataset() + if current_dataset is None: + return None + _update_dataset_name_choices(current_dataset) + _update_stack_name_choices(current_dataset) + + +def _on_action_change_phase_stack(): + widget_edit_dataset.stack_name.show() + widget_edit_dataset.dataset_name.show() + widget_edit_dataset.new_phase.show() + + current_dataset = safe_load_current_dataset() + if current_dataset is None: + return None + _update_dataset_name_choices(current_dataset) + _update_stack_name_choices(current_dataset) + + +visualize_stack_action.on_action_changed = _on_action_changed_remove_stack +visualize_stack_action.on_edit_done = _update_stack_name_choices + +rename_stack_action.on_action_changed = _on_action_rename_stack +rename_stack_action.on_edit_done = _update_stack_name_choices + +remove_stack_action.on_action_changed = 
_on_action_changed_remove_stack +remove_stack_action.on_edit_done = _update_stack_name_choices + +change_phase_stack_action.on_action_changed = _on_action_change_phase_stack +change_phase_stack_action.on_edit_done = _update_stack_name_choices + + +def _on_action_changed_delete_dataset(): + widget_edit_dataset.dataset_name.show() + + current_dataset = safe_load_current_dataset() + if current_dataset is None: + return None + _update_dataset_name_choices(current_dataset) + + +delete_dataset_action.on_action_changed = _on_action_changed_delete_dataset +delete_dataset_action.on_edit_done = _update_dataset_name_choices + +delist_dataset_action.on_action_changed = _on_action_changed_delete_dataset +delist_dataset_action.on_edit_done = _update_dataset_name_choices + +print_dataset_action.on_action_changed = _on_action_changed_delete_dataset +print_dataset_action.on_edit_done = _update_dataset_name_choices + + +def _hide_all(): + widget_edit_dataset.dataset_name.hide() + widget_edit_dataset.stack_name.hide() + widget_edit_dataset.new_stack_name.hide() + widget_edit_dataset.new_dataset_location.hide() + widget_edit_dataset.dataset_name.hide() widget_edit_dataset.phase.hide() widget_edit_dataset.is_sparse.hide() - widget_edit_dataset.stack_name.show() + widget_edit_dataset.images.hide() + widget_edit_dataset.new_phase.hide() - _update_stack_name_choices() +@widget_create_dataset.called.connect +def _create_dataset_is_called(dataset: DatasetHandler): + _update_dataset_name_choices(dataset) + _update_stack_name_choices(dataset) -_add_stack_callback() + widget_create_dataset.dataset_name.value = '' -available_modes_callbacks = { - 'Add stack to dataset': _add_stack_callback, - 'Remove stack from dataset': _remove_stack_callback, -} +@widget_edit_dataset.dataset_name.changed.connect +def update_dataset_name(dataset_name: str): + dataset = safe_load_current_dataset(dataset_name) + if dataset is None: + return None + + widget_edit_dataset.is_sparse.value = dataset.is_sparse + _update_stack_name_choices(dataset) + _update_images_choices(dataset) @widget_edit_dataset.action.changed.connect -def update_mode(action: str): - if action in available_modes_callbacks.keys(): - available_modes_callbacks[action]() +def update_action(action: str): + _hide_all() + if action in available_actions.keys(): + action_class = available_actions[action] + action_class.on_action_changed() + widget_edit_dataset.call_button.text = action_class.name -def _remove_stack_update_choices(): - _update_stack_name_choices() +@widget_edit_dataset.called.connect +def _on_done(action: str): + if action is None: + return None -available_actions_on_done = { - 'Remove stack from dataset': _remove_stack_update_choices -} + if action in available_actions.keys(): + action_class = available_actions[action] + action_class.on_edit_done() -@widget_edit_dataset.called.connect -def update_state_after_edit(action: str): - if action in available_actions_on_done.keys(): - available_actions_on_done[action]() +_hide_all() +list(available_actions.values())[0].on_action_changed() From c6e807c2cdc7798b53f87738171113cbfabfea00 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Mon, 14 Aug 2023 16:36:53 +0200 Subject: [PATCH 19/22] homogenize new dataset handler with training widget --- plantseg/dataset_tools/dataset_handler.py | 5 +++ plantseg/ui/widgets/dataset_tools.py | 2 +- plantseg/ui/widgets/training.py | 44 ++++++++++++++++++----- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/plantseg/dataset_tools/dataset_handler.py 
b/plantseg/dataset_tools/dataset_handler.py index bed637b5..fcd83fee 100644 --- a/plantseg/dataset_tools/dataset_handler.py +++ b/plantseg/dataset_tools/dataset_handler.py @@ -68,6 +68,7 @@ class DatasetHandler: train: list[str] val: list[str] test: list[str] + dimensionality: str default_file_formats = H5_EXTENSIONS def __init__(self, @@ -95,6 +96,10 @@ def init_datastructure(self): for phase in self.default_phases: (self.dataset_dir / phase).mkdir(exist_ok=True) + @property + def dimensionality(self) -> str: + return self.expected_stack_specs.dimensionality + @classmethod def from_dict(cls, dataset_dict: dict): """ diff --git a/plantseg/ui/widgets/dataset_tools.py b/plantseg/ui/widgets/dataset_tools.py index df9832ea..c59d3e10 100644 --- a/plantseg/ui/widgets/dataset_tools.py +++ b/plantseg/ui/widgets/dataset_tools.py @@ -338,7 +338,7 @@ def edit(viewer: napari.Viewer, dataset_name: str, stack_name: str, **kwargs): 'tooltip': f'Define if the stack will be added or removed from the dataset'}, dataset_name={'label': 'Dataset name', 'choices': startup_list_datasets, - 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, + 'tooltip': f'Choose the dataset to be edited'}, phase={'label': 'Phase', 'choices': ['train', 'val', 'test'], 'tooltip': f'Define if the stack will be used for training, validation or testing'}, diff --git a/plantseg/ui/widgets/training.py b/plantseg/ui/widgets/training.py index 6aaef7c9..fb83a73e 100644 --- a/plantseg/ui/widgets/training.py +++ b/plantseg/ui/widgets/training.py @@ -1,5 +1,4 @@ from concurrent.futures import Future -from pathlib import Path from typing import Tuple from magicgui import magicgui @@ -8,10 +7,15 @@ from plantseg import PLANTSEG_MODELS_DIR from plantseg.training.train import unet_training -from plantseg.utils import list_all_dimensionality from plantseg.ui.widgets.predictions import ALL_DEVICES -from plantseg.ui.widgets.utils import create_layer_name, start_threading_process, return_value_if_widget, \ - layer_properties +from plantseg.ui.widgets.utils import create_layer_name, start_threading_process, return_value_if_widget +from plantseg.utils import list_all_dimensionality +from plantseg.utils import list_datasets +from plantseg.dataset_tools.dataset_handler import load_dataset +from plantseg.ui.logging import napari_formatted_logging + +empty_dataset = ['none'] +startup_list_datasets = list_datasets() or empty_dataset def unet_training_wrapper(dataset_dir, model_name, in_channels, out_channels, patch_size, max_num_iters, dimensionality, @@ -25,9 +29,9 @@ def unet_training_wrapper(dataset_dir, model_name, in_channels, out_channels, pa @magicgui(call_button='Run Training', - dataset_dir={'label': 'Path to the dataset directory', - 'mode': 'd', - 'tooltip': 'Select a directory containing train and val subfolders'}, + dataset_name={'label': 'Dataset name', + 'choices': startup_list_datasets, + 'tooltip': f'Choose the dataset for the training.'}, model_name={'label': 'Trained model name', 'tooltip': f'Model files will be saved in f{PLANTSEG_MODELS_DIR}/model_name'}, in_channels={'label': 'Input channels', @@ -48,7 +52,7 @@ def unet_training_wrapper(dataset_dir, model_name, in_channels, out_channels, pa 'choices': ALL_DEVICES} ) def widget_unet_training(viewer: Viewer, - dataset_dir: Path = Path.home(), + dataset_name: str = startup_list_datasets[0], model_name: str = 'my-model', in_channels: int = 1, out_channels: int = 1, @@ -59,6 +63,12 @@ def widget_unet_training(viewer: Viewer, device: str = ALL_DEVICES[0]) -> 
Future[LayerDataTuple]: out_name = create_layer_name(model_name, 'training') step_kwargs = dict(model_name=model_name, sparse=sparse, dimensionality=dimensionality) + if dataset_name == empty_dataset[0]: + raise ValueError('Please select a dataset, if you do not have one, please create one using the napari ' + 'dataset creator widget.') + + dataset_dir = load_dataset(dataset_name).dataset_dir + return start_threading_process(unet_training_wrapper, runtime_kwargs={ 'dataset_dir': dataset_dir, @@ -82,6 +92,19 @@ def widget_unet_training(viewer: Viewer, ) +@widget_unet_training.dataset_name.changed.connect +def _on_dataset_name_changed(dataset_name: str): + if dataset_name == empty_dataset[0]: + return None + + dataset = load_dataset(dataset_name) + widget_unet_training.sparse.value = dataset.is_sparse + widget_unet_training.dimensionality.value = dataset.dimensionality + for spec in dataset.expected_stack_specs.list_specs: + if spec.key == 'raw': + widget_unet_training.in_channels.value = spec.num_channels + + @widget_unet_training.dimensionality.changed.connect def _on_dimensionality_changed(dimensionality: str): dimensionality = return_value_if_widget(dimensionality) @@ -100,3 +123,8 @@ def _on_sparse_change(sparse: bool): widget_unet_training.out_channels.value = 8 else: widget_unet_training.out_channels.value = 1 + + +if startup_list_datasets[0] != empty_dataset[0]: + _on_dataset_name_changed(startup_list_datasets[0]) + From d1eb7f40b9b9dc7b229c5cd0db435a46b0971e4d Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Mon, 14 Aug 2023 17:18:43 +0200 Subject: [PATCH 20/22] improve default and clean ui --- plantseg/ui/containers.py | 27 +++---- plantseg/ui/viewer.py | 7 +- plantseg/ui/widgets/dataset_tools.py | 16 ++-- plantseg/ui/widgets/predictions.py | 90 ++++++++++++++++++++-- plantseg/ui/widgets/segmentation.py | 110 +++++++++++++++++++++++---- 5 files changed, 205 insertions(+), 45 deletions(-) diff --git a/plantseg/ui/containers.py b/plantseg/ui/containers.py index e6ed0161..408de8a0 100644 --- a/plantseg/ui/containers.py +++ b/plantseg/ui/containers.py @@ -6,16 +6,16 @@ from plantseg.ui.widgets.dataprocessing import widget_cropping, widget_add_layers from plantseg.ui.widgets.dataprocessing import widget_label_processing from plantseg.ui.widgets.dataprocessing import widget_rescaling, widget_gaussian_smoothing +from plantseg.ui.widgets.dataset_tools import widget_create_dataset, widget_edit_dataset from plantseg.ui.widgets.io import open_file, export_stacks -from plantseg.ui.widgets.predictions import widget_iterative_unet_predictions, widget_add_custom_model from plantseg.ui.widgets.predictions import widget_unet_predictions, widget_test_all_unet_predictions +from plantseg.ui.widgets.predictions import widget_iterative_unet_predictions, widget_add_custom_model from plantseg.ui.widgets.proofreading.proofreading import widget_clean_scribble, widget_filter_segmentation from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles from plantseg.ui.widgets.segmentation import widget_dt_ws, widget_agglomeration -from plantseg.ui.widgets.segmentation import widget_fix_over_under_segmentation_from_nuclei -from plantseg.ui.widgets.segmentation import widget_lifted_multicut from plantseg.ui.widgets.segmentation import widget_simple_dt_ws -from plantseg.ui.widgets.dataset_tools import widget_create_dataset, widget_edit_dataset +from plantseg.ui.widgets.segmentation import widget_lifted_multicut +from plantseg.ui.widgets.segmentation import 
widget_fix_over_under_segmentation_from_nuclei def setup_menu(container, path=None): @@ -72,19 +72,14 @@ def get_dataset_workflow(): return container -def get_extra_seg(): - container = MainWindow(widgets=[widget_dt_ws, - widget_lifted_multicut, - widget_fix_over_under_segmentation_from_nuclei], - labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Extra-Seg') - return container - - -def get_extra_pred(): +def get_extra(): container = MainWindow(widgets=[widget_test_all_unet_predictions, widget_iterative_unet_predictions, - widget_add_custom_model], + widget_add_custom_model, + widget_dt_ws, + widget_lifted_multicut, + widget_fix_over_under_segmentation_from_nuclei + ], labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Extra-Pred') + container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Extra-Worflows') return container diff --git a/plantseg/ui/viewer.py b/plantseg/ui/viewer.py index ea6c0f40..ce1e566e 100644 --- a/plantseg/ui/viewer.py +++ b/plantseg/ui/viewer.py @@ -1,6 +1,6 @@ import napari -from plantseg.ui.containers import get_extra_seg, get_extra_pred +from plantseg.ui.containers import get_extra from plantseg.ui.containers import get_gasp_workflow, get_preprocessing_workflow, get_main from plantseg.ui.containers import get_dataset_workflow from plantseg.ui.logging import napari_formatted_logging @@ -16,9 +16,8 @@ def run_viewer(): for _containers, name in [(get_preprocessing_workflow(), 'Data - Processing'), (get_gasp_workflow(), 'UNet + Segmentation'), - (get_dataset_workflow(), 'Dataset'), - (get_extra_pred(), 'Extra-Pred'), - (get_extra_seg(), 'Extra-Seg'), + (get_dataset_workflow(), 'Datasets Management'), + (get_extra(), 'Extra-Workflows'), ]: _container_w = viewer.window.add_dock_widget(_containers, name=name) viewer.window._qt_window.tabifyDockWidget(main_w, _container_w) diff --git a/plantseg/ui/widgets/dataset_tools.py b/plantseg/ui/widgets/dataset_tools.py index c59d3e10..0a379ba6 100644 --- a/plantseg/ui/widgets/dataset_tools.py +++ b/plantseg/ui/widgets/dataset_tools.py @@ -25,7 +25,7 @@ class ImageType(Enum): LABELS: str = 'labels' -@magicgui(call_button='Initialize Dataset', +@magicgui(call_button='Initialize New Dataset', dataset_name={'label': 'Dataset name', 'tooltip': f'Initialize an empty dataset with name model_name'}, dataset_dir={'label': 'Path to the dataset directory', @@ -318,8 +318,8 @@ def edit(viewer: napari.Viewer, dataset_name: str, stack_name: str, **kwargs): change_phase_stack_action = ChangePhaseStack() print_dataset_action = PrintDataset() link_new_location_action = LinkNewLocation() -delete_dataset_action = DeleteDataset() delist_dataset_action = DeListDataset() +delete_dataset_action = DeleteDataset() available_actions: dict[str, Action] = {action.name: action for action in [add_stack_action, visualize_stack_action, @@ -328,8 +328,9 @@ def edit(viewer: napari.Viewer, dataset_name: str, stack_name: str, **kwargs): change_phase_stack_action, print_dataset_action, link_new_location_action, + delist_dataset_action, delete_dataset_action, - delist_dataset_action]} + ]} @magicgui(call_button=list(available_actions.values())[0].name, @@ -542,6 +543,7 @@ def _create_dataset_is_called(dataset: DatasetHandler): _update_stack_name_choices(dataset) widget_create_dataset.dataset_name.value = '' + widget_edit_dataset.show() @widget_edit_dataset.dataset_name.changed.connect @@ -575,5 +577,9 @@ def _on_done(action: str): 
action_class.on_edit_done() -_hide_all() -list(available_actions.values())[0].on_action_changed() +if startup_list_datasets == empty_dataset: + widget_edit_dataset.hide() + +else: + _hide_all() + list(available_actions.values())[0].on_action_changed() diff --git a/plantseg/ui/widgets/predictions.py b/plantseg/ui/widgets/predictions.py index f6d9a71b..cc38a559 100644 --- a/plantseg/ui/widgets/predictions.py +++ b/plantseg/ui/widgets/predictions.py @@ -12,13 +12,13 @@ from plantseg.dataprocessing.functional import image_gaussian_smoothing from plantseg.predictions.functional import unet_predictions -from plantseg.utils import list_all_modality, list_all_dimensionality, list_all_output_type -from plantseg.utils import list_models, add_custom_model, get_train_config, get_model_zoo, get_model_description from plantseg.ui.logging import napari_formatted_logging from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles from plantseg.ui.widgets.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws from plantseg.ui.widgets.utils import return_value_if_widget from plantseg.ui.widgets.utils import start_threading_process, create_layer_name, layer_properties +from plantseg.utils import list_all_modality, list_all_dimensionality, list_all_output_type +from plantseg.utils import list_models, add_custom_model, get_train_config, get_model_zoo, get_model_description ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())] MPS = ['mps'] if torch.backends.mps.is_available() else [] @@ -180,11 +180,14 @@ def _compute_multiple_predictions(image, patch_size, device): patch_size={'label': 'Patch size', 'tooltip': 'Patch size use to processed the data.'}, device={'label': 'Device', - 'choices': ALL_DEVICES} + 'choices': ALL_DEVICES}, + show_widget={'label': 'Show Widget: Try all Available Models', + 'tooltip': 'Show the widget to try all available models.'} ) def widget_test_all_unet_predictions(image: Image, patch_size: Tuple[int, int, int] = (80, 170, 170), - device: str = ALL_DEVICES[0]) -> Future[List[LayerDataTuple]]: + device: str = ALL_DEVICES[0], + show_widget: bool = False) -> Future[List[LayerDataTuple]]: func = thread_worker(partial(_compute_multiple_predictions, image=image, patch_size=patch_size, @@ -201,6 +204,23 @@ def on_done(result): return future +@widget_test_all_unet_predictions.show_widget.changed.connect +def _on_show_all_unet_predictions(show_widget: bool): + if show_widget: + widget_test_all_unet_predictions.image.show() + widget_test_all_unet_predictions.patch_size.show() + widget_test_all_unet_predictions.device.show() + widget_test_all_unet_predictions.call_button.show() + else: + widget_test_all_unet_predictions.image.hide() + widget_test_all_unet_predictions.patch_size.hide() + widget_test_all_unet_predictions.device.hide() + widget_test_all_unet_predictions.call_button.hide() + + +_on_show_all_unet_predictions(True) + + def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, single_batch_mode, device): func = partial(unet_predictions, model_name=model_name, patch=patch_size, single_batch_mode=single_batch_mode, device=device) @@ -231,7 +251,9 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc single_patch={'label': 'Single Patch', 'tooltip': 'If True, a single patch will be processed at a time to save memory.'}, device={'label': 'Device', - 'choices': ALL_DEVICES} + 'choices': ALL_DEVICES}, + show_widget={'label': 'Show Widget: 
Iterative Predictions', + 'tooltip': 'Show the widget to try all available models.'} ) def widget_iterative_unet_predictions(image: Image, model_name: str, @@ -239,7 +261,8 @@ def widget_iterative_unet_predictions(image: Image, sigma: float = 1.0, patch_size: Tuple[int, int, int] = (80, 170, 170), single_patch: bool = True, - device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]: + device: str = ALL_DEVICES[0], + show_widget: bool = False) -> Future[LayerDataTuple]: out_name = create_layer_name(image.name, f'iterative-{model_name}-x{num_iterations}') inputs_names = (image.name,) layer_kwargs = layer_properties(name=out_name, @@ -265,6 +288,31 @@ def widget_iterative_unet_predictions(image: Image, ) +@widget_iterative_unet_predictions.show_widget.changed.connect +def _on_show_iterative_unet_predictions(show_widget: bool): + if show_widget: + widget_iterative_unet_predictions.image.show() + widget_iterative_unet_predictions.model_name.show() + widget_iterative_unet_predictions.num_iterations.show() + widget_iterative_unet_predictions.sigma.show() + widget_iterative_unet_predictions.patch_size.show() + widget_iterative_unet_predictions.single_patch.show() + widget_iterative_unet_predictions.device.show() + widget_iterative_unet_predictions.call_button.show() + else: + widget_iterative_unet_predictions.image.hide() + widget_iterative_unet_predictions.model_name.hide() + widget_iterative_unet_predictions.num_iterations.hide() + widget_iterative_unet_predictions.sigma.hide() + widget_iterative_unet_predictions.patch_size.hide() + widget_iterative_unet_predictions.single_patch.hide() + widget_iterative_unet_predictions.device.hide() + widget_iterative_unet_predictions.call_button.hide() + + +_on_show_iterative_unet_predictions(False) + + @widget_iterative_unet_predictions.model_name.changed.connect def _on_model_name_changed_iterative(model_name: str): model_name = return_value_if_widget(model_name) @@ -292,6 +340,8 @@ def _on_model_name_changed_iterative(model_name: str): 'widget_type': 'ComboBox', 'tooltip': 'Type of prediction (e.g. 
cell boundaries predictions or nuclei...).', 'choices': list_all_output_type()}, + show_widget={'label': 'Show Widget: Add Custom Model', + 'tooltip': 'Show the widget to add a new custom model.'} ) def widget_add_custom_model(new_model_name: str = 'custom_model', @@ -300,7 +350,8 @@ def widget_add_custom_model(new_model_name: str = 'custom_model', description: str = 'New custom model', dimensionality: str = list_all_dimensionality()[0], modality: str = list_all_modality()[0], - output_type: str = list_all_output_type()[0]) -> None: + output_type: str = list_all_output_type()[0], + show_widget: bool = False) -> None: finished, error_msg = add_custom_model(new_model_name=new_model_name, location=model_location, resolution=resolution, @@ -321,6 +372,31 @@ def widget_add_custom_model(new_model_name: str = 'custom_model', thread='Add Custom Model') +@widget_add_custom_model.show_widget.changed.connect +def _on_show_widget_add_custom_model(show_widget: bool): + if show_widget: + widget_add_custom_model.new_model_name.show() + widget_add_custom_model.model_location.show() + widget_add_custom_model.resolution.show() + widget_add_custom_model.description.show() + widget_add_custom_model.dimensionality.show() + widget_add_custom_model.modality.show() + widget_add_custom_model.output_type.show() + widget_add_custom_model.call_button.show() + else: + widget_add_custom_model.new_model_name.hide() + widget_add_custom_model.model_location.hide() + widget_add_custom_model.resolution.hide() + widget_add_custom_model.description.hide() + widget_add_custom_model.dimensionality.hide() + widget_add_custom_model.modality.hide() + widget_add_custom_model.output_type.hide() + widget_add_custom_model.call_button.hide() + + +_on_show_widget_add_custom_model(False) + + @widget_add_custom_model.called.connect def _on_add_custom_model_called(): widget_unet_predictions.model_name.choices = list_models() diff --git a/plantseg/ui/widgets/segmentation.py b/plantseg/ui/widgets/segmentation.py index a49f5356..b5ae78a2 100644 --- a/plantseg/ui/widgets/segmentation.py +++ b/plantseg/ui/widgets/segmentation.py @@ -3,16 +3,16 @@ from typing import Tuple, Callable from magicgui import magicgui +from napari import Viewer from napari.layers import Labels, Image, Layer from napari.types import LayerDataTuple -from napari import Viewer from plantseg.dataprocessing.functional.advanced_dataprocessing import fix_over_under_segmentation_from_nuclei from plantseg.dataprocessing.functional.dataprocessing import normalize_01 from plantseg.segmentation.functional import gasp, multicut, dt_watershed, mutex_ws from plantseg.segmentation.functional import lifted_multicut_from_nuclei_segmentation, lifted_multicut_from_nuclei_pmaps -from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles from plantseg.ui.logging import napari_formatted_logging +from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles from plantseg.ui.widgets.utils import start_threading_process, create_layer_name, layer_properties @@ -97,19 +97,22 @@ def widget_agglomeration(viewer: Viewer, 'tooltip': 'Raw or boundary image to use as input for clustering.'}, nuclei={'label': 'Nuclei', 'tooltip': 'Nuclei binary predictions or Nuclei segmentation.'}, - _labels={'label': 'Over-segmentation', - 'tooltip': 'Over-segmentation labels layer to use as input for clustering.'}, + im_labels={'label': 'Over-segmentation', + 'tooltip': 'Over-segmentation labels layer to use as input for clustering.'}, 
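The visibility handlers introduced above (_on_show_all_unet_predictions, _on_show_iterative_unet_predictions and _on_show_widget_add_custom_model) all repeat the same field-by-field show/hide boilerplate. A generic helper could centralise it; the sketch below is hypothetical and not part of this patch, and it assumes the standard magicgui behaviour that a FunctionGui iterates over its sub-widgets.

# Hypothetical helper (not in this patch): toggle every sub-widget of a
# magicgui FunctionGui except the 'show_widget' checkbox itself.
from magicgui.widgets import FunctionGui


def toggle_advanced(widget: FunctionGui, visible: bool) -> None:
    for sub_widget in widget:  # a FunctionGui is an iterable Container
        if sub_widget.name != 'show_widget':
            sub_widget.visible = visible
    if getattr(widget, 'call_button', None) is not None:
        widget.call_button.visible = visible  # keep the Run button in sync


# Example wiring, reusing a widget defined in this file:
# from functools import partial
# widget_iterative_unet_predictions.show_widget.changed.connect(
#     partial(toggle_advanced, widget_iterative_unet_predictions))
# toggle_advanced(widget_iterative_unet_predictions, visible=False)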
beta={'label': 'Under/Over segmentation factor', 'tooltip': 'A low value will increase under-segmentation tendency ' 'and a large value increase over-segmentation tendency.', 'widget_type': 'FloatSlider', 'max': 1., 'min': 0.}, minsize={'label': 'Min-size', - 'tooltip': 'Minimum segment size allowed in voxels.'}) + 'tooltip': 'Minimum segment size allowed in voxels.'}, + show_widget={'label': 'Show Widget: Lifted MultiCut', + 'tooltip': 'Show the widget for the lifted multicut.'}) def widget_lifted_multicut(image: Image, nuclei: Layer, - _labels: Labels, + im_labels: Labels, beta: float = 0.5, - minsize: int = 100) -> Future[LayerDataTuple]: + minsize: int = 100, + show_widget: bool = False) -> Future[LayerDataTuple]: if 'pmap' not in image.metadata: _pmap_warn('Lifted MultiCut Widget') @@ -123,7 +126,7 @@ def widget_lifted_multicut(image: Image, raise ValueError(f'{nuclei} must be either an image or a labels layer') out_name = create_layer_name(image.name, 'LiftedMultiCut') - inputs_names = (image.name, nuclei.name, _labels.name) + inputs_names = (image.name, nuclei.name, im_labels.name) layer_kwargs = layer_properties(name=out_name, scale=image.scale, metadata=image.metadata) @@ -133,7 +136,7 @@ def widget_lifted_multicut(image: Image, return start_threading_process(lmc, runtime_kwargs={'boundary_pmaps': image.data, extra_key: nuclei.data, - 'superpixels': _labels.data}, + 'superpixels': im_labels.data}, statics_kwargs=step_kwargs, out_name=out_name, input_keys=inputs_names, @@ -143,6 +146,27 @@ def widget_lifted_multicut(image: Image, ) +@widget_lifted_multicut.show_widget.changed.connect +def _on_show_lifted_multicut(show_widget: bool): + if show_widget: + widget_lifted_multicut.image.show() + widget_lifted_multicut.nuclei.show() + widget_lifted_multicut.im_labels.show() + widget_lifted_multicut.beta.show() + widget_lifted_multicut.minsize.show() + widget_lifted_multicut.call_button.show() + else: + widget_lifted_multicut.image.hide() + widget_lifted_multicut.nuclei.hide() + widget_lifted_multicut.im_labels.hide() + widget_lifted_multicut.beta.hide() + widget_lifted_multicut.minsize.hide() + widget_lifted_multicut.call_button.hide() + + +_on_show_lifted_multicut(False) + + def dtws_wrapper(boundary_pmaps, stacked: bool = True, threshold: float = 0.5, @@ -194,7 +218,9 @@ def dtws_wrapper(boundary_pmaps, use_pixel_pitch={'label': 'Use pixel pitch'}, pixel_pitch={'label': 'Pixel pitch'}, apply_nonmax_suppression={'label': 'Apply nonmax suppression'}, - nuclei={'label': 'Is image Nuclei'} + nuclei={'label': 'Is image Nuclei'}, + show_widget={'label': 'Show widget: Advanced Watershed', + 'tooltip': 'Show the widget to run Watershed.'} ) def widget_dt_ws(image: Image, stacked: str = '2D', @@ -206,7 +232,8 @@ def widget_dt_ws(image: Image, use_pixel_pitch: bool = False, pixel_pitch: Tuple[int, int, int] = (1, 1, 1), apply_nonmax_suppression: bool = False, - nuclei: bool = False) -> Future[LayerDataTuple]: + nuclei: bool = False, + show_widget: bool = False) -> Future[LayerDataTuple]: if 'pmap' not in image.metadata: _pmap_warn("Watershed Widget") @@ -240,6 +267,39 @@ def widget_dt_ws(image: Image, ) +@widget_dt_ws.show_widget.changed.connect +def _show_widget_dt_ws(show_widget: bool): + if show_widget: + widget_dt_ws.image.show() + widget_dt_ws.stacked.show() + widget_dt_ws.threshold.show() + widget_dt_ws.min_size.show() + widget_dt_ws.sigma_seeds.show() + widget_dt_ws.sigma_weights.show() + widget_dt_ws.alpha.show() + widget_dt_ws.use_pixel_pitch.show() + widget_dt_ws.pixel_pitch.show() + 
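All of the segmentation widgets above hand their work to the same start_threading_process helper, which runs the functional step in a background thread and adds the result as a new layer. The schematic call below uses the keyword names as they appear in this file; the worker function, the arrays and the parameter values are placeholders rather than PlantSeg code.

# Schematic of the start_threading_process call shared by the widgets above.
# 'watershed_worker', 'image' and the parameter values are placeholders.
future = start_threading_process(
    watershed_worker,                                     # function executed in a background thread
    runtime_kwargs={'boundary_pmaps': image.data},        # arrays taken from the napari layers
    statics_kwargs={'threshold': 0.5, 'min_size': 100},   # plain widget parameters
    out_name=out_name,                                    # name of the layer to create
    input_keys=(image.name,),                             # names of the input layers
    layer_kwarg=layer_properties(name=out_name, scale=image.scale, metadata=image.metadata),
    layer_type='labels',                                  # napari layer type of the result
    step_name='Watershed',                                # human-readable step name
)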
widget_dt_ws.apply_nonmax_suppression.show() + widget_dt_ws.nuclei.show() + widget_dt_ws.call_button.show() + else: + widget_dt_ws.image.hide() + widget_dt_ws.stacked.hide() + widget_dt_ws.threshold.hide() + widget_dt_ws.min_size.hide() + widget_dt_ws.sigma_seeds.hide() + widget_dt_ws.sigma_weights.hide() + widget_dt_ws.alpha.hide() + widget_dt_ws.use_pixel_pitch.hide() + widget_dt_ws.pixel_pitch.hide() + widget_dt_ws.apply_nonmax_suppression.hide() + widget_dt_ws.nuclei.hide() + widget_dt_ws.call_button.hide() + + +_show_widget_dt_ws(False) + + @magicgui(call_button='Run Watershed', image={'label': 'Pmap/Image', 'tooltip': 'Raw or boundary image to use as input for Watershed.'}, @@ -294,12 +354,15 @@ def widget_simple_dt_ws(image: Image, threshold={'label': 'Threshold', 'widget_type': 'FloatRangeSlider', 'max': 100, 'min': 0, 'step': 0.1}, quantile={'label': 'Nuclei Quantile', - 'widget_type': 'FloatRangeSlider', 'max': 100, 'min': 0, 'step': 0.1}) + 'widget_type': 'FloatRangeSlider', 'max': 100, 'min': 0, 'step': 0.1}, + show_widget={'label': 'Show Widget: Fix Segmentation from Nuclei', + 'tooltip': 'Show/Hide the widget to change the parameters of the segmentation.'}) def widget_fix_over_under_segmentation_from_nuclei(cell_segmentation: Labels, nuclei_segmentation: Labels, boundary_pmaps: Image, threshold=(33, 66), - quantile=(0.1, 99.9)) -> Future[LayerDataTuple]: + quantile=(0.1, 99.9), + show_widget: bool = False) -> Future[LayerDataTuple]: out_name = create_layer_name(cell_segmentation.name, 'NucleiSegFix') threshold_merge, threshold_split = threshold threshold_merge, threshold_split = threshold_merge / 100, threshold_split / 100 @@ -332,3 +395,24 @@ def widget_fix_over_under_segmentation_from_nuclei(cell_segmentation: Labels, layer_type=layer_type, step_name=f'Fix Over / Under segmentation', ) + + +@widget_fix_over_under_segmentation_from_nuclei.show_widget.changed.connect +def _show_widget_fix_over_under_segmentation_from_nuclei(show_widget: bool): + if show_widget: + widget_fix_over_under_segmentation_from_nuclei.cell_segmentation.show() + widget_fix_over_under_segmentation_from_nuclei.nuclei_segmentation.show() + widget_fix_over_under_segmentation_from_nuclei.boundary_pmaps.show() + widget_fix_over_under_segmentation_from_nuclei.threshold.show() + widget_fix_over_under_segmentation_from_nuclei.quantile.show() + widget_fix_over_under_segmentation_from_nuclei.call_button.show() + else: + widget_fix_over_under_segmentation_from_nuclei.cell_segmentation.hide() + widget_fix_over_under_segmentation_from_nuclei.nuclei_segmentation.hide() + widget_fix_over_under_segmentation_from_nuclei.boundary_pmaps.hide() + widget_fix_over_under_segmentation_from_nuclei.threshold.hide() + widget_fix_over_under_segmentation_from_nuclei.quantile.hide() + widget_fix_over_under_segmentation_from_nuclei.call_button.hide() + + +_show_widget_fix_over_under_segmentation_from_nuclei(False) From 6ead29fb2e6423bdad37dfa84aa2568942556589 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Mon, 14 Aug 2023 17:41:17 +0200 Subject: [PATCH 21/22] cleanup formatting and imports --- plantseg/dataset_tools/dataset_handler.py | 1 - plantseg/dataset_tools/images.py | 4 ++-- plantseg/ui/containers.py | 6 +++--- plantseg/ui/viewer.py | 2 +- plantseg/ui/widgets/dataprocessing.py | 2 +- plantseg/ui/widgets/dataset_tools.py | 1 - plantseg/ui/widgets/training.py | 4 +--- 7 files changed, 8 insertions(+), 12 deletions(-) diff --git a/plantseg/dataset_tools/dataset_handler.py b/plantseg/dataset_tools/dataset_handler.py index 
fcd83fee..407caad8 100644 --- a/plantseg/dataset_tools/dataset_handler.py +++ b/plantseg/dataset_tools/dataset_handler.py @@ -354,7 +354,6 @@ def change_phase_stack(self, stack_name: str, new_phase: str): raise ValueError(f'Stack {stack_name} not found in dataset {self.name}.') - def load_dataset(dataset_name: str) -> DatasetHandler: """ Load a dataset from the user dataset config file. diff --git a/plantseg/dataset_tools/images.py b/plantseg/dataset_tools/images.py index eb2480e3..2658762a 100644 --- a/plantseg/dataset_tools/images.py +++ b/plantseg/dataset_tools/images.py @@ -263,7 +263,7 @@ def __init__(self, data: np.ndarray, elif data.ndim == 4: num_channels = data.shape[0] assert num_channels == spec.num_channels, (f'Invalid shape for 2D image: expected number of channels ' - f'{spec.num_channels}, got {num_channels}') + f'{spec.num_channels}, got {num_channels}') assert data.shape[1] == 1, (f'Invalid shape for 2D image: {data.shape}, expected (C, 1, X, Y), ' f'got (c, {data.shape[1]}, x, y') layout = 'c1xy' @@ -277,7 +277,7 @@ def __init__(self, data: np.ndarray, elif data.ndim == 4: num_channels = data.shape[0] assert num_channels == spec.num_channels, (f'Invalid shape for 3D image: expected number of channels ' - f'{spec.num_channels}, got {num_channels}') + f'{spec.num_channels}, got {num_channels}') layout = 'cxyz' else: raise ValueError(f'Invalid number of dimensions: {data.ndim}, expected 3 or 4.') diff --git a/plantseg/ui/containers.py b/plantseg/ui/containers.py index 408de8a0..df248bd3 100644 --- a/plantseg/ui/containers.py +++ b/plantseg/ui/containers.py @@ -8,14 +8,14 @@ from plantseg.ui.widgets.dataprocessing import widget_rescaling, widget_gaussian_smoothing from plantseg.ui.widgets.dataset_tools import widget_create_dataset, widget_edit_dataset from plantseg.ui.widgets.io import open_file, export_stacks -from plantseg.ui.widgets.predictions import widget_unet_predictions, widget_test_all_unet_predictions from plantseg.ui.widgets.predictions import widget_iterative_unet_predictions, widget_add_custom_model +from plantseg.ui.widgets.predictions import widget_unet_predictions, widget_test_all_unet_predictions from plantseg.ui.widgets.proofreading.proofreading import widget_clean_scribble, widget_filter_segmentation from plantseg.ui.widgets.proofreading.proofreading import widget_split_and_merge_from_scribbles from plantseg.ui.widgets.segmentation import widget_dt_ws, widget_agglomeration -from plantseg.ui.widgets.segmentation import widget_simple_dt_ws -from plantseg.ui.widgets.segmentation import widget_lifted_multicut from plantseg.ui.widgets.segmentation import widget_fix_over_under_segmentation_from_nuclei +from plantseg.ui.widgets.segmentation import widget_lifted_multicut +from plantseg.ui.widgets.segmentation import widget_simple_dt_ws def setup_menu(container, path=None): diff --git a/plantseg/ui/viewer.py b/plantseg/ui/viewer.py index ce1e566e..a889f518 100644 --- a/plantseg/ui/viewer.py +++ b/plantseg/ui/viewer.py @@ -1,8 +1,8 @@ import napari +from plantseg.ui.containers import get_dataset_workflow from plantseg.ui.containers import get_extra from plantseg.ui.containers import get_gasp_workflow, get_preprocessing_workflow, get_main -from plantseg.ui.containers import get_dataset_workflow from plantseg.ui.logging import napari_formatted_logging from plantseg.ui.widgets.proofreading.proofreading import setup_proofreading_keybindings diff --git a/plantseg/ui/widgets/dataprocessing.py b/plantseg/ui/widgets/dataprocessing.py index 8226b798..0d339946 100644 --- 
a/plantseg/ui/widgets/dataprocessing.py +++ b/plantseg/ui/widgets/dataprocessing.py @@ -12,11 +12,11 @@ from plantseg.dataprocessing.functional.dataprocessing import compute_scaling_factor, compute_scaling_voxelsize from plantseg.dataprocessing.functional.labelprocessing import relabel_segmentation as _relabel_segmentation from plantseg.dataprocessing.functional.labelprocessing import set_background_to_value -from plantseg.utils import list_models, get_model_resolution from plantseg.ui.widgets.predictions import widget_unet_predictions from plantseg.ui.widgets.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws from plantseg.ui.widgets.utils import return_value_if_widget from plantseg.ui.widgets.utils import start_threading_process, create_layer_name, layer_properties +from plantseg.utils import list_models, get_model_resolution @magicgui(call_button='Run Gaussian Smoothing', diff --git a/plantseg/ui/widgets/dataset_tools.py b/plantseg/ui/widgets/dataset_tools.py index 0a379ba6..2d3fa019 100644 --- a/plantseg/ui/widgets/dataset_tools.py +++ b/plantseg/ui/widgets/dataset_tools.py @@ -7,7 +7,6 @@ from magicgui import magicgui from napari.layers import Labels, Image, Layer -from plantseg import PLANTSEG_MODELS_DIR from plantseg.dataset_tools.dataset_handler import DatasetHandler, save_dataset, load_dataset from plantseg.dataset_tools.dataset_handler import delete_dataset, change_dataset_location from plantseg.dataset_tools.images import Image as PlantSegImage diff --git a/plantseg/ui/widgets/training.py b/plantseg/ui/widgets/training.py index fb83a73e..24f172a3 100644 --- a/plantseg/ui/widgets/training.py +++ b/plantseg/ui/widgets/training.py @@ -6,13 +6,12 @@ from napari.types import LayerDataTuple from plantseg import PLANTSEG_MODELS_DIR +from plantseg.dataset_tools.dataset_handler import load_dataset from plantseg.training.train import unet_training from plantseg.ui.widgets.predictions import ALL_DEVICES from plantseg.ui.widgets.utils import create_layer_name, start_threading_process, return_value_if_widget from plantseg.utils import list_all_dimensionality from plantseg.utils import list_datasets -from plantseg.dataset_tools.dataset_handler import load_dataset -from plantseg.ui.logging import napari_formatted_logging empty_dataset = ['none'] startup_list_datasets = list_datasets() or empty_dataset @@ -127,4 +126,3 @@ def _on_sparse_change(sparse: bool): if startup_list_datasets[0] != empty_dataset[0]: _on_dataset_name_changed(startup_list_datasets[0]) - From cb61d4ca5e42abe7ee88a6b4d7c8e39f7a714887 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Tue, 15 Aug 2023 10:42:56 +0200 Subject: [PATCH 22/22] improves hints --- plantseg/dataset_tools/images.py | 2 +- plantseg/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plantseg/dataset_tools/images.py b/plantseg/dataset_tools/images.py index 2658762a..442c3d64 100644 --- a/plantseg/dataset_tools/images.py +++ b/plantseg/dataset_tools/images.py @@ -92,7 +92,7 @@ class GenericImage: data_type: str infos: tuple = None - def __init__(self, data: np.ndarray, + def __init__(self, data: Union[np.ndarray, MockData], spec: ImageSpecs): """ Generic image class to handle 2D and 3D images consistently. diff --git a/plantseg/utils.py b/plantseg/utils.py index 441a354a..d9b728c1 100644 --- a/plantseg/utils.py +++ b/plantseg/utils.py @@ -283,7 +283,7 @@ def clean_models(): "Are you sure you want to continue? 
(y/n) ") if answer == 'y': shutil.rmtree(PLANTSEG_LOCAL_DIR) - print("All models deleted... PlantSeg will now close") + print("All models/configs deleted... PlantSeg will now close") return None elif answer == 'n':