Changing number of iterations per level of registration #2087

Open
0xJustin opened this issue Apr 7, 2024 · 1 comment


0xJustin commented Apr 7, 2024

Describe the bug
I have a multiscale B-spline registration, and I have noticed that after the first handful of iterations on the levels with the most grid points, the B-splines become overfit according to my evaluation metric. I would like to specify a different number of iterations for each level. In the code below, I try to do this within a callback, but it does not change the optimizer that is actually running. Is there any practical way to do this?

```python
import numpy as np
import SimpleITK as sitk

# stack_coords_trunc, eval_points_pre, static_reg_points_masked and
# moving_reg_points are my own arrays, defined elsewhere.

best_params = None  # (fixed parameters, parameters) of the best B-spline seen so far
min_norm = 100


def command_iteration(method, bspline_transform):
    global best_params, min_norm
    if method.GetOptimizerIteration() == 0:
        # The BSpline is resized before the first optimizer
        # iteration is completed per level. Print the transform object
        # to show the adapted BSpline transform.
        print(bspline_transform)
    sitk_points = [tuple(point.astype(float)) for point in stack_coords_trunc]
    transformed_points = np.array([bspline_transform.TransformPoint(point) for point in sitk_points])
    # Mean Euclidean distance to the evaluation landmarks.
    norm = np.linalg.norm(transformed_points - eval_points_pre, axis=1).mean()
    if norm < min_norm:
        # Snapshot the parameters: keeping a reference to the transform object
        # would also pick up all later (overfit) updates made in place.
        best_params = (bspline_transform.GetFixedParameters(), bspline_transform.GetParameters())
        min_norm = norm
        print("New best norm")
    print(norm)
    print(method.GetOptimizerLearningRate())
    print("{0:3} = {1:10.5f}".format(method.GetOptimizerIteration(), method.GetMetricValue()))


def command_multi_iteration(method):
    # The sitkMultiResolutionIterationEvent occurs before the
    # resolution of the transform. This event is used here to print
    # the status of the optimizer from the previous registration level.
    if method.GetCurrentLevel() > 0:
        print("Optimizer stop condition: {0}".format(method.GetOptimizerStopConditionDescription()))
        print(" Iteration: {0}".format(method.GetOptimizerIteration()))
        print(" Metric value: {0}".format(method.GetMetricValue()))
    print(method.GetCurrentLevel())
    if method.GetCurrentLevel() > 0:
        # Attempt to cap the later levels at 5 iterations. This does NOT
        # change the optimizer that is already running (the problem reported here).
        print("Change Its")
        method.SetOptimizerAsGradientDescentLineSearch(numberOfIterations=5, convergenceMinimumValue=1e-6, convergenceWindowSize=5, estimateLearningRate=method.Once, learningRate=200)
    print("--------- Resolution Changing ---------")


fixed_image = sitk.GetImageFromArray(np.transpose(static_reg_points_masked, axes=(2, 1, 0)).astype(np.float32))
moving_image = sitk.GetImageFromArray(np.transpose(moving_reg_points, axes=(2, 1, 0)).astype(np.float32))

initial_transform = sitk.BSplineTransformInitializer(fixed_image, (2, 2, 2))

# Create the registration method.
registration_method = sitk.ImageRegistrationMethod()

# B-spline initial transform, with one mesh scale factor per level.
registration_method.SetInitialTransformAsBSpline(initial_transform, scaleFactors=[2, 4, 6])
registration_method.SetOptimizerAsGradientDescentLineSearch(numberOfIterations=150, convergenceMinimumValue=1e-4, convergenceWindowSize=5, estimateLearningRate=registration_method.EachIteration, learningRate=200)

# Multi-resolution framework.
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[2, 2, 1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[10, 2, 0])

# Similarity metric: correlation, with linear interpolation.
registration_method.SetMetricAsCorrelation()
registration_method.SetInterpolator(sitk.sitkLinear)

# Randomly sample 1% of the voxels for the metric.
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)

registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method, initial_transform))
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, lambda: command_multi_iteration(registration_method))

final_transform = registration_method.Execute(fixed_image, moving_image)

# Rebuild the best B-spline (by the evaluation metric) from its snapshot.
best_transform = None
if best_params is not None:
    best_transform = sitk.BSplineTransform(3)
    best_transform.SetFixedParameters(best_params[0])
    best_transform.SetParameters(best_params[1])
```
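A possible workaround, offered only as an untested sketch: calling `SetOptimizerAs...()` from inside a callback most likely has no effect because the optimizer for the run is set up when `Execute()` starts. Instead, keep `numberOfIterations` at the largest per-level budget and end each level early from an iteration callback. This assumes that `ImageRegistrationMethod.StopRegistration()` stops only the current level's optimizer, so the next level still runs; please verify that on your SimpleITK version. `make_level_budget_callback` and the budget list are hypothetical names.

```python
def make_level_budget_callback(method, iterations_per_level):
    """Return an iteration callback that ends each level after its budget.

    Hypothetical helper. Assumptions (verify on your SimpleITK version):
      * method.GetCurrentLevel() is 0 for the coarsest level;
      * method.GetOptimizerIteration() counts from 0 within each level;
      * method.StopRegistration() stops only the current level's optimizer.
    """
    def callback():
        level = method.GetCurrentLevel()
        # Reuse the last budget if there are more levels than budgets.
        budget = iterations_per_level[min(level, len(iterations_per_level) - 1)]
        # After iteration index i, i + 1 iterations have completed.
        if method.GetOptimizerIteration() + 1 >= budget:
            method.StopRegistration()
    return callback
```

With the code above, it would be registered alongside the existing callbacks, e.g. `registration_method.AddCommand(sitk.sitkIterationEvent, make_level_budget_callback(registration_method, [150, 20, 5]))` while keeping `numberOfIterations=150`; the budgets `[150, 20, 5]` are only illustrative.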

0xJustin commented Apr 7, 2024

Additionally, if there are examples of ways to regularize volumetric B-splines like the one above from the Python interface, I'd be keen to know about them (besides just using fewer grid points). I know there are implementations in C++, but I don't think they have been integrated here.
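One rough proxy that can at least be computed from Python is a bending-energy-like penalty on the B-spline coefficient grid. This is a swapped-in technique, not a regularizer inside the optimization: the hypothetical helper `coefficient_roughness` below only measures how irregular the control-point displacements are, for example to weigh against the landmark metric when choosing which snapshot to keep. It uses pure second differences along each axis and ignores mixed derivatives.

```python
import numpy as np


def coefficient_roughness(coefficient_arrays):
    """Mean squared second difference of B-spline coefficient grids.

    Hypothetical helper. It is zero when every displacement component varies
    linearly along each grid axis, and grows as the control-point
    displacements become more irregular. Mixed derivatives are ignored.
    """
    total = 0.0
    count = 0
    for coeffs in coefficient_arrays:  # one array per displacement component
        c = np.asarray(coeffs, dtype=float)
        for axis in range(c.ndim):
            if c.shape[axis] < 3:  # a second difference needs 3 samples
                continue
            d2 = np.diff(c, n=2, axis=axis)
            total += float(np.sum(d2 ** 2))
            count += d2.size
    return total / count if count else 0.0
```

For a SimpleITK `BSplineTransform`, the input could be `[sitk.GetArrayFromImage(img) for img in bspline_transform.GetCoefficientImages()]`.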
