diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 51c524f5..810f2803 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -1,6 +1,6 @@
 use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
@@ -112,11 +112,6 @@ impl Config {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs
index c4a2d1db..e69f08c8 100644
--- a/candle-transformers/src/models/bigcode.rs
+++ b/candle-transformers/src/models/bigcode.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
     let weight = vb.get((size2, size1), "weight")?;
@@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Lin
     Ok(Linear::new(weight, bias))
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     let weight = vb.get(size, "weight")?;
     let bias = vb.get(size, "bias")?;
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 6ede136a..ef5a92fc 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 const MAX_SEQ_LEN: usize = 5000;
 
@@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     Ok(LayerNorm::new(weight, bias, eps))
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
 #[derive(Debug)]
 pub struct Config {
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7e8c8920..f003866a 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,6 +1,6 @@
 use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -136,11 +136,6 @@ impl Cache {
     }
 }
 
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, cfg.hidden_size))
-}
-
 struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,
@@ -409,7 +404,7 @@ impl Llama {
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 25454ba6..ea2a59b9 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,12 +1,7 @@
 use super::Config;
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
+use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 
 fn conv1d(
     in_channels: usize,