Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Lazy upcasting for t5. (#2589)
@@ -118,7 +118,7 @@ impl T5WithTokenizer {
             .to_vec();
         tokens.resize(self.max_position_embeddings, 0);
         let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        let embeddings = self.t5.forward(&input_token_ids)?;
+        let embeddings = self.t5.forward_dt(&input_token_ids, Some(DType::F32))?;
         Ok(embeddings)
     }
 }
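The only change in the forward pass: instead of running a T5 model whose weights were already materialised in F32, the code now calls forward_dt and asks for F32 activations while the stored weights stay in F16, i.e. they are upcast lazily where they are used. A minimal, self-contained sketch of that pattern using candle-core directly (LazyLinear and its forward_dt are illustrative stand-ins, not the actual T5 implementation; inside the candle workspace the crate is referenced as candle rather than candle_core):

use candle_core::{DType, Device, Result, Tensor};

// Illustrative layer that keeps its weight in the dtype it was loaded with (F16 here).
struct LazyLinear {
    weight: Tensor,
}

impl LazyLinear {
    // Upcast the stored weight only at the point of use, matching the shape of the
    // `forward_dt(input, Some(DType::F32))` call in the diff above.
    fn forward_dt(&self, xs: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
        let target = dtype.unwrap_or(self.weight.dtype());
        let w = self.weight.to_dtype(target)?; // temporary F32 copy of this weight only
        xs.to_dtype(target)?.matmul(&w.t()?)
    }
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Weights as they would come out of an F16 safetensors load.
    let weight = Tensor::randn(0f32, 1f32, (8, 16), &device)?.to_dtype(DType::F16)?;
    let layer = LazyLinear { weight };
    let xs = Tensor::randn(0f32, 1f32, (4, 16), &device)?.to_dtype(DType::F16)?;
    // Request F32 activations; the stored weight stays F16.
    let ys = layer.forward_dt(&xs, Some(DType::F32))?;
    assert_eq!(ys.dtype(), DType::F32);
    Ok(())
}

The point of the pattern is that the F32 copy is a short-lived per-weight temporary rather than a whole-model duplicate, so the resident checkpoint can stay in half precision.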
@@ -144,7 +144,7 @@ impl StableDiffusion3TripleClipWithTokenizer {
             candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
         };
         let vb_t5 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)?
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, device)?
         };
         let max_position_embeddings = 77usize;
         let clip_l = ClipWithTokenizer::new(
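With this, the t5xxl weights are memory-mapped in half precision just like the two CLIP checkpoints, roughly halving the memory needed to hold the T5 encoder compared to the earlier F32 load. The dtype passed to the VarBuilder is what every tensor it hands out is converted to; a small runnable sketch of that propagation, using VarBuilder::zeros as a checkpoint-free stand-in for from_mmaped_safetensors (the prefix and shape below are illustrative, not the real T5 layout):

use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for:
    //   unsafe { VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, &device)? }
    let vb = VarBuilder::zeros(DType::F16, &device);
    // Every tensor retrieved through this builder comes back in F16.
    let w = vb.pp("t5xxl.transformer").get((4096, 4096), "weight")?;
    assert_eq!(w.dtype(), DType::F16);
    Ok(())
}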
@@ -164,11 +164,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
             max_position_embeddings,
         )?;
 
-        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
-        // This is a temporary workaround until the T5 implementation is updated to support fp16.
-        // Also see:
-        // https://github.com/huggingface/candle/issues/2480
-        // https://github.com/huggingface/candle/pull/2481
         let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
         Ok(Self {
             clip_l,
@@ -178,34 +173,26 @@ impl StableDiffusion3TripleClipWithTokenizer {
         })
     }
 
-    pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
+    pub fn new(vb: candle_nn::VarBuilder) -> Result<Self> {
         let max_position_embeddings = 77usize;
         let clip_l = ClipWithTokenizer::new(
-            vb_fp16.pp("clip_l.transformer"),
+            vb.pp("clip_l.transformer"),
             stable_diffusion::clip::Config::sdxl(),
             "openai/clip-vit-large-patch14",
             max_position_embeddings,
         )?;
 
         let clip_g = ClipWithTokenizer::new(
-            vb_fp16.pp("clip_g.transformer"),
+            vb.pp("clip_g.transformer"),
             stable_diffusion::clip::Config::sdxl2(),
             "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
             max_position_embeddings,
         )?;
 
-        let text_projection = candle_nn::linear_no_bias(
-            1280,
-            1280,
-            vb_fp16.pp("clip_g.transformer.text_projection"),
-        )?;
+        let text_projection =
+            candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?;
 
-        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
-        // This is a temporary workaround until the T5 implementation is updated to support fp16.
-        // Also see:
-        // https://github.com/huggingface/candle/issues/2480
-        // https://github.com/huggingface/candle/pull/2481
-        let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
+        let t5 = T5WithTokenizer::new(vb.pp("t5xxl.transformer"), max_position_embeddings)?;
         Ok(Self {
             clip_l,
             clip_g,
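After this hunk the constructor takes a single VarBuilder: clip_l, clip_g, the text projection and t5xxl are all read from one F16-loaded checkpoint, and only the T5 tower upcasts internally at forward time. For reference, candle_nn::linear_no_bias builds the projection in whatever dtype the builder carries; a small runnable sketch mirroring the refactored text_projection line, again with VarBuilder::zeros standing in for the real checkpoint:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for the shared F16 checkpoint builder passed to `new`.
    let vb = VarBuilder::zeros(DType::F16, &device);
    // Same call shape as the refactored text_projection above.
    let text_projection =
        candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?;
    // 77 is the max_position_embeddings used throughout this file.
    let xs = Tensor::zeros((1, 77, 1280), DType::F16, &device)?;
    let ys = text_projection.forward(&xs)?;
    // The CLIP-side projection stays in F16; only T5 requests F32 via forward_dt.
    assert_eq!(ys.dtype(), DType::F16);
    Ok(())
}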