theme: default
title: Data valuation for machine learning
info: |
  ## A primer on data valuation and attribution
  Some examples of how to attribute data sources and how to value data in your projects using pyDVL.
  Learn more at [pydvl.org](https://pydvl.org)
class: text-center
highlighter: shiki
drawings:
  persist:
transition: slide-left
mdc: true
themeConfig:
  primary: "#084059"
hideInToc: true

Data valuation for ML

Detecting mislabelled and out-of-distribution samples with pyDVL

PyData logo

Miguel de Benito Delgado - Kristof Schröder

appliedAI Institute logo


layout: fact hideInToc: true title: What

What is data valuation?


title: What is data valuation? level: 1 layout: two-cols-header class: self-center text-center p-6 transition: fade-out

We are interested in

the contribution of a training point to...

or

::left::

the overall model performance

("global" methods: Data Shapley & co.)

::right::

a single prediction

("local" methods: influences)


hideInToc: true layout: center

Global valuation methods


level: 1 layout: two-cols-header class: px-6

Two examples of how to measure contribution

utility(some_data) := model.fit(some_data).score(validation)

Take one training point $x \in T$


::left::

1: Contribution to the whole dataset
score_with = u(train)
score_without = u(train.drop(x))
value = score_with - score_without
Leave-One-Out

$$\text{value}(x) = u(T) - u(T \setminus \{x\})$$

$n$ retrainings

low signal
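
A minimal Leave-One-Out sketch on a toy problem. The "model" (predict the training mean) and the utility (negative MAE on the validation set) are hypothetical stand-ins for `model.fit(...).score(...)`, not pyDVL code.

```python
def utility(train, validation):
    """Negative MAE of a toy model that always predicts the training mean.

    Hypothetical stand-in for model.fit(train).score(validation);
    an empty training set predicts 0.
    """
    prediction = sum(train) / len(train) if train else 0.0
    return -sum(abs(y - prediction) for y in validation) / len(validation)


def leave_one_out(train, validation):
    """value(x) = u(T) - u(T without x): one extra 'retraining' per point."""
    full_score = utility(train, validation)
    return [
        full_score - utility(train[:i] + train[i + 1:], validation)
        for i in range(len(train))
    ]


train = [1.0, 1.2, 0.9, 1.1, 10.0]  # the last point is an outlier
validation = [1.0, 1.1]
loo_values = leave_one_out(train, validation)

# Index of the point whose removal helps the most (lowest value).
worst = min(range(len(train)), key=loo_values.__getitem__)
print(worst, loo_values[worst])
```

The outlier gets a clearly negative value, while the values of the other points stay small: with larger datasets that gap shrinks, which is the "low signal" problem.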

::right::

2: Contribution to subsets
for subset in sampler.from_data(train.drop(x)):
  scores_with.append(u(subset.union({x})))
  scores_without.append(u(subset))
value = weighted_mean(scores_with - scores_without, coefficients)
Semivalue (e.g. Data Shapley)

$$\text{value}(x) = \sum_{S \subseteq T \setminus \{x\}} w(S) \left[ u(S \cup \{x\}) - u(S) \right]$$

$2^{n-1}$ retrainings (naive)
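
A sketch of how the exponential sum can be approximated, assuming permutation sampling (one common Monte Carlo estimator for Data Shapley, where the weights $w(S)$ are implicit in the random orderings). The toy utility and all names are hypothetical and this is not how pyDVL implements it.

```python
import random


def utility(train, validation):
    """Negative MAE of a toy model that predicts the training mean (0 if empty)."""
    prediction = sum(train) / len(train) if train else 0.0
    return -sum(abs(y - prediction) for y in validation) / len(validation)


def shapley_permutation(train, validation, n_permutations=200, seed=0):
    """Average the marginal contribution of each point over random orderings."""
    rng = random.Random(seed)
    n = len(train)
    totals = [0.0] * n
    for _ in range(n_permutations):
        order = list(range(n))
        rng.shuffle(order)
        subset, score = [], utility([], validation)
        for i in order:
            subset.append(train[i])
            new_score = utility(subset, validation)
            totals[i] += new_score - score  # u(S + {x_i}) - u(S)
            score = new_score
    return [total / n_permutations for total in totals]


train = [1.0, 1.2, 0.9, 1.1, 10.0]  # the last point is an outlier
validation = [1.0, 1.1]
shapley_values = shapley_permutation(train, validation)
print(shapley_values)
```

Each permutation costs $n$ utility evaluations, so the budget is set by `n_permutations` rather than by $2^{n-1}$.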


layout: fact hideInToc: true

What can data valuation do for you?


level: 1 title: "Example 1: Data cleaning" layout: two-cols class: p-4 table-center

Example 1: Data cleaning


| Data dropped | MAE improvement |
|---|---|
| 10% | 8% (± 2%) |
| 15% | 10% (± 3%) |

::right::

Three steps

// First example
```python {none|1-2|3-4|5-7|all}
train, val, test = load_spotify_dataset(...)
model = GradientBoostingRegressor(...)
scorer = SupervisedScorer("neg_mean_absolute_error", val)
utility = Utility(model, scorer)
valuation = DataShapleyValuation(utility, ...)
with joblib.parallel_backend("loky", n_jobs=16):
    valuation.fit(train)
```

```python {2,3}
train, val, test = load_data()
model = AnyModel()
scorer = CustomScorer(val)
utility = Utility(model, scorer)
valuation = DataShapleyValuation(utility, ...)
with joblib.parallel_backend("loky", n_jobs=16):
    valuation.fit(train)
```

```python {5}
train, val, test = load_data()
model = AnyModel()
scorer = CustomScorer(val)
utility = Utility(model, scorer)
valuation = AnyValuationMethod(utility, ...)
with joblib.parallel_backend("loky", n_jobs=16):
    valuation.fit(train)
```

```python {6,7}
train, val, test = load_data()
model = AnyModel()
scorer = CustomScorer(val)
utility = Utility(model, scorer)
valuation = AnyValuationMethod(utility, ...)
with joblib.parallel_backend("ray", n_jobs=480):
    valuation.fit(train)
```

and

values = valuation.values(sort=True)
clean_data = train.drop_indices(values[:100].indices)

model.fit(clean_data)
assert model.score(test) > 1.02 * previous_score

Profit!
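
The cleaning step can be sketched without any library, assuming the values are already computed. `drop_lowest` and the numbers below are hypothetical and only illustrate the idea of removing the lowest-valued points.

```python
def drop_lowest(points, values, n_drop):
    """Return the points left after removing the n_drop lowest-valued ones."""
    ranked = sorted(range(len(points)), key=values.__getitem__)
    dropped = set(ranked[:n_drop])
    return [p for i, p in enumerate(points) if i not in dropped]


points = ["a", "b", "c", "d", "e"]
values = [0.3, -1.2, 0.1, 0.4, -0.5]  # made-up values for illustration
clean = drop_lowest(points, values, n_drop=2)
print(clean)
```

As in the snippet above, the retrained model should be scored on the held-out test set, not on the validation set used to compute the values.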

new interface v0.10

title: Other tasks level: 1 layout: two-cols-header class: p-6

What can data valuation do for you?

::left::

  • We increased accuracy by removing bogus points
  • Better: select data for inspection
  • Data debugging
    what's wrong with this data?
  • Model debugging
    why are these data detrimental?

::right::

But also

  • Data acquisition: prioritize data sources
  • Attribution: find the most important data points

And more speculatively

  • Continual learning: compress your dataset
  • Data markets: price your data
  • Improve fairness metrics
  • ...

layout: fact hideInToc: true

What do you need?


title: Requirements level: 1 layout: two-cols class: px-6 table-invisible

Requirements

  • Any scikit-learn model
  • Or a wrapper with a fit() method
  • A scoring function
  • An imperfect dataset
pip install pydvl

pyDVL logo

::right::

What frameworks?

  • numpy and sklearn
  • joblib for parallelization
  • memcached for caching
  • Influence Functions use pytorch
  • Planned: allow jax and torch everywhere
  • dask for large datasets




pyDVL is still evolving!


layout: two-cols-center title: Problems with data valuation level: 1

Where's the catch?

  • Computational cost
  • Has my approximation converged?
  • Consistency across runs
  • Model and metric dependence

::right::


Some solutions

  • Monte-Carlo approximations
  • Efficient subset sampling strategies
  • Proxy models (value transfer)
  • Model-specific methods (KNN-Shap, Data-OOB, ...)
  • Utility learning (YMMV)
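
For "Has my approximation converged?", one possible check is to stop sampling once the standard error of the running mean falls below a tolerance. This is a hypothetical criterion written for illustration, not one of pyDVL's stopping criteria. The Gaussian draw stands in for a sampled marginal contribution.

```python
import math
import random


class RunningMean:
    """Welford's online mean and variance for a stream of marginal contributions."""

    def __init__(self):
        self.n, self.mean, self._m2 = 0, 0.0, 0.0

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self._m2 += delta * (x - self.mean)

    def stderr(self):
        """Standard error of the mean; infinite until two samples are seen."""
        if self.n < 2:
            return math.inf
        return math.sqrt(self._m2 / (self.n - 1) / self.n)


def sample_until_converged(draw, tolerance, min_samples=30, max_samples=10_000):
    """Keep drawing until the standard error is small or the budget runs out."""
    stats = RunningMean()
    while stats.n < max_samples:
        stats.update(draw())
        if stats.n >= min_samples and stats.stderr() < tolerance:
            break
    return stats


rng = random.Random(0)
# Stand-in for "utility with x minus utility without x" on a random subset.
stats = sample_until_converged(lambda: rng.gauss(1.0, 0.5), tolerance=0.05)
print(stats.n, stats.mean, stats.stderr())
```

A small standard error only says the estimate is stable for this model and metric. It does not remove the model and metric dependence listed on the left.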

layout: fact title: Influence functions level: 1

The influence of a training point


title: The influence of a training point level: 1 layout: two-cols-header class: table-center p-6

The influence of a training point

::left::

| Data | | Test loss |
|---|---|---|
| $\{z_1, z_2, ..., z_n\}$ | (... train ...) $\to$ | $L(z)$ |
| $\{z_1, \red{\sout{z_2}}, ..., z_n\}$ | (... train ...) $\to$ | $L_{\red{-z_2}}(z)$ |

The "influence" of $z_2$ on test point $z$ is roughly

$$L(z) - L_{-z_2}(z)$$
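
The definition can be checked by brute force on a toy model. The sketch below assumes a one-parameter least-squares fit $y = \theta x$ with a closed-form solution, so "retraining" is cheap. The data and function names are made up for illustration.

```python
def fit(points):
    """Least-squares slope of y = theta * x (no intercept), in closed form."""
    return sum(x * y for x, y in points) / sum(x * x for x, _ in points)


def loss(theta, point):
    """Squared-error loss 0.5 * (y - theta * x)^2 at a single point."""
    x, y = point
    return 0.5 * (y - theta * x) ** 2


def retraining_influence(train, z):
    """L(z) - L_{-z_i}(z) for every training point z_i: one refit per point."""
    full_loss = loss(fit(train), z)
    return [
        full_loss - loss(fit(train[:i] + train[i + 1:]), z)
        for i in range(len(train))
    ]


train = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, -8.0)]  # last label is flipped
z = (5.0, 10.0)  # test point on the true line y = 2x
influences = retraining_influence(train, z)
print(influences)
```

With this sign convention, a large positive value means the test loss drops when the training point is removed, which is what happens for the flipped label.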

::right::

  • One value per training / test point pair $(z_i, z)$
  • A full retraining per training point!
  • However: $$I(z_i, z) = \nabla_\theta L^\top \cdot H^{-1}_{\theta} \cdot \nabla_\theta L$$
  • Implicit computation and approximations
  • Are they good?
  • Does it matter?
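
A sketch of the gradient and Hessian formula on a one-parameter least-squares model, where $H$ is a scalar and can be inverted directly. The factor $-1/n$ is an assumption made here so that the result approximates $L(z) - L_{-z_i}(z)$. Sign and scaling conventions differ between papers and libraries.

```python
def fit(points):
    """Least-squares slope of y = theta * x (no intercept), in closed form."""
    return sum(x * y for x, y in points) / sum(x * x for x, _ in points)


def grad(theta, point):
    """d/dtheta of the loss 0.5 * (y - theta * x)^2 at a single point."""
    x, y = point
    return -(y - theta * x) * x


def influence(train, z_i, z):
    """First-order estimate of L(z) - L_{-z_i}(z), with no retraining."""
    n = len(train)
    theta = fit(train)
    hessian = sum(x * x for x, _ in train) / n  # Hessian of the mean training loss
    return -grad(theta, z) * (1.0 / hessian) * grad(theta, z_i) / n


train = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, -8.0)]  # last label is flipped
z = (5.0, 10.0)
approx = [influence(train, z_i, z) for z_i in train]
print(approx)
```

On this toy problem the estimate for the flipped point is about 53, against roughly 57 from actually retraining without it. For real models the Hessian is too large to invert, which is where implicit computation and approximations come in.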

layout: two-cols-header level: 1 class: table-invisible table-center py-6 no-bullet-points

Example 2: Finding mislabeled cells

::left::

  • NIH dataset with ~28K images for malaria screening1
  • Goal: detect the mislabeled data points with pyDVL
Uninfected Infected
Uninfected cell Infected cell

::right::

```python {hide|1-4|5|7-8|10|all|4,7}
torch_model = ...  # Trained model
train, test = ... # Dataloaders

if_model = DirectInfluence(torch_model, loss, ...)
if_model.fit(train)

if_calc = SequentialInfluenceCalculator(if_model)
lazy_values = if_calc.influences(test, train)

values = lazy_values.to_zarr(path, ...)  # memmapped
```

```python {4,7}
torch_model = ...  # Trained model
train, test = ... # Dataloaders

if_model = ArnoldiInfluence(torch_model, loss, ...)
if_model.fit(train)

if_calc = SequentialInfluenceCalculator(if_model)
lazy_values = if_calc.influences(test, train)

values = lazy_values.to_zarr(path, ...)  # memmapped
```

```python {4,7}
torch_model = ...  # Trained model
train, test = ... # Dataloaders

if_model = NystroemSketchInfluence(torch_model, loss, ...)
if_model.fit(train)

if_calc = SequentialInfluenceCalculator(if_model)
lazy_values = if_calc.influences(test, train)

values = lazy_values.to_zarr(path, ...)  # memmapped
```

```python {4,7}
torch_model = ...  # Trained model
train, test = ... # Dataloaders

if_model = NystroemSketchInfluence(torch_model, loss, ...)
if_model.fit(train)

if_calc = DaskInfluenceCalculator(if_model)
lazy_values = if_calc.influences(test, train)

values = lazy_values.to_zarr(path, ...)   # memmapped
```

(Plus CG, LiSSa, E-KFAC, ...)


title: "Example 2: Procedure" level: 1 layout: two-cols-header class: p-6

::left::

Procedure

  • Compute all pairs of influences
  • For each training point: 25th percentile of influences (same labels)
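
The second step can be sketched as follows, assuming the influences are available as a small in-memory matrix and that low or negative values flag suspicious points. The matrix, labels and function name are made up for illustration.

```python
from statistics import quantiles


def same_label_percentile(influences, test_labels, train_labels):
    """25th percentile, per training point, of its influences on same-label test points.

    influences[t][i] is the influence of training point i on test point t.
    """
    scores = []
    for i, label in enumerate(train_labels):
        column = [
            row[i]
            for row, test_label in zip(influences, test_labels)
            if test_label == label
        ]
        scores.append(quantiles(column, n=4, method="inclusive")[0])
    return scores


test_labels = [0, 0, 0, 1, 1, 1]
train_labels = [0, 0, 1, 1]  # training point 1 plays the mislabeled sample
influences = [
    [0.5, -0.8, 0.0, 0.1],
    [0.4, -0.6, 0.1, 0.0],
    [0.6, -0.9, -0.1, 0.0],
    [0.0, 0.3, 0.5, 0.4],
    [0.1, 0.2, 0.6, 0.3],
    [-0.1, 0.4, 0.4, 0.5],
]
scores = same_label_percentile(influences, test_labels, train_labels)
suspects = sorted(range(len(scores)), key=scores.__getitem__)  # most suspicious first
print(suspects[0], scores[suspects[0]])
```

A low percentile instead of the minimum makes the score less sensitive to a single unusual test point.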

::right::

Cells labelled as healthy

Cells labelled as parasitized


layout: two-cols-header level: 1 class: p-6

Accelerating IF computation


::left::

Problems

  • Computational complexity: $H^{-1}_{\theta} \nabla_\theta L$
  • Memory complexity: how many gradients fit on the device?

::right::

What can we do?

  • Approximation of the inverse Hessian vector product
  • Parallelization
  • Out-of-core computation
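
One way to approximate the inverse Hessian vector product is conjugate gradient, which solves $H x = g$ using only products $H v$ and never forms $H^{-1}$. The sketch below is plain Python with a small explicit matrix standing in for the Hessian. It is not the pyDVL implementation.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def conjugate_gradient(hvp, g, tol=1e-10, max_iter=100):
    """Solve H x = g for symmetric positive definite H, given only v -> H v."""
    x = [0.0] * len(g)
    r = list(g)  # residual g - H x, with x = 0
    p = list(r)
    rs = dot(r, r)
    for _ in range(max_iter):
        if rs < tol:
            break
        hp = hvp(p)
        alpha = rs / dot(p, hp)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hpi for ri, hpi in zip(r, hp)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


H = [[4.0, 1.0], [1.0, 3.0]]  # toy stand-in for the Hessian
g = [1.0, 2.0]                # toy stand-in for a gradient


def hvp(v):
    """Hessian-vector product; with autodiff this needs no explicit H."""
    return [dot(row, v) for row in H]


ihvp = conjugate_gradient(hvp, g)
print(ihvp)
```

Each iteration costs one Hessian-vector product, so the memory footprint stays at a few parameter-sized vectors.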

::bottom::

```python {1,3}
if_model = DirectInfluence(torch_model, loss, ...)
(...)
if_calc = SequentialInfluenceCalculator(if_model)
```

```python {1,3-5}
if_model = NystroemSketchInfluence(torch_model, loss, rank=10, ...)
(...)
client = Client(LocalCUDACluster())
if_calc = DaskInfluenceCalculator(if_model, client)
```

title: Picking methods level: 1 layout: two-cols-header class: p-6 text-center no-bullet-points

How to choose between IF and DV?



::left::

Influence functions

  • Large models with costly retrainings
  • torch interface
  • Point to point valuation

::right::

Data valuation

  • Smaller models
  • sklearn interface
  • Value over a test set

::bottom::

These are tools for data debugging!


layout: two-cols title: Thank you! hideInToc: true class: text-center table-center table-invisible p-6

Thank you for your attention!


slides and code:

pydata2024.pydvl.org

::right::

PyDVL contributors

Anes Benmerzoug Miguel de Benito Delgado Janoś Gabler
Jakob Kruse Markus Semmler Fabio Peruzzo
Kristof Schröder Bastien Zim YouYou!

layout: end

Appendix


layout: image-right image: data-valuation-taxonomy.svg backgroundSize: contain class: invertible

Many methods for data valuation

It's a growing field 1

  • Fit before, during, or after training
  • With or without reference datasets
  • Specific to classification / regression / unsupervised
  • Different model assumptions (from none to strong)
  • Local and global valuation

layout: two-cols-header dragPos: square: 841,363,20,20,270

Computing values with pyDVL

Three steps for all valuation methods

::left::

  • Prepare Dataset and model
  • Choose Scorer and Utility
  • Compute values (contribution to performance)




::right::

train, test = Dataset.from_sklearn(load_iris(), train_size=0.6)
model = LogisticRegression()
scorer = SupervisedScorer("accuracy", test)
utility = Utility(model, scorer)
valuation = DataBanzhafValuation(
    utility, MSRSampler(), RankCorrelation()
)
with joblib.parallel_backend("ray", n_jobs=48):
    valuation.fit(train)
<style> li { padding-top:2rem; } </style>

transition: fade-out hideInToc: true

Contents

Footnotes

  1. https://www.kaggle.com/datasets/paradisejoy/top-hits-spotify-from-20002019