[WIP] Gradual Unfreezing to mitigate catastrophic forgetting #3967
base: master
Changes from all commits: a00660e, 884f121, f6f6d39, db6a78a, d2ba5cb
ludwig/modules/gradual_unfreezer.py (new file, +28 lines):

```python
from ludwig.schema.gradual_unfreezer import GradualUnfreezerConfig


class GradualUnfreezer:
    def __init__(self, config: GradualUnfreezerConfig, model):
        self.config = config
        self.model = model
        self.thaw_epochs = self.config.thaw_epochs
        self.layers_to_thaw = self.config.layers_to_thaw

        if len(self.thaw_epochs) != len(self.layers_to_thaw):
            raise ValueError("The length of thaw_epochs and layers_to_thaw must be equal.")
        self.layers = dict(zip(self.thaw_epochs, self.layers_to_thaw))

    def thaw(self, current_epoch: int) -> None:
        if current_epoch in self.layers:
            current_layers = self.layers[current_epoch]
            for layer in current_layers:
                self.thawParameter(layer)

    # Thaw an individual layer by enabling gradients on all matching parameters.
    def thawParameter(self, layer):
        # Is there a better way to do this instead of iterating through all parameters?
        matched = False
        for name, p in self.model.named_parameters():
            if layer in str(name):
                p.requires_grad_(True)
                matched = True
        # Raise only if no parameter matched the requested layer.
        if not matched:
            raise ValueError("Layer type doesn't exist within model architecture")
```

Review comments:
- On `self.config = config`: Is this variable referenced outside of `__init__`?
- On `self.layers_to_thaw`: If a network has hundreds of layers, won't the config get unwieldy?
- On `self.layers = dict(zip(...))`: Can you call this […]
- On `current_layers = self.layers[current_epoch]`: Better to call this: […]
- On `thawParameter`: Make this private.
- On the parameter loop: Perhaps in the `__init__` make a map of […]. Only include the ones that will be thawed.
- On the `ValueError`: Can you add the layer name to the error message.
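A sketch combining two of the suggestions above — precomputing an epoch-to-parameters map in `__init__` and naming the missing layer in the error. This is a sketch only; the attribute name `_epoch_to_params` is hypothetical, not from the PR:

```python
from ludwig.schema.gradual_unfreezer import GradualUnfreezerConfig


class GradualUnfreezer:
    def __init__(self, config: GradualUnfreezerConfig, model):
        if len(config.thaw_epochs) != len(config.layers_to_thaw):
            raise ValueError("The length of thaw_epochs and layers_to_thaw must be equal.")
        # Precompute epoch -> parameters once, covering only layers that will be
        # thawed, so thaw() never rescans model.named_parameters().
        self._epoch_to_params = {}
        for epoch, layers in zip(config.thaw_epochs, config.layers_to_thaw):
            params = []
            for layer in layers:
                matching = [p for name, p in model.named_parameters() if layer in name]
                if not matching:
                    # Error message names the missing layer, per the review comment.
                    raise ValueError(f"Layer '{layer}' doesn't exist within the model architecture")
                params.extend(matching)
            self._epoch_to_params[epoch] = params

    def thaw(self, current_epoch: int) -> None:
        # Epochs with nothing scheduled are a no-op.
        for p in self._epoch_to_params.get(current_epoch, []):
            p.requires_grad_(True)
```

One trade-off of this design: validation moves from thaw time to construction time, so a misspelled layer name fails fast before training starts.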
ludwig/schema/gradual_unfreezer.py (new file, +78 lines):

```python
from abc import ABC
from dataclasses import field
from typing import Dict

from marshmallow import fields, ValidationError

import ludwig.schema.utils as schema_utils
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import MODEL_ECD
from ludwig.schema.metadata import TRAINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@ludwig_dataclass
class GradualUnfreezerConfig(schema_utils.BaseMarshmallowConfig, ABC):
    """Configuration for gradual unfreezing parameters."""

    thaw_epochs: list = schema_utils.List(
        int,
        default=None,
        description="Epochs to thaw at. For example, [1, 2, 3, 4] thaws the corresponding rows of the layers_to_thaw 2D array.",
        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["gradual_unfreezer"]["thaw_epochs"],
    )

    layers_to_thaw: list = schema_utils.List(
        list,
        inner_type=str,
        default=None,
        description="Individual layers to thaw at each thaw epoch, as a 2D array.",
        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["gradual_unfreezer"]["layers_to_thaw"],
    )


@DeveloperAPI
def GradualUnfreezerDataclassField(description: str, default: Dict = None):
    allow_none = True
    default = default or {}

    class GradualUnfreezerMarshmallowField(fields.Field):
        def _deserialize(self, value, attr, data, **kwargs):
            if value is None:
                return value
            if isinstance(value, dict):
                try:
                    return GradualUnfreezerConfig.Schema().load(value)
                except (TypeError, ValidationError) as e:
                    raise ValidationError(
                        f"Invalid params for gradual unfreezer: {value}, see GradualUnfreezerConfig class. Error: {e}"
                    )
            raise ValidationError("Field should be None or dict")

        def _jsonschema_type_mapping(self):
            return {
                **schema_utils.unload_jsonschema_from_marshmallow_class(GradualUnfreezerConfig),
                "title": "gradual_unfreeze_options",
                "description": description,
            }

    if not isinstance(default, dict):
        raise ValidationError(f"Invalid default: `{default}`")

    load_default = lambda: GradualUnfreezerConfig.Schema().load(default)
    dump_default = GradualUnfreezerConfig.Schema().dump(default)

    return field(
        metadata={
            "marshmallow_field": GradualUnfreezerMarshmallowField(
                allow_none=allow_none,
                load_default=load_default,
                dump_default=dump_default,
                metadata={
                    "description": description,
                },
            )
        },
        default_factory=load_default,
    )
```
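To illustrate how these fields fit together, here is a small usage example that loads a params dict through the generated marshmallow schema; the values mirror the test at the bottom of this diff:

```python
from ludwig.schema.gradual_unfreezer import GradualUnfreezerConfig

params = {
    "thaw_epochs": [1, 2],
    "layers_to_thaw": [["features.0", "features.1"], ["features.2"]],
}

# BaseMarshmallowConfig exposes a generated Schema, as used in _deserialize above.
config = GradualUnfreezerConfig.Schema().load(params)
assert config.thaw_epochs == [1, 2]
```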
ludwig/trainers/trainer.py (modified):

```diff
@@ -57,6 +57,7 @@
 from ludwig.models.ecd import ECD
 from ludwig.models.llm import LLM
 from ludwig.models.predictor import Predictor
+from ludwig.modules.gradual_unfreezer import GradualUnfreezer
 from ludwig.modules.lr_scheduler import LRScheduler
 from ludwig.modules.metric_modules import get_improved_fn, get_initial_validation_value
 from ludwig.modules.metric_registry import get_metric_objective
@@ -215,6 +216,7 @@ def __init__(
         self.dist_model = None
         self.optimizer = None
         self.scheduler = None
+        self.gradual_unfreezer = None

         self.prepare()

@@ -1002,6 +1004,11 @@ def train(
                 total_steps=self.total_steps,
             )

+            # Initialize gradual unfreezer
+            if self.config.gradual_unfreezer.thaw_epochs:
+                self.gradual_unfreezer = GradualUnfreezer(self.config.gradual_unfreezer, self.model)
+                logger.info(f"Gradual unfreezing for {len(self.gradual_unfreezer.thaw_epochs)} epoch(s)")
+
             if self.is_coordinator():
                 logger.info(
                     f"Training for {self.total_steps} step(s), approximately "
@@ -1029,7 +1036,12 @@ def train(
             if profiler:
                 profiler.start()

+            current_epoch = 0
+
             while progress_tracker.steps < self.total_steps:
+                if self.gradual_unfreezer:
+                    self.gradual_unfreezer.thaw(current_epoch)
+
                 # note that batch size may change over epochs
                 batcher.set_epoch(progress_tracker.epoch, progress_tracker.batch_size)

@@ -1086,6 +1098,8 @@ def train(
                 # Early stop if needed.
                 if should_break:
                     break
+
+                current_epoch += 1
         finally:
             # ================ Finished Training ================
             self.callback(
```

Review comments:
- On the `logger.info` line: Can we make this more descriptive? Maybe something like: […]
- On `current_epoch = 0`: Why can't we use […]?
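The reviewer's suggested wording for that log line is truncated above; one hypothetical, more descriptive variant would enumerate the per-epoch plan, sketched here as a standalone snippet (the values stand in for what `GradualUnfreezerConfig` would supply):

```python
import logging

logger = logging.getLogger(__name__)

# Values as they would come from GradualUnfreezerConfig; taken from the test below.
thaw_epochs = [1, 2]
layers_to_thaw = [["features.0", "features.1"], ["features.2"]]

# Hypothetical replacement for the single-line summary in the diff above.
for epoch, layers in zip(thaw_epochs, layers_to_thaw):
    logger.info(f"Gradual unfreezing: thawing layers {layers} at epoch {epoch}")
```

On the second comment: the loop body already passes `progress_tracker.epoch` to `batcher.set_epoch`, so the truncated question plausibly asks whether that existing counter could replace the new local `current_epoch`.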
New test (+27 lines):

```python
from ludwig.encoders.image.torchvision import TVSwinTransformerEncoder
from ludwig.modules.gradual_unfreezer import GradualUnfreezer, GradualUnfreezerConfig
from ludwig.utils.misc_utils import set_random_seed


def test_gradual_unfreezer():
    set_random_seed(13)

    model = TVSwinTransformerEncoder(
        model_variant="t",
        use_pretrained=False,
        saved_weights_in_checkpoint=True,
        trainable=False,
    )
    config = GradualUnfreezerConfig(thaw_epochs=[1, 2], layers_to_thaw=[["features.0", "features.1"], ["features.2"]])

    unfreezer = GradualUnfreezer(config=config, model=model)

    for epoch in range(10):
        unfreezer.thaw(epoch)

    for name, p in model.named_parameters():
        layer_to_thaw = any(layer in str(name) for layer_list in config.layers_to_thaw for layer in layer_list)
        if layer_to_thaw:
            assert p.requires_grad
        else:
            assert not p.requires_grad
```
Review comment:
I thought we were planning on doing something much simpler than this at first, like a single regex to declare which layers to unfreeze. In the transfer learning tutorials, the classification head is trained until convergence (step 1), and then some of the encoder layers are unfrozen for a few more epochs of low-learning-rate training (step 2). Is this new functionality supposed to be part of step 2?
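For reference, the simpler regex-driven approach the reviewer describes could look like this minimal sketch; `unfreeze_by_regex` and the example pattern are hypothetical, not part of this PR:

```python
import re

import torch.nn as nn


def unfreeze_by_regex(model: nn.Module, pattern: str) -> int:
    """Enable gradients on every parameter whose name matches pattern.

    Returns the number of parameters unfrozen.
    """
    regex = re.compile(pattern)
    count = 0
    for name, param in model.named_parameters():
        if regex.search(name):
            param.requires_grad_(True)
            count += 1
    if count == 0:
        raise ValueError(f"No parameters matched pattern '{pattern}'")
    return count


# Example (step 2 of the tutorial flow): after the classification head has
# converged, unfreeze the last two Swin feature stages for low-LR fine-tuning.
# unfreeze_by_regex(model, r"features\.(6|7)\.")
```

Compared to the per-epoch schedule in this PR, the regex approach trades fine-grained timing control for a much smaller config surface.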