Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU-based vectorized Specaug Version 2 #9155

Merged
merged 15 commits into from
May 15, 2024
Merged
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 90 additions & 1 deletion nemo/collections/asr/parts/submodules/spectr_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class SpecAugment(nn.Module, Typing):
to be cut in one segment.
If a float value, defines maximum percentage of timesteps that
are cut adaptively.
fast - GPU-based implementation with batched masking and GPU rng,
amorari-nvidia marked this conversation as resolved.
Show resolved Hide resolved
setting it to False reverts to the legacy implementation.
Fast implementation is inspired by torchaudio:
https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816
"""

@property
Expand All @@ -56,7 +60,14 @@ def output_types(self):
return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}

def __init__(
self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, rng=None, mask_value=0.0,
self,
freq_masks: int = 0,
time_masks: int = 0,
freq_width: int = 10,
time_width: int | float = 10,
rng: random.Random | None = None,
mask_value: float = 0.0,
use_vectorized_code: bool = True,
):
super().__init__()

Expand All @@ -69,6 +80,7 @@ def __init__(
self.time_width = time_width

self.mask_value = mask_value
self.use_vectorized_code = use_vectorized_code

if isinstance(time_width, int):
self.adaptive_temporal_width = False
Expand All @@ -81,6 +93,12 @@ def __init__(
@typecheck()
@torch.no_grad()
def forward(self, input_spec, length):
    """Apply SpecAugment masking, dispatching to the vectorized or legacy path.

    The implementation is selected by ``self.use_vectorized_code``.
    """
    impl = self._forward_vectorized if self.use_vectorized_code else self._forward_legacy
    return impl(input_spec, length)

def _forward_legacy(self, input_spec, length):
batch_size, num_freq_bins, _ = input_spec.shape
# Move lengths to CPU before repeated indexing
lengths_cpu = length.cpu().numpy()
Expand Down Expand Up @@ -112,6 +130,77 @@ def forward(self, input_spec, length):
masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
return masked_spec

def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    """Apply time masks then frequency masks using the batched GPU implementation.

    Time masks run along axis 2 with ``self.time_width``; frequency masks run
    along axis 1 with ``self.freq_width``. Both passes share ``self.mask_value``.
    """
    # (num_masks, width, axis) for the time pass, then the frequency pass.
    mask_passes = (
        (self.time_masks, self.time_width, 2),
        (self.freq_masks, self.freq_width, 1),
    )
    for num_masks, width, axis in mask_passes:
        input_spec = self._apply_masks(
            input_spec=input_spec,
            num_masks=num_masks,
            length=length,
            width=width,
            axis=axis,
            mask_value=self.mask_value,
        )
    return input_spec

def _apply_masks(
self,
input_spec: torch.Tensor,
num_masks: int,
length: torch.Tensor,
width: int | float,
mask_value: float,
axis: int,
) -> torch.Tensor:

batch_size = input_spec.shape[0]
axis_length = input_spec.shape[axis]

# If width is float then it is transformed into a tensor
if isinstance(width, float):
width = torch.clamp(width * length, max=axis_length).unsqueeze(1)

# Generate [0-1) random numbers and then scale the tensors.
# Use float32 dtype for begin/end mask markers before they are quantized to long.
# Using x.dtype might cause us to encounter dtypes such as bf16 or smaller which
# wouldn't be able to represent every frame index leading to subtle bugs.
mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
mask_width = mask_width.long()
mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * (
axis_length - mask_width
)
mask_start.long()
amorari-nvidia marked this conversation as resolved.
Show resolved Hide resolved
mask_end = mask_start + mask_width

# Create mask values using vectorized indexing
indices = torch.arange(axis_length, device=input_spec.device)
amorari-nvidia marked this conversation as resolved.
Show resolved Hide resolved
# Create a mask_tensor with all the indices.
# The mask_tensor shape is (batch_size, num_masks, axis_length).
mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))
# Reduce masks to one mask
mask_tensor = mask_tensor.any(dim=1)

# Create a final mask that aligns with the full tensor
mask = torch.zeros_like(input_spec, dtype=torch.bool)
if axis == 2:
mask_ranges = mask_tensor[:, None, :]
amorari-nvidia marked this conversation as resolved.
Show resolved Hide resolved
elif axis == 1:
mask_ranges = mask_tensor[:, :, None]
else:
raise Exception("axis can be either 1 or 2")
mask[:, :, :] = mask_ranges

# Apply the mask value and return a new tensor
return input_spec.masked_fill(mask=mask, value=mask_value)


class SpecCutout(nn.Module, Typing):
"""
Expand Down