mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add support to UL2 model family (#1300)
* Add support to UL2 model family * Update docs with UL2 * Create ActivationWithOptionalGating to avoid polluting activations * Also refactor quantized t5 * Remove useless conversion * Revert Activation::NewGelu name change * Remove useless return * Apply rustfmt and clippy recommendations * Reuse t5::ActivationWithOptionalGating in quantized version * (cosmetic change) use a match rather than ifs + avoid early returns. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
// T5 Text Model, quantized version
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
|
||||
use crate::models::with_tracing::QMatMul;
|
||||
use crate::quantized_nn::Embedding;
|
||||
pub use crate::quantized_var_builder::VarBuilder;
|
||||
@ -54,8 +55,8 @@ pub struct Config {
|
||||
dropout_rate: f64,
|
||||
layer_norm_epsilon: f64,
|
||||
initializer_factor: f64,
|
||||
#[serde(default)]
|
||||
feed_forward_proj: Activation,
|
||||
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
|
||||
pub feed_forward_proj: ActivationWithOptionalGating,
|
||||
#[serde(default = "default_tie_word_embeddings")]
|
||||
tie_word_embeddings: bool,
|
||||
#[serde(default = "default_is_decoder")]
|
||||
@ -83,7 +84,10 @@ impl Default for Config {
|
||||
dropout_rate: 0.1,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
initializer_factor: 1.0,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
feed_forward_proj: ActivationWithOptionalGating {
|
||||
gated: false,
|
||||
activation: Activation::Relu,
|
||||
},
|
||||
tie_word_embeddings: true,
|
||||
is_decoder: false,
|
||||
is_encoder_decoder: true,
|
||||
@ -176,7 +180,7 @@ impl T5DenseGatedActDense {
|
||||
wi_0,
|
||||
wi_1,
|
||||
wo,
|
||||
act: Activation::NewGelu,
|
||||
act: cfg.feed_forward_proj.activation,
|
||||
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
|
||||
})
|
||||
}
|
||||
@ -205,7 +209,7 @@ impl T5LayerFF {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
|
||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
|
||||
(
|
||||
None,
|
||||
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
|
||||
|
Reference in New Issue
Block a user