Add support to flan-t5 (#840)
@@ -6,6 +6,8 @@ use serde::Deserialize;
 pub enum Activation {
     #[default]
     Gelu,
+    #[serde(rename = "gated-gelu")]
+    NewGelu,
     Relu,
     Elu(f64),
 }
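The #[serde(rename = "gated-gelu")] attribute is what maps the flan-t5 config value "gated-gelu" onto the new NewGelu variant when the model config is deserialized. Below is a minimal, self-contained sketch of that mapping; the stand-in enum, the rename_all attribute, and the use of serde_json are assumptions made for the example, not taken from the diff.

use serde::Deserialize;

// Stand-in for the Activation enum above, only to demonstrate the rename;
// rename_all = "lowercase" is an assumption made for this sketch.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Activation {
    Gelu,
    #[serde(rename = "gated-gelu")]
    NewGelu,
    Relu,
}

fn main() -> Result<(), serde_json::Error> {
    // A flan-t5 style config sets "feed_forward_proj": "gated-gelu", which lands on NewGelu.
    let act: Activation = serde_json::from_str("\"gated-gelu\"")?;
    assert_eq!(act, Activation::NewGelu);
    println!("{act:?}");
    Ok(())
}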
@@ -14,6 +16,10 @@ impl super::Module for Activation {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu(),
+            // TODO: This is "gelu_new", not the original "gelu".
+            // There's some small numerical difference:
+            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
+            Self::NewGelu => xs.gelu(),
             Self::Relu => xs.relu(),
             &Self::Elu(alpha) => xs.elu(alpha),
         }
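For context on the TODO: the linked transformers file defines "gelu_new" as a tanh-based approximation of the exact erf-based GELU, 0.5 * x * (1 + erf(x / sqrt(2))), and the small numerical difference mentioned in the comment comes from that approximation. A standalone sketch of the approximated formula on plain f64 values (independent of the xs.gelu() call in the diff):

// The "gelu_new" formula from the transformers link above, on plain f64:
// gelu_new(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
fn gelu_new(x: f64) -> f64 {
    let inner = (2.0 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x.powi(3));
    0.5 * x * (1.0 + inner.tanh())
}

fn main() {
    for x in [-2.0_f64, -1.0, 0.0, 1.0, 2.0] {
        println!("gelu_new({x:+.1}) = {:.6}", gelu_new(x));
    }
}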
@@ -148,27 +148,71 @@ impl T5DenseActDense {
     }
 }
 
+#[derive(Debug)]
+struct T5DenseGatedActDense {
+    wi_0: Linear,
+    wi_1: Linear,
+    wo: Linear,
+    act: Activation,
+}
+
+impl T5DenseGatedActDense {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        Ok(Self {
+            wi_0,
+            wi_1,
+            wo,
+            act: Activation::NewGelu,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+        let hidden_linear = self.wi_1.forward(xs)?;
+        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
+}
+
 #[derive(Debug)]
 struct T5LayerFF {
-    dense_relu_dense: T5DenseActDense,
+    dense_act: Option<T5DenseActDense>,
+    gated_dense_act: Option<T5DenseGatedActDense>,
     layer_norm: T5LayerNorm,
 }
 
 impl T5LayerFF {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        // is_gated_act is not supported.
-        let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+            (
+                None,
+                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+            )
+        } else {
+            (
+                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+                None,
+            )
+        };
         Ok(Self {
-            dense_relu_dense,
+            dense_act,
+            gated_dense_act,
             layer_norm,
         })
     }
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let ys = self.layer_norm.forward(xs)?;
-        let ys = self.dense_relu_dense.forward(&ys)?;
+        let ys = match &self.dense_act {
+            Some(dense_act) => dense_act.forward(&ys)?,
+            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+        };
         let xs = (xs + ys)?;
         Ok(xs)
     }
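The new T5DenseGatedActDense is the gated feed-forward block used by flan-t5: the wi_0 projection goes through the activation and is multiplied element-wise with the wi_1 projection before the wo output projection, i.e. wo(act(wi_0 x) * wi_1 x), in contrast to the ungated wo(act(wi x)) of T5DenseActDense. The sketch below traces that dataflow on plain vectors rather than candle tensors; the helper names and the tiny dimensions in main are made up for illustration.

// Standalone sketch of the gated feed-forward dataflow, on plain vectors.
// Shapes: x has length d_model, wi_0 and wi_1 are (d_ff x d_model) matrices,
// wo is (d_model x d_ff); the names mirror the struct fields in the diff.
fn gelu_new(x: f64) -> f64 {
    0.5 * x * (1.0 + ((2.0 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

// out = wo * (gelu_new(wi_0 * x) element-wise-times (wi_1 * x))
fn gated_ff(wi_0: &[Vec<f64>], wi_1: &[Vec<f64>], wo: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let hidden_gelu: Vec<f64> = matvec(wi_0, x).into_iter().map(gelu_new).collect();
    let hidden_linear = matvec(wi_1, x);
    let gated: Vec<f64> = hidden_gelu
        .iter()
        .zip(&hidden_linear)
        .map(|(a, b)| a * b)
        .collect();
    matvec(wo, &gated)
}

fn main() {
    // Made-up dimensions: d_model = 2, d_ff = 3.
    let wi_0 = vec![vec![0.5, -0.1], vec![0.2, 0.3], vec![-0.4, 0.8]];
    let wi_1 = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
    let wo = vec![vec![0.1, 0.2, 0.3], vec![-0.3, 0.1, 0.0]];
    println!("{:?}", gated_ff(&wi_0, &wi_1, &wo, &[1.0, -2.0]));
}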
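One note on how the branch is chosen: T5LayerFF::load looks only at cfg.feed_forward_proj, so a config whose feed_forward_proj deserializes to Activation::NewGelu (i.e. "gated-gelu", as flan-t5 checkpoints are expected to use) loads T5DenseGatedActDense, and anything else keeps the original T5DenseActDense path. A minimal sketch of just that selection, with a stand-in Config that carries only the inspected field:

#[derive(Debug, PartialEq)]
enum Activation { Gelu, NewGelu, Relu }

// Stand-in Config exposing only the field that T5LayerFF::load inspects.
struct Config { feed_forward_proj: Activation }

// Mirrors `if cfg.feed_forward_proj == Activation::NewGelu` from the diff.
fn uses_gated_ff(cfg: &Config) -> bool {
    cfg.feed_forward_proj == Activation::NewGelu
}

fn main() {
    let flan_t5 = Config { feed_forward_proj: Activation::NewGelu }; // from "gated-gelu"
    let t5 = Config { feed_forward_proj: Activation::Relu };
    println!("flan-t5 gated: {}", uses_gated_ff(&flan_t5));
    println!("t5 gated: {}", uses_gated_ff(&t5));
}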