Stable diffusion 3.5 support. (#2578)

* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
This commit is contained in:
Laurent Mazare
2024-10-27 10:01:04 +01:00
committed by GitHub
parent 07849aa595
commit 37e0ab8c64
5 changed files with 209 additions and 85 deletions

View File

@ -30,7 +30,7 @@ pub fn euler_sample(
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[x.clone(), x.clone()], 0)?,
&Tensor::cat(&[&x, &x], 0)?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,