Influence Functions with (Eigenvalue-corrected) Kronecker-Factored Approximate Curvature

pomonam/kronfluence

Kronfluence


Kronfluence is a research repository designed to compute influence functions using Kronecker-Factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper "Studying Large Language Model Generalization with Influence Functions".
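Influence functions require an inverse-curvature-vector product, which is intractable with the full Gauss-Newton curvature G of a large network. To illustrate the idea, here is a minimal NumPy sketch for a single bias-free linear layer. It is not Kronfluence's implementation, and every function name in it is hypothetical. It assumes per-example input activations a_i and gradients g_i with respect to the layer's outputs are already available. KFAC approximates G ≈ A ⊗ S with A = E[a aᵀ] and S = E[g gᵀ], so the damped inverse can be applied in the eigenbases of A and S. EKFAC keeps those eigenbases but replaces the eigenvalue products λ_S λ_A with the mean squared per-example gradient projected onto them. The influence score is then, up to sign convention, the inner product of the preconditioned query gradient with a training gradient.

```python
import numpy as np


def kfac_factors(acts, grads):
    """Covariance factors for one bias-free linear layer (illustrative sketch).

    acts:  (n, d_in)  per-example input activations a_i
    grads: (n, d_out) per-example gradients g_i w.r.t. the layer outputs
    Returns A = E[a a^T] (d_in x d_in) and S = E[g g^T] (d_out x d_out).
    """
    n = acts.shape[0]
    return acts.T @ acts / n, grads.T @ grads / n


def kfac_eigenvalues(A, S):
    """Eigenbases of A and S plus the KFAC eigenvalues lambda_S[j] * lambda_A[k]."""
    eval_a, Q_A = np.linalg.eigh(A)
    eval_s, Q_S = np.linalg.eigh(S)
    return Q_A, Q_S, np.outer(eval_s, eval_a)  # shape (d_out, d_in)


def ekfac_eigenvalues(acts, grads, Q_A, Q_S):
    """EKFAC correction: second moment of per-example gradients in the KFAC eigenbasis.

    The per-example weight gradient is G_i = g_i a_i^T, and
    (Q_S^T G_i Q_A)[j, k] = (Q_S^T g_i)[j] * (Q_A^T a_i)[k].
    """
    rot_g = grads @ Q_S  # row i holds Q_S^T g_i
    rot_a = acts @ Q_A  # row i holds Q_A^T a_i
    return np.einsum("nj,nk->jk", rot_g**2, rot_a**2) / acts.shape[0]


def precondition(V, Q_A, Q_S, eigenvalues, damping):
    """Apply (G_hat + damping * I)^{-1} to V (shape (d_out, d_in)) in matrix form."""
    rotated = Q_S.T @ V @ Q_A  # move into the Kronecker eigenbasis
    return Q_S @ (rotated / (eigenvalues + damping)) @ Q_A.T  # rescale, rotate back


def influence_score(query_grad, train_grad, Q_A, Q_S, eigenvalues, damping=1e-3):
    """Inner product of the preconditioned query gradient with a training gradient."""
    preconditioned = precondition(query_grad, Q_A, Q_S, eigenvalues, damping)
    return float(np.sum(preconditioned * train_grad))


# Toy usage with random data (illustrative only, not real activations).
rng = np.random.default_rng(0)
acts, grads = rng.normal(size=(256, 8)), rng.normal(size=(256, 4))
A, S = kfac_factors(acts, grads)
Q_A, Q_S, kfac_eigs = kfac_eigenvalues(A, S)
ekfac_eigs = ekfac_eigenvalues(acts, grads, Q_A, Q_S)
query_grad = np.outer(grads[0], acts[0])  # G_0 = g_0 a_0^T, shape (4, 8)
train_grad = np.outer(grads[1], acts[1])
print("KFAC score: ", influence_score(query_grad, train_grad, Q_A, Q_S, kfac_eigs))
print("EKFAC score:", influence_score(query_grad, train_grad, Q_A, Q_S, ekfac_eigs))
```

The Kronecker product is never formed explicitly: each preconditioning step needs only two small eigendecompositions (computed once) and a few matrix products, which is what lets the approach scale to large layers.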


Warning

This repository is under active development and has not reached its first stable release.

Installation

Important

Requirements:

  • Python: Version 3.9 or later
  • PyTorch: Version 2.1 or later

To install the latest release from PyPI, use the following pip command:

pip install kronfluence

Alternatively, you can install directly from source:

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .
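To confirm that an environment meets the requirements listed above, a small standard-library check like the following can help. It is a convenience sketch, not part of Kronfluence; it assumes the packages are installed under the distribution names `torch` and `kronfluence`, and the helper names are hypothetical.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def major_minor(version_string):
    """Parse strings such as '2.1.0+cu121' or '2.10.1' into a (major, minor) tuple."""
    parts = version_string.split(".")
    minor = parts[1] if len(parts) > 1 else "0"
    # Keep only the leading digits of the minor part, e.g. '3rc1' -> '3'.
    digits = ""
    for ch in minor:
        if not ch.isdigit():
            break
        digits += ch
    return int(parts[0]), int(digits or 0)


def check_environment():
    """Return a list of human-readable problems (empty if the requirements are met)."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"Python 3.9+ required, found {sys.version.split()[0]}")
    for package, minimum in (("torch", (2, 1)), ("kronfluence", None)):
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if minimum is not None and major_minor(installed) < minimum:
            problems.append(f"{package} {minimum[0]}.{minimum[1]}+ required, found {installed}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("OK" if not issues else "\n".join(issues))
```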

Getting Started

Kronfluence supports influence computations on nn.Linear and nn.Conv2d modules. See the Technical Documentation page for a comprehensive guide.
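For each supported module, KFAC keeps two small factors (A for inputs, S for output gradients) instead of one huge curvature block over all of the module's parameters. The sketch below estimates their sizes. It assumes, as is standard in KFAC, that a bias is handled by appending a constant 1 to the inputs, and that nn.Conv2d inputs are unfolded into patches of in_channels × kh × kw values (groups and dilation are ignored). The helper names are hypothetical and not part of Kronfluence.

```python
def linear_factor_dims(in_features, out_features, bias=True):
    """(dim of A, dim of S) for nn.Linear; a bias adds one constant input (assumption)."""
    return in_features + int(bias), out_features


def conv2d_factor_dims(in_channels, out_channels, kernel_size, bias=True):
    """(dim of A, dim of S) for nn.Conv2d, treating each unfolded input patch as an activation."""
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    return in_channels * kh * kw + int(bias), out_channels


def kfac_entries(dims):
    """Total number of stored factor entries: d_A^2 + d_S^2 summed over layers."""
    return sum(d_a * d_a + d_s * d_s for d_a, d_s in dims)


def full_block_entries(dims):
    """Entries of each layer's full (non-factored) curvature block, for comparison."""
    return sum((d_a * d_s) ** 2 for d_a, d_s in dims)


# The four nn.Linear layers of the MNIST MLP used in this README's example.
mlp_dims = [
    linear_factor_dims(784, 1024),
    linear_factor_dims(1024, 1024),
    linear_factor_dims(1024, 1024),
    linear_factor_dims(1024, 10),
]
print("KFAC factor entries:", kfac_entries(mlp_dims))
print("Full per-layer block entries:", full_block_entries(mlp_dims))
```

EKFAC additionally stores one d_S × d_A matrix of corrected eigenvalues per layer, which is still far smaller than the full per-layer block.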

Learn More

The examples folder contains several examples demonstrating how to use Kronfluence, and more will be added in the future. TL;DR: prepare a trained model and datasets, define a task, and pass them to the Analyzer class.

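from typing import Any

import torch
import torch.nn.functional as F
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task


# A minimal classification task for MNIST, following the Task interface described
# in the Technical Documentation. Adapt the loss and measurement to your own setup.
class MnistTask(Task):
    def compute_train_loss(self, batch: Any, model: nn.Module, sample: bool = False) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            return F.cross_entropy(logits, labels, reduction="sum")
        # With `sample=True`, labels are drawn from the model's predictive
        # distribution (true Fisher) instead of using the dataset labels.
        with torch.no_grad():
            probs = F.softmax(logits.detach(), dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch: Any, model: nn.Module) -> torch.Tensor:
        # The quantity whose change is attributed to training examples (here, the query loss).
        inputs, labels = batch
        logits = model(inputs)
        return F.cross_entropy(logits, labels, reduction="sum")


# Define the model and load the trained model weights.
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

# Load the datasets. `ToTensor()` converts PIL images so batches can be collated.
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)
# Use the held-out test split as queries.
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,
    transform=torchvision.transforms.ToTensor(),
)

# Define the task. See the Technical Documentation page for details.
task = MnistTask()

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
# The result is keyed by module name; "all_modules" aggregates over all tracked modules.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")["all_modules"]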
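The pairwise scores form a matrix of shape `len(eval_dataset) x len(train_dataset)`, so a common next step is to look up the highest-scoring training examples for a given query. A minimal NumPy sketch, assuming the score tensor has been converted with something like `.cpu().numpy()`; `top_k_train_indices` is a hypothetical helper, not part of Kronfluence. Whether the largest or the most negative scores are the ones you care about depends on the score's sign convention, so check the Technical Documentation.

```python
import numpy as np


def top_k_train_indices(scores, query_index, k=5):
    """Indices of the k training examples with the largest scores for one query, best first."""
    row = np.asarray(scores)[query_index]
    k = min(k, row.shape[0])
    # argpartition finds the k largest entries in linear time (unordered) ...
    top = np.argpartition(row, -k)[-k:]
    # ... then only those k are sorted in descending order of score.
    return top[np.argsort(row[top])[::-1]]


# Toy 2 x 4 score matrix (illustrative values, not real Kronfluence output).
toy_scores = np.array([[0.1, 0.5, -0.2, 0.9], [0.3, -0.4, 0.8, 0.0]])
print(top_k_train_indices(toy_scores, query_index=0, k=2))
```

With PyTorch tensors, `torch.topk` performs the same lookup without leaving the GPU.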

Contributing

Contributions are welcome! To get started, please review our Code of Conduct. For bug fixes, please submit a pull request. If you would like to propose new features or extensions, we kindly request that you open an issue first to discuss your ideas.

Setting Up Development Environment

To contribute to Kronfluence, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ".[dev]"

Acknowledgements

Omkar Dige contributed to the profiling, DDP, and FSDP utilities, and Adil Asif provided valuable insights and suggestions on structuring the DDP and FSDP implementations. I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

License

This software is released under the Apache 2.0 License, as detailed in the LICENSE file.