diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 8c3ca2ee..14642e9a 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -96,6 +96,10 @@ struct Args { /// information. #[arg(long, default_value_t = 0.8)] img2img_strength: f64, + + /// The seed to use when generating random samples. + #[arg(long)] + seed: Option, } #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] @@ -374,6 +378,7 @@ fn run(args: Args) -> Result<()> { use_flash_attn, img2img, img2img_strength, + seed, .. } = args; @@ -427,6 +432,9 @@ fn run(args: Args) -> Result<()> { let scheduler = sd_config.build_scheduler(n_steps)?; let device = candle_examples::device(cpu)?; + if let Some(seed) = seed { + device.set_seed(seed)?; + } let use_guide_scale = guidance_scale > 1.0; let which = match sd_version {