
Feature Request: New SSL algorithm called SparK: Sparse and Hierarchical masKed modeling #1462

Open
Djoels opened this issue Dec 20, 2023 · 13 comments


Djoels commented Dec 20, 2023

It would be great if this new MAE-style method called SparK were added to lightly.

Paper: https://arxiv.org/abs/2301.03580 featured in ICLR'23 Spotlight
Code: https://github.com/keyu-tian/SparK

It has been successfully applied to medical imaging applications, as documented in this Nature paper:
https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main

@philippmwirth

Hey @Djoels, thanks for bringing this up! We will take a look shortly and add it to the model tracker if relevant 🙂

guarin commented Jan 2, 2024

This looks interesting indeed, thanks a lot for the issue! Added it to the methods tracker and will consider it for the paper session next week.

@johnsutor

I can take this issue.

guarin commented Jan 22, 2024

Thanks for looking into this @johnsutor! The original codebase implements the sparse net in quite a hacky way (see code here), and I was wondering whether it would be possible to pass the masks explicitly to the forward function instead of assigning them to a global variable. Maybe this would be interesting to explore, wdyt?
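Roughly something like the following, just to illustrate the idea (a minimal sketch; ExplicitMaskConv2d and the nearest-neighbour mask resizing are made up for illustration, not the SparK or lightly API):

# Sketch: a conv layer that receives the mask explicitly instead of reading it
# from a global variable. Sparse convolution on a regular grid is emulated by a
# dense convolution followed by zeroing the masked output positions.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExplicitMaskConv2d(nn.Conv2d):
    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        out = super().forward(x)
        # Resize the binary visibility mask (B, 1, h, w) to the output resolution
        # and broadcast it over the channels.
        mask = F.interpolate(mask.to(out.dtype), size=out.shape[-2:], mode="nearest")
        return out * mask

The downside is that every container module (e.g. a ResNet block) would then have to thread the mask through its own forward call, which is presumably why the original code falls back to a global variable.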

@johnsutor

I'll investigate and get back to you!

johnsutor commented Jan 28, 2024

Seems fairly straightforward to achieve based on https://github.com/keyu-tian/SparK/tree/main/pretrain#regarding-sparse-convolution. I don't mind giving it a stab. My thought is to implement the encoder and decoder from their code base (https://github.com/keyu-tian/SparK/tree/main/pretrain) within https://github.com/lightly-ai/lightly/tree/master/lightly/models, naming the file something like spark.py. If this sounds good, I'll give it a go.

guarin commented Jan 29, 2024

Sounds good! Thanks a lot for looking into it.

Maybe create a lightly/models/sparse subdirectory and put it there. You could even name the file sparse_resnet.py. And it would be great if you could keep the same structure as the original ResNet in torchvision; then it would be easy to convert from a sparse ResNet to a dense ResNet and vice versa.
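To illustrate why matching the torchvision structure pays off: if the sparse ResNet keeps the same module names and parameter shapes, converting between the two is essentially a state_dict round-trip (hedged sketch; SparseResNet50 is a hypothetical name for the model proposed above):

# Sketch (assumption): the sparse ResNet mirrors torchvision's parameter names,
# so weights can be copied back and forth via load_state_dict.
import torchvision

dense = torchvision.models.resnet50()
sparse = SparseResNet50()  # hypothetical sparse counterpart with matching names

# dense -> sparse: warm-start SparK pretraining from torchvision weights
sparse.load_state_dict(dense.state_dict(), strict=False)

# sparse -> dense: export pretrained weights for downstream fine-tuning
dense.load_state_dict(sparse.state_dict(), strict=False)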

@johnsutor

I went ahead and implemented a resnet compatible with the standard torchvision library, so that we don't have to add timm as a dependency.

Furthermore, I managed to pass the mask at runtime without setting a global variable by using a forward pre-hook. This is how it looks so far:

from typing import Optional, Tuple

import torch
import torch.nn as nn

# SparseConv2d, SparseMaxPooling, SparseAvgPooling, SparseBatchNorm2d and
# SparseSyncBatchNorm2d are assumed to be the sparse operator wrappers adapted
# from the SparK code base.


class SparseEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int, sync_bn: bool = False):
        """Sparse Encoder as used by SparK [0]

        Default params are the ones explained in the original code base
        [0] Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling https://arxiv.org/abs/2301.03580

        Attributes:
            backbone:
                Backbone model to extract features from images. Should have both
                the methods get_downsample_ratio() and get_feature_map_channels()
                implemented.
            input_size:
                Size of the input image.
            sync_bn:
                Whether or not to use Sync Batch Norm in this model.

        """
        super().__init__()
        self.mask: Optional[torch.Tensor] = None
        self.sp_backbone = self.dense_model_to_sparse(m=backbone, sbn=sync_bn)
        self.input_size, self.downsample_ratio, self.enc_feat_map_chs = (
            input_size,
            backbone.get_downsample_ratio(),
            backbone.get_feature_map_channels(),
        )

    def mask_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor, ...]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Forward pre-hooks receive (module, input); inject the current mask as
        # the second positional argument of the wrapped sparse module.
        return (input[0], self.mask)

    def dense_model_to_sparse(self, m: nn.Module, sbn: bool = False):
        oup = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            oup = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            oup = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override,
            )
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(
                m.weight.shape[0],
                eps=m.eps,
                momentum=m.momentum,
                affine=m.affine,
                track_running_stats=m.track_running_stats,
            )
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
            oup.register_forward_pre_hook(self.mask_hook)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError

        for name, child in m.named_children():
            oup.add_module(name, self.dense_model_to_sparse(child, sbn=sbn))
        del m
        return oup

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        assert (
            mask is not None or self.mask is not None
        ), "Mask must be supplied for training"
        if mask is not None:
            self.mask = mask
        return self.sp_backbone(x, hierarchical=True)

If that works, I'll go ahead and implement the SparK module as well. The one thing I'm thinking about altering there is configuring the forward pass to return only the reconstructions, and perhaps creating a separate method for calculating the reconstruction loss, to keep the code similar to the masked autoencoder.
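Something roughly along these lines (a rough sketch only; the constructor arguments and the loss definition are placeholders, not a final lightly API):

# Sketch (assumption): forward returns only the reconstructions, and the
# reconstruction loss lives in a separate method, mirroring the MAE-style code.
import torch
import torch.nn as nn


class SparK(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        features = self.encoder(images, mask)  # hierarchical sparse features
        return self.decoder(features)          # reconstructed images

    def reconstruction_loss(
        self, reconstructions: torch.Tensor, images: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        # Mean squared error on the masked positions only (mask == 1 means visible).
        error = ((reconstructions - images) ** 2).mean(dim=1, keepdim=True)
        masked = 1.0 - mask.to(error.dtype)
        return (error * masked).sum() / masked.sum().clamp(min=1.0)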

guarin commented Feb 1, 2024

Oh wow, thanks a lot for looking into this! It looks really good!

I have some comments/questions:

  • While hooks are a very elegant solution, they are not very fast. My concern is that calling a hook before every module could slow down training considerably. I drafted an alternative that doesn't rely on hooks below. Let me know what you think about it.
  • Is the input_size parameter needed? The user has to pass the mask anyways, can we not infer the size of it?
  • I don't think torchvision models have get_downsample_ratio and get_feature_map_channels methods that we can call. I think we should be able to calculate those within the sparse modules and adapt the mask accordingly.
  • I believe we can drop the sync_bn parameter as we can check for the module type when we make the conversion (see code below).
  • AFAIK parameter.data.copy_ is deprecated and parameter.copy_ should be used instead (changed this already in the code below). I was actually wondering whether we even need to copy the parameters; can we not just assign them with oup.weight = m.weight etc.?

Here is the draft for a version that doesn't use hooks. Instead, it saves a SparseMask object on all modules that need access to the mask. The modules can then modify this mask in their forward pass. As the object is shared across all modules they'll all have access to it. I also moved the dense_model_to_sparse function outside of the SparseEncoder class as it doesn't really need access to the class. This would also make it easier to reuse the method in other modules.

from typing import Union

import torch
from torch import Tensor, nn
from torch.nn import Module

# The Sparse* operators are assumed to be mask-aware variants that accept a
# sparse_mask argument.


class SparseMask:
    def __init__(self):
        self.mask: Union[Tensor, None] = None


class SparseEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int):
        """Sparse Encoder as used by SparK [0]

        Default params are the ones explained in the original code base
        [0] Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling https://arxiv.org/abs/2301.03580

        Attributes:
            backbone:
                Backbone model to extract features from images. Should have both
                the methods get_downsample_ratio() and get_feature_map_channels()
                implemented.
            input_size:
                Size of the input image.

        """
        super().__init__()
        self.sparse_mask = SparseMask()
        self.sparse_backbone = dense_model_to_sparse(
            m=backbone,
            sparse_mask=self.sparse_mask,
        )
        self.input_size, self.downsample_ratio, self.enc_feat_map_chs = (
            input_size,
            backbone.get_downsample_ratio(),
            backbone.get_feature_map_channels(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        # All submodules will now have access to the sparse mask
        self.sparse_mask.mask = mask
        return self.sparse_backbone(x, hierarchical=True)


def dense_model_to_sparse(m: Module, sparse_mask: SparseMask) -> Module:
    oup = m
    if isinstance(m, nn.Conv2d):
        m: nn.Conv2d
        bias = m.bias is not None
        oup = SparseConv2d(
            m.in_channels,
            m.out_channels,
            kernel_size=m.kernel_size,
            stride=m.stride,
            padding=m.padding,
            dilation=m.dilation,
            groups=m.groups,
            bias=bias,
            padding_mode=m.padding_mode,
            sparse_mask=sparse_mask,
        )
        oup.weight.copy_(m.weight)
        if bias:
            oup.bias.copy_(m.bias)
    elif isinstance(m, nn.MaxPool2d):
        m: nn.MaxPool2d
        oup = SparseMaxPooling(
            m.kernel_size,
            stride=m.stride,
            padding=m.padding,
            dilation=m.dilation,
            return_indices=m.return_indices,
            ceil_mode=m.ceil_mode,
            sparse_mask=sparse_mask,
        )
    elif isinstance(m, nn.AvgPool2d):
        m: nn.AvgPool2d
        oup = SparseAvgPooling(
            m.kernel_size,
            m.stride,
            m.padding,
            ceil_mode=m.ceil_mode,
            count_include_pad=m.count_include_pad,
            divisor_override=m.divisor_override,
            sparse_mask=sparse_mask,
        )
    elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
        m: nn.BatchNorm2d
        oup = (SparseSyncBatchNorm2d if isinstance(m, nn.SyncBatchNorm) else SparseBatchNorm2d)(
            m.weight.shape[0],
            eps=m.eps,
            momentum=m.momentum,
            affine=m.affine,
            track_running_stats=m.track_running_stats,
            sparse_mask=sparse_mask,
        )
        oup.weight.copy_(m.weight)
        oup.bias.copy_(m.bias)
        oup.running_mean.copy_(m.running_mean)
        oup.running_var.copy_(m.running_var)
        oup.num_batches_tracked.copy_(m.num_batches_tracked)
        if hasattr(m, "qconfig"):
            oup.qconfig = m.qconfig
    elif isinstance(m, (nn.Conv1d,)):
        raise NotImplementedError

    for name, child in m.named_children():
        oup.add_module(name, dense_model_to_sparse(child, sparse_mask=sparse_mask))
    del m
    return oup

johnsutor commented Feb 1, 2024

Hey, thanks for checking it out! In regards to your bullets:

  1. I think that approach of storing the sparse mask on the modules works (this was going to be my second approach if we didn't use hooks, so that works well for me haha).
  2. The input size is used by the SparK module to determine which channels should get masked when reshaping the image tensor, though not by the encoder, so I'll remove it from the encoder. As for the SparK module itself, we could determine the input size on the fly, though I'm not sure whether that would have adverse effects on training if different batches have different spatial dimensions. It's ultimately only used to calculate the number of channels to keep from the flattened representations, so it's your call whether to define the input size up front to enforce consistent spatial dimensions.
  3. With the prototype SparK-compatible ResNet that I have in the works, I calculate the feature map channels like so:
        with torch.no_grad():
            self._feature_map_channels = []
            x = self.layer1(x)
            self._feature_map_channels.append(x.shape[1])
            x = self.layer2(x)
            self._feature_map_channels.append(x.shape[1])
            x = self.layer3(x)
            self._feature_map_channels.append(x.shape[1])
            x = self.layer4(x)
            self._feature_map_channels.append(x.shape[1])

Perhaps for a more general-purpose feature extractor that works with all models, we can determine the resolution of the feature maps by calling create_feature_extractor during initialization and comparing the feature map sizes to the input size. Or we can call get_graph_node_names and return the intermediate outputs up to the final pooling and linear layers; this should work with most models (see the sketch after this list).
  4. Sounds good!
  5. That's fine by me! I tried to leave the code as similar as possible to avoid breaking anything, but I doubt that change will alter anything.
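Here is a hedged sketch of the create_feature_extractor idea from point 3 (the return nodes layer1–layer4 assume a ResNet-style backbone; other models would need different node names):

# Sketch (assumption): infer per-stage feature map channels and the total
# downsample ratio from a torchvision backbone with a dummy forward pass,
# instead of requiring get_feature_map_channels()/get_downsample_ratio().
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

backbone = torchvision.models.resnet50()
return_nodes = {"layer1": "c1", "layer2": "c2", "layer3": "c3", "layer4": "c4"}
extractor = create_feature_extractor(backbone, return_nodes=return_nodes)

with torch.no_grad():
    dummy = torch.zeros(1, 3, 224, 224)
    feats = extractor(dummy)

feature_map_channels = [f.shape[1] for f in feats.values()]  # e.g. [256, 512, 1024, 2048]
downsample_ratio = dummy.shape[-1] // feats["c4"].shape[-1]  # e.g. 32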

guarin commented Feb 1, 2024

  1. Haha perfect!
  2. Sounds good :)
  3. I am wondering whether the feature map channels have to be known in advance. Are they used for anything other than the mask resizing? I imagine we could calculate the size on the fly in the forward pass of the sparse modules, something along the lines of this:
class SparseConv2d(Conv2d):
    def forward(self, x: Tensor) -> Tensor:
        x = super().forward(x)
        # get_mask_with_size and apply_mask are placeholders for resizing the
        # shared mask to the feature map and zeroing the masked positions.
        mask = get_mask_with_size(self.sparse_mask, x)
        x = apply_mask(x, mask)
        return x

johnsutor commented Feb 2, 2024

The feature map channels are used in step three of the forward process, where the hierarchical dense features are calculated for decoding. When the SparK module is created, it creates a mask token and a densify norm layer for filling in the masked locations with the mask token. We could circumvent the norm issue using lazy batch normalization, and perhaps for the mask token itself we could create it on the fly during the first pass, right before this line?
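A small sketch of what that could look like, using nn.LazyBatchNorm2d for the densify norm and creating the mask token from the channel count seen in the first forward pass (names and init are placeholders, not the SparK code):

# Sketch (assumption): densify norm and mask token are created lazily, so the
# feature map channels don't have to be known at construction time.
from typing import Optional

import torch
import torch.nn as nn


class LazyDensify(nn.Module):
    def __init__(self):
        super().__init__()
        # LazyBatchNorm2d infers num_features from the first input it sees.
        self.densify_norm = nn.LazyBatchNorm2d()
        self.mask_token: Optional[nn.Parameter] = None

    def forward(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        if self.mask_token is None:
            # Create the mask token from the channel count of the first batch.
            channels = features.shape[1]
            self.mask_token = nn.Parameter(
                torch.zeros(1, channels, 1, 1, device=features.device)
            )
        features = self.densify_norm(features)
        # Fill the masked positions with the broadcast mask token (mask == 1 is visible).
        mask = mask.to(features.dtype)
        return features * mask + self.mask_token * (1.0 - mask)

One caveat: a parameter created inside forward would not be registered with an optimizer built before the first forward pass, so it may still be preferable to create the token in __init__ once the channel count is known.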

johnsutor commented Feb 9, 2024

Update: I've been busy with other life requirements; I'll get back to it when I can. If you want, I can commit the code that I've been working on.
