Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Gradual Unfreezing to mitigate catastrophic forgetting #3967

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

ethanreidel
Copy link
Contributor

@ethanreidel ethanreidel commented Mar 15, 2024

Adds the ability to gradually unfreeze or thaw specific layers within a pre-trained model's architecture. Aims to mitigate catastrophic forgetting/improve transfer learning capabilities. Currently works for ECD architecture.

User passes in two things:
thaw_epochs (list of integers) and layers_to_thaw (2D array of layer strings)

thaw_epochs:
-1
-2
layers_to_thaw:

  • ["features.0", "features.1"] (thaws these layers (weights+biases) at epoch 1)
  • ["features.2", "features.3"] (epoch 2)
    (keep in mind "features.0" will thaw all layers with the prefix "features.0" e.g. "features.0.1/2/3")

TODO/potential issues:

  • potentially change config syntax
  • users currently need to know the exact strings in architecture for thawing which is inconvenient
  • unittest iffy

test:
[tests/ludwig/modules/test_gradual_unfreezing.py]

Any and all feedback is greatly appreciated. 👍

Copy link

Unit Test Results

       6 files  ±       0         6 suites  ±0   52m 7s ⏱️ + 22m 2s
2 990 tests  -        3  2 966 ✔️  -      15  23 💤 +11  1 +1 
8 970 runs  +5 941  8 898 ✔️ +5 893  69 💤 +45  3 +3 

For more details on these failures, see this check.

Results for commit d2ba5cb. ± Comparison against base commit 606c732.

@ethanreidel
Copy link
Contributor Author

@skanjila @saad-palapa

@skanjila skanjila self-requested a review March 15, 2024 19:55

if len(self.thaw_epochs) != len(self.layers_to_thaw):
raise ValueError("The length of thaw_epochs and layers_to_thaw must be equal.")
self.layers = dict(zip(self.thaw_epochs, self.layers_to_thaw))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you call this epoch_to_layers

@@ -1029,7 +1036,12 @@ def train(
if profiler:
profiler.start()

current_epoch = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we use progress_tracker.epoch here?

self.config = config
self.model = model
self.thaw_epochs = self.config.thaw_epochs
self.layers_to_thaw = self.config.layers_to_thaw
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a network has hundreds of layers, won't the config get unwieldy?

from ludwig.schema.gradual_unfreezer import GradualUnfreezerConfig


class GradualUnfreezer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we were planning on doing something much simpler than this at first? Like a single regex to declare which layers to unfreeze.

In the transfer learning tutorials it trains the classification head until convergence (step 1) and then it unfreezes some of the encoder layers for a few more epochs of low learning rate training (step 2). Is this new functionality supposed to be part of step 2?

# Initialize gradual unfreezer
if self.config.gradual_unfreezer.thaw_epochs:
self.gradual_unfreezer = GradualUnfreezer(self.config.gradual_unfreezer, self.model)
logger.info(f"Gradual unfreezing for {len(self.gradual_unfreezer.thaw_epochs)} epoch(s)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this more descriptive. Maybe something like:

Gradual unfreezing:

Epoch 10: unfreezing x1, x2, x3
Epoch 15: unfreezing x4, x5
Epoch 20: unfreezing x6
...


def thaw(self, current_epoch: int) -> None:
if current_epoch in self.layers:
current_layers = self.layers[current_epoch]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to call this:
layers_to_thaw

if layer in str(name):
p.requires_grad_(True)
else:
raise ValueError("Layer type doesn't exist within model architecture")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the layer name to the error message


# thaw individual layer
def thawParameter(self, layer):
# is there a better way to do this instead of iterating through all parameters?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps in the init make a map of:
layer name => parameters

Only include the ones that will be thawed

self.thawParameter(layer)

# thaw individual layer
def thawParameter(self, layer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this private


class GradualUnfreezer:
def __init__(self, config: GradualUnfreezerConfig, model):
self.config = config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this variable referenced outside of init?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants