
Modeling multi-channel NCC-based registration #313

Open
neel-dey opened this issue May 21, 2021 · 5 comments · May be fixed by #314

Comments

@neel-dey
Contributor

Hi Adrian & co.,

For multi-channel registration (e.g., RGB image registration, or 4D registration of subject A with T1 and T2 <---> subject B with T1 and T2), vxm implements local NCC with a 4D window (e.g., of size [9, 9, 9, 2] for T1+T2).

I wonder if this may be a problem when dealing with domain shifts (e.g., scanner differences) in a heterogeneous dataset. Typically, 3D NCC handles this by standardizing local statistics and is therefore largely insensitive to domain shift. However, T1 and T2 intensities may not change under the same transformation, and this skews the statistics of the 4D window.
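
For reference, writing local NCC over a window $w$ in my own notation (not vxm's exact implementation):

$$
\mathrm{NCC}_w(I, J) = \frac{\left( \sum_{x \in w} (I(x) - \bar{I}_w)(J(x) - \bar{J}_w) \right)^2}
{\sum_{x \in w} (I(x) - \bar{I}_w)^2 \, \sum_{x \in w} (J(x) - \bar{J}_w)^2},
\qquad
\mathrm{NCC}_w(aI + b,\, cJ + d) = \mathrm{NCC}_w(I, J) \;\; \text{for } a, c \neq 0.
$$

So per-channel affine intensity changes cancel within each window. But if the window also spans both channels and each channel is rescaled by a different $(a, b)$, the shared window mean and variance no longer absorb the change, and the value shifts.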

In practice, when training multi-channel templates on a multi-center dataset, the NCC loss values had high variance and depended strongly on the center (which eventually led to divergence). This effect went away once I used a separate 3D NCC term for each modality (ANTs uses separate NCC terms as well; a sketch of the workaround follows below). I imagine this would not be an issue with a large enough batch size, but we're stuck with a small one for 3D MRI. :)
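
Concretely, the per-modality workaround I'm using looks roughly like the sketch below (split_ncc_loss and the equal 0.5 weights are just my own convention, not anything in vxm):

import tensorflow as tf

from voxelmorph.tf.losses import NCC


def split_ncc_loss(win=(9, 9, 9), weights=(0.5, 0.5)):
    """Weighted sum of single-channel local NCC losses, one term per channel."""
    ncc = NCC(win=list(win))

    def loss(y_true, y_pred):
        # slice out each channel (keeping the channel axis) and apply 3D NCC
        terms = [
            w * ncc.loss(y_true[..., c:c + 1], y_pred[..., c:c + 1])
            for c, w in enumerate(weights)
        ]
        return tf.add_n(terms)

    return loss

The returned function can be used anywhere a (y_true, y_pred) Keras loss is expected.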

Here's a minimal example demonstrating that 4D NCC is sensitive to domain shifts, whereas 3D NCC on each channel is relatively insensitive. The example uses ICBM 2009a Nonlinear Asymmetric T1+T2 as image 1 and NIH's pediatric template as image 2.

import numpy as np
import SimpleITK as sitk
import tensorflow as tf

from voxelmorph.tf.losses import NCC

ncc_object = NCC(win=[9, 9, 9], eps=1e-3)

# -----------------------------------------------------------------------------
# Utility functions

def load_images(fpath):
    img = sitk.GetArrayFromImage(sitk.ReadImage(fpath))
    return img


def stack_to_tf_tensor(arr1, arr2):
    arr = np.stack((arr1, arr2), axis=-1)  # stack T1 and T2 as channels
    arr = arr[np.newaxis, ...]  # add batch axis
    return tf.convert_to_tensor(arr)


def scale_shift_clamp(arr, scale, shift):
    arr = scale*arr + shift  # linearly transform image intensities
    return np.maximum(arr, 0)


def ch(tfarr, dim):
    """Extract a channel from a (bs, x, y, z, ch) array.""" 
    return tfarr[..., dim, tf.newaxis]


# -----------------------------------------------------------------------------
# Load images

# Multimodal image 1:
adult_t1 = load_images('./adult/mni_icbm152_t1_tal_nlin_asym_09a.nii')
adult_t2 = load_images('./adult/mni_icbm152_t2_tal_nlin_asym_09a.nii')

adult = stack_to_tf_tensor(adult_t1, adult_t2)

# Multimodal image 2:
pediatric_t1 = load_images('./pediatric/nihpd_asym_04.5-18.5_t1w.nii')
pediatric_t2 = load_images('./pediatric/nihpd_asym_04.5-18.5_t2w.nii')

pediatric = stack_to_tf_tensor(pediatric_t1, pediatric_t2)

# -----------------------------------------------------------------------------
# Initial NCC

print('Original 4D NCC: {}'.format(ncc_object.loss(adult, pediatric)))


# -----------------------------------------------------------------------------
# Domain shift images

# Simulate 3 different domains/scanner pairs with arbitrary transforms: 
# Adult images:
adult_transform1 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 0.5, 10), 
    scale_shift_clamp(adult_t2, 1.3, 47),
)
adult_transform2 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 1.2, 16), 
    scale_shift_clamp(adult_t2, 0.4, 0),
)
adult_transform3 = stack_to_tf_tensor(
    scale_shift_clamp(adult_t1, 1.0, 20), 
    scale_shift_clamp(adult_t2, 2.0, 60),
)

# Pediatric images:
pediatric_transform1 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 0.9, 30), 
    scale_shift_clamp(pediatric_t2, 1.4, 3),
)
pediatric_transform2 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 2.0, 12), 
    scale_shift_clamp(pediatric_t2, 0.9, 0),
)
pediatric_transform3 = stack_to_tf_tensor(
    scale_shift_clamp(pediatric_t1, 0.8, 0), 
    scale_shift_clamp(pediatric_t2, 1.1, 0),
)


# -----------------------------------------------------------------------------
# Calculate 4D NCC between original images with new domain shifts

print('4D NCC domain 1: {}'.format(
    ncc_object.loss(adult_transform1, pediatric_transform1),
))
print('4D NCC domain 2: {}'.format(
    ncc_object.loss(adult_transform2, pediatric_transform2),
))
print('4D NCC domain 3: {}'.format(
    ncc_object.loss(adult_transform3, pediatric_transform3),
))


# -----------------------------------------------------------------------------
# Calculate 3D NCC_T1 + NCC_T2 between original images with new domain shifts

print('Split 3D NCC domain 1: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform1, 0), ch(pediatric_transform1, 0))
    + 0.5*ncc_object.loss(ch(adult_transform1, 1), ch(pediatric_transform1, 1)),
))
print('Split 3D NCC domain 2: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform2, 0), ch(pediatric_transform2, 0))
    + 0.5*ncc_object.loss(ch(adult_transform2, 1), ch(pediatric_transform2, 1)),
))
print('Split 3D NCC domain 3: {}'.format(
    0.5*ncc_object.loss(ch(adult_transform3, 0), ch(pediatric_transform3, 0))
    + 0.5*ncc_object.loss(ch(adult_transform3, 1), ch(pediatric_transform3, 1)),
))

This yields the following output:

Original 4D NCC: [-0.6175599]
4D NCC domain 1: [-0.57820976]
4D NCC domain 2: [-0.9330061]
4D NCC domain 3: [-0.6425893]
Split 3D NCC domain 1: [-0.54893446]
Split 3D NCC domain 2: [-0.54617786]
Split 3D NCC domain 3: [-0.53752065]

Do you have any thoughts on this phenomenon, and on whether 4D NCC would be preferable to split 3D NCC in other applications?

Thanks!

@adalca
Collaborator

adalca commented May 22, 2021

Hi @neel-dey ,

Thanks, these are very good points. The multi-modal NCC was only added experimentally (and we should note this in the code), and we haven't done thorough testing -- I think your domain-shift (counter)example is a good one. I suspect we could still implement a 'split' NCC without actually splitting the computation.

@brf2 and @ahoopes might be interested in this discussion as well.

@neel-dey
Contributor Author

neel-dey commented May 23, 2021

Makes sense, thanks for the response. Implementing channel-wise NCC directly should only need two changes to https://github.com/voxelmorph/voxelmorph/blob/master/voxelmorph/tf/losses.py (at L37 and L51), marked with # CHANGED HERE below:

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K


class NCC:
    """
    Local (over window) normalized cross correlation loss.
    """

    def __init__(self, win=None, eps=1e-5):
        self.win = win
        self.eps = eps

    def ncc(self, I, J):
        # get dimension of volume
        # assumes I, J are sized [batch_size, *vol_shape, nb_feats]
        ndims = len(I.get_shape().as_list()) - 2
        assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims

        # set window size
        if self.win is None:
            self.win = [9] * ndims

        # get convolution function
        conv_fn = getattr(tf.nn, 'conv%dd' % ndims)

        # compute CC squares
        I2 = I * I
        J2 = J * J
        IJ = I * J

        # compute filters
        in_ch = J.get_shape().as_list()[-1]
        sum_filt = tf.ones([*self.win, 1, in_ch])  # CHANGED HERE
        strides = 1
        if ndims > 1:
            strides = [1] * (ndims + 2)

        # compute local sums via convolution
        padding = 'SAME'
        I_sum = conv_fn(I, sum_filt, strides, padding)
        J_sum = conv_fn(J, sum_filt, strides, padding)
        I2_sum = conv_fn(I2, sum_filt, strides, padding)
        J2_sum = conv_fn(J2, sum_filt, strides, padding)
        IJ_sum = conv_fn(IJ, sum_filt, strides, padding)

        # compute cross correlation
        win_size = np.prod(self.win)  # CHANGED HERE
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size  # TODO: simplify this
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

        cc = cross * cross / (I_var * J_var + self.eps)

        # return mean cc for each entry in batch
        return tf.reduce_mean(K.batch_flatten(cc), axis=-1)

    def loss(self, y_true, y_pred):
        return - self.ncc(y_true, y_pred)

I haven't assessed this super carefully or critically, but numerical tests worked fine. If it looks correct, I can send a PR.
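
For completeness, a quick check along these lines on random volumes (just a sketch -- it assumes the grouped conv filter shape above is supported on the local TF build, and the tolerance is arbitrary):

ncc = NCC(win=[9, 9, 9])

x = tf.convert_to_tensor(np.random.rand(1, 32, 32, 32, 2).astype('float32'))
y = tf.convert_to_tensor(np.random.rand(1, 32, 32, 32, 2).astype('float32'))

# the channel-wise NCC on the stacked tensor should equal the mean of the
# per-channel 3D NCC losses
joint = ncc.loss(x, y)
split = 0.5 * (ncc.loss(x[..., :1], y[..., :1]) + ncc.loss(x[..., 1:], y[..., 1:]))
print(np.allclose(joint, split, atol=1e-4))  # expect True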

@adalca
Collaborator

adalca commented May 23, 2021

Thanks so much @neel-dey, this looks great. Feel free to open a PR and we'll evaluate.

@brf2 I think you were playing with multi-channel NCC at some point -- does this sound OK to you?

@neel-dey
Contributor Author

Great, opened PR #314.

@brf2
Collaborator

brf2 commented May 24, 2021 via email
