Stable diffusion: retrieve the model files from the HF hub. (#414)

* Retrieve the model files from the HF hub in the stable diffusion example.

* Add to the readme.
This commit is contained in:
Laurent Mazare
2023-08-11 19:57:06 +02:00
committed by GitHub
parent 91dbf907d3
commit 1d0157bbc4
3 changed files with 74 additions and 34 deletions

View File

@ -172,7 +172,11 @@ impl StableDiffusionConfig {
)
}
pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
device: &Device,
) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
let weights = weights.deserialize()?;
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
@ -181,9 +185,9 @@ impl StableDiffusionConfig {
Ok(autoencoder)
}
pub fn build_unet(
pub fn build_unet<P: AsRef<std::path::Path>>(
&self,
unet_weights: &str,
unet_weights: P,
device: &Device,
in_channels: usize,
) -> Result<unet_2d::UNet2DConditionModel> {
@ -198,9 +202,9 @@ impl StableDiffusionConfig {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
pub fn build_clip_transformer(
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
&self,
clip_weights: &str,
clip_weights: P,
device: &Device,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };