Files
candle/candle-examples/examples/stable-diffusion-3/clip.rs
Czxck001 ca7cf5cb3b Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-13 22:08:40 +02:00

202 lines
6.6 KiB
Rust

use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use tokenizers::tokenizer::Tokenizer;
struct ClipWithTokenizer {
clip: stable_diffusion::clip::ClipTextTransformer,
config: stable_diffusion::clip::Config,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl ClipWithTokenizer {
fn new(
vb: candle_nn::VarBuilder,
config: stable_diffusion::clip::Config,
tokenizer_path: &str,
max_position_embeddings: usize,
) -> Result<Self> {
let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
let path_buf = hf_hub::api::sync::Api::new()?
.model(tokenizer_path.to_string())
.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
"Failed to serialize huggingface PathBuf of CLIP tokenizer",
))?)
.map_err(E::msg)?;
Ok(Self {
clip,
config,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let pad_id = match &self.config.pad_with {
Some(padding) => *self
.tokenizer
.get_vocab(true)
.get(padding.as_str())
.ok_or(E::msg("Failed to tokenize CLIP padding."))?,
None => *self
.tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
};
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let eos_position = tokens.len() - 1;
while tokens.len() < self.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let (text_embeddings, text_embeddings_penultimate) = self
.clip
.forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
Ok((text_embeddings_penultimate, text_embeddings_pooled))
}
}
struct T5WithTokenizer {
t5: t5::T5EncoderModel,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl T5WithTokenizer {
fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
let api = hf_hub::api::sync::Api::new()?;
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok(Self {
t5: model,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<Tensor> {
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(self.max_position_embeddings, 0);
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.t5.forward(&input_token_ids)?;
Ok(embeddings)
}
}
pub struct StableDiffusion3TripleClipWithTokenizer {
clip_l: ClipWithTokenizer,
clip_g: ClipWithTokenizer,
clip_g_text_projection: candle_nn::Linear,
t5: T5WithTokenizer,
}
impl StableDiffusion3TripleClipWithTokenizer {
pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb_fp16.pp("clip_l.transformer"),
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let clip_g = ClipWithTokenizer::new(
vb_fp16.pp("clip_g.transformer"),
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let text_projection = candle_nn::linear_no_bias(
1280,
1280,
vb_fp16.pp("clip_g.transformer.text_projection"),
)?;
// Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
// This is a temporary workaround until the T5 implementation is updated to support fp16.
// Also see:
// https://github.com/huggingface/candle/issues/2480
// https://github.com/huggingface/candle/pull/2481
let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let (clip_l_embeddings, clip_l_embeddings_pooled) =
self.clip_l.encode_text_to_embedding(prompt, device)?;
let (clip_g_embeddings, clip_g_embeddings_pooled) =
self.clip_g.encode_text_to_embedding(prompt, device)?;
let clip_g_embeddings_pooled = self
.clip_g_text_projection
.forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
.squeeze(0)?;
let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
.unsqueeze(0)?;
let clip_embeddings_concat = Tensor::cat(
&[&clip_l_embeddings, &clip_g_embeddings],
D::Minus1,
)?
.pad_with_zeros(D::Minus1, 0, 2048)?;
let t5_embeddings = self
.t5
.encode_text_to_embedding(prompt, device)?
.to_dtype(DType::F16)?;
let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
Ok((context, y))
}
}