[nn] Move the Embedding and Activation parts. (#116)

* Share the Embedding and Activation parts.

* Tweak some activations.
This commit is contained in:
Laurent Mazare
2023-07-10 10:24:52 +01:00
committed by GitHub
parent 9ce0f1c010
commit b06e1a7e54
9 changed files with 91 additions and 149 deletions

View File

@ -2,7 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use candle_nn::{LayerNorm, Linear};
use candle_nn::{Embedding, LayerNorm, Linear};
use serde::Deserialize;
use std::collections::HashMap;
@ -63,21 +63,6 @@ impl<'a> VarBuilder<'a> {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HiddenAct {
Gelu,
Relu,
}
impl HiddenAct {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),
}
}
}
// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
#[derive(Debug, Clone, PartialEq, Deserialize)]
@ -111,32 +96,9 @@ impl Config {
}
}
struct Embedding {
embeddings: Tensor,
hidden_size: usize,
}
impl Embedding {
fn new(embeddings: Tensor, hidden_size: usize) -> Self {
Self {
embeddings,
hidden_size,
}
}
fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
Ok(Self::new(embeddings, hidden_size))
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
let values = Tensor::embedding(&indexes, &self.embeddings)?;
let values = values.reshape(final_dims)?;
Ok(values)
}
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
@ -449,8 +411,7 @@ impl TextDecoder {
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
let token_embedding =
Embedding::load(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
let positional_embedding =
vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
let blocks = (0..cfg.decoder_layers)
@ -483,7 +444,10 @@ impl TextDecoder {
x = block.forward(&x, Some(xa), Some(&self.mask))?;
}
let x = self.ln.forward(&x)?;
let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
let w = self
.token_embedding
.embeddings()
.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
Ok(logits)
}