Support for freezing pretrained vision model layers with regex #3981
Changes from 18 commits
@@ -0,0 +1,152 @@
#! /usr/bin/env python
# Copyright (c) 2024 Predibase, Inc., 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import importlib

from ludwig.api_annotations import DeveloperAPI
from ludwig.contrib import add_contrib_callback_args
from ludwig.globals import LUDWIG_VERSION
from ludwig.utils.print_utils import print_ludwig
Review discussion:

Reviewer: Wait, are we really supporting all of these models? I thought we were just going to go out the door with a couple of models to start.

ethanreidel: For this specific feature (simple regex freezing), as long as you have access to the string representation of the layers plus the actual model architecture, you can freeze any layers you'd like. It wasn't any extra work to add support for all torchvision models beyond adding to this list. I don't like the look of this long model array, though.

Reviewer: @ethanreidel While it looks like this will be supported for torchvision, what about text/LLMs? (This is related to my previous comment in the Trainers section.) Thanks!

Reviewer: @ethanreidel Sorry, could you please point me to this "long model array"? Which line in your code has it? Thanks!

ethanreidel: For your first question, Alex: as long as access to the model layers plus their requires_grad parameter is available, in theory this feature should work on LLMs/text. I'm not too familiar with LLM architecture and I'll have to do some quick checks, but I'm 99% sure it is an easy addition. Second question: in a previous commit, I had a pretty hacky solution where users had another command-line option (under pretrained_summary) that would list all available model names. Those names were stored in a Python list, which had a few issues, namely having to expand it regularly and many lines of unnecessary code. Saad made a good point and said to fully remove it (it was not needed), so it's no longer there.

ethanreidel: Just checked, and sure enough you can apply the same regex freezing technique to an LLM.

Reviewer: As part of the examples you have, it would be good to create 2 example Python files: What do you think?
||
models = [
    "alexnet",
    "convnext",
    "convnext_base",
    "convnext_large",
    "convnext_small",
    "convnext_tiny",
    "densenet",
    "densenet121",
    "densenet161",
    "densenet169",
    "densenet201",
    "efficientnet",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_l",
    "efficientnet_v2_m",
    "efficientnet_v2_s",
    "googlenet",
    "inception",
    "inception_v3",
    "maxvit",
    "maxvit_t",
    "mnasnet",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
    "mobilenet",
    "mobilenet_v2",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
    "mobilenetv2",
    "mobilenetv3",
    "regnet",
    "regnet_x_16gf",
    "regnet_x_1_6gf",
    "regnet_x_32gf",
    "regnet_x_3_2gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_8gf",
    "regnet_y_128gf",
    "regnet_y_16gf",
    "regnet_y_1_6gf",
    "regnet_y_32gf",
    "regnet_y_3_2gf",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_8gf",
    "resnet",
    "resnet101",
    "resnet152",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "resnext50_32x4d",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
    "shufflenetv2",
    "squeezenet",
    "squeezenet1_0",
    "squeezenet1_1",
    "swin_transformer",
    "vgg",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
    "vit_b_16",
    "vit_b_32",
    "vit_h_14",
    "vit_l_16",
    "vit_l_32",
    "wide_resnet101_2",
    "wide_resnet50_2",
]


def pretrained_summary(model_name, **kwargs) -> None:
    if model_name in models:
        module = importlib.import_module("torchvision.models")
        encoder_class = getattr(module, model_name)
        model = encoder_class()

        for name, _ in model.named_parameters():
            print(name)
    else:
        print(f"No encoder found for '{model_name}'")


@DeveloperAPI
def cli_summarize_pretrained(sys_argv):
    parser = argparse.ArgumentParser(
        description="This script displays a summary of a pretrained model for freezing purposes.",
        prog="ludwig pretrained_summary",
        usage="%(prog)s [options]",
    )
    parser.add_argument("-m", "--model_name", help="output model layers", required=False, type=str)
    parser.add_argument("-l", "--list_models", action="store_true", help="print available models")

    add_contrib_callback_args(parser)
    args = parser.parse_args(sys_argv)

    args.callbacks = args.callbacks or []
    for callback in args.callbacks:
        callback.on_cmdline("pretrained_summary", *sys_argv)

    print_ludwig("Model Summary", LUDWIG_VERSION)
    if args.list_models:
        print("Available models:")
        for model in models:
            print(f"- {model}")
    else:
        pretrained_summary(**vars(args))
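The summary command resolves the encoder class dynamically through importlib rather than importing every model up front. A minimal sketch of that lookup pattern, using the stdlib math module as a stand-in for torchvision.models so it runs without torchvision installed:

```python
import importlib

# Same dynamic-lookup pattern as pretrained_summary: import a module by its
# dotted name, then fetch an attribute by name. "math" and "sqrt" stand in
# for "torchvision.models" and a model constructor such as "resnet18".
module = importlib.import_module("math")
fn = getattr(module, "sqrt")
print(fn(9.0))  # -> 3.0
```

The same two calls are all pretrained_summary needs; getattr raises AttributeError for an unknown name, which is why the CLI guards the lookup with the models list first.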
@@ -1,4 +1,5 @@
import logging
import re
from collections import defaultdict
from typing import Dict, List, Tuple, TYPE_CHECKING

@@ -10,6 +11,7 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import AUTO, COMBINED, LOSS
from ludwig.models.base import BaseModel
from ludwig.models.ecd import ECD
from ludwig.modules.metric_modules import get_best_function
from ludwig.utils.data_utils import save_json
from ludwig.utils.metric_utils import TrainerMetric

@@ -408,3 +410,16 @@ def get_rendered_batch_size_grad_accum(config: "BaseTrainerConfig", num_workers:
    gradient_accumulation_steps = 1

    return batch_size, gradient_accumulation_steps


def freeze_layers_regex(config: "BaseTrainerConfig", model: ECD) -> None:
    """Freezes layers based on provided regular expression."""
Review discussion:

Reviewer: Let's add all of the comments around inputs/outputs as well.

Reviewer: @ethanreidel I think that if you put …

ethanreidel: I tested the annotations import, and it worked, but the git pre-commit was forcing changes (e.g. converting all uppercase Dicts to lowercase dicts) that I didn't like.

Reviewer: @ethanreidel I think that makes sense. Are you able to expand on the docstring itself for this function? Also, if it also supports …
    try:
        pattern = re.compile(config.layers_to_freeze_regex)
    except re.error:
        logger.warning("Invalid regex input.\n")
        exit()

    for name, p in model.named_parameters():
        if re.search(pattern, str(name)):
            p.requires_grad = False

Review discussion:

Reviewer (on the re.compile call): @ethanreidel Would you be interested if I gave you a reasonably well-featured RegEx utility so that you can just put it into the utils and use it? It will save a lot of boilerplate like this. Please let me know any time. Thanks!

ethanreidel: Yeah, that sounds good. Thanks.

Reviewer (on the exit() call): Instead of … In fact, here's a thought I have: we can move this check earlier in the code path, that is, to config validation time. Specifically, you can create a … Here's an example explaining the same idea in a different part of the Ludwig codepath: https://github.com/ludwig-ai/ludwig/blob/master/ludwig/schema/llms/peft.py#L443
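Note that the matching rule above uses re.search, so a pattern can match anywhere in the dotted parameter name, not just from the start. A small self-contained sketch of which names a typical pattern would freeze, using hypothetical parameter names in the style that named_parameters() yields for torchvision encoders (not taken from a real model):

```python
import re

# Hypothetical parameter names, in the dotted style produced by
# model.named_parameters() for torchvision encoders.
param_names = [
    "features.0.0.weight",
    "features.1.0.block.0.0.weight",
    "features.2.0.block.0.0.weight",
    "classifier.1.weight",
]

# Same logic as freeze_layers_regex: re.search, not re.fullmatch, so
# "features\.1" would also match a "features.10..." layer if one existed.
pattern = re.compile(r"(features\.1.*|features\.2.*)")
frozen = [name for name in param_names if re.search(pattern, name)]
print(frozen)  # only the "features.1" and "features.2" names
```

In freeze_layers_regex each matched name has its parameter's requires_grad set to False; unmatched parameters are left trainable.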
@@ -0,0 +1,70 @@
import re
from contextlib import nullcontext as no_error_raised

import pytest

from ludwig.api import LudwigModel
from ludwig.constants import TRAINER
from ludwig.encoders.image.torchvision import TVEfficientNetEncoder
from ludwig.schema.trainer import BaseTrainerConfig
from ludwig.utils.misc_utils import set_random_seed
from ludwig.utils.trainer_utils import freeze_layers_regex
from tests.integration_tests.utils import category_feature, generate_data, image_feature

RANDOM_SEED = 130


@pytest.mark.parametrize(
    "regex",
    [
        r"(features\.1.*|features\.2.*|features\.3.*|model\.features\.4\.1\.block\.3\.0\.weight)",
        r"(features\.1.*|features\.2\.*|features\.3.*)",
        r"(features\.4\.0\.block|features\.4\.\d+\.block)",
        r"(features\.5\.*|features\.6\.*|features\.7\.*)",
        r"(features\.8\.\d+\.weight|features\.8\.\d+\.bias)",
    ],
)
def test_tv_efficientnet_freezing(regex):
    set_random_seed(RANDOM_SEED)

    pretrained_model = TVEfficientNetEncoder(
        model_variant="b0", use_pretrained=False, saved_weights_in_checkpoint=True, trainable=True
    )

    config = BaseTrainerConfig(layers_to_freeze_regex=regex)
    freeze_layers_regex(config, pretrained_model)
    for name, param in pretrained_model.named_parameters():
        if re.search(re.compile(regex), name):
            assert not param.requires_grad
        else:
            assert param.requires_grad


def test_frozen_tv_training(tmpdir, csv_filename):
    input_features = [image_feature(tmpdir)]
    output_features = [category_feature()]

    config = {
        "input_features": input_features,
        "output_features": output_features,
        TRAINER: {
            "layers_to_freeze_regex": r"(features\.1.*|features\.2.*|model\.features\.4\.1\.block\.3\.0\.weight)",
            "epochs": 1,
            "train_steps": 1,
        },
        "encoder": {"type": "efficientnet", "use_pretrained": False},
    }

    training_data_csv_path = generate_data(config["input_features"], config["output_features"], csv_filename)
    model = LudwigModel(config)

    with no_error_raised():
        model.experiment(
            dataset=training_data_csv_path,
            skip_save_training_description=True,
            skip_save_training_statistics=True,
            skip_save_model=True,
            skip_save_progress=True,
            skip_save_log=True,
            skip_save_processed_input=True,
        )
Review discussion:

Reviewer: It might be good to add some example runs and outputs to the docs.

Reviewer: @ethanreidel To second @skanjila: would it be possible to show an example of running this command, in terms of how it differs from the existing one, along with an example output? Thank you very much.

ethanreidel: Certainly.

ethanreidel: Quick question: when you say you'd like an example, do you mean an example in the Ludwig docs, or how would you prefer it?

Reviewer: @ethanreidel One option is to create an example in the examples/ top-level directory in Ludwig.
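Earlier in the review it was suggested to validate the regex at config time rather than calling exit() inside the trainer. A minimal sketch of such a check, using a hypothetical validator name (the real hook would live in Ludwig's schema/config-validation layer, as in the linked peft.py example):

```python
import re


def check_layers_to_freeze_regex(pattern: str) -> None:
    # Hypothetical config-validation-time check: reject a bad pattern with an
    # exception the config machinery can surface, instead of exit() mid-training.
    try:
        re.compile(pattern)
    except re.error as e:
        raise ValueError(f"Invalid layers_to_freeze_regex {pattern!r}: {e}")


check_layers_to_freeze_regex(r"(features\.1.*|features\.2.*)")  # valid: no error
try:
    check_layers_to_freeze_regex(r"(features\.1")  # unterminated group
except ValueError as e:
    print("rejected:", e)
```

Failing at validation time means a typo in the pattern is reported before data preprocessing and model construction begin, which is the point of the reviewer's suggestion.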