
A simple JAX implementation of influence functions.


pomonam/jax-influence


JAX-Influence

License: Apache 2.0

JAX-Influence is a JAX implementation of influence functions, a classical technique from robust statistics that estimates the impact of removing a single training data point on a model's learned parameters. This repository complements the paper "If Influence Functions are the Answer, Then What is the Question?".
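
The following minimal sketch makes the estimated quantity concrete. It assumes a twice-differentiable training objective J(θ) = (1/n) Σ_j L(z_j, θ) + (λ/2)‖θ‖², minimized at θ̂. Under that assumption, the classical first-order estimate of the parameter change from removing example z_i is θ̂₋ᵢ − θ̂ ≈ (1/n) H⁻¹ ∇θ L(z_i, θ̂), where H is the Hessian of J at θ̂. The sketch uses a hypothetical toy L2-regularized logistic regression. All function names here (`point_loss`, `objective`, `fit`, `removal_influence`) are illustrative and are not this repository's API. Because it forms H explicitly, the approach is only feasible for a handful of parameters. The sketch compares the estimate against actually retraining without the example.

```python
import jax
import jax.numpy as jnp

LAMBDA = 0.1  # assumed L2 regularization strength (keeps H positive definite)


def point_loss(theta, x, y):
    """Logistic loss for a single example with label y in {0, 1}."""
    logit = x @ theta
    return jnp.logaddexp(0.0, logit) - y * logit


def objective(theta, X, Y, w):
    """J(theta): per-example-weighted loss normalized by n, plus an L2 penalty."""
    losses = jax.vmap(point_loss, in_axes=(None, 0, 0))(theta, X, Y)
    return jnp.sum(w * losses) / X.shape[0] + 0.5 * LAMBDA * jnp.sum(theta**2)


def fit(X, Y, w, steps=25):
    """Minimize J with plain Newton steps (cheap here because d is tiny)."""
    theta = jnp.zeros(X.shape[1])
    for _ in range(steps):
        g = jax.grad(objective)(theta, X, Y, w)
        H = jax.hessian(objective)(theta, X, Y, w)
        theta = theta - jnp.linalg.solve(H, g)
    return theta


def removal_influence(theta_hat, X, Y, w, i):
    """First-order estimate of theta_{-i} - theta_hat = (1/n) H^{-1} grad L(z_i)."""
    H = jax.hessian(objective)(theta_hat, X, Y, w)
    g_i = jax.grad(point_loss)(theta_hat, X[i], Y[i])
    return jnp.linalg.solve(H, g_i) / X.shape[0]


# Synthetic data (illustrative only).
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
n = 200
X = jax.random.normal(key_x, (n, 3))
true_theta = jnp.array([1.5, -2.0, 0.5])
Y = (jax.random.uniform(key_y, (n,)) < jax.nn.sigmoid(X @ true_theta)).astype(jnp.float32)
w = jnp.ones(n)

theta_hat = fit(X, Y, w)
i = 0
estimate = removal_influence(theta_hat, X, Y, w, i)
actual = fit(X, Y, w.at[i].set(0.0)) - theta_hat  # leave-one-out retraining
print("estimated change:", estimate)
print("actual change:   ", actual)
```

Removing z_i corresponds to upweighting it by ε = −1/n while keeping the 1/n normalization. For that reason, the retrained model gives the example weight 0 rather than averaging over n − 1 points. Averaging over n − 1 points would shift the effective regularization strength by an amount of the same O(1/n) order as the change being estimated.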

The repository aims to provide a simple and minimal implementation of influence functions in JAX. For those interested in implementations in other frameworks, a PyTorch version is available here, and a PyTorch EK-FAC implementation can be found here.
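
Materializing the Hessian costs memory quadratic in the number of parameters. For that reason, larger-scale implementations usually avoid forming it explicitly. EK-FAC, for instance, approximates the (Gauss-Newton) Hessian with a Kronecker-factored structure. A simpler Hessian-free alternative solves the damped system (H + λI) x = g iteratively using only Hessian-vector products. The sketch below takes that approach using `jax.jvp` over `jax.grad` and `jax.scipy.sparse.linalg.cg`. It is a generic illustration under the assumption that H + λI is positive definite; it is not necessarily how this repository computes inverse-Hessian-vector products. The helpers `hvp` and `inverse_hvp` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def hvp(loss_fn, params, v):
    """Hessian-vector product H v via forward-over-reverse autodiff (H is never formed)."""
    return jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]


def inverse_hvp(loss_fn, params, g, damping=1e-3, maxiter=100):
    """Approximately solve (H + damping * I) x = g with conjugate gradient.

    CG assumes H + damping * I is symmetric positive definite; for non-convex
    losses that may require a larger damping or a Gauss-Newton approximation.
    """
    def matvec(v):
        return hvp(loss_fn, params, v) + damping * v

    x, _ = cg(matvec, g, maxiter=maxiter)
    return x


# Sanity check on a quadratic loss, whose Hessian is exactly A.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])


def quadratic(p):
    return 0.5 * p @ A @ p


g = jnp.array([1.0, -1.0])
x = inverse_hvp(quadratic, jnp.zeros(2), g, damping=0.0)
print(x)  # should approximately match jnp.linalg.solve(A, g), i.e. [0.6, -0.8]
```

Each CG iteration costs roughly one extra backward-plus-forward pass, so this approach trades the O(d²) memory of an explicit Hessian for repeated gradient computations.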

Installation

To install JAX-Influence, you can use pip to install from the source:

```bash
git clone https://github.com/pomonam/jax-influence
cd jax-influence
pip install -e .
# Replace `jax_gpu` with `jax_cpu` if you wish to install the CPU version.
pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
```
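
After installing, a quick way to confirm which backend JAX picked up is the check below. It uses only the standard `jax` package and is not part of this repository. If the GPU extra installed correctly, the default backend is expected to be `gpu`; otherwise it falls back to `cpu`.

```python
import jax

# Report the installed JAX version, the default backend, and the visible devices.
print(jax.__version__)
print(jax.default_backend())  # "gpu" after a working CUDA install, otherwise "cpu"
print(jax.devices())
```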
