diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 6a08d9c8..74e816d4 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -97,7 +97,7 @@ struct Args { img2img_strength: f64, } -#[derive(Debug, Clone, Copy, clap::ValueEnum)] +#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] enum StableDiffusionVersion { V1_5, V2_1, @@ -204,7 +204,18 @@ impl ModelFile { Self::Clip => (version.repo(), version.clip_file(use_f16)), Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), Self::Unet => (version.repo(), version.unet_file(use_f16)), - Self::Vae => (version.repo(), version.vae_file(use_f16)), + Self::Vae => { + // Override for SDXL when using f16 weights. + // See https://github.com/huggingface/candle/issues/1060 + if version == StableDiffusionVersion::Xl && use_f16 { + ( + "madebyollin/sdxl-vae-fp16-fix", + "diffusion_pytorch_model.safetensors", + ) + } else { + (version.repo(), version.vae_file(use_f16)) + } + } }; let filename = Api::new()?.model(repo.to_string()).get(path)?; Ok(filename)