mirror of https://github.com/huggingface/candle.git
Lazy upcasting for t5. (#2589)
@@ -1,12 +1,38 @@
 // T5 Text Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

-use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
+use crate::models::with_tracing::Embedding;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;

+#[derive(Debug, Clone)]
+pub struct Linear {
+    weight: Tensor,
+    span: tracing::Span,
+}
+
+pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+    let weight = vb.get_with_hints((d2, d1), "weight", init_ws)?;
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    Ok(Linear { weight, span })
+}
+
+impl Module for Linear {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let weight = self.weight.to_dtype(xs.dtype())?;
+        let w = match *xs.dims() {
+            [b1, b2, _, _] => weight.broadcast_left((b1, b2))?.t()?,
+            [bsize, _, _] => weight.broadcast_left(bsize)?.t()?,
+            _ => weight.t()?,
+        };
+        xs.matmul(&w)
+    }
+}
+
 fn default_relative_attention_max_distance() -> usize {
     128
 }
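The new module-local Linear keeps its weight in whatever dtype the checkpoint was loaded with and only upcasts it to the activation dtype inside forward. Below is a minimal sketch of that behaviour, not part of the diff; it assumes the usual `candle = { package = "candle-core" }` dependency rename and that `Linear`/`linear_no_bias` remain public under `candle_transformers::models::t5`.

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::t5;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A zero-initialized VarBuilder standing in for bf16 checkpoint weights.
    let vb = VarBuilder::zeros(DType::BF16, &dev);
    let lin = t5::linear_no_bias(8, 16, vb)?;
    // f32 activations: the bf16 weight is upcast on the fly inside forward.
    let xs = Tensor::zeros((2, 4, 8), DType::F32, &dev)?;
    let ys = lin.forward(&xs)?;
    assert_eq!(ys.dims(), &[2, 4, 16]);
    Ok(())
}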
@@ -185,7 +211,7 @@ impl Module for T5LayerNorm {
         let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
         let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
-        let xs = xs.broadcast_mul(&self.weight)?;
+        let xs = xs.broadcast_mul(&self.weight.to_dtype(dtype)?)?;
         Ok(xs)
     }
 }
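The layer norm still computes its statistics in f32 and casts the result back to the input dtype; the change is that the stored weight is now converted to that dtype at multiply time instead of being assumed to already match. A simplified, free-standing version of the dtype flow follows; the helper name is hypothetical and only mirrors the logic shown in the hunk above.

use candle::{DType, Result, Tensor, D};

// RMS-norm style dtype handling mirroring T5LayerNorm::forward after this change.
fn rms_norm(xs: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let dtype = xs.dtype();
    let xs_f32 = xs.to_dtype(DType::F32)?;
    let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
    let xs = xs_f32.broadcast_div(&(variance + eps)?.sqrt()?)?;
    // Cast activations back, then bring the stored weight to the same dtype.
    xs.to_dtype(dtype)?.broadcast_mul(&weight.to_dtype(dtype)?)
}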
@@ -472,7 +498,8 @@ impl T5Attention {
                 let position_bias = relative_attention_bias
                     .forward(&relative_buckets)?
                     .permute((2, 0, 1))?
-                    .unsqueeze(0)?;
+                    .unsqueeze(0)?
+                    .to_dtype(scores.dtype())?;
                 (scores.broadcast_add(&position_bias)?, Some(position_bias))
                 // TODO: position_bias_masked?
             }
@@ -678,9 +705,22 @@ impl T5Stack {
         &mut self,
         input_ids: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        self.forward_dt(input_ids, encoder_hidden_states, None)
+    }
+
+    fn forward_dt(
+        &mut self,
+        input_ids: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        dtype: Option<DType>,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
+        let input_embeds = match dtype {
+            None => input_embeds,
+            Some(dtype) => input_embeds.to_dtype(dtype)?,
+        };
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter_mut() {
@@ -729,6 +769,11 @@ impl T5EncoderModel {
         self.encoder.forward(input_ids, None)
     }

+    pub fn forward_dt(&mut self, input_ids: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.encoder.forward_dt(input_ids, None, dtype)
+    }
+
     pub fn device(&self) -> &Device {
         &self.device
     }
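T5EncoderModel::forward_dt is the public entry point that threads the optional dtype down through T5Stack, so a model whose weights are memory-mapped in bf16 can run its forward pass in f32 while the weights are upcast lazily per layer. A hedged usage sketch follows; the file names, the anyhow error handling, and the surrounding function are placeholders, not taken from the diff.

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::t5;

fn encode(input_ids: &[u32]) -> anyhow::Result<Tensor> {
    let device = Device::Cpu;
    // Placeholder paths: a T5 config.json and safetensors checkpoint.
    let config: t5::Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::BF16, &device)?
    };
    let mut model = t5::T5EncoderModel::load(vb, &config)?;
    let input_ids = Tensor::new(input_ids, &device)?.unsqueeze(0)?;
    // Keep the weights in bf16 but run the encoder in f32, upcasting lazily.
    let hidden = model.forward_dt(&input_ids, Some(DType::F32))?;
    Ok(hidden)
}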