
[training examples] reduce complexity by running final validations before export #7959

Open
bghira opened this issue May 16, 2024 · 3 comments
@bghira
Contributor

bghira commented May 16, 2024

I was thinking lately about Sayak's suggestion that the training examples are too long, and went through them looking for redundant or unnecessary code sections that could be reduced or eliminated for readability.

The main thing that stands out is how the validations occur during the trainer unwind stage.

During training, we have access to the unet and the other components, and we pass is_final_validation=False to the log_validation method, which behaves differently across the training examples. In the ControlNet example, the final-validation path ends up reloading the ControlNet model from args.output_dir and building a fresh pipeline around it:

def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
    logger.info("Running validation... ")

    if not is_final_validation:
        controlnet = accelerator.unwrap_model(controlnet)
        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            vae=vae,
            unet=unet,
            controlnet=controlnet,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
    else:
        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
        if args.pretrained_vae_model_name_or_path is not None:
            vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
        else:
            vae = AutoencoderKL.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
            )

        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            vae=vae,
            controlnet=controlnet,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )

This seems to happen only because, at the end of training, the method is called again after everything has been unloaded:

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        controlnet = unwrap_model(controlnet)
        controlnet.save_pretrained(args.output_dir)

        # Run a final round of validation.
        # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
        image_logs = None
        if args.validation_prompt is not None:
            image_logs = log_validation(
                vae=None,
                unet=None,
                controlnet=None,
                args=args,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                step=global_step,
                is_final_validation=True,
            )
  1. Theoretically it shows us the results of the final export, but in practice the result is the same as running inference on the already-loaded weights without reloading them.
  2. In this particular case the unet/vae/controlnet models end up loaded twice, since the in-memory copies would need to be deleted before the new ones are loaded.
  3. When max_train_steps = 1000 and validation_steps = 100 (or any other value that divides evenly into max_train_steps), we run two validations back to back: one just before exiting the training loop, and then this one.
  4. It causes an unnecessary slowdown on systems that (for reasons I cannot discern) take a very long time to load pipelines.

If we just remove the final inference code, the earlier condition can be updated to run the validation before exiting the loop, which would solve issues 2-4:

From:

                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                        image_logs = log_validation(...)

To:

                    if args.validation_prompt is not None and (global_step % args.validation_steps == 0 or global_step >= args.max_train_steps):
                        image_logs = log_validation(...)
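
With that change, the end-of-training block could shrink to something like the following (a sketch based on the snippet above; the final log_validation call and its None arguments simply go away):

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        controlnet = unwrap_model(controlnet)
        controlnet.save_pretrained(args.output_dir)
        # No separate final validation here: the last in-loop validation
        # already ran on these exact weights just before the loop exited.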
@christopher-beckham
Contributor

christopher-beckham commented May 27, 2024

I also think that code seems unnecessarily complicated. The scripts should be set up so that the pipeline gets constructed first, then components get pulled out as necessary depending on what needs to be finetuned, e.g. the unet. Right now it looks like the opposite -- various components get instantiated depending on what is needed in the training script, then a pipeline gets created each time validation happens.

If I can add some complementary thoughts: in my own versions of those training scripts I actually construct a pipeline before training starts, and whenever log_validation is called I pass that pipeline in and it gets used instead. The only "gotcha" I found is that calling the pipeline mutates the state of pipeline.scheduler: running inference for a given number of steps changes internal scheduler state (e.g. pipeline.scheduler.sigmas), a design choice I really despise, intentional or not. This means that if you call log_validation halfway through training, your script may crash as soon as log_validation finishes and training resumes, because some internal state got changed. Therefore my log_validation method looks like:

def log_validation(...):
    # Build a fresh scheduler from the original config so validation cannot
    # leak mutated state (e.g. sigmas) back into the training loop.
    original_scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config)

    # validation code
    ...

    # Swap the clean scheduler back in before returning to training.
    pipeline.scheduler = original_scheduler

    return image_logs

@sayakpaul
Member

Thanks for initiating the discussion! @christopher-beckham, thanks for your thoughts, too. It's all very reasonable, especially the point about the scheduler.

However, we don't use the noise_scheduler (the one used for training) in the pipeline while performing the validation inference steps. So, I don't understand it fully. Could you point me to a LoC in the library that doesn't follow this?

The scripts should be set up so that the pipeline gets constructed first, then components get pulled out as necessary depending on what needs to be finetuned, e.g. the unet.

Running validation inference is conditional. So how would constructing a pipeline no matter what be sensible here?

Right now it looks like the opposite -- various components get instantiated depending on what is needed in the training script, then a pipeline gets created each time validation happens.

The pipeline gets created from already-instantiated components, so I don't fully understand the negative implications.

Finally, as pointed out, I find it better to be explicit about this:

theoretically it shows us the results of the final export, but in practice it's the same result as if we inference on the loaded weights without reloading them

It mimics the situation where a user loads up the trained model in a fresh environment. Even if it takes more time, I think it's better to be explicit in this case.

@christopher-beckham
Contributor

christopher-beckham commented May 30, 2024

Hi @sayakpaul,

Thanks for your input.

However, we don't use the noise_scheduler (the one used for training) in the pipeline while performing the validation inference steps. So, I don't understand it fully. Could you point me to a LoC in the library that doesn't follow this?

Good point, this is something I missed. I think in some old code of mine, originally taken from this repo (a ControlNet-based script), the training scheduler was Euler. But to keep the discussion clean I'll refer solely to the current version of the SDXL training script (which indeed does use DDPMScheduler for training).

Running validation inference is conditional. So, how would constructing a pipeline no matter what would be sensible here?

In my own refactoring of the script I was trying to clean this up. Every component (the two text encoders, the tokenizers, etc.) is instantiated individually, which just seems overly verbose to me. If each component had to come from a different model id I would understand, but right now they are all instantiated from the same one, i.e. args.pretrained_model_name_or_path (with the exception of the VAE). I don't see why that code couldn't simply be replaced with an instantiation of the entire pipeline, e.g.

pipeline = StableDiffusionXLPipeline.from_pretrained(args.pretrained_model_name_or_path)

and then pull things out as needed for training, e.g. unet = pipeline.unet, text_encoder_one = pipeline.text_encoder, etc. One caveat is that the scheduler needs to be switched to DDPMScheduler, but you could just manually replace it with pipeline.scheduler = DDPMScheduler.from_pretrained(...), and that solves that. Similarly for the VAE. That's basically three lines of code instead of what is there currently; a rough sketch is below.
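
A rough sketch of what I mean (assuming the standard diffusers classes; the attribute names follow the SDXL pipeline, and from_config is just one way to swap the scheduler):

from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(args.pretrained_model_name_or_path)

# Pull out the components the training loop actually needs.
unet = pipeline.unet
text_encoder_one = pipeline.text_encoder
text_encoder_two = pipeline.text_encoder_2

# Swap in the training-time noise scheduler (from_pretrained with
# subfolder="scheduler" would work just as well).
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)

# And a dedicated VAE, if one was given.
if args.pretrained_vae_model_name_or_path is not None:
    pipeline.vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path)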

If one went with that design decision, it would also be convenient to pass the pipeline into log_validation; since the inference scheduler is different (it's EulerDiscreteScheduler), you could internally construct a new pipeline from the existing pipeline's components and pass scheduler= for the Euler one.
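
For example (a sketch; pipeline.components is diffusers' dict of already-instantiated modules, so rebuilding a pipeline from it is cheap):

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline

# Reuse the already-loaded components, swapping only the scheduler for inference.
components = pipeline.components
components["scheduler"] = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
validation_pipeline = StableDiffusionXLPipeline(**components)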

The pipeline gets created with already instantiated components, so, I don't understand the negative implications, fully.

Yes I agree, I completely forgot it's cheap to do this given the components are already instantiated.
