Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Add the forward pass for the T5 model. (#152)

* Add the forward pass for the T5 model.
* More t5 forward pass.
@@ -3,7 +3,8 @@
 use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
-use candle::Tensor;
+use candle::{DType, Tensor, D};
+use std::sync::Arc;

 #[derive(Debug, Clone, PartialEq)]
 pub struct Config {
@@ -92,6 +93,18 @@ impl T5LayerNorm {
             variance_epsilon: eps,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let dtype = xs.dtype();
+        let xs_f32 = xs.to_dtype(DType::F32)?;
+        let xs2_f32 = (&xs_f32 * &xs_f32)?;
+        let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+        let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
+        let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+        let xs = xs.to_dtype(dtype)?;
+        let xs = xs.broadcast_mul(&self.weight)?;
+        Ok(xs)
+    }
 }

 #[derive(Debug)]
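A note on the hunk above: T5's layer norm is an RMS-style norm. It rescales by the root mean square of the activations and applies a learned weight, with no mean subtraction and no bias, and the reduction is done in f32 before casting back to the input dtype. A minimal plain-Rust sketch of that computation over a single row (deliberately not using candle's tensor API):

```rust
// RMS-style layer norm over one row, as T5 defines it:
// y = x / sqrt(mean(x^2) + eps) * weight   (no mean subtraction, no bias)
fn t5_layer_norm(xs: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    // Mean of squares over the last (here: only) dimension, in f32.
    let mean_sq = xs.iter().map(|x| x * x).sum::<f32>() / xs.len() as f32;
    let scale = (mean_sq + eps).sqrt().recip();
    xs.iter().zip(weight).map(|(x, w)| x * scale * w).collect()
}
```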
@@ -114,6 +127,14 @@ impl T5DenseActDense {
             act: HiddenAct::Relu,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.wi.forward(xs)?;
+        let xs = self.act.forward(&xs)?;
+        let xs = self.dropout.forward(&xs)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
 }

 #[derive(Debug)]
@@ -136,6 +157,13 @@ impl T5LayerFF {
             dropout,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let ys = self.layer_norm.forward(xs)?;
+        let ys = self.dense_relu_dense.forward(&ys)?;
+        let xs = (xs + self.dropout.forward(&ys)?)?;
+        Ok(xs)
+    }
 }

 #[derive(Debug)]
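The two hunks above follow T5's pre-norm residual pattern: the sublayer (here the wi -> ReLU -> wo feed-forward) sees the normalized input, and its dropout-regularized output is added back onto the raw input, i.e. y = x + dropout(ff(layer_norm(x))). A generic sketch of the pattern with hypothetical helper signatures, treating dropout as the identity as it would be at inference time:

```rust
// Pre-norm residual wrapper: y = x + f(norm(x)).
// `norm` and `f` stand in for T5LayerNorm and T5DenseActDense here;
// dropout is omitted (identity), as at inference time.
fn pre_norm_residual(
    x: &[f32],
    norm: impl Fn(&[f32]) -> Vec<f32>,
    f: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let y = f(&norm(x));
    x.iter().zip(&y).map(|(a, b)| a + b).collect()
}
```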
@@ -144,6 +172,8 @@ struct T5Attention {
     k: Linear,
     v: Linear,
     o: Linear,
+    n_heads: usize,
+    d_kv: usize,
     relative_attention_bias: Option<Embedding>,
 }

@@ -169,9 +199,33 @@ impl T5Attention {
             k,
             v,
             o,
+            n_heads: cfg.num_heads,
+            d_kv: cfg.d_kv,
             relative_attention_bias,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+        let q = self.q.forward(xs)?;
+        let k = self.k.forward(xs)?;
+        let v = self.v.forward(xs)?;
+        let q = q
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let scores = q.matmul(&k.t()?)?;
+        // position_bias_masked
+        let attn_weights = scores.softmax(D::Minus1)?;
+        let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = self.o.forward(&attn_output)?;
+        Ok(attn_output)
+    }
 }

 #[derive(Debug)]
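In the attention hunk above, the reshape/transpose pair splits the projections from (b_sz, seq_len, n_heads * d_kv) into (b_sz, n_heads, seq_len, d_kv) so the matmuls run per head. Note that the scores are deliberately not divided by sqrt(d_kv): T5 omits that scaling (it is effectively folded into the weight initialization), and the `// position_bias_masked` comment marks where the relative position bias would be added to the scores before the softmax, which this commit does not yet implement. A minimal single-head sketch in plain Rust, mirroring the unscaled scores -> softmax -> weighted-sum steps (all vectors assumed to share a dimension):

```rust
// Unscaled single-head attention: scores = q.k, softmax, weighted sum of v.
fn attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    q.iter()
        .map(|qi| {
            // Unscaled dot-product scores against every key (no 1/sqrt(d_kv)).
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum())
                .collect();
            // Numerically stable softmax over the scores.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let denom: f32 = exps.iter().sum();
            // Weighted sum of the value vectors.
            let mut out = vec![0.0; v[0].len()];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / denom * x;
                }
            }
            out
        })
        .collect()
}
```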
@@ -193,6 +247,13 @@ impl T5LayerSelfAttention {
             dropout,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let normed_xs = self.layer_norm.forward(xs)?;
+        let ys = self.self_attention.forward(&normed_xs)?;
+        let ys = (xs + ys)?;
+        Ok(ys)
+    }
 }

 #[derive(Debug)]
@@ -202,6 +263,10 @@ impl T5LayerCrossAttention {
     fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
         todo!()
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }

 #[derive(Debug)]
@@ -228,18 +293,31 @@ impl T5Block {
             ff,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = self.self_attn.forward(xs)?;
+        // TODO: clamp for f16?
+        if let Some(cross_attn) = &self.cross_attn {
+            xs = cross_attn.forward(&xs)?;
+            // TODO: clamp for f16?
+        }
+        let xs = self.ff.forward(&xs)?;
+        // TODO: clamp for f16?
+        Ok(xs)
+    }
 }

 #[derive(Debug)]
 struct T5Stack {
     // TODO: Add embed_tokens if needed (shared embedding layer).
     block: Vec<T5Block>,
+    shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
     dropout: Dropout,
 }

 impl T5Stack {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
         let block = (0..cfg.num_layers)
             .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
             .collect::<Result<Vec<_>>>()?;
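The hunk above also threads the shared token-embedding table into the stack: T5 reuses one `shared` embedding across its sub-models, so it is wrapped in an `Arc` and cheaply cloned into each consumer instead of being loaded twice. A sketch of the pattern with a hypothetical stand-in type:

```rust
use std::sync::Arc;

// Hypothetical stand-in for an embedding table (vocab_size x d_model).
struct Embed(Vec<Vec<f32>>);

struct Stack {
    // A clone of the Arc, not of the weights: both owners see one table.
    shared: Arc<Embed>,
}

fn build(vocab_size: usize, d_model: usize) -> (Arc<Embed>, Stack) {
    let shared = Arc::new(Embed(vec![vec![0.0; d_model]; vocab_size]));
    let stack = Stack { shared: shared.clone() }; // bumps a refcount only
    (shared, stack)
}
```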
@@ -251,22 +329,42 @@ impl T5Stack {
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             block,
+            shared: shared.clone(),
             final_layer_norm,
             dropout,
         })
     }
+
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let input_embeds = self.shared.as_ref().forward(input_ids)?;
+        let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
+
+        let mut hidden_states = self.dropout.forward(&input_embeds)?;
+        for block in self.block.iter() {
+            hidden_states = block.forward(&hidden_states)?
+        }
+        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        Ok(hidden_states)
+    }
 }

 #[derive(Debug)]
 pub struct T5EncoderModel {
-    shared: Embedding,
+    shared: Arc<Embedding>,
     encoder: T5Stack,
 }

 impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
-        let encoder = T5Stack::load(vb.pp("encoder"), cfg)?;
+        let shared = Arc::new(shared);
+        let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
         Ok(Self { shared, encoder })
     }
+
+    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_outputs = self.encoder.forward(input_ids)?;
+        Ok(encoder_outputs)
+    }
 }
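Taken together, the encoder forward pass added by this commit is: look the token ids up in the shared embedding, apply dropout, run the block stack, then the final layer norm and dropout, returning hidden states of shape (batch, seq_len, d_model). A compact plain-Rust sketch of that control flow, with hypothetical closures standing in for the candle modules and dropout again treated as the identity:

```rust
// Encoder control flow: embed -> dropout -> blocks -> final norm -> dropout.
fn encoder_forward(
    input_ids: &[usize],
    embed: impl Fn(usize) -> Vec<f32>,
    blocks: &[Box<dyn Fn(Vec<Vec<f32>>) -> Vec<Vec<f32>>>],
    final_layer_norm: impl Fn(Vec<Vec<f32>>) -> Vec<Vec<f32>>,
) -> Vec<Vec<f32>> {
    // One embedding row per input token id.
    let mut hidden: Vec<Vec<f32>> = input_ids.iter().map(|&id| embed(id)).collect();
    for block in blocks {
        hidden = block(hidden);
    }
    final_layer_norm(hidden)
}
```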