mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Override the repo for SDXL f16 vae weights. (#1064)
* Override the repo for SDXL f16 vae weights. * Slightly simpler change.
This commit is contained in:
@ -97,7 +97,7 @@ struct Args {
|
||||
img2img_strength: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||
enum StableDiffusionVersion {
|
||||
V1_5,
|
||||
V2_1,
|
||||
@ -204,7 +204,18 @@ impl ModelFile {
|
||||
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
||||
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
|
||||
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
||||
Self::Vae => (version.repo(), version.vae_file(use_f16)),
|
||||
Self::Vae => {
|
||||
// Override for SDXL when using f16 weights.
|
||||
// See https://github.com/huggingface/candle/issues/1060
|
||||
if version == StableDiffusionVersion::Xl && use_f16 {
|
||||
(
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
} else {
|
||||
(version.repo(), version.vae_file(use_f16))
|
||||
}
|
||||
}
|
||||
};
|
||||
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
||||
Ok(filename)
|
||||
|
Reference in New Issue
Block a user