mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example. * Add to the readme.
This commit is contained in:
@ -26,6 +26,8 @@ Check out our [examples](./candle-examples/examples/):
|
|||||||
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
|
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
|
||||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
|
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
|
||||||
generation.
|
generation.
|
||||||
|
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||||
|
image generative model, currently CPU-only and on the slow side.
|
||||||
|
|
||||||
Run them using the following commands:
|
Run them using the following commands:
|
||||||
```
|
```
|
||||||
@ -34,6 +36,7 @@ cargo run --example llama --release
|
|||||||
cargo run --example falcon --release
|
cargo run --example falcon --release
|
||||||
cargo run --example bert --release
|
cargo run --example bert --release
|
||||||
cargo run --example bigcode --release
|
cargo run --example bigcode --release
|
||||||
|
cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
|
||||||
```
|
```
|
||||||
|
|
||||||
In order to use **CUDA** add `--features cuda` to the example command line.
|
In order to use **CUDA** add `--features cuda` to the example command line.
|
||||||
|
@ -45,21 +45,21 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
width: Option<usize>,
|
width: Option<usize>,
|
||||||
|
|
||||||
/// The UNet weight file, in .ot or .safetensors format.
|
/// The UNet weight file, in .safetensors format.
|
||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
unet_weights: Option<String>,
|
unet_weights: Option<String>,
|
||||||
|
|
||||||
/// The CLIP weight file, in .ot or .safetensors format.
|
/// The CLIP weight file, in .safetensors format.
|
||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
clip_weights: Option<String>,
|
clip_weights: Option<String>,
|
||||||
|
|
||||||
/// The VAE weight file, in .ot or .safetensors format.
|
/// The VAE weight file, in .safetensors format.
|
||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
vae_weights: Option<String>,
|
vae_weights: Option<String>,
|
||||||
|
|
||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
/// The file specifying the tokenizer to use for tokenization.
|
/// The file specifying the tokenizer to use for tokenization.
|
||||||
tokenizer: String,
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -91,34 +91,63 @@ enum StableDiffusionVersion {
|
|||||||
V2_1,
|
V2_1,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
fn clip_weights(&self) -> String {
|
enum ModelFile {
|
||||||
match &self.clip_weights {
|
Tokenizer,
|
||||||
Some(w) => w.clone(),
|
Clip,
|
||||||
None => match self.sd_version {
|
Unet,
|
||||||
StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(),
|
Vae,
|
||||||
StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(),
|
}
|
||||||
},
|
|
||||||
|
impl StableDiffusionVersion {
|
||||||
|
fn repo(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
||||||
|
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vae_weights(&self) -> String {
|
fn unet_file(&self) -> &'static str {
|
||||||
match &self.vae_weights {
|
match self {
|
||||||
Some(w) => w.clone(),
|
Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
|
||||||
None => match self.sd_version {
|
|
||||||
StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(),
|
|
||||||
StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unet_weights(&self) -> String {
|
fn vae_file(&self) -> &'static str {
|
||||||
match &self.unet_weights {
|
match self {
|
||||||
Some(w) => w.clone(),
|
Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
|
||||||
None => match self.sd_version {
|
}
|
||||||
StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(),
|
}
|
||||||
StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(),
|
|
||||||
},
|
fn clip_file(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelFile {
|
||||||
|
const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
|
||||||
|
const TOKENIZER_PATH: &str = "tokenizer.json";
|
||||||
|
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
filename: Option<String>,
|
||||||
|
version: StableDiffusionVersion,
|
||||||
|
) -> Result<std::path::PathBuf> {
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
|
match filename {
|
||||||
|
Some(filename) => Ok(std::path::PathBuf::from(filename)),
|
||||||
|
None => {
|
||||||
|
let (repo, path) = match self {
|
||||||
|
Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
|
||||||
|
Self::Clip => (version.repo(), version.clip_file()),
|
||||||
|
Self::Unet => (version.repo(), version.unet_file()),
|
||||||
|
Self::Vae => (version.repo(), version.vae_file()),
|
||||||
|
};
|
||||||
|
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
||||||
|
Ok(filename)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -151,9 +180,6 @@ fn output_filename(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn run(args: Args) -> Result<()> {
|
fn run(args: Args) -> Result<()> {
|
||||||
let clip_weights = args.clip_weights();
|
|
||||||
let vae_weights = args.vae_weights();
|
|
||||||
let unet_weights = args.unet_weights();
|
|
||||||
let Args {
|
let Args {
|
||||||
prompt,
|
prompt,
|
||||||
uncond_prompt,
|
uncond_prompt,
|
||||||
@ -166,6 +192,9 @@ fn run(args: Args) -> Result<()> {
|
|||||||
sliced_attention_size,
|
sliced_attention_size,
|
||||||
num_samples,
|
num_samples,
|
||||||
sd_version,
|
sd_version,
|
||||||
|
clip_weights,
|
||||||
|
vae_weights,
|
||||||
|
unet_weights,
|
||||||
..
|
..
|
||||||
} = args;
|
} = args;
|
||||||
let sd_config = match sd_version {
|
let sd_config = match sd_version {
|
||||||
@ -180,6 +209,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||||
let device = candle_examples::device(cpu)?;
|
let device = candle_examples::device(cpu)?;
|
||||||
|
|
||||||
|
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
let pad_id = match &sd_config.clip.pad_with {
|
let pad_id = match &sd_config.clip.pad_with {
|
||||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||||
@ -207,14 +237,17 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
println!("Building the Clip transformer.");
|
println!("Building the Clip transformer.");
|
||||||
|
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
|
||||||
let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
|
let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
|
||||||
let text_embeddings = text_model.forward(&tokens)?;
|
let text_embeddings = text_model.forward(&tokens)?;
|
||||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
|
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
|
||||||
|
|
||||||
println!("Building the autoencoder.");
|
println!("Building the autoencoder.");
|
||||||
|
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
|
||||||
let vae = sd_config.build_vae(&vae_weights, &device)?;
|
let vae = sd_config.build_vae(&vae_weights, &device)?;
|
||||||
println!("Building the unet.");
|
println!("Building the unet.");
|
||||||
|
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
|
||||||
let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
|
let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
|
||||||
|
|
||||||
let bsize = 1;
|
let bsize = 1;
|
||||||
|
@ -172,7 +172,11 @@ impl StableDiffusionConfig {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
|
pub fn build_vae<P: AsRef<std::path::Path>>(
|
||||||
|
&self,
|
||||||
|
vae_weights: P,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<vae::AutoEncoderKL> {
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
||||||
@ -181,9 +185,9 @@ impl StableDiffusionConfig {
|
|||||||
Ok(autoencoder)
|
Ok(autoencoder)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_unet(
|
pub fn build_unet<P: AsRef<std::path::Path>>(
|
||||||
&self,
|
&self,
|
||||||
unet_weights: &str,
|
unet_weights: P,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
in_channels: usize,
|
in_channels: usize,
|
||||||
) -> Result<unet_2d::UNet2DConditionModel> {
|
) -> Result<unet_2d::UNet2DConditionModel> {
|
||||||
@ -198,9 +202,9 @@ impl StableDiffusionConfig {
|
|||||||
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_clip_transformer(
|
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
|
||||||
&self,
|
&self,
|
||||||
clip_weights: &str,
|
clip_weights: P,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<clip::ClipTextTransformer> {
|
) -> Result<clip::ClipTextTransformer> {
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
||||||
|
Reference in New Issue
Block a user