Support for freezing pretrained vision model layers with regex #3981

Merged on Jun 1, 2024 (38 commits)
Commits
92366d5
added regex support for freezing specific layers
ethanreidel Mar 22, 2024
cbe1b67
fixed changes to trainer yaml config
ethanreidel Mar 22, 2024
5df7362
regen static schema
ethanreidel Mar 22, 2024
b7985f6
added trainer schema changes
ethanreidel Mar 26, 2024
3a4507d
fixed var names
ethanreidel Mar 26, 2024
9fe5df6
added unit test, cleaned up trainer code, added function in trainer_u…
ethanreidel Mar 27, 2024
82167ce
added training test
ethanreidel Mar 28, 2024
3b8bbd3
cleaned up tests
ethanreidel Mar 28, 2024
a7683de
misc comments/var name changes
ethanreidel Mar 28, 2024
5418a6a
updated description of layers_to_freeze_regex parameter
ethanreidel Mar 28, 2024
aeb2121
Revert "regen static schema"
ethanreidel Mar 28, 2024
1600d19
fixed typo
ethanreidel Mar 28, 2024
58d18ef
well another typo fix
ethanreidel Mar 28, 2024
f4e9cb4
initial summary CLI addition
ethanreidel Apr 2, 2024
192119b
removed try statement
ethanreidel Apr 2, 2024
6649433
added test and model list function
ethanreidel Apr 2, 2024
5ff7e7c
use pretrained off
ethanreidel Apr 2, 2024
ad77764
use_pretrained false for test
ethanreidel Apr 2, 2024
cc94eb4
added more thorough checking for valid regex
ethanreidel Apr 22, 2024
6f36aba
fixed train test, cleaned up pretrained summary CLI
ethanreidel Apr 22, 2024
31511fe
various nits fixed
ethanreidel Apr 25, 2024
a900e48
nit fixes
ethanreidel Apr 25, 2024
02bb963
Merge branch 'develop' of https://github.com/ethanreidel/ludwig into …
ethanreidel Apr 25, 2024
a375296
reverted changes to trainer utils
ethanreidel Apr 25, 2024
23a8b3c
updated collect summary + cli changes
ethanreidel May 16, 2024
1c0f173
post init changes + trainer cleanup
ethanreidel May 16, 2024
f96b5b3
updated unit test for LLM freezing
ethanreidel May 17, 2024
5546c83
two examples and various fixes
ethanreidel May 18, 2024
7483111
small fix
ethanreidel May 18, 2024
3210c56
fix
ethanreidel May 18, 2024
66fb3de
spaces fix
ethanreidel May 18, 2024
f8a46ca
added instructions for new functionality
ethanreidel May 18, 2024
8fff02e
quick fixes
ethanreidel May 19, 2024
d2e0690
small llm test changes
ethanreidel May 19, 2024
1feb853
added padding token for IT
ethanreidel May 19, 2024
0c5b762
cleaned up llm+ecd tests
ethanreidel May 23, 2024
4e62b92
rmtree files examples, fixed llm freezing unittest
ethanreidel May 23, 2024
e85d16d
remove files at end of example
ethanreidel May 23, 2024
6 changes: 6 additions & 0 deletions ludwig/cli.py
@@ -58,6 +58,7 @@ def __init__(self):
render_config Renders the fully populated config with all defaults set
check_install Runs a quick training run on synthetic data to verify installation status
upload Push trained model artifacts to a registry (e.g., Predibase, HuggingFace Hub)
pretrained_summary Displays a summary of a pretrained model (e.g., alexnet, efficientnet)
""",
)
parser.add_argument("command", help="Subcommand to run")
@@ -191,6 +192,11 @@ def upload(self):

upload.cli(sys.argv[2:])

def pretrained_summary(self):
from ludwig.utils import pretrained_summary

pretrained_summary.cli_summarize_pretrained(sys.argv[2:])

Collaborator:

It might be good to add some example runs and outputs in the docs.

Collaborator:

@ethanreidel To second @skanjila -- would it be possible to show an example of running this command in terms of how it is different from the existing one -- and an example output. Thank you very much.

Contributor Author:

Certainly.

Contributor Author:

Qq: when you say you'd like an example, do you mean an example in the Ludwig docs or how would you prefer it?

Contributor:

@ethanreidel One option is to create an example in the examples/ top level directory in Ludwig
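
For reference, a hedged illustration of what such a run could look like with the subcommand added in this PR (output abbreviated, Ludwig banner omitted; the exact parameter names depend on the installed torchvision version):

ludwig pretrained_summary --model_name alexnet
features.0.weight
features.0.bias
features.3.weight
features.3.bias
...
classifier.4.weight
classifier.4.bias
classifier.6.weight
classifier.6.bias

Since freeze_layers_regex applies re.search to each parameter name, the printed names are the substrings that a layers_to_freeze_regex pattern can match against.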


def main():
ludwig.contrib.preload(sys.argv)
9 changes: 9 additions & 0 deletions ludwig/schema/metadata/configs/trainer.yaml
@@ -70,6 +70,15 @@ ecd:
In many large-scale training runs, evaluation is often configured to run on
a sub-epoch time scale, or every few thousand steps.
ui_display_name: Checkpoints per epoch
layers_to_freeze_regex:
default_value_reasoning:
By default, no layers will be frozen when fine-tuning a pretrained model.
description_implications:
Freezing specific layers can improve a pretrained model's performance in a number
of ways. At a basic level, freezing early layers can prevent overfitting by retaining
more general features (beneficial for small datasets). It can also reduce computational
resource use and lower overall training time because fewer gradients are computed.
expected_impact: 1
early_stop:
default_value_reasoning:
Deep learning models are prone to overfitting. It's generally
11 changes: 11 additions & 0 deletions ludwig/schema/trainer.py
@@ -86,6 +86,17 @@ class BaseTrainerConfig(schema_utils.BaseMarshmallowConfig, ABC):
),
)

layers_to_freeze_regex: str = schema_utils.String(
default=None,
allow_none=True,
description=(
"Freeze specific layers based on the provided regex. Freezing specific layers can improve a "
"pretrained model's performance in a number of ways. At a basic level, freezing early layers can "
"prevent overfitting by retaining more general features (beneficial for small datasets). It can also "
"reduce computational resource use and lower overall training time because fewer gradients are computed."
),
)

early_stop: int = schema_utils.IntegerRange(
default=5,
min=-1,
6 changes: 6 additions & 0 deletions ludwig/trainers/trainer.py
@@ -68,6 +68,7 @@
from ludwig.utils.torch_utils import get_torch_device
from ludwig.utils.trainer_utils import (
append_metrics,
freeze_layers_regex,
get_final_steps_per_checkpoint,
get_latest_metrics_dict,
get_new_progress_tracker,
@@ -154,6 +155,7 @@ def __init__(
self._validation_field = config.validation_field
self._validation_metric = config.validation_metric
self.early_stop = config.early_stop
self.layers_to_freeze_regex = config.layers_to_freeze_regex
self.steps_per_checkpoint = config.steps_per_checkpoint
self.checkpoints_per_epoch = config.checkpoints_per_epoch
self.evaluate_training_set = config.evaluate_training_set
@@ -225,6 +227,10 @@ def prepare(self):
base_learning_rate *= lr_scale_fn(self.distributed.size())
self.base_learning_rate = base_learning_rate

# If a regex is supplied, freeze the matching layers
if self.config.layers_to_freeze_regex:
freeze_layers_regex(self.config, self.model)

# We may need to replace the embedding layer when using 8-bit optimizers from bitsandbytes.
update_embedding_layer(self.compiled_model, self.config)

152 changes: 152 additions & 0 deletions ludwig/utils/pretrained_summary.py
@@ -0,0 +1,152 @@
#! /usr/bin/env python
# Copyright (c) 2024 Predibase, Inc., 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import importlib

from ludwig.api_annotations import DeveloperAPI
from ludwig.contrib import add_contrib_callback_args
from ludwig.globals import LUDWIG_VERSION
from ludwig.utils.print_utils import print_ludwig

Collaborator:

wait are we really supporting all of these models, I thought we were just going to go out the door with a couple of models to start?

Contributor Author:

For this specific feature (simple regex freezing), as long as you have access to the string representation of layers + actual model architecture, you can freeze any layers that you'd like. It wasn't any extra work adding support for all torchvision models besides adding to this list. I however don't like the look of this long model array though

Collaborator:

@ethanreidel While it looks like for torchvision this will be supported, what about text/LLMs (this is kind of related to my previous comment in the Trainers section). Thanks!

Collaborator:

> For this specific feature (simple regex freezing), as long as you have access to the string representation of layers + actual model architecture, you can freeze any layers that you'd like. It wasn't any extra work adding support for all torchvision models besides adding to this list. I however don't like the look of this long model array though

@ethanreidel Sorry, could you please point me to this "long model array"? Which line in your code has it? Thanks!

Contributor Author:

For your first question, Alex: as long as access to the model layers + their requires_grad parameter is available, in theory this feature should work on LLMs/text. I'm not too familiar with LLM architecture and I'll have to do some quick checks, but I'm 99% sure it is an easy addition. Second question: in a previous commit, I had a pretty hacky solution where users had another command line option (under pretrained_summary) which would list all available model names. Those names were stored in a Python list, which had a few issues, namely having to expand it regularly and many lines of unnecessary code. Saad made a good point and said to fully remove it (it was not needed), so it's no longer there.

Contributor Author:

Just checked and sure enough you can apply the same regex freezing technique to an LLM

Contributor:

> Just checked and sure enough you can apply the same regex freezing technique to an LLM

@ethanreidel That's awesome. Maybe we can then use one of my earlier comments to only add this parameter to the ECDTrainerConfig and FineTuneTrainerConfig for now.

As part of the examples you have, it would be good to create 2 example Python files:

  1. To show how to use it with a computer vision model
  2. To show how to use it with an LLM base model

What do you think?
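
As a rough sketch of what the first of those example files might look like (the feature names, dataset path, and exact regex below are illustrative assumptions, not code from this PR):

from ludwig.api import LudwigModel

# Hedged sketch: fine-tune a torchvision EfficientNet encoder while freezing its
# first two feature blocks via the new layers_to_freeze_regex trainer option.
config = {
    "input_features": [
        {
            "name": "image_path",  # illustrative column name
            "type": "image",
            "encoder": {"type": "efficientnet", "use_pretrained": True},
        }
    ],
    "output_features": [{"name": "label", "type": "category"}],  # illustrative
    "trainer": {
        # Layer names can be inspected with `ludwig pretrained_summary -m efficientnet_b0`.
        "layers_to_freeze_regex": r"(features\.1\..*|features\.2\..*)",
        "epochs": 1,
    },
}

model = LudwigModel(config)
# model.train(dataset="train.csv")  # dataset path is a placeholder

The LLM counterpart would follow the same pattern, with the regex written against the base model's parameter names (for example, matching specific transformer blocks).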

models = [
"alexnet",
"convnext",
"convnext_base",
"convnext_large",
"convnext_small",
"convnext_tiny",
"densenet",
"densenet121",
"densenet161",
"densenet169",
"densenet201",
"efficientnet",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
"efficientnet_b3",
"efficientnet_b4",
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
"efficientnet_v2_l",
"efficientnet_v2_m",
"efficientnet_v2_s",
"googlenet",
"inception",
"inception_v3",
"maxvit",
"maxvit_t",
"mnasnet",
"mnasnet0_5",
"mnasnet0_75",
"mnasnet1_0",
"mnasnet1_3",
"mobilenet",
"mobilenet_v2",
"mobilenet_v3_large",
"mobilenet_v3_small",
"mobilenetv2",
"mobilenetv3",
"regnet",
"regnet_x_16gf",
"regnet_x_1_6gf",
"regnet_x_32gf",
"regnet_x_3_2gf",
"regnet_x_400mf",
"regnet_x_800mf",
"regnet_x_8gf",
"regnet_y_128gf",
"regnet_y_16gf",
"regnet_y_1_6gf",
"regnet_y_32gf",
"regnet_y_3_2gf",
"regnet_y_400mf",
"regnet_y_800mf",
"regnet_y_8gf",
"resnet",
"resnet101",
"resnet152",
"resnet18",
"resnet34",
"resnet50",
"resnext101_32x8d",
"resnext101_64x4d",
"resnext50_32x4d",
"shufflenet_v2_x0_5",
"shufflenet_v2_x1_0",
"shufflenet_v2_x1_5",
"shufflenet_v2_x2_0",
"shufflenetv2",
"squeezenet",
"squeezenet1_0",
"squeezenet1_1",
"swin_transformer",
"vgg",
"vgg11",
"vgg11_bn",
"vgg13",
"vgg13_bn",
"vgg16",
"vgg16_bn",
"vgg19",
"vgg19_bn",
"vit_b_16",
"vit_b_32",
"vit_h_14",
"vit_l_16",
"vit_l_32",
"wide_resnet101_2",
"wide_resnet50_2",
]


def pretrained_summary(model_name, **kwargs) -> None:
if model_name in models:
module = importlib.import_module("torchvision.models")
encoder_class = getattr(module, model_name)
model = encoder_class()

for name, _ in model.named_parameters():
print(name)
else:
print(f"No encoder found for '{model_name}'")


@DeveloperAPI
def cli_summarize_pretrained(sys_argv):
parser = argparse.ArgumentParser(
description="This script displays a summary of a pretrained model for freezing purposes.",
prog="ludwig pretrained_summary",
usage="%(prog)s [options]",
)
parser.add_argument("-m", "--model_name", help="output model layers", required=False, type=str)
parser.add_argument("-l", "--list_models", action="store_true", help="print available models")

add_contrib_callback_args(parser)
args = parser.parse_args(sys_argv)

args.callbacks = args.callbacks or []
for callback in args.callbacks:
callback.on_cmdline("pretrained_summary", *sys_argv)

print_ludwig("Model Summary", LUDWIG_VERSION)
if args.list_models:
print("Available models:")
for model in models:
print(f"- {model}")
else:
pretrained_summary(**vars(args))
15 changes: 15 additions & 0 deletions ludwig/utils/trainer_utils.py
@@ -1,4 +1,5 @@
import logging
import re
from collections import defaultdict
from typing import Dict, List, Tuple, TYPE_CHECKING

@@ -10,6 +11,7 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import AUTO, COMBINED, LOSS
from ludwig.models.base import BaseModel
from ludwig.models.ecd import ECD
from ludwig.modules.metric_modules import get_best_function
from ludwig.utils.data_utils import save_json
from ludwig.utils.metric_utils import TrainerMetric
@@ -408,3 +410,16 @@ def get_rendered_batch_size_grad_accum(config: "BaseTrainerConfig", num_workers:
gradient_accumulation_steps = 1

return batch_size, gradient_accumulation_steps


def freeze_layers_regex(config: "BaseTrainerConfig", model: ECD) -> None:
"""Freezes layers based on provided regular expression."""
Collaborator:

Let's add all of the comments around inputs/outputs as well.

Collaborator:

@ethanreidel I think that if you put from __future__ import annotations as the very first line in the module, you would not need to quote the types. Would you like to give it a try and see if it works? Thanks!

Contributor Author:

I tested the annotations import, and it worked, but the git pre-commit was forcing changes (e.g. converting all uppercase Dicts to lowercase dicts) that I didn't like.

Contributor:

@ethanreidel I think that makes sense.

Are you able to expand on the docstring itself for this function? Also, if it also supports LLM, can we make model a union of ECD and LLM?

try:
pattern = re.compile(config.layers_to_freeze_regex)
Collaborator:

@ethanreidel Would you be interested if I gave you a reasonably well-featured RegEx utility so that you can just put it into the utils and use it -- it will save a lot of boilerplate like this. Please let me know any time. Thanks!

Contributor Author:

Yeah that sounds good. Thanks

except re.error:
logger.warning("Invalid regex input.\n")
exit()
Contributor:

Instead of exit(), let's raise a RuntimeError() with the same message.

In fact, here's a thought I have: We can move this check to earlier in the code path, that is, at config validation time. Specifically, you can create a __post_init__() hook for ECDTrainerConfig and FineTuneTrainerConfig that tries to do re.compile() and if it fails, throws a ConfigValidationError with the error message. That way, we don't have to wait for all of preprocessing etc to be done before catching this error.

Here's an example explaining the same idea in a different part of the Ludwig codepath: https://github.com/ludwig-ai/ludwig/blob/master/ludwig/schema/llms/peft.py#L443


for name, p in model.named_parameters():
if re.search(pattern, str(name)):
p.requires_grad = False
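
A hedged sketch of the __post_init__ validation suggested in the review comment above (the ConfigValidationError import path is assumed; the code actually merged may differ):

import re

from ludwig.error import ConfigValidationError


# Inside ECDTrainerConfig / FineTuneTrainerConfig (sketch only):
def __post_init__(self):
    # Fail fast at config validation time instead of after preprocessing.
    if self.layers_to_freeze_regex is not None:
        try:
            re.compile(self.layers_to_freeze_regex)
        except re.error as e:
            raise ConfigValidationError(
                f"layers_to_freeze_regex '{self.layers_to_freeze_regex}' "
                f"is not a valid regular expression: {e}"
            )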
18 changes: 18 additions & 0 deletions tests/integration_tests/test_cli.py
@@ -268,6 +268,24 @@ def test_collect_summary_activations_weights_cli(tmpdir, csv_filename):
assert _run_ludwig("collect_summary", model=os.path.join(tmpdir, "experiment_run", "model"))


@pytest.mark.parametrize(
"model_name",
[
"alexnet",
"convnext_base",
"convnext_large",
"convnext_small",
"convnext_tiny",
"densenet121",
"densenet161",
"densenet169",
],
)
def test_pretrained_summary_cli(model_name: str):
"""Test pretrained_summary cli."""
_run_ludwig("pretrained_summary", model_name=model_name)


def test_synthesize_dataset_cli(tmpdir, csv_filename):
"""Test synthesize_data cli."""
# test depends on default setting of --dataset_size
70 changes: 70 additions & 0 deletions tests/ludwig/modules/test_regex_freezing.py
@@ -0,0 +1,70 @@
import re
from contextlib import nullcontext as no_error_raised

import pytest

from ludwig.api import LudwigModel
from ludwig.constants import TRAINER
from ludwig.encoders.image.torchvision import TVEfficientNetEncoder
from ludwig.schema.trainer import BaseTrainerConfig
from ludwig.utils.misc_utils import set_random_seed
from ludwig.utils.trainer_utils import freeze_layers_regex
from tests.integration_tests.utils import category_feature, generate_data, image_feature

RANDOM_SEED = 130


@pytest.mark.parametrize(
"regex",
[
r"(features\.1.*|features\.2.*|features\.3.*|model\.features\.4\.1\.block\.3\.0\.weight)",
r"(features\.1.*|features\.2\.*|features\.3.*)",
r"(features\.4\.0\.block|features\.4\.\d+\.block)",
r"(features\.5\.*|features\.6\.*|features\.7\.*)",
r"(features\.8\.\d+\.weight|features\.8\.\d+\.bias)",
],
)
def test_tv_efficientnet_freezing(regex):
set_random_seed(RANDOM_SEED)

pretrained_model = TVEfficientNetEncoder(
model_variant="b0", use_pretrained=False, saved_weights_in_checkpoint=True, trainable=True
)

config = BaseTrainerConfig(layers_to_freeze_regex=regex)
freeze_layers_regex(config, pretrained_model)
for name, param in pretrained_model.named_parameters():
if re.search(re.compile(regex), name):
assert not param.requires_grad
else:
assert param.requires_grad


def test_frozen_tv_training(tmpdir, csv_filename):
input_features = [image_feature(tmpdir)]
output_features = [category_feature()]

config = {
"input_features": input_features,
"output_features": output_features,
TRAINER: {
"layers_to_freeze_regex": r"(features\.1.*|features\.2.*|model\.features\.4\.1\.block\.3\.0\.weight)",
"epochs": 1,
"train_steps": 1,
},
"encoder": {"type": "efficientnet", "use_pretrained": False},
}

training_data_csv_path = generate_data(config["input_features"], config["output_features"], csv_filename)
model = LudwigModel(config)

with no_error_raised():
model.experiment(
dataset=training_data_csv_path,
skip_save_training_description=True,
skip_save_training_statistics=True,
skip_save_model=True,
skip_save_progress=True,
skip_save_log=True,
skip_save_processed_input=True,
)