[WIP] Gradual Unfreezing to mitigate catastrophic forgetting #3967

Open · wants to merge 5 commits into master
28 changes: 28 additions & 0 deletions ludwig/modules/gradual_unfreezer.py
@@ -0,0 +1,28 @@
from ludwig.schema.gradual_unfreezer import GradualUnfreezerConfig


class GradualUnfreezer:
Contributor comment:
I thought we were planning on doing something much simpler than this at first? Like a single regex to declare which layers to unfreeze.

In the transfer learning tutorials it trains the classification head until convergence (step 1) and then it unfreezes some of the encoder layers for a few more epochs of low learning rate training (step 2). Is this new functionality supposed to be part of step 2?
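A rough sketch of that simpler regex-driven alternative, for comparison; the helper name and the example pattern below are illustrative and not something this PR or Ludwig defines:

```python
import re

import torch.nn as nn


def unfreeze_matching(model: nn.Module, pattern: str) -> int:
    """Set requires_grad=True on every parameter whose name matches the regex; return how many matched."""
    regex = re.compile(pattern)
    count = 0
    for name, param in model.named_parameters():
        if regex.search(name):
            param.requires_grad_(True)
            count += 1
    return count


# Step 2 of the tutorial flow could then be a single call, e.g.:
# unfreeze_matching(model, r"encoder\.layers\.(10|11)\.")
```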

    def __init__(self, config: GradualUnfreezerConfig, model):
        self.config = config
Contributor comment:
Is this variable referenced outside of init?

        self.model = model
        self.thaw_epochs = self.config.thaw_epochs
        self.layers_to_thaw = self.config.layers_to_thaw
Contributor comment:
If a network has hundreds of layers, won't the config get unwieldy?


        if len(self.thaw_epochs) != len(self.layers_to_thaw):
            raise ValueError("The length of thaw_epochs and layers_to_thaw must be equal.")
        self.layers = dict(zip(self.thaw_epochs, self.layers_to_thaw))
Contributor comment:
Can you call this epoch_to_layers


    def thaw(self, current_epoch: int) -> None:
        if current_epoch in self.layers:
            current_layers = self.layers[current_epoch]
Contributor comment:
Better to call this: layers_to_thaw

            for layer in current_layers:
                self.thawParameter(layer)

    # thaw individual layer
    def thawParameter(self, layer):
Contributor comment:
Make this private

        # is there a better way to do this instead of iterating through all parameters?
Contributor comment:
Perhaps in the init make a map of:
layer name => parameters

Only include the ones that will be thawed
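A sketch of that suggestion, written here as a free function for clarity; the name and signature are illustrative only:

```python
from typing import Dict, List

import torch.nn as nn


def build_layer_param_map(model: nn.Module, layers_to_thaw: List[List[str]]) -> Dict[str, List[nn.Parameter]]:
    """Collect the parameters for each configured layer name in one pass, validating names up front."""
    layer_names = [layer for group in layers_to_thaw for layer in group]
    layer_to_params: Dict[str, List[nn.Parameter]] = {layer: [] for layer in layer_names}
    for name, param in model.named_parameters():
        for layer in layer_names:
            if layer in name:
                layer_to_params[layer].append(param)
    missing = [layer for layer, params in layer_to_params.items() if not params]
    if missing:
        raise ValueError(f"Layers not found in model architecture: {missing}")
    return layer_to_params
```

Building the map once in __init__ would also surface bad layer names at construction time rather than mid-training.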

        for name, p in self.model.named_parameters():
            if layer in str(name):
                p.requires_grad_(True)
            else:
                raise ValueError("Layer type doesn't exist within model architecture")
Contributor comment:
Can you add the layer name to the error message
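A possible revision folding in these comments (private name, layer name in the error message), which also raises only after scanning all parameters, since the current else branch fires on the first non-matching name. This is a suggested drop-in for thawParameter inside GradualUnfreezer, not the PR's final code:

```python
    def _thaw_parameter(self, layer: str) -> None:
        """Unfreeze every parameter whose name contains `layer`; raise only if nothing matched."""
        matched = False
        for name, param in self.model.named_parameters():
            if layer in name:
                param.requires_grad_(True)
                matched = True
        if not matched:
            raise ValueError(f"Layer '{layer}' doesn't exist within the model architecture")
```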

78 changes: 78 additions & 0 deletions ludwig/schema/gradual_unfreezer.py
@@ -0,0 +1,78 @@
from abc import ABC
from dataclasses import field
from typing import Dict

from marshmallow import fields, ValidationError

import ludwig.schema.utils as schema_utils
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import MODEL_ECD
from ludwig.schema.metadata import TRAINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@ludwig_dataclass
class GradualUnfreezerConfig(schema_utils.BaseMarshmallowConfig, ABC):
"""Configuration for gradual unfreezing parameters."""

    thaw_epochs: list = schema_utils.List(
        int,
        default=None,
        description="Epochs to thaw at. For example, [1, 2, 3, 4] will thaw layers in layers_to_thaw 2D array",
        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["gradual_unfreezer"]["thaw_epochs"],
    )

    layers_to_thaw: list = schema_utils.List(
        list,
        inner_type=str,
        default=None,
        description="Individual layers to thaw at each thaw_epoch. 2D Array",
        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["gradual_unfreezer"]["layers_to_thaw"],
    )
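For illustration, the two lists pair up positionally: thaw_epochs[i] thaws every layer named in layers_to_thaw[i]. A hedged example with made-up layer names:

```python
from ludwig.schema.gradual_unfreezer import GradualUnfreezerConfig

# Epoch 2 thaws the top two (hypothetical) encoder blocks; epoch 4 thaws the next one down.
config = GradualUnfreezerConfig(
    thaw_epochs=[2, 4],
    layers_to_thaw=[["encoder.layers.11", "encoder.layers.10"], ["encoder.layers.9"]],
)
```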


@DeveloperAPI
def GradualUnfreezerDataclassField(description: str, default: Dict = None):
    allow_none = True
    default = default or {}

    class GradualUnfreezerMarshmallowField(fields.Field):
        def _deserialize(self, value, attr, data, **kwargs):
            if value is None:
                return value
            if isinstance(value, dict):
                try:
                    return GradualUnfreezerConfig.Schema().load(value)
                except (TypeError, ValidationError) as e:
                    raise ValidationError(
                        f"Invalid params for gradual unfreezer: {value}, see GradualUnfreezerConfig class. Error: {e}"
                    )
            raise ValidationError("Field should be None or dict")

        def _jsonschema_type_mapping(self):
            return {
                **schema_utils.unload_jsonschema_from_marshmallow_class(GradualUnfreezerConfig),
                "title": "gradual_unfreeze_options",
                "description": description,
            }

    if not isinstance(default, dict):
        raise ValidationError(f"Invalid default: `{default}`")

    load_default = lambda: GradualUnfreezerConfig.Schema().load(default)
    dump_default = GradualUnfreezerConfig.Schema().dump(default)

    return field(
        metadata={
            "marshmallow_field": GradualUnfreezerMarshmallowField(
                allow_none=allow_none,
                load_default=load_default,
                dump_default=dump_default,
                metadata={
                    "description": description,
                },
            )
        },
        default_factory=load_default,
    )
5 changes: 5 additions & 0 deletions ludwig/schema/metadata/configs/trainer.yaml
@@ -640,6 +640,11 @@ ecd:
    eta_min:
      expected_impact: 1
      ui_display_name: Eta Min
  gradual_unfreezer:
    thaw_epochs:
      expected_impact: 1
    layers_to_thaw:
      expected_impact: 1
gbm:
  learning_rate:
    commonly_used: true
6 changes: 6 additions & 0 deletions ludwig/schema/trainer.py
@@ -18,6 +18,7 @@
)
from ludwig.error import ConfigValidationError
from ludwig.schema import utils as schema_utils
from ludwig.schema.gradual_unfreezer import GradualUnfreezerConfig, GradualUnfreezerDataclassField
from ludwig.schema.lr_scheduler import LRSchedulerConfig, LRSchedulerDataclassField
from ludwig.schema.metadata import TRAINER_METADATA
from ludwig.schema.optimizers import (
@@ -177,6 +178,11 @@ def __post_init__(self):
        ],
    )

    gradual_unfreezer: GradualUnfreezerConfig = GradualUnfreezerDataclassField(
        description="Parameter values for gradual unfreezer.",
        default=None,
    )

    learning_rate_scheduler: LRSchedulerConfig = LRSchedulerDataclassField(
        description="Parameter values for learning rate scheduler.",
        default=None,
14 changes: 14 additions & 0 deletions ludwig/trainers/trainer.py
@@ -57,6 +57,7 @@
from ludwig.models.ecd import ECD
from ludwig.models.llm import LLM
from ludwig.models.predictor import Predictor
from ludwig.modules.gradual_unfreezer import GradualUnfreezer
from ludwig.modules.lr_scheduler import LRScheduler
from ludwig.modules.metric_modules import get_improved_fn, get_initial_validation_value
from ludwig.modules.metric_registry import get_metric_objective
@@ -215,6 +216,7 @@ def __init__(
        self.dist_model = None
        self.optimizer = None
        self.scheduler = None
        self.gradual_unfreezer = None

        self.prepare()

@@ -1002,6 +1004,11 @@ def train(
            total_steps=self.total_steps,
        )

        # Initialize gradual unfreezer
        if self.config.gradual_unfreezer.thaw_epochs:
            self.gradual_unfreezer = GradualUnfreezer(self.config.gradual_unfreezer, self.model)
            logger.info(f"Gradual unfreezing for {len(self.gradual_unfreezer.thaw_epochs)} epoch(s)")
Contributor comment:
Can we make this more descriptive. Maybe something like:

Gradual unfreezing:

Epoch 10: unfreezing x1, x2, x3
Epoch 15: unfreezing x4, x5
Epoch 20: unfreezing x6
...
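A sketch of what that could look like, reusing the epoch-to-layers dict the unfreezer already builds (self.gradual_unfreezer.layers); illustrative only, not part of the diff:

```python
# Hypothetical replacement for the one-line message: log one line per scheduled epoch.
if self.config.gradual_unfreezer.thaw_epochs:
    self.gradual_unfreezer = GradualUnfreezer(self.config.gradual_unfreezer, self.model)
    schedule = "\n".join(
        f"  Epoch {epoch}: unfreezing {', '.join(layers)}"
        for epoch, layers in self.gradual_unfreezer.layers.items()
    )
    logger.info(f"Gradual unfreezing:\n{schedule}")
```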


        if self.is_coordinator():
            logger.info(
                f"Training for {self.total_steps} step(s), approximately "
@@ -1029,7 +1036,12 @@
            if profiler:
                profiler.start()

            current_epoch = 0
Contributor comment:
Why can't we use progress_tracker.epoch here?
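If progress_tracker.epoch does advance once per pass over the data, the separate counter could likely be dropped; a sketch of the loop head under that assumption, mirroring the surrounding diff rather than defining anything new:

```python
# Sketch: drive thawing from the tracker instead of a separate counter, assuming
# progress_tracker.epoch increments once per completed epoch.
while progress_tracker.steps < self.total_steps:
    if self.gradual_unfreezer:
        self.gradual_unfreezer.thaw(progress_tracker.epoch)

    # note that batch size may change over epochs
    batcher.set_epoch(progress_tracker.epoch, progress_tracker.batch_size)
    # ... rest of the loop unchanged, and current_epoch += 1 goes away ...
```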


            while progress_tracker.steps < self.total_steps:
                if self.gradual_unfreezer:
                    self.gradual_unfreezer.thaw(current_epoch)

                # note that batch size may change over epochs
                batcher.set_epoch(progress_tracker.epoch, progress_tracker.batch_size)

@@ -1086,6 +1098,8 @@
                # Early stop if needed.
                if should_break:
                    break

                current_epoch += 1
        finally:
            # ================ Finished Training ================
            self.callback(
77 changes: 77 additions & 0 deletions tests/ludwig/config_sampling/static_schema.json
@@ -145285,6 +145285,83 @@
"title": "learning_rate_scheduler_options",
"type": "object"
},
"gradual_unfreezer": {
"additionalProperties": true,
"description": "Parameter values for gradual unfreezing.",
"properties": {
"thaw_epochs": {
"default": null,
"description": "List of epochs to unfreeze layers",
"items": {
"title": "thaw_epochs",
"type": "integer"
},
"parameter_metadata": {
"commonly_used": false,
"compute_tier": 0,
"default_value_reasoning": null,
"description_implications": null,
"example_value": null,
"expected_impact": 3,
"internal_only": false,
"literature_references": [
"https://aclanthology.org/P18-1031.pdf",
"https://arxiv.org/pdf/2301.05487.pdf"
],
"long_description": "",
"other_information": null,
"related_parameters": null,
"short_description": "",
"suggested_values": null,
"suggested_values_reasoning": null,
"ui_display_name": "Thaw Epochs"
},
"title": "thaw_epochs",
"type": [
"integer",
"null"
]
},
"layers_to_thaw": {
"default": [],
"description": "List of layers to thaw at each epoch",
"items": {
"type": "array",
"items": {
"type": "string",
"title": "layer_name"
}
},
"parameter_metadata": {
"commonly_used": false,
"compute_tier": 0,
"default_value_reasoning": null,
"description_implications": null,
"example_value": null,
"expected_impact": 3,
"internal_only": false,
"literature_references": [
"https://aclanthology.org/P18-1031.pdf",
"https://arxiv.org/pdf/2301.05487.pdf"
],
"long_description": "",
"other_information": null,
"related_parameters": null,
"short_description": "",
"suggested_values": null,
"suggested_values_reasoning": null,
"ui_display_name": "Layers To Thaw"
},
"title": "layers_to_thaw",
"type": [
"array",
"null"
]
}
},
"title": "gradual_unfreeze_options",
"type": "object"
},
"max_batch_size": {
"default": 1099511627776,
"description": "Auto batch size tuning and increasing batch size on plateau will be capped at this value. The default value is 2^40.",
27 changes: 27 additions & 0 deletions tests/ludwig/modules/test_gradual_unfreezing.py
@@ -0,0 +1,27 @@
from ludwig.encoders.image.torchvision import TVSwinTransformerEncoder
from ludwig.modules.gradual_unfreezer import GradualUnfreezer, GradualUnfreezerConfig
from ludwig.utils.misc_utils import set_random_seed


def test_gradual_unfreezer():
    set_random_seed(13)

    model = TVSwinTransformerEncoder(
        model_variant="t",
        use_pretrained=False,
        saved_weights_in_checkpoint=True,
        trainable=False,
    )
    config = GradualUnfreezerConfig(thaw_epochs=[1, 2], layers_to_thaw=[["features.0", "features.1"], ["features.2"]])

    unfreezer = GradualUnfreezer(config=config, model=model)

    for epoch in range(10):
        unfreezer.thaw(epoch)

    for name, p in model.named_parameters():
        layer_to_thaw = any(layer in str(name) for layer_list in config.layers_to_thaw for layer in layer_list)
        if layer_to_thaw:
            assert p.requires_grad
        else:
            assert not p.requires_grad
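For context, a hedged sketch of how a user might enable the feature through the trainer config once this lands; the field layout follows the schema added above, and the layer names simply mirror the Swin encoder used in the test:

```python
# Hypothetical trainer section of a Ludwig config (Python dict form); adjust
# layer names and epochs to the actual model and training schedule.
trainer_config = {
    "epochs": 20,
    "gradual_unfreezer": {
        "thaw_epochs": [5, 10],
        "layers_to_thaw": [["features.0", "features.1"], ["features.2"]],
    },
}
```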