mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
[nn] Move the Embedding and Activation parts. (#116)
* Share the Embedding and Activation parts. * Tweak some activations.
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
// T5 Text Encoder
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
|
||||
use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
|
||||
use anyhow::Result;
|
||||
use candle::Tensor;
|
||||
|
||||
@ -159,7 +159,7 @@ impl T5Attention {
|
||||
let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
|
||||
let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
|
||||
let relative_attention_bias = if h {
|
||||
let emb = Embedding::load(
|
||||
let emb = embedding(
|
||||
cfg.relative_attention_num_buckets,
|
||||
cfg.num_heads,
|
||||
&format!("{p}.relative_attention_bias"),
|
||||
@ -281,7 +281,7 @@ pub struct T5EncoderModel {
|
||||
|
||||
impl T5EncoderModel {
|
||||
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = Embedding::load(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
|
||||
let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
|
||||
Ok(Self { shared, encoder })
|
||||
}
|
||||
|
Reference in New Issue
Block a user