Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU-based vectorized Specaug Version 2 #9155

Merged
merged 15 commits into from
May 15, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 25 additions & 22 deletions nemo/collections/asr/parts/submodules/spectr_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
time_width: int | float = 10,
rng: random.Random | None = None,
mask_value: float = 0.0,
fast: bool = True,
use_vectorized_code: bool = True,
):
super().__init__()

Expand All @@ -80,7 +80,7 @@ def __init__(
self.time_width = time_width

self.mask_value = mask_value
self.fast = fast
self.use_vectorized_code = use_vectorized_code

if isinstance(time_width, int):
self.adaptive_temporal_width = False
Expand All @@ -93,8 +93,8 @@ def __init__(
@typecheck()
@torch.no_grad()
def forward(self, input_spec, length):
    """Apply SpecAugment masking to a batch of spectrograms.

    Dispatches between the two equivalent implementations based on the
    ``use_vectorized_code`` flag set in ``__init__``.

    Args:
        input_spec: batch of spectrograms to be masked.
        length: per-example valid lengths (used for adaptive mask widths).

    Returns:
        The masked spectrogram batch from the selected implementation.
    """
    # The scraped diff showed both the old (`self.fast`) and the renamed
    # (`self.use_vectorized_code`) conditionals stacked on top of each other;
    # only the post-change branch is kept here.
    if self.use_vectorized_code:
        return self._forward_vectorized(input_spec, length)
    else:
        return self._forward_legacy(input_spec, length)

Expand Down Expand Up @@ -130,7 +130,7 @@ def _forward_legacy(self, input_spec, length):
masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
return masked_spec

def _forward_fast(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
#time masks
input_spec = self._apply_masks(input_spec=input_spec, num_masks=self.time_masks,
length=length, width = self.time_width, axis=2, mask_value=self.mask_value)
Expand All @@ -142,40 +142,43 @@ def _forward_fast(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch
def _apply_masks(
self, input_spec: torch.Tensor, num_masks: int, length: torch.Tensor, width: int | float,
mask_value: float, axis: int) -> torch.Tensor:

batch_size = input_spec.shape[0]
axis_length = input_spec.shape[axis]

# We use float32 dtype for begin/end mask markers before they are quantized to long.
# Using x.dtype might cause us to encounter dtypes such as bf16 or smaller which
# wouldn't be able to represent every frame index leading to subtle bugs.
#If width is float then it is transformed into a tensor
if isinstance(width, float):
scaled_width = torch.clamp(width * length, max = axis_length).unsqueeze(1)
mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * scaled_width
mask_width = mask_width.long()
mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * (axis_length - mask_width)
mask_start.long()
else:
#Since we don't need to compute scaled width, we can call randint mask_width and mask_start
width = min(width, axis_length)
mask_start = torch.randint(low=0, high = max(1, axis_length - width), size=(batch_size, num_masks),
device=input_spec.device, dtype=torch.long)
mask_width = torch.randint(low=0, high = max(1, width), size=(batch_size, num_masks),
device=input_spec.device, dtype=torch.long)
mask_end = mask_start + mask_width
width = torch.clamp(width * length, max = axis_length).unsqueeze(1)

# Generate [0-1) random numbers and then scale the tensors.
# Use float32 dtype for begin/end mask markers before they are quantized to long.
# Using x.dtype might cause us to encounter dtypes such as bf16 or smaller which
# wouldn't be able to represent every frame index leading to subtle bugs.
mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
mask_width = mask_width.long()
mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * (axis_length - mask_width)
mask_start.long()
amorari-nvidia marked this conversation as resolved.
Show resolved Hide resolved
mask_end = mask_start + mask_width

# Create mask values using vectorized indexing
indices = torch.arange(axis_length, device=input_spec.device)
amorari-nvidia marked this conversation as resolved.
Show resolved Hide resolved
# Create a mask_tensor with all the indices.
# The mask_tensor shape is (batch_size, num_masks, axis_length).
mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))
# Reduce masks to one mask
mask_tensor = mask_tensor.any(dim=1)

# Create a final mask that aligns with the full tensor
mask = torch.zeros_like(input_spec, dtype=torch.bool)
if axis == 2:
mask_ranges = mask_tensor[:, None, :]
amorari-nvidia marked this conversation as resolved.
Show resolved Hide resolved
elif axis == 1:
mask_ranges = mask_tensor[:, :, None]

else:
raise Exception("axis can be either 1 or 2")
mask[:, :, :] = mask_ranges

# Apply the mask value and return a new tensor
return input_spec.masked_fill(mask=mask, value=mask_value)

class SpecCutout(nn.Module, Typing):
Expand Down