Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Add the forward pass for the T5 model. (#152)
* Add the forward pass for the T5 model.
* More t5 forward pass.
@@ -3,7 +3,8 @@
 
 use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
-use candle::Tensor;
+use candle::{DType, Tensor, D};
+use std::sync::Arc;
 
 #[derive(Debug, Clone, PartialEq)]
 pub struct Config {
@@ -92,6 +93,18 @@ impl T5LayerNorm {
             variance_epsilon: eps,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let dtype = xs.dtype();
+        let xs_f32 = xs.to_dtype(DType::F32)?;
+        let xs2_f32 = (&xs_f32 * &xs_f32)?;
+        let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+        let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
+        let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+        let xs = xs.to_dtype(dtype)?;
+        let xs = xs.broadcast_mul(&self.weight)?;
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
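For reference, T5's layer norm in the upstream transformers implementation this file follows is a scale-only RMS norm: no mean subtraction, no bias, with the variance taken as the mean of the squared activations over the last dimension. A minimal scalar sketch of that reference computation (illustrative only, not part of the commit):

fn rms_norm(xs: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    // Variance is the mean of the squared activations over the hidden dimension.
    let variance = xs.iter().map(|x| x * x).sum::<f32>() / xs.len() as f32;
    let scale = 1.0 / (variance + eps).sqrt();
    // Rescale each activation and apply the learned per-channel weight.
    xs.iter().zip(weight.iter()).map(|(x, w)| w * x * scale).collect()
}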
@@ -114,6 +127,14 @@ impl T5DenseActDense {
             act: HiddenAct::Relu,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.wi.forward(xs)?;
+        let xs = self.act.forward(&xs)?;
+        let xs = self.dropout.forward(&xs)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
@@ -136,6 +157,13 @@ impl T5LayerFF {
             dropout,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let ys = self.layer_norm.forward(xs)?;
+        let ys = self.dense_relu_dense.forward(&ys)?;
+        let xs = (xs + self.dropout.forward(&ys)?)?;
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
@@ -144,6 +172,8 @@ struct T5Attention {
     k: Linear,
     v: Linear,
     o: Linear,
+    n_heads: usize,
+    d_kv: usize,
     relative_attention_bias: Option<Embedding>,
 }
 
@@ -169,9 +199,33 @@ impl T5Attention {
             k,
             v,
             o,
+            n_heads: cfg.num_heads,
+            d_kv: cfg.d_kv,
             relative_attention_bias,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+        let q = self.q.forward(xs)?;
+        let k = self.k.forward(xs)?;
+        let v = self.v.forward(xs)?;
+        let q = q
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let scores = q.matmul(&k.t()?)?;
+        // position_bias_masked
+        let attn_weights = scores.softmax(D::Minus1)?;
+        let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = self.o.forward(&attn_output)?;
+        Ok(attn_output)
+    }
 }
 
 #[derive(Debug)]
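The forward pass above leaves the relative position bias as a placeholder (the `relative_attention_bias` embedding is loaded but not yet applied, hence the `// position_bias_masked` comment). For orientation, here is a scalar sketch of the bucketing the reference transformers implementation uses to index that embedding in the bidirectional (encoder) case; the function name and its standalone form are illustrative only and not part of the commit:

fn relative_position_bucket(relative_position: i64, num_buckets: i64, max_distance: i64) -> i64 {
    // Bidirectional case: half the buckets are reserved for keys to the right
    // of the query position, the other half for keys to the left.
    let num_buckets = num_buckets / 2;
    let bucket = if relative_position > 0 { num_buckets } else { 0 };
    let relative_position = relative_position.abs();
    // Nearby positions get one bucket each; larger distances are binned
    // logarithmically up to max_distance.
    let max_exact = num_buckets / 2;
    if relative_position < max_exact {
        bucket + relative_position
    } else {
        let log_ratio = (relative_position as f64 / max_exact as f64).ln()
            / (max_distance as f64 / max_exact as f64).ln();
        let large = max_exact + (log_ratio * (num_buckets - max_exact) as f64) as i64;
        bucket + large.min(num_buckets - 1)
    }
}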
@@ -193,6 +247,13 @@ impl T5LayerSelfAttention {
             dropout,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let normed_xs = self.layer_norm.forward(xs)?;
+        let ys = self.self_attention.forward(&normed_xs)?;
+        let ys = (xs + ys)?;
+        Ok(ys)
+    }
 }
 
 #[derive(Debug)]
@@ -202,6 +263,10 @@ impl T5LayerCrossAttention {
     fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
        todo!()
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
 
 #[derive(Debug)]
@@ -228,18 +293,31 @@ impl T5Block {
             ff,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = self.self_attn.forward(xs)?;
+        // TODO: clamp for f16?
+        if let Some(cross_attn) = &self.cross_attn {
+            xs = cross_attn.forward(&xs)?;
+            // TODO: clamp for f16?
+        }
+        let xs = self.ff.forward(&xs)?;
+        // TODO: clamp for f16?
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
 struct T5Stack {
     // TODO: Add embed_tokens if needed (shared embedding layer).
     block: Vec<T5Block>,
+    shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
     dropout: Dropout,
 }
 
 impl T5Stack {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
         let block = (0..cfg.num_layers)
             .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
             .collect::<Result<Vec<_>>>()?;
@@ -251,22 +329,42 @@ impl T5Stack {
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             block,
+            shared: shared.clone(),
             final_layer_norm,
             dropout,
         })
     }
+
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let input_embeds = self.shared.as_ref().forward(input_ids)?;
+        let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
+
+        let mut hidden_states = self.dropout.forward(&input_embeds)?;
+        for block in self.block.iter() {
+            hidden_states = block.forward(&hidden_states)?
+        }
+        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        Ok(hidden_states)
+    }
 }
 
 #[derive(Debug)]
 pub struct T5EncoderModel {
-    shared: Embedding,
+    shared: Arc<Embedding>,
     encoder: T5Stack,
 }
 
 impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
-        let encoder = T5Stack::load(vb.pp("encoder"), cfg)?;
+        let shared = Arc::new(shared);
+        let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
         Ok(Self { shared, encoder })
     }
+
+    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_outputs = self.encoder.forward(input_ids)?;
+        Ok(encoder_outputs)
+    }
 }
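A rough usage sketch for the encoder added above (not part of the commit): assuming a VarBuilder already populated with the checkpoint weights and a (batch, seq_len) tensor of token ids, loading and running the model comes down to:

fn encode(vb: VarBuilder, cfg: &Config, input_ids: &Tensor) -> Result<Tensor> {
    // Builds the shared embedding plus the encoder stack from the "encoder.*" weights.
    let model = T5EncoderModel::load(vb, cfg)?;
    // Returns the final hidden states of the encoder: (batch, seq_len, d_model).
    model.forward(input_ids)
}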