Custom DPO Trainer CUDA OOM #1626

Open
TheGhoul21 opened this issue May 7, 2024 · 2 comments
@TheGhoul21

Hi, I overrode parts of the DPOTrainer implementation to test this paper.

About halfway through training there's a sudden spike in GPU memory usage that is quite hard to understand.
[Screenshot: GPU memory usage spiking, 2024-05-07 06:02:29]

After sitting at 99.9% memory usage for a while, it crashes with the following error:

CUDA out of memory. Tried to allocate 20.39 GiB. GPU has a total capacity of 22.19 GiB of which 11.72 GiB is free. Process 751671 has 10.22 GiB memory in use. Of the allocated memory 7.48 GiB is allocated by PyTorch, and 2.43 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
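For what it's worth, the allocator option the error message points to can be set from the training script itself; a minimal sketch (it must run before the first CUDA allocation, i.e. before the model is created):

import os

# Opt in to the expandable-segments allocator to reduce fragmentation,
# as suggested by the OOM message. Must run before anything touches CUDA.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"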

I tried different setups (2x A10G, 1x A10G, 1x T4, etc.) but nothing changes; after a while memory usage blows up anyway. I also tried a smaller batch size, at the cost of under-utilizing the GPU most of the time: after about half of the training, the GPU is still not enough.
Since I was planning to contribute this custom loss function (the paper's implementation) to this project, I'm asking for help. Here's the code:

def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    In addition to the usual outputs, this also returns the cross-entropy (NLL) loss on the chosen sequences.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "labels": concatenated_batch["concatenated_labels"],
            "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
        }
        if self.is_encoder_decoder
        else {}
    )
    all_logits = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    ).logits
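    # NOTE: all_logits has shape (2 * batch, seq_len, vocab_size); with long sequences this
    # tensor, together with the activations saved for backward, dominates GPU memory.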

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=self.loss_type == "ipo",
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )


    def cross_entropy_loss(logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        loss = loss_fct(logits, labels)
        return loss

    if self.is_encoder_decoder:
        labels = concatenated_batch["concatenated_labels"].clone()
    else:
        labels = concatenated_batch["concatenated_input_ids"].clone()

    # Extra NLL (cross-entropy) term on the chosen half of the concatenated batch;
    # it is added to the DPO loss later in get_batch_loss_metrics.
    chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        chosen_nll_loss
    ) = self.concatenated_forward(model, batch)

    # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
    if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
        reference_chosen_logps = batch["reference_chosen_logps"]
        reference_rejected_logps = batch["reference_rejected_logps"]
    else:
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, batch)

    losses, chosen_rewards, rejected_rewards = self.dpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    )
    chosen_nll_loss = chosen_nll_loss.to(self.accelerator.device)
    loss = chosen_nll_loss + losses.mean()
    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
    metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
    metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
    metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
    metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
    metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
    metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
    metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

    return loss, metrics



def compute_reference_log_probs(self, padded_batch: Dict) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Computes the reference model's log probabilities for a single padded batch of a DPO-specific dataset."""
    compute_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

    # compute reference logps
    with torch.no_grad(), compute_ref_context_manager():
        if self.ref_model is None:
            with self.null_ref_context():
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                    _
                ) = self.concatenated_forward(self.model, padded_batch)
        else:
            (
                reference_chosen_logps,
                reference_rejected_logps,
                _,
                _,
                _
            ) = self.concatenated_forward(self.ref_model, padded_batch)

    return reference_chosen_logps, reference_rejected_logps
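For context, the objective this code ends up optimizing is the DPO loss plus an NLL term on the chosen responses (loss = chosen_nll_loss + losses.mean() in get_batch_loss_metrics). A standalone sketch of that combination, assuming TRL's default sigmoid DPO loss with a beta hyperparameter:

import torch
import torch.nn.functional as F

def dpo_plus_nll_loss(
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    chosen_nll_loss: torch.FloatTensor,
    beta: float = 0.1,
) -> torch.FloatTensor:
    # Sigmoid DPO loss on the policy-vs-reference log-ratio margin...
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    dpo_losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    # ...plus the cross-entropy (NLL) term computed on the chosen sequences.
    return chosen_nll_loss + dpo_losses.mean()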
@TheGhoul21
Author

I was thinking: could this be related to the fact that some sequences later in the training are longer?
If a sequence is roughly double the length, it would need roughly double the memory to process.
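That scaling is roughly what the shapes predict: concatenated_forward materializes all_logits of shape (2 * batch, seq_len, vocab_size), and the activations kept for the backward pass grow with sequence length as well. A back-of-the-envelope sketch with hypothetical numbers:

# Rough size of the all_logits tensor alone (hypothetical batch/vocab values;
# the real ones depend on the model and data).
batch_pairs = 4          # chosen + rejected -> 2 * batch_pairs sequences
vocab_size = 32_000
bytes_per_element = 2    # fp16 / bf16

def logits_gib(seq_len: int) -> float:
    return 2 * batch_pairs * seq_len * vocab_size * bytes_per_element / 1024**3

print(logits_gib(1024))  # ~0.49 GiB
print(logits_gib(2048))  # ~0.98 GiB: double the length, double the allocation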

@younesbelkada
Collaborator

Hi @TheGhoul21
Thanks a lot!
I think it's a bit hard to tell... perhaps adding a torch.cuda.empty_cache() at the end of the loss computation might help, but I am not sure. In any case, this could be a great contribution and useful for the community. Would you be happy to still open the PR for it?
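A minimal sketch of that suggestion (where exactly to call it is an assumption; it only returns cached, unused allocator blocks and cannot free tensors that are still referenced):

import torch

def release_cached_cuda_memory() -> None:
    """Hypothetical helper: call e.g. at the end of get_batch_loss_metrics to
    return cached-but-unused allocator blocks, which can ease fragmentation."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()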
