Quantized version of flux. (#2500)

* Quantized version of flux.

* More generic sampling.

* Hook the quantized model.

* Use the newly minted gguf file.

* Fix for the quantized model.

* Default to avoid the faster cuda kernels.
This commit is contained in:
Laurent Mazare
2024-09-26 10:23:43 +02:00
committed by GitHub
parent d01207dbf3
commit 10d47183c0
6 changed files with 555 additions and 26 deletions

View File

@ -13,7 +13,7 @@ descriptions,
```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024
--height 1024 --width 1024 \
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```

View File

@ -23,6 +23,10 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Use the quantized model.
#[arg(long)]
quantized: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -40,6 +44,10 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
/// Use the faster kernels which are buggy at the moment.
#[arg(long)]
no_dmmv: bool,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
quantized,
..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let state = if quantized {
flux::sampling::State::new(
&t5_emb.to_dtype(candle::DType::F32)?,
&clip_emb.to_dtype(candle::DType::F32)?,
&img.to_dtype(candle::DType::F32)?,
)?
} else {
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
};
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;
println!("{state:?}");
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
if quantized {
let model_file = match model {
Model::Schnell => api
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
.get("flux1-schnell.gguf")?,
Model::Dev => todo!(),
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
.to_dtype(dtype)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
};
flux::sampling::unpack(&img, height, width)?
}
@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
run(args)
}