mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Preliminary support for SDXL. (#647)
* Preliminary support for SDXL. * More SDXL support. * More SDXL. * Use the proper clip config. * Querying for existing tensors. * More robust test.
This commit is contained in:
@ -17,7 +17,7 @@ mod utils;
|
||||
mod vae;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use candle::{DType, Device, IndexOp, Tensor, D};
|
||||
use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
@ -102,12 +102,16 @@ struct Args {
|
||||
enum StableDiffusionVersion {
|
||||
V1_5,
|
||||
V2_1,
|
||||
Xl,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ModelFile {
|
||||
Tokenizer,
|
||||
Tokenizer2,
|
||||
Clip,
|
||||
Clip2,
|
||||
Unet,
|
||||
Vae,
|
||||
}
|
||||
@ -115,6 +119,7 @@ enum ModelFile {
|
||||
impl StableDiffusionVersion {
|
||||
fn repo(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
||||
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
||||
}
|
||||
@ -122,7 +127,7 @@ impl StableDiffusionVersion {
|
||||
|
||||
fn unet_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 => {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors"
|
||||
} else {
|
||||
@ -134,7 +139,7 @@ impl StableDiffusionVersion {
|
||||
|
||||
fn vae_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 => {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors"
|
||||
} else {
|
||||
@ -146,7 +151,7 @@ impl StableDiffusionVersion {
|
||||
|
||||
fn clip_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 => {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"text_encoder/model.fp16.safetensors"
|
||||
} else {
|
||||
@ -155,12 +160,21 @@ impl StableDiffusionVersion {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clip2_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"text_encoder_2/model.fp16.safetensors"
|
||||
} else {
|
||||
"text_encoder_2/model.safetensors"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelFile {
|
||||
const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
|
||||
const TOKENIZER_PATH: &str = "tokenizer.json";
|
||||
|
||||
fn get(
|
||||
&self,
|
||||
filename: Option<String>,
|
||||
@ -172,8 +186,24 @@ impl ModelFile {
|
||||
Some(filename) => Ok(std::path::PathBuf::from(filename)),
|
||||
None => {
|
||||
let (repo, path) = match self {
|
||||
Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
|
||||
Self::Tokenizer => {
|
||||
let tokenizer_repo = match version {
|
||||
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
|
||||
"openai/clip-vit-base-patch32"
|
||||
}
|
||||
StableDiffusionVersion::Xl => {
|
||||
// This seems similar to the patch32 version except some very small
|
||||
// difference in the split regex.
|
||||
"openai/clip-vit-large-patch14"
|
||||
}
|
||||
};
|
||||
(tokenizer_repo, "tokenizer.json")
|
||||
}
|
||||
Self::Tokenizer2 => {
|
||||
("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
|
||||
}
|
||||
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
||||
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
|
||||
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
||||
Self::Vae => (version.repo(), version.vae_file(use_f16)),
|
||||
};
|
||||
@ -211,6 +241,71 @@ fn output_filename(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn text_embeddings(
|
||||
prompt: &str,
|
||||
uncond_prompt: &str,
|
||||
tokenizer: Option<String>,
|
||||
clip_weights: Option<String>,
|
||||
sd_version: StableDiffusionVersion,
|
||||
sd_config: &stable_diffusion::StableDiffusionConfig,
|
||||
use_f16: bool,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
first: bool,
|
||||
) -> Result<Tensor> {
|
||||
let tokenizer_file = if first {
|
||||
ModelFile::Tokenizer
|
||||
} else {
|
||||
ModelFile::Tokenizer2
|
||||
};
|
||||
let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let pad_id = match &sd_config.clip.pad_with {
|
||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
||||
};
|
||||
println!("Running with prompt \"{prompt}\".");
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
tokens.push(pad_id)
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the Clip transformer.");
|
||||
let clip_weights_file = if first {
|
||||
ModelFile::Clip
|
||||
} else {
|
||||
ModelFile::Clip2
|
||||
};
|
||||
let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
|
||||
let clip_config = if first {
|
||||
&sd_config.clip
|
||||
} else {
|
||||
sd_config.clip2.as_ref().unwrap()
|
||||
};
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward(&tokens)?;
|
||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
|
||||
Ok(text_embeddings)
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
@ -252,46 +347,37 @@ fn run(args: Args) -> Result<()> {
|
||||
StableDiffusionVersion::V2_1 => {
|
||||
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
|
||||
}
|
||||
StableDiffusionVersion::Xl => {
|
||||
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
|
||||
}
|
||||
};
|
||||
|
||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||
let device = candle_examples::device(cpu)?;
|
||||
|
||||
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let pad_id = match &sd_config.clip.pad_with {
|
||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
||||
};
|
||||
println!("Running with prompt \"{prompt}\".");
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
tokens.push(pad_id)
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the Clip transformer.");
|
||||
let text_embeddings = {
|
||||
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
|
||||
let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward(&tokens)?;
|
||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||
Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
|
||||
let which = match sd_version {
|
||||
StableDiffusionVersion::Xl => vec![true, false],
|
||||
_ => vec![true],
|
||||
};
|
||||
let text_embeddings = which
|
||||
.iter()
|
||||
.map(|first| {
|
||||
text_embeddings(
|
||||
&prompt,
|
||||
&uncond_prompt,
|
||||
tokenizer.clone(),
|
||||
clip_weights.clone(),
|
||||
sd_version,
|
||||
&sd_config,
|
||||
use_f16,
|
||||
&device,
|
||||
dtype,
|
||||
*first,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
|
||||
println!("{text_embeddings:?}");
|
||||
|
||||
println!("Building the autoencoder.");
|
||||
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
||||
|
Reference in New Issue
Block a user