Support the flux-dev model too. (#2395)

This commit is contained in:
Laurent Mazare
2024-08-04 11:16:24 +01:00
committed by GitHub
parent c0a559d427
commit 89eae41efd

View File

@ -37,6 +37,15 @@ struct Args {
#[arg(long)] #[arg(long)]
decode_only: Option<String>, decode_only: Option<String>,
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum Model {
Schnell,
Dev,
} }
fn run(args: Args) -> Result<()> { fn run(args: Args) -> Result<()> {
@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> {
width, width,
tracing, tracing,
decode_only, decode_only,
model,
} = args; } = args;
let width = width.unwrap_or(1360); let width = width.unwrap_or(1360);
let height = height.unwrap_or(768); let height = height.unwrap_or(768);
@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> {
}; };
let api = hf_hub::api::sync::Api::new()?; let api = hf_hub::api::sync::Api::new()?;
let bf_repo = api.repo(hf_hub::Repo::model( let bf_repo = {
"black-forest-labs/FLUX.1-schnell".to_string(), let name = match model {
)); Model::Dev => "black-forest-labs/FLUX.1-dev",
Model::Schnell => "black-forest-labs/FLUX.1-schnell",
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?; let device = candle_examples::device(cpu)?;
let dtype = device.bf16_default_to_f32(); let dtype = device.bf16_default_to_f32();
let img = match decode_only { let img = match decode_only {
@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> {
}; };
println!("CLIP\n{clip_emb}"); println!("CLIP\n{clip_emb}");
let img = { let img = {
let model_file = bf_repo.get("flux1-schnell.sft")?; let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
Model::Dev => bf_repo.get("flux1-dev.sft")?,
};
let vb = let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::model::Config::schnell(); let cfg = match model {
let model = flux::model::Flux::new(&cfg, vb)?; Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;
println!("{state:?}"); println!("{state:?}");
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
println!("{timesteps:?}"); println!("{timesteps:?}");
flux::sampling::denoise( flux::sampling::denoise(
&model, &model,
@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> {
let img = { let img = {
let model_file = bf_repo.get("ae.sft")?; let model_file = bf_repo.get("ae.sft")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::autoencoder::Config::schnell(); let cfg = match model {
Model::Dev => flux::autoencoder::Config::dev(),
Model::Schnell => flux::autoencoder::Config::schnell(),
};
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?; let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
model.decode(&img)? model.decode(&img)?
}; };