use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use tokenizers::tokenizer::Tokenizer;
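
/// A CLIP text encoder paired with its tokenizer, so that a prompt can be
/// turned into token embeddings and a pooled embedding in a single call.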
struct ClipWithTokenizer {
    clip: stable_diffusion::clip::ClipTextTransformer,
    config: stable_diffusion::clip::Config,
    tokenizer: Tokenizer,
    max_position_embeddings: usize,
}

impl ClipWithTokenizer {
    fn new(
        vb: candle_nn::VarBuilder,
        config: stable_diffusion::clip::Config,
        tokenizer_path: &str,
        max_position_embeddings: usize,
    ) -> Result<Self> {
        let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
        // Fetch the matching tokenizer definition from the Hugging Face Hub.
        let path_buf = hf_hub::api::sync::Api::new()?
            .model(tokenizer_path.to_string())
            .get("tokenizer.json")?;
        let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
            "Failed to convert the CLIP tokenizer path to a string",
        ))?)
        .map_err(E::msg)?;
        Ok(Self {
            clip,
            config,
            tokenizer,
            max_position_embeddings,
        })
    }

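    /// Encodes a prompt, returning the penultimate-layer token embeddings
    /// together with the pooled embedding taken at the end-of-text position.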
    fn encode_text_to_embedding(
        &self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<(Tensor, Tensor)> {
        let pad_id = match &self.config.pad_with {
            Some(padding) => *self
                .tokenizer
                .get_vocab(true)
                .get(padding.as_str())
                .ok_or(E::msg("CLIP padding token not found in the vocabulary."))?,
            None => *self
                .tokenizer
                .get_vocab(true)
                .get("<|endoftext|>")
                .ok_or(E::msg("CLIP end-of-text token not found in the vocabulary."))?,
        };

        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        // The pooled embedding is read at the end-of-text token, so record its
        // position before padding the sequence to the full context length.
        let eos_position = tokens.len() - 1;

        while tokens.len() < self.max_position_embeddings {
            tokens.push(pad_id)
        }
        let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
        // Run the transformer to the final layer while also capturing the
        // penultimate (-2) hidden states, which SD3 uses for conditioning.
        let (text_embeddings, text_embeddings_penultimate) = self
            .clip
            .forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
        let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;

        Ok((text_embeddings_penultimate, text_embeddings_pooled))
    }
}

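/// The T5-XXL text encoder paired with its tokenizer.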
struct T5WithTokenizer {
    t5: t5::T5EncoderModel,
    tokenizer: Tokenizer,
    max_position_embeddings: usize,
}

impl T5WithTokenizer {
    fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.repo(hf_hub::Repo::with_revision(
            "google/t5-v1_1-xxl".to_string(),
            hf_hub::RepoType::Model,
            "refs/pr/2".to_string(),
        ));
        let config_filename = repo.get("config.json")?;
        let config = std::fs::read_to_string(config_filename)?;
        let config: t5::Config = serde_json::from_str(&config)?;
        let model = t5::T5EncoderModel::load(vb, &config)?;

        let tokenizer_filename = api
            .model("lmz/mt5-tokenizers".to_string())
            .get("t5-v1_1-xxl.tokenizer.json")?;

        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok(Self {
            t5: model,
            tokenizer,
            max_position_embeddings,
        })
    }

    fn encode_text_to_embedding(
        &mut self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<Tensor> {
        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        // Pad (or truncate) the token ids to the fixed context length.
        tokens.resize(self.max_position_embeddings, 0);
        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        let embeddings = self.t5.forward(&input_token_ids)?;
        Ok(embeddings)
    }
}

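/// The triple text-encoder stack used by Stable Diffusion 3: CLIP-L, CLIP-G
/// (with its text projection), and T5-XXL.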
pub struct StableDiffusion3TripleClipWithTokenizer {
    clip_l: ClipWithTokenizer,
    clip_g: ClipWithTokenizer,
    clip_g_text_projection: candle_nn::Linear,
    t5: T5WithTokenizer,
}

impl StableDiffusion3TripleClipWithTokenizer {
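    /// Loads the three encoders and their tokenizers. The CLIP weights are
    /// read from the fp16 VarBuilder, T5 from the fp32 one (see below).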
    pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
        let max_position_embeddings = 77usize;
        let clip_l = ClipWithTokenizer::new(
            vb_fp16.pp("clip_l.transformer"),
            stable_diffusion::clip::Config::sdxl(),
            "openai/clip-vit-large-patch14",
            max_position_embeddings,
        )?;

        let clip_g = ClipWithTokenizer::new(
            vb_fp16.pp("clip_g.transformer"),
            stable_diffusion::clip::Config::sdxl2(),
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
            max_position_embeddings,
        )?;

        let text_projection = candle_nn::linear_no_bias(
            1280,
            1280,
            vb_fp16.pp("clip_g.transformer.text_projection"),
        )?;

        // The current T5 implementation does not support fp16, so we use an
        // fp32 VarBuilder for T5. This is a temporary workaround until the T5
        // implementation is updated to support fp16. Also see:
        // https://github.com/huggingface/candle/issues/2480
        // https://github.com/huggingface/candle/pull/2481
        let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;

        Ok(Self {
            clip_l,
            clip_g,
            clip_g_text_projection: text_projection,
            t5,
        })
    }

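    /// Encodes a prompt with all three encoders, returning the token-level
    /// `context` tensor fed to the MMDiT and the pooled conditioning `y`.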
    pub fn encode_text_to_embedding(
        &mut self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<(Tensor, Tensor)> {
        let (clip_l_embeddings, clip_l_embeddings_pooled) =
            self.clip_l.encode_text_to_embedding(prompt, device)?;
        let (clip_g_embeddings, clip_g_embeddings_pooled) =
            self.clip_g.encode_text_to_embedding(prompt, device)?;

        let clip_g_embeddings_pooled = self
            .clip_g_text_projection
            .forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
            .squeeze(0)?;

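        // The pooled CLIP-L (768-dim) and projected CLIP-G (1280-dim)
        // embeddings are concatenated into the 2048-dim `y` vector. The
        // token-level embeddings are concatenated the same way (768 + 1280 =
        // 2048 channels) and zero-padded to 4096 channels to match the
        // T5-XXL hidden size before being stacked along the sequence axis.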
        let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
            .unsqueeze(0)?;
        let clip_embeddings_concat = Tensor::cat(
            &[&clip_l_embeddings, &clip_g_embeddings],
            D::Minus1,
        )?
        .pad_with_zeros(D::Minus1, 0, 2048)?;

        let t5_embeddings = self
            .t5
            .encode_text_to_embedding(prompt, device)?
            .to_dtype(DType::F16)?;
        let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;

        Ok((context, y))
    }
}
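
// A minimal usage sketch, assuming `vb_fp16` and `vb_fp32` are `VarBuilder`s
// over the SD3 text-encoder weights and `device` is a `candle::Device`; the
// variable names here are illustrative:
//
//     let mut encoders = StableDiffusion3TripleClipWithTokenizer::new(vb_fp16, vb_fp32)?;
//     let (context, y) = encoders.encode_text_to_embedding("a rusty robot", &device)?;
//     // `context` feeds the MMDiT's cross-attention inputs; `y` is the pooled
//     // conditioning vector.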