Any specific reason sampling is not in FP16? #377

Open
danbochman opened this issue Feb 19, 2024 · 1 comment

@danbochman

During training, the forward method casts its inputs to FP16, but sampling does not:

    @torch.no_grad()
    @cast_torch_tensor  # no cast_fp16=True here, unlike forward
    def sample(self, *args, **kwargs):

        self.print_untrained_unets()
        if not self.is_main:
            kwargs["use_tqdm"] = False

        output = self.imagen.sample(*args, device=self.device, **kwargs)

        return output

    @partial(cast_torch_tensor, cast_fp16=True)  # inputs are cast to FP16 for training
    def forward(self, *args, unet_number=None, **kwargs):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        assert (
            not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number
        ), f"you can only train unet #{self.only_train_unet_number}"

        with self.accelerator.accumulate(self.unet_being_trained):
            with self.accelerator.autocast():
                loss = self.imagen(*args, unet=self.unet_being_trained, unet_number=unet_number, **kwargs)

            if self.training:
                self.accelerator.backward(loss)

        return loss
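Besides the input cast, forward also runs the model inside self.accelerator.autocast(), while sample does not. The sketch below illustrates why that matters, assuming a toy model: the `toy_sample` function and the `nn.Linear` "denoiser" are hypothetical stand-ins, not imagen-pytorch code. Under `torch.autocast`, eligible ops such as matmuls run in reduced precision even when tensors created inside the loop are float32. It uses CPU with bfloat16 so it runs without a GPU; on CUDA, `device_type="cuda"` with `dtype=torch.float16` would be the FP16 analogue.

```python
import torch
from torch import nn

# Hypothetical stand-in for a denoising UNet (not part of imagen-pytorch).
model = nn.Linear(8, 8)

@torch.no_grad()
def toy_sample(model, steps=3, use_autocast=False):
    # Autocast picks the dtype per op; when disabled, everything stays float32.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=use_autocast):
        x = torch.randn(2, 8)  # fresh noise defaults to float32
        for _ in range(steps):
            x = model(x)  # Linear is autocast to bfloat16 on CPU when enabled
    return x

print(toy_sample(model).dtype)                     # torch.float32
print(toy_sample(model, use_autocast=True).dtype)  # torch.bfloat16
```

Autocast casts at each eligible op rather than once at the input, so it is less sensitive to float32 tensors introduced mid-loop than a one-time cast of the arguments.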

I tried casting to FP16 for sampling, but something in the sampling loop switches back to float32 even when the inputs are float16.
I wonder whether you have already encountered this, and whether that is why sampling does not cast to FP16.
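One plausible cause of the switch back to float32, an assumption the thread does not confirm, is PyTorch's type promotion. Any tensor created inside the loop with default settings (for example fresh noise from `torch.randn`) is float32, and combining it with a float16 tensor produces float32. The minimal sketch below uses hypothetical tensor names (`x`, `noise`) to show the promotion and one way to avoid it by creating new tensors in the input's dtype.

```python
import torch

# Input already cast to float16, as cast_fp16=True would produce.
x = torch.randn(2, 4).to(torch.float16)

# Hypothetical per-step noise created with default settings: float32.
noise = torch.randn(2, 4)
print((x + noise).dtype)  # torch.float32: float16 op float32 promotes

# Creating the noise in the input's dtype keeps the result in float16.
noise_fp16 = torch.randn(2, 4).to(x.dtype)
print((x + noise_fp16).dtype)  # torch.float16

# torch.result_type reports the promoted dtype without computing the op.
print(torch.result_type(x, noise))  # torch.float32
```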

Best regards and thanks for the great repo,

@rajnish159

How do I run this code? Please provide the steps for text-to-image generation.