Lazy upcasting for t5. (#2589)

Laurent Mazare
2024-10-30 18:08:51 +01:00
committed by GitHub
parent d232e132f6
commit 7ac0de15a9
3 changed files with 59 additions and 34 deletions
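The SD3 example previously loaded the T5-XXL text encoder through a separate fp32 VarBuilder because candle's T5 implementation did not support fp16 (see the issue/PR links in the comments removed below). With this change the whole fp16 checkpoint is read through a single F16 VarBuilder and the T5 encoder upcasts to F32 lazily, per use, via the new forward_dt entry point, so the fp16 weights no longer have to be materialized as fp32 tensors at load time.

A minimal sketch of that pattern, not the actual t5.rs change (that is the third file of this commit and is not shown in this excerpt). LazyLinear, the shapes, and the main function are made up for illustration; only Tensor::to_dtype, matmul, t, and randn from the candle core crate are assumed:

use candle::{DType, Device, Result, Tensor};

// Hypothetical layer: the weight keeps the dtype it was loaded with (F16 here)
// and is upcast only for the duration of the matmul.
struct LazyLinear {
    weight: Tensor,
}

impl LazyLinear {
    fn forward_dt(&self, xs: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
        match dtype {
            // Upcast the weight and the input on the fly; no fp32 copy is kept around.
            Some(dt) => xs.to_dtype(dt)?.matmul(&self.weight.to_dtype(dt)?.t()?),
            // No dtype requested: behave like a plain `forward`.
            None => xs.matmul(&self.weight.t()?),
        }
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let layer = LazyLinear {
        weight: Tensor::randn(0f32, 1f32, (8, 16), &dev)?.to_dtype(DType::F16)?,
    };
    let xs = Tensor::randn(0f32, 1f32, (4, 16), &dev)?.to_dtype(DType::F16)?;
    // F16 storage, F32 compute: mirrors `forward_dt(&input_token_ids, Some(DType::F32))` below.
    let ys = layer.forward_dt(&xs, Some(DType::F32))?;
    println!("{:?} -> {:?}", xs.dtype(), ys.dtype()); // F16 -> F32
    Ok(())
}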


@@ -118,7 +118,7 @@ impl T5WithTokenizer {
             .to_vec();
         tokens.resize(self.max_position_embeddings, 0);
         let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        let embeddings = self.t5.forward(&input_token_ids)?;
+        let embeddings = self.t5.forward_dt(&input_token_ids, Some(DType::F32))?;
         Ok(embeddings)
     }
 }
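Note that the compute dtype for T5 stays F32 (the Some(DType::F32) argument); what changes is the storage dtype, since the t5xxl VarBuilder below now loads F16 weights. The forward_dt implementation itself lives in the T5 model code, the third changed file of this commit, which is not included in this excerpt; the fp16 numerical issues that motivated the old fp32 workaround are tracked in the issue and PR linked from the removed comments (#2480, #2481).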
@@ -144,7 +144,7 @@ impl StableDiffusion3TripleClipWithTokenizer {
             candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
         };
         let vb_t5 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)?
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, device)?
         };
         let max_position_embeddings = 77usize;
         let clip_l = ClipWithTokenizer::new(
@@ -164,11 +164,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
             max_position_embeddings,
         )?;
-        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
-        // This is a temporary workaround until the T5 implementation is updated to support fp16.
-        // Also see:
-        // https://github.com/huggingface/candle/issues/2480
-        // https://github.com/huggingface/candle/pull/2481
         let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
         Ok(Self {
             clip_l,
@@ -178,34 +173,26 @@ impl StableDiffusion3TripleClipWithTokenizer {
         })
     }
 
-    pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
+    pub fn new(vb: candle_nn::VarBuilder) -> Result<Self> {
         let max_position_embeddings = 77usize;
         let clip_l = ClipWithTokenizer::new(
-            vb_fp16.pp("clip_l.transformer"),
+            vb.pp("clip_l.transformer"),
             stable_diffusion::clip::Config::sdxl(),
             "openai/clip-vit-large-patch14",
             max_position_embeddings,
         )?;
         let clip_g = ClipWithTokenizer::new(
-            vb_fp16.pp("clip_g.transformer"),
+            vb.pp("clip_g.transformer"),
             stable_diffusion::clip::Config::sdxl2(),
             "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
             max_position_embeddings,
         )?;
-        let text_projection = candle_nn::linear_no_bias(
-            1280,
-            1280,
-            vb_fp16.pp("clip_g.transformer.text_projection"),
-        )?;
+        let text_projection =
+            candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?;
-        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
-        // This is a temporary workaround until the T5 implementation is updated to support fp16.
-        // Also see:
-        // https://github.com/huggingface/candle/issues/2480
-        // https://github.com/huggingface/candle/pull/2481
-        let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
+        let t5 = T5WithTokenizer::new(vb.pp("t5xxl.transformer"), max_position_embeddings)?;
         Ok(Self {
             clip_l,
             clip_g,
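With the fp32 VarBuilder gone, both constructors above build every encoder, T5 included, from the fp16 weights, and new takes a single VarBuilder; the call site in the example binary is updated to match in the next hunk.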


@@ -194,18 +194,11 @@ fn main() -> Result<()> {
             api.repo(hf_hub::Repo::model(name.to_string()))
         };
         let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
-        let vb_fp16 = unsafe {
+        let vb = unsafe {
             candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
         };
-        let vb_fp32 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
-        };
-        let triple = StableDiffusion3TripleClipWithTokenizer::new(
-            vb_fp16.pp("text_encoders"),
-            vb_fp32.pp("text_encoders"),
-        )?;
-        (MMDiTConfig::sd3_medium(), triple, vb_fp16)
+        let triple = StableDiffusion3TripleClipWithTokenizer::new(vb.pp("text_encoders"))?;
+        (MMDiTConfig::sd3_medium(), triple, vb)
     };
     let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
     let (context_uncond, y_uncond) =