mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Override the repo for SDXL f16 vae weights. (#1064)
* Override the repo for SDXL f16 vae weights. * Slightly simpler change.
This commit is contained in:
@ -97,7 +97,7 @@ struct Args {
|
|||||||
img2img_strength: f64,
|
img2img_strength: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
|
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||||
enum StableDiffusionVersion {
|
enum StableDiffusionVersion {
|
||||||
V1_5,
|
V1_5,
|
||||||
V2_1,
|
V2_1,
|
||||||
@ -204,7 +204,18 @@ impl ModelFile {
|
|||||||
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
||||||
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
|
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
|
||||||
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
||||||
Self::Vae => (version.repo(), version.vae_file(use_f16)),
|
Self::Vae => {
|
||||||
|
// Override for SDXL when using f16 weights.
|
||||||
|
// See https://github.com/huggingface/candle/issues/1060
|
||||||
|
if version == StableDiffusionVersion::Xl && use_f16 {
|
||||||
|
(
|
||||||
|
"madebyollin/sdxl-vae-fp16-fix",
|
||||||
|
"diffusion_pytorch_model.safetensors",
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(version.repo(), version.vae_file(use_f16))
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
||||||
Ok(filename)
|
Ok(filename)
|
||||||
|
Reference in New Issue
Block a user