Reproducibility when using a model with batch norm #1732

Closed
4 tasks
BenjaminBossan opened this issue May 14, 2024 · 0 comments · Fixed by #1734
Labels
help wanted Extra attention is needed

Comments


BenjaminBossan commented May 14, 2024

System Info

Latest version of PEFT

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

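import pytest
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_id = "microsoft/resnet-18"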

@pytest.fixture
def image_processor():
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    return image_processor

@pytest.fixture
def data(image_processor):
    dataset = load_dataset("huggingface/cats-image")
    image = dataset["test"]["image"][0]
    return image_processor(image, return_tensors="pt")

def test_model_with_batchnorm(tmp_path, data):
    torch.manual_seed(0)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    config = LoraConfig(target_modules=["convolution"], modules_to_save=["classifier"])
    model = get_peft_model(model, config)

    # record outputs before training
    model.eval()
    with torch.inference_mode():
        output_before = model(**data)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    batch_size = 4
    max_steps = 5 * batch_size
    labels = torch.zeros(1, 1000)
    labels[0, 283] = 1
    for i in range(0, max_steps, batch_size):
        optimizer.zero_grad()
        outputs = model(**data, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.inference_mode():
        output_after = model(**data)
    assert torch.isfinite(output_after.logits).all()
    atol, rtol = 1e-4, 1e-4
    # sanity check: model was updated
    assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)

    # check saving the model and loading it
    model.save_pretrained(tmp_path)
    del model
    torch.manual_seed(0)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    model = PeftModel.from_pretrained(model, tmp_path).eval()
    with torch.inference_mode():
        output_loaded = model(**data)
    # THIS FAILS
    assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)

Expected behavior

After loading a model that was trained with PEFT on a base model containing a batch norm layer, the loaded model should produce the same output as the trained model did before saving. Right now, this does not happen.

The reason is that, during training, the batch norm buffers (running mean, running variance, etc.) are updated, but they are not saved when calling save_pretrained on the PeftModel instance. Normally in PEFT, we assume that the base model's state is kept constant during training, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, everything else is restored exactly. As a result, the updated buffer values are lost completely.
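The effect described above can be reproduced without PEFT. The following is a minimal sketch, assuming a tiny stand-in model (the `base` module is hypothetical and not part of the original report): even with every parameter frozen, a forward pass in training mode changes the batch norm buffers.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a frozen base model that contains batch norm.
base = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))
for param in base.parameters():
    param.requires_grad_(False)  # freeze all parameters, as PEFT does for the base model

bn = base[1]
mean_before = bn.running_mean.clone()
batches_before = bn.num_batches_tracked.item()

base.train()  # in training mode, batch norm updates its running statistics
base(torch.randn(2, 3, 8, 8))

params_frozen = all(not p.requires_grad for p in base.parameters())
buffers_changed = not torch.equal(mean_before, bn.running_mean)
print(params_frozen, buffers_changed)
```

No optimizer step is involved here; the buffers change as a side effect of the forward pass alone, which is why freezing parameters does not protect them.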

One possible solution would be to include the buffers in the PEFT adapter checkpoint, which is not very pretty. For this to work, we would need a way to distinguish buffers that were updated during training from those that are static. If someone knows a way to achieve this, or has a better idea for how to fix this, please let us know.
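One way to tell updated buffers from static ones, sketched here purely as an illustration and not as what PEFT does, is to snapshot all buffers before training and compare afterwards. `TinyNet` and `changed_buffers` are hypothetical names introduced for this example.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical model with batch norm buffers and one static buffer."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.bn = nn.BatchNorm2d(4)
        # a static buffer that the forward pass never modifies
        self.register_buffer("scale", torch.ones(1))

    def forward(self, x):
        return self.bn(self.conv(x)) * self.scale


def changed_buffers(model, snapshot):
    """Return the names of buffers whose values differ from the snapshot."""
    return [
        name
        for name, buf in model.named_buffers()
        if not torch.equal(buf, snapshot[name])
    ]


torch.manual_seed(0)
model = TinyNet().train()
# copy every buffer before the "training" forward pass
snapshot = {name: buf.detach().clone() for name, buf in model.named_buffers()}
model(torch.randn(2, 3, 8, 8))
updated = changed_buffers(model, snapshot)
print(updated)
```

The trade-off is that a snapshot doubles the memory held for buffers and only detects changes that occurred during the steps that were observed.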

Edit: Best suggestion so far, by @kashif: check for the track_running_stats attribute and, if it is True, save the module's buffers. This will not cover all possible corner cases, but hopefully most.
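A rough sketch of that suggestion, assuming a plain PyTorch model; the helper `running_stats_buffers` is a hypothetical name and not a PEFT function:

```python
import torch
from torch import nn


def running_stats_buffers(model):
    """Collect buffers of modules that have track_running_stats set to True."""
    state = {}
    for module_name, module in model.named_modules():
        if getattr(module, "track_running_stats", False):
            # recurse=False: only the buffers owned by this module itself
            for buf_name, buf in module.named_buffers(recurse=False):
                key = f"{module_name}.{buf_name}" if module_name else buf_name
                state[key] = buf.detach().clone()
    return state


model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.BatchNorm2d(4),                             # tracked: buffers are collected
    nn.BatchNorm2d(4, track_running_stats=False),  # not tracked: skipped
    nn.LayerNorm([4, 6, 6]),                       # no track_running_stats attribute
)
extra_state = running_stats_buffers(model)
print(sorted(extra_state))
```

The collected dictionary uses the same keys as the model's state dict, so it could be restored on a freshly loaded model with `model.load_state_dict(extra_state, strict=False)`. Modules that update buffers without exposing `track_running_stats` would be missed, which is one of the corner cases mentioned above.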

BenjaminBossan added the help wanted label May 14, 2024
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue May 15, 2024
Fixes huggingface#1732

This PR fixes this issue by saving the buffers of the batch norm layers.
They are identified by checking for the presence of the
track_running_stats attribute.

Note: One test for BOFT is currently failing, see the comment in the
test file.
BenjaminBossan added a commit that referenced this issue May 22, 2024
Fixes #1732

The fix is to add the batch norm layers to modules_to_save. This fix is
now documented and tested.