diff --git a/candle-examples/examples/flux/README.md b/candle-examples/examples/flux/README.md index 528f058e..dfc8ad5f 100644 --- a/candle-examples/examples/flux/README.md +++ b/candle-examples/examples/flux/README.md @@ -13,7 +13,7 @@ descriptions, ```bash cargo run --features cuda --example flux -r -- \ - --height 1024 --width 1024 + --height 1024 --width 1024 \ --prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k" ``` diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index e8609a8f..17c406b4 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -156,7 +156,15 @@ fn run(args: Args) -> Result<()> { Model::Schnell => flux::model::Config::schnell(), }; let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; - let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; + let state = if quantized { + flux::sampling::State::new( + &t5_emb.to_dtype(candle::DType::F32)?, + &clip_emb.to_dtype(candle::DType::F32)?, + &img.to_dtype(candle::DType::F32)?, + )? + } else { + flux::sampling::State::new(&t5_emb, &clip_emb, &img)? + }; let timesteps = match model { Model::Dev => { flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15))) @@ -187,6 +195,7 @@ fn run(args: Args) -> Result<()> { ×teps, 4., )? + .to_dtype(dtype)? } else { let model_file = match model { Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,