[nn] Move the Embedding and Activation parts. (#116)

* Share the Embedding and Activation parts.

* Tweak some activations.
This commit is contained in:
Laurent Mazare
2023-07-10 10:24:52 +01:00
committed by GitHub
parent 9ce0f1c010
commit b06e1a7e54
9 changed files with 91 additions and 149 deletions

View File

@@ -1,7 +1,7 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use anyhow::Result;
use candle::Tensor;
@@ -159,7 +159,7 @@ impl T5Attention {
let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
let relative_attention_bias = if h {
let emb = Embedding::load(
let emb = embedding(
cfg.relative_attention_num_buckets,
cfg.num_heads,
&format!("{p}.relative_attention_bias"),
@@ -281,7 +281,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let shared = Embedding::load(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
Ok(Self { shared, encoder })
}