mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add support to UL2 model family (#1300)
* Add support to UL2 model family
* Update docs with UL2
* Create ActivationWithOptionalGating to avoid polluting activations
* Also refactor quantized t5
* Remove useless conversion
* Revert Activation::NewGelu name change
* Remove useless return
* Apply rustfmt and clippy recommendations
* Reuse t5::ActivationWithOptionalGating in quantized version
* (cosmetic change) use a match rather than ifs + avoid early returns.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
// T5 Text Model, quantized version
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
|
||||
use crate::models::with_tracing::QMatMul;
|
||||
use crate::quantized_nn::Embedding;
|
||||
pub use crate::quantized_var_builder::VarBuilder;
|
||||
@ -54,8 +55,8 @@ pub struct Config {
|
||||
dropout_rate: f64,
|
||||
layer_norm_epsilon: f64,
|
||||
initializer_factor: f64,
|
||||
#[serde(default)]
|
||||
feed_forward_proj: Activation,
|
||||
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
|
||||
pub feed_forward_proj: ActivationWithOptionalGating,
|
||||
#[serde(default = "default_tie_word_embeddings")]
|
||||
tie_word_embeddings: bool,
|
||||
#[serde(default = "default_is_decoder")]
|
||||
@ -83,7 +84,10 @@ impl Default for Config {
|
||||
dropout_rate: 0.1,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
initializer_factor: 1.0,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
feed_forward_proj: ActivationWithOptionalGating {
|
||||
gated: false,
|
||||
activation: Activation::Relu,
|
||||
},
|
||||
tie_word_embeddings: true,
|
||||
is_decoder: false,
|
||||
is_encoder_decoder: true,
|
||||
@ -176,7 +180,7 @@ impl T5DenseGatedActDense {
|
||||
wi_0,
|
||||
wi_1,
|
||||
wo,
|
||||
act: Activation::NewGelu,
|
||||
act: cfg.feed_forward_proj.activation,
|
||||
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
|
||||
})
|
||||
}
|
||||
@ -205,7 +209,7 @@ impl T5LayerFF {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
|
||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
|
||||
(
|
||||
None,
|
||||
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
|
||||
|
@ -37,6 +37,37 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
|
||||
pub struct ActivationWithOptionalGating {
|
||||
pub gated: bool,
|
||||
pub activation: candle_nn::Activation,
|
||||
}
|
||||
|
||||
pub fn deserialize_feed_forward_proj_activation<'de, D>(
|
||||
deserializer: D,
|
||||
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
match String::deserialize(deserializer)?.as_str() {
|
||||
"gated-gelu" => Ok(ActivationWithOptionalGating {
|
||||
gated: true,
|
||||
activation: candle_nn::Activation::NewGelu,
|
||||
}),
|
||||
"gated-silu" => Ok(ActivationWithOptionalGating {
|
||||
gated: true,
|
||||
activation: candle_nn::Activation::Silu,
|
||||
}),
|
||||
buf => {
|
||||
let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
|
||||
Ok(ActivationWithOptionalGating {
|
||||
gated: false,
|
||||
activation,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
@ -52,8 +83,8 @@ pub struct Config {
|
||||
dropout_rate: f64,
|
||||
layer_norm_epsilon: f64,
|
||||
initializer_factor: f64,
|
||||
#[serde(default)]
|
||||
feed_forward_proj: Activation,
|
||||
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
|
||||
feed_forward_proj: ActivationWithOptionalGating,
|
||||
#[serde(default = "default_tie_word_embeddings")]
|
||||
tie_word_embeddings: bool,
|
||||
#[serde(default = "default_is_decoder")]
|
||||
@ -81,7 +112,10 @@ impl Default for Config {
|
||||
dropout_rate: 0.1,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
initializer_factor: 1.0,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
feed_forward_proj: ActivationWithOptionalGating {
|
||||
gated: false,
|
||||
activation: Activation::Relu,
|
||||
},
|
||||
tie_word_embeddings: true,
|
||||
is_decoder: false,
|
||||
is_encoder_decoder: true,
|
||||
@ -102,7 +136,10 @@ impl Config {
|
||||
d_model: 768,
|
||||
dropout_rate: 0.1,
|
||||
eos_token_id: 1,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
feed_forward_proj: ActivationWithOptionalGating {
|
||||
gated: false,
|
||||
activation: Activation::Relu,
|
||||
},
|
||||
tie_word_embeddings: true,
|
||||
initializer_factor: 1.0,
|
||||
is_decoder: false,
|
||||
@ -202,7 +239,7 @@ impl T5DenseGatedActDense {
|
||||
wi_0,
|
||||
wi_1,
|
||||
wo,
|
||||
act: Activation::NewGelu,
|
||||
act: cfg.feed_forward_proj.activation,
|
||||
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
|
||||
})
|
||||
}
|
||||
@ -231,7 +268,7 @@ impl T5LayerFF {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
|
||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
|
||||
(
|
||||
None,
|
||||
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
|
||||
|
Reference in New Issue
Block a user