Override the repo for SDXL f16 vae weights. (#1064)

* Override the repo for SDXL f16 vae weights.

* Slightly simpler change.
This commit is contained in:
Laurent Mazare
2023-10-09 06:52:28 +01:00
committed by GitHub
parent 392fe02fba
commit 4d04ac83c7

View File

@ -97,7 +97,7 @@ struct Args {
img2img_strength: f64,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum StableDiffusionVersion {
V1_5,
V2_1,
@ -204,7 +204,18 @@ impl ModelFile {
Self::Clip => (version.repo(), version.clip_file(use_f16)),
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
if version == StableDiffusionVersion::Xl && use_f16 {
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
)
} else {
(version.repo(), version.vae_file(use_f16))
}
}
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)