Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DDP at eval stage + DDP metrics #2518

Draft
wants to merge 37 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
e61505c
ddp metrics sync seems to work + sync the update avg
Adel-Moumen Apr 20, 2024
d699f06
sync_average_loss fn
Adel-Moumen Apr 20, 2024
5b09ae7
update
Adel-Moumen Apr 21, 2024
81e7657
update distributed
Adel-Moumen Apr 22, 2024
f04b5b6
pre-commit
Adel-Moumen Apr 22, 2024
969fdb4
Merge remote-tracking branch 'speechbrain/develop' into sync_metrics
Adel-Moumen Apr 22, 2024
2a9b61b
docstring
Adel-Moumen Apr 22, 2024
dbbf8f7
move in metrics
Adel-Moumen Apr 22, 2024
1b3b6d0
update docstring
Adel-Moumen Apr 22, 2024
70dba5f
remove junk file
Adel-Moumen Apr 22, 2024
fcd05fe
update gather
Adel-Moumen Apr 22, 2024
436e5d1
remove outdated code
Adel-Moumen Apr 22, 2024
bba5f7b
fix unittest
Adel-Moumen Apr 22, 2024
017b4e2
add gather in sync_average_loss
Adel-Moumen Apr 22, 2024
de28a9e
unittest
Adel-Moumen Apr 22, 2024
3adf773
update
Adel-Moumen Apr 23, 2024
a24e802
test all metrics
Adel-Moumen Apr 23, 2024
bb44d02
docstring
Adel-Moumen Apr 23, 2024
83bd499
update reduce method
Adel-Moumen Apr 23, 2024
9ded249
update code
Adel-Moumen Apr 23, 2024
e88c44b
update code
Adel-Moumen Apr 23, 2024
3e77947
enable -> enable_progressbar
Adel-Moumen Apr 23, 2024
38ba5d2
fix path
Adel-Moumen Apr 23, 2024
b20ed99
fix name enable to be enable_progressbar
Adel-Moumen Apr 23, 2024
d0e1d85
fix potential issue with DDP on GPUs
Adel-Moumen Apr 23, 2024
84aa76c
update reduce + test
Adel-Moumen Apr 23, 2024
1bf09a4
remove comment
Adel-Moumen Apr 23, 2024
e6b6b60
reduce op
Adel-Moumen Apr 23, 2024
eadd3b8
simplify codebase by removing DistributedState and using provided fea…
Adel-Moumen Apr 24, 2024
734d261
utilities fn
Adel-Moumen Apr 24, 2024
5e897c7
add docstring gater_object
Adel-Moumen Apr 24, 2024
6caaed8
improve test to reflect diff between gather_obj and gather
Adel-Moumen Apr 24, 2024
398bbea
update tests metrics
Adel-Moumen Apr 24, 2024
760a3a4
comment bert tests
Adel-Moumen Apr 24, 2024
b601fef
add sacrebleu ci
Adel-Moumen Apr 24, 2024
39f0ca2
add again bert
Adel-Moumen Apr 24, 2024
aa42be2
Merge remote-tracking branch 'speechbrain/develop' into sync_metrics
Adel-Moumen Apr 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
run: |
pip install uv
uv pip install --system ctc-segmentation # ctc-segmentation is funky with uv due to their oldest-supported-numpy dependency
uv pip install --system -r requirements.txt torch==2.2.1+cpu torchaudio==2.2.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu k2==1.24.4.dev20240223+cpu.torch2.2.1 --find-links https://k2-fsa.github.io/k2/cpu.html kaldilm==1.15.1 spacy==3.7.4 flair==0.13.1
uv pip install --system -r requirements.txt torch==2.2.1+cpu torchaudio==2.2.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu k2==1.24.4.dev20240223+cpu.torch2.2.1 --find-links https://k2-fsa.github.io/k2/cpu.html kaldilm==1.15.1 spacy==3.7.4 flair==0.13.1 sacrebleu
uv pip install --system --editable . --no-deps # already installed pinned deps from requirements.txt, we're good
- name: Install sox
run: |
Expand Down
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/CTC/extra_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
https://github.com/kpu/kenlm/archive/master.zip
# k2 # It is better to install k2 with the procedure listed here: https://k2-fsa.github.io/k2/installation/from_wheels.html
kaldilm==1.15
kaldilm==1.15.1
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ sample_rate: 16000
# With DDP batch_size is multiplied by N jobs
# Must be 3 per GPU to fit 32GB of VRAM
batch_size: 6
test_batch_size: 8
test_batch_size: 1

# Dataloader options
train_dataloader_opts:
Expand Down
12 changes: 6 additions & 6 deletions recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def compute_forward(self, batch, stage):
elif stage == sb.Stage.TEST:
p_tokens = test_searcher(p_ctc, wav_lens)

candidates = []
scores = []
if hasattr(self.hparams, "rescorer"):
candidates = []
scores = []

for batch in p_tokens:
candidates.append([hyp.text for hyp in batch])
scores.append([hyp.score for hyp in batch])
for batch in p_tokens:
candidates.append([hyp.text for hyp in batch])
scores.append([hyp.score for hyp in batch])

if hasattr(self.hparams, "rescorer"):
p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)

return p_ctc, wav_lens, p_tokens
Expand Down
9 changes: 3 additions & 6 deletions recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,9 @@ def compute_objectives(self, predictions, batch, stage):

predicted_words = [wrd.split(" ") for wrd in predicted_texts]
target_words = [wrd.split(" ") for wrd in batch.wrd]
self.wer_metrics[k].append(
batch.id, predicted_words, target_words
)
self.cer_metrics[k].append(
batch.id, predicted_words, target_words
)
ids = batch.id
self.wer_metrics[k].append(ids, predicted_words, target_words)
self.cer_metrics[k].append(ids, predicted_words, target_words)
# For TEST and VALID stages, the loss value is not exact.
# The <UNK> words have a target length (e.g., number of phones or characters) of 1.
# As such, sentences with <UNK> have a higher loss during CTC loss 'mean' reduction mode.
Expand Down
1 change: 0 additions & 1 deletion recipes/LibriSpeech/ASR/CTC/train_with_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def compute_objectives(self, predictions, batch, stage):
# Convert indices to words
target_words = undo_padding(tokens, tokens_lens)
target_words = self.tokenizer(target_words, task="decode_from_list")

self.wer_metric.append(ids, predicted_words, target_words)
self.cer_metric.append(ids, predicted_words, target_words)

Expand Down
58 changes: 42 additions & 16 deletions speechbrain/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* Aku Rouhe 2021
* Andreas Nautsch 2022
* Sylvain de Langen 2023
* Adel Moumen 2023
* Adel Moumen 2023, 2024
"""

import argparse
Expand Down Expand Up @@ -1000,12 +1000,10 @@ def make_dataloader(
# TRAIN stage is handled specially.
if stage == sb.Stage.TRAIN:
loader_kwargs = self._train_loader_specifics(dataset, loader_kwargs)
# This commented-out code block is useful when one can ensure
# metric reporting is DDP-valid for VALID & EVAL datasets.
# elif self.distributed_launch:
# loader_kwargs = sb.dataio.dataloader.distributed_loader_specifics(
# self.distributed_launch, self.rank, dataset, loader_kwargs
# )
elif self.distributed_launch:
loader_kwargs = sb.dataio.dataloader.distributed_loader_specifics(
self.distributed_launch, self.rank, dataset, loader_kwargs
)
dataloader = sb.dataio.dataloader.make_dataloader(
dataset, **loader_kwargs
)
Expand Down Expand Up @@ -1382,7 +1380,7 @@ def evaluate_batch(self, batch, stage):
loss = self.compute_objectives(out, batch, stage=stage)
return loss.detach().cpu()

def _fit_train(self, train_set, epoch, enable):
def _fit_train(self, train_set, epoch, enable_progressbar):
# Training stage
self.on_stage_start(Stage.TRAIN, epoch)
self.modules.train()
Expand All @@ -1403,7 +1401,7 @@ def _fit_train(self, train_set, epoch, enable):
train_set,
initial=self.step,
dynamic_ncols=True,
disable=not enable,
disable=not enable_progressbar,
colour=self.tqdm_barcolor["train"],
) as t:
if self.profiler is not None:
Expand Down Expand Up @@ -1442,7 +1440,12 @@ def _fit_train(self, train_set, epoch, enable):
steps_since_ckpt = 0

# Run train "on_stage_end" on all processes
self.zero_grad(set_to_none=True) # flush gradients
# flush gradients
self.zero_grad(set_to_none=True)
# sync the avg_loss across all processes
self.avg_train_loss = sb.utils.distributed.reduce(
torch.tensor(self.avg_train_loss, device=self.device)
).item()
self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
self.avg_train_loss = 0.0
self.step = 0
Expand Down Expand Up @@ -1478,7 +1481,7 @@ def _should_save_intra_epoch_ckpt(self, last_ckpt_time, steps_since_ckpt):
torch.distributed.broadcast_object_list(broadcast_list, src=0)
return broadcast_list[0]

def _fit_valid(self, valid_set, epoch, enable):
def _fit_valid(self, valid_set, epoch, enable_progressbar):
# Validation stage
if valid_set is not None:
self.on_stage_start(Stage.VALID, epoch)
Expand All @@ -1488,7 +1491,7 @@ def _fit_valid(self, valid_set, epoch, enable):
for batch in tqdm(
valid_set,
dynamic_ncols=True,
disable=not enable,
disable=not enable_progressbar,
colour=self.tqdm_barcolor["valid"],
):
self.step += 1
Expand All @@ -1500,6 +1503,10 @@ def _fit_valid(self, valid_set, epoch, enable):
break

self.step = 0
# sync the avg_loss across all processes
avg_valid_loss = sb.utils.distributed.reduce(
torch.tensor(avg_valid_loss, device=self.device)
).item()
self.on_stage_end(Stage.VALID, avg_valid_loss, epoch)

def fit(
Expand Down Expand Up @@ -1585,12 +1592,22 @@ def fit(
progressbar = not self.noprogressbar

# Only show progressbar if requested and main_process
enable = progressbar and sb.utils.distributed.if_main_process()
enable_progressbar = (
progressbar and sb.utils.distributed.if_main_process()
)

# Iterate epochs
for epoch in epoch_counter:
self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
self._fit_valid(valid_set=valid_set, epoch=epoch, enable=enable)
self._fit_train(
train_set=train_set,
epoch=epoch,
enable_progressbar=enable_progressbar,
)
self._fit_valid(
valid_set=valid_set,
epoch=epoch,
enable_progressbar=enable_progressbar,
)

# Debug mode only runs a few epochs
if (
Expand Down Expand Up @@ -1746,6 +1763,11 @@ def evaluate(
if progressbar is None:
progressbar = not self.noprogressbar

# Only show progressbar if requested and main_process
enable_progressbar = (
progressbar and sb.utils.distributed.if_main_process()
)

if not (
isinstance(test_set, DataLoader)
or isinstance(test_set, LoopedLoader)
Expand All @@ -1762,7 +1784,7 @@ def evaluate(
for batch in tqdm(
test_set,
dynamic_ncols=True,
disable=not progressbar,
disable=not enable_progressbar,
colour=self.tqdm_barcolor["test"],
):
self.step += 1
Expand All @@ -1773,6 +1795,10 @@ def evaluate(
if self.debug and self.step == self.debug_batches:
break

# sync the avg_loss across all processes
avg_test_loss = sb.utils.distributed.reduce(
Adel-Moumen marked this conversation as resolved.
Show resolved Hide resolved
torch.tensor(avg_test_loss, device=self.device)
).item()
self.on_stage_end(Stage.TEST, avg_test_loss, None)
self.step = 0
return avg_test_loss
Expand Down
5 changes: 5 additions & 0 deletions speechbrain/utils/Accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch

from speechbrain.dataio.dataio import length_to_mask
from speechbrain.utils.distributed_metrics import gather_for_metrics


def Accuracy(log_probabilities, targets, length=None):
Expand Down Expand Up @@ -86,6 +87,10 @@ def append(self, log_probabilities, targets, length=None):
length : torch.Tensor
Length of target (batch_size,).
"""
log_probabilities = gather_for_metrics(log_probabilities)
targets = gather_for_metrics(targets)
if length is not None:
length = gather_for_metrics(length)
numerator, denominator = Accuracy(log_probabilities, targets, length)
self.correct += numerator
self.total += denominator
Expand Down
4 changes: 4 additions & 0 deletions speechbrain/utils/bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torch

import speechbrain as sb
from speechbrain.lobes.models.huggingface_transformers import TextEncoder
from speechbrain.utils.distances import cosine_similarity_matrix
from speechbrain.utils.metric_stats import MetricStats
Expand Down Expand Up @@ -104,6 +105,9 @@ def append(self, ids, predict, target):
targets: list
the ground truths in tokenizable format
"""
ids = sb.utils.distributed_metrics.gather_for_metrics(ids)
predict = sb.utils.distributed_metrics.gather_for_metrics(predict)
target = sb.utils.distributed_metrics.gather_for_metrics(target)
self.ids.extend(ids)
self.predictions.extend(predict)
self.targets.extend(target)
Expand Down
6 changes: 5 additions & 1 deletion speechbrain/utils/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* Mirco Ravanelli 2021
"""

import speechbrain as sb
from speechbrain.utils.metric_stats import MetricStats


Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self, merge_words=True, max_ngram_order=4):
try:
from sacrebleu.metrics import BLEU
except ImportError:
print(
raise ImportError(
"Please install sacrebleu (https://pypi.org/project/sacrebleu/) in order to use the BLEU metric"
)

Expand All @@ -83,6 +84,9 @@ def append(self, ids, predict, targets, ind2lab=None):
Callable that maps from indices to labels, operating on batches,
for writing alignments.
"""
ids = sb.utils.distributed_metrics.gather_for_metrics(ids)
predict = sb.utils.distributed_metrics.gather_for_metrics(predict)
targets = sb.utils.distributed_metrics.gather_for_metrics(targets)
self.ids.extend(ids)

if ind2lab is not None:
Expand Down
67 changes: 67 additions & 0 deletions speechbrain/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

import torch

from speechbrain.utils.distributed_utils import (
distributed_is_initialized,
recursively_apply,
)

MAIN_PROC_ONLY: int = 0


Expand Down Expand Up @@ -206,3 +211,65 @@ def ddp_init_group(run_opts):
rank=rank,
timeout=datetime.timedelta(seconds=7200),
)


def reduce(tensor, reduction="mean"):
    """Recursively reduce the tensors in a nested list/tuple/dictionary of
    tensors across all processes, using either the mean or the sum.

    For instance, if each process holds a tensor such as a loss, this function
    can be used to synchronize the tensors across all processes and then
    reduce them by the mean or the sum. This is particularly useful when you
    want to calculate the mean loss across all processes.

    When no distributed process group is initialized, every tensor is simply
    returned as a clone of the input, whatever the value of `reduction`.

    Arguments
    ---------
    tensor : nested list/tuple/dictionary of `torch.Tensor`
        The data to reduce.
    reduction : str, optional
        The reduction method. Can be "mean" or "sum" (default: "mean").

    Returns
    -------
    The same data structure as `tensor` with all the tensors reduced.
    With `reduction="mean"`, integer tensors are returned as floating point
    tensors, since the mean is computed with a true division.

    Raises
    ------
    ValueError
        If `reduction` is neither "mean" nor "sum".

    Example
    -------
    >>> tensor = torch.arange(2) + 1 + 2 * rank  # doctest: +SKIP
    >>> tensor  # doctest: +SKIP
    tensor([1, 2]) # Rank 0
    tensor([3, 4]) # Rank 1
    >>> reduce(tensor, reduction="sum")  # doctest: +SKIP
    tensor([4, 6]) # Rank 0 and 1 combined
    >>> reduce(tensor, reduction="mean")  # doctest: +SKIP
    tensor([2., 3.]) # Rank 0 and 1 combined
    >>> obj = [{"a": [(torch.arange(2) + 1 + 2 * rank).float() for _ in range(4)]}]  # doctest: +SKIP
    [{'a': [tensor([1., 2.]), tensor([1., 2.]), tensor([1., 2.]), tensor([1., 2.])]}] # Rank 0
    [{'a': [tensor([3., 4.]), tensor([3., 4.]), tensor([3., 4.]), tensor([3., 4.])]}] # Rank 1
    >>> reduce(obj, reduction="sum")  # doctest: +SKIP
    [{'a': [tensor([4., 6.]), tensor([4., 6.]), tensor([4., 6.]), tensor([4., 6.])]}] # Rank 0 and 1 combined
    >>> reduce(obj, reduction="mean")  # doctest: +SKIP
    [{'a': [tensor([2., 3.]), tensor([2., 3.]), tensor([2., 3.]), tensor([2., 3.])]}] # Rank 0 and 1 combined
    """
    # Fail early on a typo such as "avg": without this check any unknown
    # value would silently behave like "sum".
    if reduction not in ("mean", "sum"):
        raise ValueError(
            f"Unsupported reduction '{reduction}': expected 'mean' or 'sum'."
        )

    def _reduce_across_processes(tensor, reduction="mean"):
        # Work on a copy so that the caller's tensor is never modified by the
        # in-place all_reduce below.
        cloned_tensor = tensor.clone()

        if not distributed_is_initialized():
            return cloned_tensor

        from torch.distributed import ReduceOp

        torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)

        # we cannot use `ReduceOp.AVG` since it is unavailable for gloo backend
        if reduction == "mean":
            # Out-of-place division: an in-place `/=` raises for integer
            # tensors because the float result cannot be cast back.
            cloned_tensor = cloned_tensor / torch.distributed.get_world_size()

        return cloned_tensor

    # NOTE(review): `recursively_apply` is defined in
    # speechbrain.utils.distributed_utils; presumably it maps the function
    # over every tensor leaf and, with `error_on_other_type=True`, rejects
    # non-tensor leaves -- confirm against that helper.
    return recursively_apply(
        _reduce_across_processes,
        tensor,
        error_on_other_type=True,
        reduction=reduction,
    )