[nn] Move the Embedding and Activation parts. (#116)

* Share the Embedding and Activation parts.

* Tweak some activations.
Laurent Mazare authored 2023-07-10 10:24:52 +01:00, committed by GitHub
parent 9ce0f1c010
commit b06e1a7e54
9 changed files with 91 additions and 149 deletions

View File

@@ -1,4 +1,6 @@
-use crate::nn::{layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
+use crate::nn::{
+    embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
+};
 use crate::{encodec_model, t5_model};
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
@@ -283,7 +285,7 @@ impl MusicgenDecoder {
         };
         let embed_dim = cfg.vocab_size + 1;
         let embed_tokens = (0..cfg.num_codebooks)
-            .map(|i| Embedding::load(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
+            .map(|i| embedding(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
            .collect::<Result<Vec<_>>>()?;
         let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
         let layers = (0..cfg.num_hidden_layers)
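The hunk above threads each codebook through the new shared `embedding` helper. A minimal self-contained sketch of the same per-codebook pattern, using zero tensors in place of weights fetched through a `VarBuilder` (the sizes here are made up for illustration):

use candle::{DType, Device, Tensor};
use candle_nn::Embedding;

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Made-up sizes: 4 codebooks, vocab 2048 (+1 for the extra token), hidden 1024.
    let (num_codebooks, embed_dim, h): (usize, usize, usize) = (4, 2048 + 1, 1024);
    let embed_tokens = (0..num_codebooks)
        .map(|_i| {
            // Stand-in for `vb.get((embed_dim, h), "...embed_tokens.{i}.weight")`.
            let w = Tensor::zeros((embed_dim, h), DType::F32, &dev)?;
            Ok(Embedding::new(w, h))
        })
        .collect::<candle::Result<Vec<_>>>()?;
    assert_eq!(embed_tokens.len(), num_codebooks);
    Ok(())
}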

View File

@@ -113,33 +113,16 @@ impl Dropout {
     }
 }
 
-#[derive(Debug)]
-pub struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    pub fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
-}
+pub type Embedding = candle_nn::Embedding;
+
+pub fn embedding(
+    vocab_size: usize,
+    hidden_size: usize,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
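The deleted `forward` flattened the index tensor, looked rows up in the weight matrix, and reshaped back with the hidden size appended; after this change that behavior is expected from `candle_nn::Embedding` instead. A small sketch of that contract (on recent candle-nn the `forward` method comes from the `Module` trait; on the version this commit targets it was inherent and the extra import can be dropped):

use candle::{Device, Tensor};
use candle_nn::{Embedding, Module};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // A 4-row, 3-dim lookup table standing in for loaded weights.
    let weights = Tensor::arange(0f32, 12., &dev)?.reshape((4, 3))?;
    let emb = Embedding::new(weights, 3);
    // (2, 2) indices come back as (2, 2, 3): the hidden dim is appended.
    let ids = Tensor::new(&[[0u32, 1], [2, 3]], &dev)?;
    let out = emb.forward(&ids)?;
    assert_eq!(out.dims(), &[2, 2, 3]);
    Ok(())
}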
@@ -197,17 +180,4 @@ impl Conv1D {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum HiddenAct {
-    Gelu,
-    Relu,
-}
-
-impl HiddenAct {
-    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
-        match self {
-            Self::Gelu => xs.gelu(),
-            Self::Relu => xs.relu(),
-        }
-    }
-}
+pub type HiddenAct = candle_nn::Activation;
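Call sites such as `cfg.hidden_act.forward(xs)?` keep compiling through the alias as long as `candle_nn::Activation` carries the same `Gelu` and `Relu` variants, which is the assumption here. A quick sketch of using the aliased type directly (again, `Module` supplies `forward` on recent candle-nn):

use candle::{Device, Tensor};
use candle_nn::{Activation, Module};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[-1f32, 0., 1.], &dev)?;
    // Relu zeroes the negative entry; Gelu shrinks it smoothly instead.
    let relu_out = Activation::Relu.forward(&xs)?;
    let gelu_out = Activation::Gelu.forward(&xs)?;
    println!("relu: {relu_out}\ngelu: {gelu_out}");
    Ok(())
}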

View File

@@ -1,7 +1,7 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
+use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;
@@ -159,7 +159,7 @@ impl T5Attention {
         let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
         let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
         let relative_attention_bias = if h {
-            let emb = Embedding::load(
+            let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
                 &format!("{p}.relative_attention_bias"),
@@ -281,7 +281,7 @@ pub struct T5EncoderModel {
 
 impl T5EncoderModel {
     pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared = Embedding::load(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
+        let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
         let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
         Ok(Self { shared, encoder })
     }
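The shape of the whole commit is the same in each file: keep a local alias so downstream code stays untouched, and funnel construction through one free function. A distilled sketch of that pattern, with a hypothetical `get_weights` standing in for the real `vb.get(..., "{p}.weight")` call:

use candle::{DType, Device, Tensor};

// Local alias: call sites keep naming `Embedding` while the
// implementation now lives in candle_nn.
pub type Embedding = candle_nn::Embedding;

// Hypothetical stand-in for fetching `{p}.weight` through a VarBuilder.
fn get_weights(vocab_size: usize, hidden_size: usize) -> candle::Result<Tensor> {
    Tensor::zeros((vocab_size, hidden_size), DType::F32, &Device::Cpu)
}

// Same shape as the commit's `embedding` helper: fetch the weights,
// then wrap them in the shared type.
pub fn embedding(vocab_size: usize, hidden_size: usize) -> candle::Result<Embedding> {
    let embeddings = get_weights(vocab_size, hidden_size)?;
    Ok(Embedding::new(embeddings, hidden_size))
}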