mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Support the flux-dev model too. (#2395)
This commit is contained in:
@ -37,6 +37,15 @@ struct Args {
|
|||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
decode_only: Option<String>,
|
decode_only: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, value_enum, default_value = "schnell")]
|
||||||
|
model: Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||||
|
enum Model {
|
||||||
|
Schnell,
|
||||||
|
Dev,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(args: Args) -> Result<()> {
|
fn run(args: Args) -> Result<()> {
|
||||||
@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
width,
|
width,
|
||||||
tracing,
|
tracing,
|
||||||
decode_only,
|
decode_only,
|
||||||
|
model,
|
||||||
} = args;
|
} = args;
|
||||||
let width = width.unwrap_or(1360);
|
let width = width.unwrap_or(1360);
|
||||||
let height = height.unwrap_or(768);
|
let height = height.unwrap_or(768);
|
||||||
@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let bf_repo = api.repo(hf_hub::Repo::model(
|
let bf_repo = {
|
||||||
"black-forest-labs/FLUX.1-schnell".to_string(),
|
let name = match model {
|
||||||
));
|
Model::Dev => "black-forest-labs/FLUX.1-dev",
|
||||||
|
Model::Schnell => "black-forest-labs/FLUX.1-schnell",
|
||||||
|
};
|
||||||
|
api.repo(hf_hub::Repo::model(name.to_string()))
|
||||||
|
};
|
||||||
let device = candle_examples::device(cpu)?;
|
let device = candle_examples::device(cpu)?;
|
||||||
let dtype = device.bf16_default_to_f32();
|
let dtype = device.bf16_default_to_f32();
|
||||||
let img = match decode_only {
|
let img = match decode_only {
|
||||||
@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
println!("CLIP\n{clip_emb}");
|
println!("CLIP\n{clip_emb}");
|
||||||
let img = {
|
let img = {
|
||||||
let model_file = bf_repo.get("flux1-schnell.sft")?;
|
let model_file = match model {
|
||||||
|
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
|
||||||
|
Model::Dev => bf_repo.get("flux1-dev.sft")?,
|
||||||
|
};
|
||||||
let vb =
|
let vb =
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||||
let cfg = flux::model::Config::schnell();
|
let cfg = match model {
|
||||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
Model::Dev => flux::model::Config::dev(),
|
||||||
|
Model::Schnell => flux::model::Config::schnell(),
|
||||||
|
};
|
||||||
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
|
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
|
||||||
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
|
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
|
||||||
|
let timesteps = match model {
|
||||||
|
Model::Dev => {
|
||||||
|
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
|
||||||
|
}
|
||||||
|
Model::Schnell => flux::sampling::get_schedule(4, None),
|
||||||
|
};
|
||||||
|
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||||
|
|
||||||
println!("{state:?}");
|
println!("{state:?}");
|
||||||
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
|
|
||||||
println!("{timesteps:?}");
|
println!("{timesteps:?}");
|
||||||
flux::sampling::denoise(
|
flux::sampling::denoise(
|
||||||
&model,
|
&model,
|
||||||
@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let img = {
|
let img = {
|
||||||
let model_file = bf_repo.get("ae.sft")?;
|
let model_file = bf_repo.get("ae.sft")?;
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||||
let cfg = flux::autoencoder::Config::schnell();
|
let cfg = match model {
|
||||||
|
Model::Dev => flux::autoencoder::Config::dev(),
|
||||||
|
Model::Schnell => flux::autoencoder::Config::schnell(),
|
||||||
|
};
|
||||||
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
|
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
|
||||||
model.decode(&img)?
|
model.decode(&img)?
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user