mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a seed to the flux example. (#2529)
This commit is contained in:
@ -45,9 +45,13 @@ struct Args {
|
|||||||
#[arg(long, value_enum, default_value = "schnell")]
|
#[arg(long, value_enum, default_value = "schnell")]
|
||||||
model: Model,
|
model: Model,
|
||||||
|
|
||||||
/// Use the faster kernels which are buggy at the moment.
|
/// Use the slower kernels.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
no_dmmv: bool,
|
use_dmmv: bool,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long)]
|
||||||
|
seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||||
@ -91,6 +95,9 @@ fn run(args: Args) -> Result<()> {
|
|||||||
api.repo(hf_hub::Repo::model(name.to_string()))
|
api.repo(hf_hub::Repo::model(name.to_string()))
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(cpu)?;
|
let device = candle_examples::device(cpu)?;
|
||||||
|
if let Some(seed) = args.seed {
|
||||||
|
device.set_seed(seed)?;
|
||||||
|
}
|
||||||
let dtype = device.bf16_default_to_f32();
|
let dtype = device.bf16_default_to_f32();
|
||||||
let img = match decode_only {
|
let img = match decode_only {
|
||||||
None => {
|
None => {
|
||||||
@ -250,6 +257,6 @@ fn run(args: Args) -> Result<()> {
|
|||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
|
candle::quantized::cuda::set_force_dmmv(args.use_dmmv);
|
||||||
run(args)
|
run(args)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user