mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example. * Add to the readme.
This commit is contained in:
@ -172,7 +172,11 @@ impl StableDiffusionConfig {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
|
||||
pub fn build_vae<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
vae_weights: P,
|
||||
device: &Device,
|
||||
) -> Result<vae::AutoEncoderKL> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
||||
@ -181,9 +185,9 @@ impl StableDiffusionConfig {
|
||||
Ok(autoencoder)
|
||||
}
|
||||
|
||||
pub fn build_unet(
|
||||
pub fn build_unet<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
unet_weights: &str,
|
||||
unet_weights: P,
|
||||
device: &Device,
|
||||
in_channels: usize,
|
||||
) -> Result<unet_2d::UNet2DConditionModel> {
|
||||
@ -198,9 +202,9 @@ impl StableDiffusionConfig {
|
||||
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
||||
}
|
||||
|
||||
pub fn build_clip_transformer(
|
||||
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
clip_weights: &str,
|
||||
clip_weights: P,
|
||||
device: &Device,
|
||||
) -> Result<clip::ClipTextTransformer> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
||||
|
Reference in New Issue
Block a user