Add a seed to the flux example. (#2529)

This commit is contained in:
Laurent Mazare
2024-10-02 10:52:02 +02:00
committed by GitHub
parent fd08d3d0a4
commit f479840ce6

View File

@ -45,9 +45,13 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")] #[arg(long, value_enum, default_value = "schnell")]
model: Model, model: Model,
/// Use the faster kernels which are buggy at the moment. /// Use the slower kernels.
#[arg(long)] #[arg(long)]
no_dmmv: bool, use_dmmv: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
} }
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@ -91,6 +95,9 @@ fn run(args: Args) -> Result<()> {
api.repo(hf_hub::Repo::model(name.to_string())) api.repo(hf_hub::Repo::model(name.to_string()))
}; };
let device = candle_examples::device(cpu)?; let device = candle_examples::device(cpu)?;
if let Some(seed) = args.seed {
device.set_seed(seed)?;
}
let dtype = device.bf16_default_to_f32(); let dtype = device.bf16_default_to_f32();
let img = match decode_only { let img = match decode_only {
None => { None => {
@ -250,6 +257,6 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> { fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv); candle::quantized::cuda::set_force_dmmv(args.use_dmmv);
run(args) run(args)
} }