Release the mmdit model earlier to reduce memory usage. (#2581)

* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.

* Release the mmdit model earlier to reduce memory usage.
This commit is contained in:
Laurent Mazare
2024-10-28 16:06:53 +01:00
committed by GitHub
parent 0e2c8c17fb
commit 498bc2cdc9

View File

@@ -183,17 +183,17 @@ fn main() -> Result<()> {
let context = Tensor::cat(&[context, context_uncond], 0)?; let context = Tensor::cat(&[context, context_uncond], 0)?;
let y = Tensor::cat(&[y, y_uncond], 0)?; let y = Tensor::cat(&[y, y_uncond], 0)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let start_time = std::time::Instant::now();
let x = {
let mmdit = MMDiT::new( let mmdit = MMDiT::new(
&mmdit_config, &mmdit_config,
use_flash_attn, use_flash_attn,
vb.pp("model.diffusion_model"), vb.pp("model.diffusion_model"),
)?; )?;
sampling::euler_sample(
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let start_time = std::time::Instant::now();
let x = sampling::euler_sample(
&mmdit, &mmdit,
&y, &y,
&context, &context,
@@ -202,7 +202,8 @@ fn main() -> Result<()> {
time_shift, time_shift,
height, height,
width, width,
)?; )?
};
let dt = start_time.elapsed().as_secs_f32(); let dt = start_time.elapsed().as_secs_f32();
println!( println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s", "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",