[nn] Move the Embedding and Activation parts. (#116)

* Share the Embedding and Activation parts.

* Tweak some activations.
This commit is contained in:
Laurent Mazare
2023-07-10 10:24:52 +01:00
committed by GitHub
parent 9ce0f1c010
commit b06e1a7e54
9 changed files with 91 additions and 149 deletions

View File

@ -1,6 +1,6 @@
use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
use candle_nn::{LayerNorm, Linear};
use candle_nn::{Embedding, LayerNorm, Linear};
use std::collections::HashMap;
const MAX_SEQ_LEN: usize = 5000;
@ -108,33 +108,9 @@ impl Dropout {
}
}
#[derive(Debug)]
struct Embedding {
embeddings: Tensor,
hidden_size: usize,
}
impl Embedding {
fn new(embeddings: Tensor, hidden_size: usize) -> Self {
Self {
embeddings,
hidden_size,
}
}
fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
Ok(Self::new(embeddings, hidden_size))
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
let values = Tensor::embedding(&indexes, &self.embeddings)?;
let values = values.reshape(final_dims)?;
Ok(values)
}
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
Ok(Embedding::new(embeddings, hidden_size))
}
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
@ -563,7 +539,7 @@ impl Falcon {
}
pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
let word_embeddings = Embedding::load(
let word_embeddings = embedding(
cfg.vocab_size,
cfg.hidden_size,
"transformer.word_embeddings",