From b06e1a7e54403478b72237cb0d6d5aaddc88e132 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 10 Jul 2023 10:24:52 +0100
Subject: [PATCH] [nn] Move the Embedding and Activation parts. (#116)

* Share the Embedding and Activation parts.

* Tweak some activations.
---
 candle-examples/examples/bert/main.rs         | 37 +++----------
 candle-examples/examples/falcon/model.rs      | 34 ++----------
 .../examples/musicgen/musicgen_model.rs       |  6 ++-
 candle-examples/examples/musicgen/nn.rs       | 50 ++++-------------
 candle-examples/examples/musicgen/t5_model.rs |  6 +--
 candle-examples/examples/whisper/model.rs     | 54 ++++---------------
 candle-nn/src/activation.rs                   | 18 +++++++
 candle-nn/src/embedding.rs                    | 29 ++++++++++
 candle-nn/src/lib.rs                          |  6 +++
 9 files changed, 91 insertions(+), 149 deletions(-)
 create mode 100644 candle-nn/src/activation.rs
 create mode 100644 candle-nn/src/embedding.rs

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index b2f92bbc..d4830495 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -6,7 +6,7 @@ extern crate intel_mkl_src;
 use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
-use candle_nn::{LayerNorm, Linear};
+use candle_nn::{Embedding, LayerNorm, Linear};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -167,32 +167,9 @@ impl Config {
     }
 }
 
-struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
+fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
 }
 
 fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
@@ -249,19 +226,19 @@ struct BertEmbeddings {
 
 impl BertEmbeddings {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let word_embeddings = Embedding::load(
+        let word_embeddings = embedding(
             config.vocab_size,
             config.hidden_size,
             &format!("{p}.word_embeddings"),
             vb,
         )?;
-        let position_embeddings = Embedding::load(
+        let position_embeddings = embedding(
             config.max_position_embeddings,
             config.hidden_size,
             &format!("{p}.position_embeddings"),
             vb,
         )?;
-        let token_type_embeddings = Embedding::load(
+        let token_type_embeddings = embedding(
             config.type_vocab_size,
             config.hidden_size,
             &format!("{p}.token_type_embeddings"),
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index e22b7b47..9283f229 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -1,6 +1,6 @@
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
-use candle_nn::{LayerNorm, Linear};
+use candle_nn::{Embedding, LayerNorm, Linear};
 use std::collections::HashMap;
 
 const MAX_SEQ_LEN: usize = 5000;
@@ -108,33 +108,9 @@ impl Dropout {
     }
 }
 
-#[derive(Debug)]
-struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
+fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
 }
 
 // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
@@ -563,7 +539,7 @@ impl Falcon {
     }
 
     pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
-        let word_embeddings = Embedding::load(
+        let word_embeddings = embedding(
             cfg.vocab_size,
             cfg.hidden_size,
             "transformer.word_embeddings",
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 6ed4335a..6ef15c21 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -1,4 +1,6 @@
-use crate::nn::{layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
+use crate::nn::{
+    embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
+};
 use crate::{encodec_model, t5_model};
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
@@ -283,7 +285,7 @@ impl MusicgenDecoder {
         };
         let embed_dim = cfg.vocab_size + 1;
         let embed_tokens = (0..cfg.num_codebooks)
-            .map(|i| Embedding::load(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
+            .map(|i| embedding(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
             .collect::<Result<Vec<_>>>()?;
         let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
         let layers = (0..cfg.num_hidden_layers)
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs
index 19f35586..1a2be3d0 100644
--- a/candle-examples/examples/musicgen/nn.rs
+++ b/candle-examples/examples/musicgen/nn.rs
@@ -113,33 +113,16 @@ impl Dropout {
     }
 }
 
-#[derive(Debug)]
-pub struct Embedding {
-    embeddings: Tensor,
+pub type Embedding = candle_nn::Embedding;
+
+pub fn embedding(
+    vocab_size: usize,
     hidden_size: usize,
-}
-
-impl Embedding {
-    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    pub fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -197,17 +180,4 @@ impl Conv1D {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum HiddenAct {
-    Gelu,
-    Relu,
-}
-
-impl HiddenAct {
-    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
-        match self {
-            Self::Gelu => xs.gelu(),
-            Self::Relu => xs.relu(),
-        }
-    }
-}
+pub type HiddenAct = candle_nn::Activation;
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 9e37fbd8..c904d67c 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -1,7 +1,7 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
+use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;
 
@@ -159,7 +159,7 @@ impl T5Attention {
         let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
         let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
         let relative_attention_bias = if h {
-            let emb = Embedding::load(
+            let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
                 &format!("{p}.relative_attention_bias"),
@@ -281,7 +281,7 @@ pub struct T5EncoderModel {
 
 impl T5EncoderModel {
     pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared = Embedding::load(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
+        let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
         let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
         Ok(Self { shared, encoder })
     }
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 4c4ff4e7..f74eb8bf 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -2,7 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
-use candle_nn::{LayerNorm, Linear};
+use candle_nn::{Embedding, LayerNorm, Linear};
 use serde::Deserialize;
 use std::collections::HashMap;
 
@@ -63,21 +63,6 @@ impl<'a> VarBuilder<'a> {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum HiddenAct {
-    Gelu,
-    Relu,
-}
-
-impl HiddenAct {
-    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
-        match self {
-            Self::Gelu => xs.gelu(),
-            Self::Relu => xs.relu(),
-        }
-    }
-}
-
 // The names in comments correspond to the original implementation:
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
 #[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -111,32 +96,9 @@ impl Config {
     }
 }
 
-struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
+fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
 }
 
 fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
@@ -449,8 +411,7 @@ impl TextDecoder {
         let n_state = cfg.d_model;
         let n_head = cfg.decoder_attention_heads;
         let n_ctx = cfg.max_target_positions;
-        let token_embedding =
-            Embedding::load(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
+        let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
         let positional_embedding =
             vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
         let blocks = (0..cfg.decoder_layers)
@@ -483,7 +444,10 @@ impl TextDecoder {
             x = block.forward(&x, Some(xa), Some(&self.mask))?;
         }
         let x = self.ln.forward(&x)?;
-        let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
+        let w = self
+            .token_embedding
+            .embeddings()
+            .broadcast_left(x_dims[0])?;
         let logits = x.matmul(&w.t()?)?;
         Ok(logits)
     }
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
new file mode 100644
index 00000000..9554e68a
--- /dev/null
+++ b/candle-nn/src/activation.rs
@@ -0,0 +1,18 @@
+use candle::Tensor;
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum Activation {
+    Gelu,
+    Relu,
+    Elu(f64),
+}
+
+impl Activation {
+    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::Gelu => xs.gelu(),
+            Self::Relu => xs.relu(),
+            &Self::Elu(alpha) => xs.elu(alpha),
+        }
+    }
+}
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
new file mode 100644
index 00000000..deeba01e
--- /dev/null
+++ b/candle-nn/src/embedding.rs
@@ -0,0 +1,29 @@
+use candle::{Result, Tensor};
+
+#[derive(Debug)]
+pub struct Embedding {
+    embeddings: Tensor,
+    hidden_size: usize,
+}
+
+impl Embedding {
+    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
+        Self {
+            embeddings,
+            hidden_size,
+        }
+    }
+
+    pub fn embeddings(&self) -> &Tensor {
+        &self.embeddings
+    }
+
+    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+        let mut final_dims = indexes.dims().to_vec();
+        final_dims.push(self.hidden_size);
+        let indexes = indexes.flatten_all()?;
+        let values = Tensor::embedding(&indexes, &self.embeddings)?;
+        let values = values.reshape(final_dims)?;
+        Ok(values)
+    }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 09fe65b9..45aa5e92 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -1,5 +1,11 @@
+// For now this crate shares its error type with candle-core. We may introduce some separate
+// error type if needed or add some specialized cases on the candle-core side.
+mod activation;
+mod embedding;
 mod layer_norm;
 mod linear;
 
+pub use activation::Activation;
+pub use embedding::Embedding;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
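
The sketch below shows how the `Embedding` and `Activation` types shared by this
patch fit together. It is a minimal, illustrative example, not code from the
commit: the vocab size, hidden size, and index values are toy numbers, and it
assumes candle's `Tensor::zeros` / `Tensor::new` constructors on a CPU device.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Activation, Embedding};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy embedding table: vocab_size = 4, hidden_size = 3, all zeros.
    let weights = Tensor::zeros((4, 3), DType::F32, &device)?;
    let emb = Embedding::new(weights, 3);
    // Indices keep their original shape; forward appends the hidden dimension.
    let indexes = Tensor::new(&[0u32, 2, 1], &device)?;
    let xs = emb.forward(&indexes)?; // shape (3, 3)
    // Activations are plain enum values; Elu carries its alpha parameter.
    let ys = Activation::Elu(1.0).forward(&xs)?;
    println!("{:?}", ys.dims()); // [3, 3]
    Ok(())
}

Compared with the per-example copies this patch removes, the shared types also
expose an `embeddings()` accessor, which is what lets the whisper decoder reuse
the token embedding matrix for its output projection via `broadcast_left`.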