This repository has been archived by the owner on May 1, 2023. It is now read-only.

Sensitivity pruning not setting network weights to zero #539

Open
elliottloveridge opened this issue Sep 16, 2020 · 0 comments

@elliottloveridge

I'm implementing both sensitivity pruning and sparsity-level pruning for 3D networks.

Sensitivity analysis works as expected, as the figure below shows, so I assume element-wise pruning methods themselves work with a 3D model.

[Figure: sensitivity analysis results]

However, when I prune during training, the network weights are not being set to zero.

An abbreviated version of my training loop:

```python
# for each epoch
for i in range(opt.begin_epoch, opt.begin_epoch + opt.n_epochs):
    compression_scheduler.on_epoch_begin(i)
    model.train()

    # for each mini-batch
    for j, (inputs, targets) in enumerate(train_loader):
        compression_scheduler.on_minibatch_begin(i, minibatch_id=j, minibatches_per_epoch=len(train_loader))

        targets = targets.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # before the backward pass: the scheduler returns the loss with any regularization added
        loss = compression_scheduler.before_backward_pass(i, minibatch_id=j, minibatches_per_epoch=len(train_loader), loss=loss)

        optimizer.zero_grad()
        loss.backward()

        compression_scheduler.before_parameter_optimization(i, minibatch_id=j, minibatches_per_epoch=len(train_loader), optimizer=optimizer)

        optimizer.step()

        compression_scheduler.on_minibatch_end(i, minibatch_id=j, minibatches_per_epoch=len(train_loader))
```
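One way to tell whether the masks are actually reaching the weights is to count exact zeros in each targeted parameter, for example after `on_minibatch_end` in the loop above. The stdlib-only sketch below illustrates that check; `zero_fraction` and `sparsity_report` are hypothetical helpers, not Distiller functions, and the sample values are made up. With PyTorch, the input dict could be built from `model.named_parameters()` using `p.detach().flatten().tolist()`.

```python
from typing import Dict, Iterable, List


def zero_fraction(values: Iterable[float]) -> float:
    """Return the fraction of elements that are exactly zero."""
    flat: List[float] = list(values)
    if not flat:
        return 0.0
    return sum(1 for v in flat if v == 0.0) / len(flat)


def sparsity_report(named_weights: Dict[str, Iterable[float]]) -> Dict[str, float]:
    """Map each parameter name to the fraction of its elements that are zero."""
    return {name: zero_fraction(values) for name, values in named_weights.items()}


if __name__ == "__main__":
    # Made-up values standing in for flattened weight tensors (hypothetical data).
    weights = {
        "module.classifier.1.weight": [0.0, 0.3, 0.0, -0.7],
        "module.features.18.1.weight": [0.1, -0.2, 0.05, 0.4],
    }
    for name, frac in sparsity_report(weights).items():
        print(f"{name}: {frac:.0%} zeros")
```

If a parameter listed under `sensitivities` stays at 0% zeros after several scheduled epochs, that would suggest its mask was never computed or never applied to that tensor.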

An example compression schedule I'm using:

```yaml
version: 1
pruners:
    mobilenetv2_fullyconnected:
        class: 'SensitivityPruner'
        sensitivities:
            'module.features.18.1.weight': 1.0
            'module.classifier.1.weight': 1.0

policies:
    - pruner:
        instance_name: 'mobilenetv2_fullyconnected'
      starting_epoch: 0
      ending_epoch: 20
      frequency: 1
```
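As I understand Distiller's documentation, the sensitivity pruner sets a threshold equal to the configured sensitivity multiplied by the standard deviation of the weight tensor and masks weights whose magnitude falls below it. The stdlib sketch below mimics that rule on a flat list of floats; it is an illustration under that assumption, not Distiller's code, and it assumes the sample (n - 1) standard deviation that `torch.std` uses by default. It also includes a hypothetical helper, `unmatched_names`, because the keys under `sensitivities` presumably have to match `model.named_parameters()` exactly: the `module.` prefix only appears when the model is wrapped in `torch.nn.DataParallel`, and if I read `SensitivityPruner` correctly, a parameter with no matching key (and no `'*'` entry) is simply left unmasked.

```python
import statistics
from typing import Iterable, List


def sensitivity_mask(weights: List[float], sensitivity: float) -> List[float]:
    """Return a 0/1 mask (1 = keep, 0 = prune) for a flat list of weights.

    Assumed rule: threshold = sensitivity * std(weights); weights with
    |w| <= threshold are pruned. statistics.stdev is the sample (n - 1)
    estimator, matching torch.std's default.
    """
    threshold = sensitivity * statistics.stdev(weights)
    return [1.0 if abs(w) > threshold else 0.0 for w in weights]


def apply_mask(weights: List[float], mask: List[float]) -> List[float]:
    """Zero out every weight whose mask entry is 0."""
    return [w if m else 0.0 for w, m in zip(weights, mask)]


def unmatched_names(configured: Iterable[str], model_names: Iterable[str]) -> List[str]:
    """Configured parameter names that do not exist in the model (hypothetical helper)."""
    available = set(model_names)
    return [name for name in configured if name not in available]


if __name__ == "__main__":
    w = [-2.0, -1.0, 0.0, 1.0, 2.0]  # sample std = sqrt(2.5), about 1.58
    mask = sensitivity_mask(w, sensitivity=1.0)
    print(mask)  # [1.0, 0.0, 0.0, 0.0, 1.0]
    print(apply_mask(w, mask))
    # Names from the schedule vs. names a model without DataParallel would report.
    print(unmatched_names(
        ["module.features.18.1.weight", "module.classifier.1.weight"],
        ["features.18.1.weight", "classifier.1.weight"],
    ))
```

Printing the real `name` values from `model.named_parameters()` next to the `sensitivities` keys is a cheap way to rule out a silent name mismatch before looking elsewhere.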

Any help would be appreciated!
