Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
[nn] Move the Embedding and Activation parts. (#116)

* Share the Embedding and Activation parts.
* Tweak some activations.
@@ -6,7 +6,7 @@ extern crate intel_mkl_src;
 use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
-use candle_nn::{LayerNorm, Linear};
+use candle_nn::{Embedding, LayerNorm, Linear};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -167,32 +167,9 @@ impl Config {
     }
 }
 
-struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
-}
+fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
 
 fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
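For reference, a minimal self-contained sketch of this loader pattern. The `ToyVarBuilder` type, its zero-initializing `get`, and the `model.embed_tokens` prefix are illustrative assumptions standing in for the example's real `VarBuilder`; only `candle_nn::Embedding` comes from this commit.

use candle::{DType, Device, Result, Tensor};
use candle_nn::Embedding;

// Toy stand-in for the example's VarBuilder: it zero-initializes whatever
// shape is requested so the loader pattern can run end to end.
struct ToyVarBuilder {
    device: Device,
}

impl ToyVarBuilder {
    fn get(&self, shape: (usize, usize), _name: &str) -> Result<Tensor> {
        Tensor::zeros(shape, DType::F32, &self.device)
    }
}

// Mirrors the embedding() helper introduced in the hunk above.
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &ToyVarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
    Ok(Embedding::new(embeddings, hidden_size))
}

fn main() -> Result<()> {
    let vb = ToyVarBuilder { device: Device::Cpu };
    let emb = embedding(16, 4, "model.embed_tokens", &vb)?;
    assert_eq!(emb.embeddings().dims(), &[16, 4]);
    Ok(())
}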
@@ -249,19 +226,19 @@ struct BertEmbeddings {
 
 impl BertEmbeddings {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let word_embeddings = Embedding::load(
+        let word_embeddings = embedding(
             config.vocab_size,
             config.hidden_size,
             &format!("{p}.word_embeddings"),
             vb,
         )?;
-        let position_embeddings = Embedding::load(
+        let position_embeddings = embedding(
             config.max_position_embeddings,
             config.hidden_size,
             &format!("{p}.position_embeddings"),
             vb,
         )?;
-        let token_type_embeddings = Embedding::load(
+        let token_type_embeddings = embedding(
             config.type_vocab_size,
             config.hidden_size,
             &format!("{p}.token_type_embeddings"),
@@ -1,6 +1,6 @@
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
-use candle_nn::{LayerNorm, Linear};
+use candle_nn::{Embedding, LayerNorm, Linear};
 use std::collections::HashMap;
 
 const MAX_SEQ_LEN: usize = 5000;
@@ -108,33 +108,9 @@ impl Dropout {
     }
 }
 
-#[derive(Debug)]
-struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
-}
+fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
 
 // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
@@ -563,7 +539,7 @@ impl Falcon {
     }
 
     pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
-        let word_embeddings = Embedding::load(
+        let word_embeddings = embedding(
             cfg.vocab_size,
             cfg.hidden_size,
             "transformer.word_embeddings",
@@ -1,4 +1,6 @@
-use crate::nn::{layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
+use crate::nn::{
+    embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
+};
 use crate::{encodec_model, t5_model};
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
@@ -283,7 +285,7 @@ impl MusicgenDecoder {
         };
         let embed_dim = cfg.vocab_size + 1;
         let embed_tokens = (0..cfg.num_codebooks)
-            .map(|i| Embedding::load(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
+            .map(|i| embedding(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
            .collect::<Result<Vec<_>>>()?;
         let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
         let layers = (0..cfg.num_hidden_layers)
@@ -113,33 +113,16 @@ impl Dropout {
     }
 }
 
-#[derive(Debug)]
-pub struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    pub fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
-}
+pub type Embedding = candle_nn::Embedding;
+
+pub fn embedding(
+    vocab_size: usize,
+    hidden_size: usize,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -197,17 +180,4 @@ impl Conv1D {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum HiddenAct {
-    Gelu,
-    Relu,
-}
-
-impl HiddenAct {
-    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
-        match self {
-            Self::Gelu => xs.gelu(),
-            Self::Relu => xs.relu(),
-        }
-    }
-}
+pub type HiddenAct = candle_nn::Activation;
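With the alias in place, call sites written against the removed enum keep compiling: `HiddenAct::Gelu` now resolves to `candle_nn::Activation::Gelu`. A minimal sketch of that compatibility shim, reproduced standalone (in the commit the alias lives inside the examples' shared nn module, not at crate top level):

use candle::{Device, Tensor};

// The alias from the hunk above, reproduced standalone.
pub type HiddenAct = candle_nn::Activation;

fn main() -> candle::Result<()> {
    let xs = Tensor::new(&[-1.0f32, 0.0, 1.0], &Device::Cpu)?;
    // Old call sites written against the removed enum work unchanged.
    let ys = HiddenAct::Gelu.forward(&xs)?;
    assert_eq!(ys.dims(), &[3]);
    Ok(())
}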
@@ -1,7 +1,7 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
+use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;
 
@@ -159,7 +159,7 @@ impl T5Attention {
         let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
         let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
         let relative_attention_bias = if h {
-            let emb = Embedding::load(
+            let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
                 &format!("{p}.relative_attention_bias"),
@@ -281,7 +281,7 @@ pub struct T5EncoderModel {
 
 impl T5EncoderModel {
     pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared = Embedding::load(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
+        let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
         let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
         Ok(Self { shared, encoder })
     }
@@ -2,7 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
-use candle_nn::{LayerNorm, Linear};
+use candle_nn::{Embedding, LayerNorm, Linear};
 use serde::Deserialize;
 use std::collections::HashMap;
 
@@ -63,21 +63,6 @@ impl<'a> VarBuilder<'a> {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum HiddenAct {
-    Gelu,
-    Relu,
-}
-
-impl HiddenAct {
-    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
-        match self {
-            Self::Gelu => xs.gelu(),
-            Self::Relu => xs.relu(),
-        }
-    }
-}
-
 // The names in comments correspond to the original implementation:
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
 #[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -111,32 +96,9 @@ impl Config {
     }
 }
 
-struct Embedding {
-    embeddings: Tensor,
-    hidden_size: usize,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        Self {
-            embeddings,
-            hidden_size,
-        }
-    }
-
-    fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings, hidden_size))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let mut final_dims = indexes.dims().to_vec();
-        final_dims.push(self.hidden_size);
-        let indexes = indexes.flatten_all()?;
-        let values = Tensor::embedding(&indexes, &self.embeddings)?;
-        let values = values.reshape(final_dims)?;
-        Ok(values)
-    }
-}
+fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
 
 fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
@@ -449,8 +411,7 @@ impl TextDecoder {
         let n_state = cfg.d_model;
         let n_head = cfg.decoder_attention_heads;
         let n_ctx = cfg.max_target_positions;
-        let token_embedding =
-            Embedding::load(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
+        let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
         let positional_embedding =
             vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
         let blocks = (0..cfg.decoder_layers)
@@ -483,7 +444,10 @@ impl TextDecoder {
             x = block.forward(&x, Some(xa), Some(&self.mask))?;
         }
         let x = self.ln.forward(&x)?;
-        let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
+        let w = self
+            .token_embedding
+            .embeddings()
+            .broadcast_left(x_dims[0])?;
         let logits = x.matmul(&w.t()?)?;
         Ok(logits)
     }
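This call chain is the weight-tying trick: the decoder reuses the token-embedding matrix as its output projection, now reached through the `embeddings()` accessor since the field is private in `candle_nn`. A minimal sketch of the pattern, under the assumption of a hypothetical `tied_logits` helper (the diff inlines this in `TextDecoder::forward`):

use candle::{Result, Tensor};
use candle_nn::Embedding;

// Weight tying: project hidden states back onto the vocabulary with the
// transposed token-embedding matrix instead of a separate output layer.
fn tied_logits(x: &Tensor, token_embedding: &Embedding) -> Result<Tensor> {
    // x: (batch, seq, n_state); embeddings(): (vocab, n_state)
    let batch_size = x.dims()[0];
    let w = token_embedding.embeddings().broadcast_left(batch_size)?;
    // (batch, seq, n_state) x (batch, n_state, vocab) -> (batch, seq, vocab)
    x.matmul(&w.t()?)
}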
candle-nn/src/activation.rs (new file, +18)
@@ -0,0 +1,18 @@
+use candle::Tensor;
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum Activation {
+    Gelu,
+    Relu,
+    Elu(f64),
+}
+
+impl Activation {
+    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::Gelu => xs.gelu(),
+            Self::Relu => xs.relu(),
+            &Self::Elu(alpha) => xs.elu(alpha),
+        }
+    }
+}
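A short usage sketch for the new module. Note that `Activation` derives `PartialEq` but not `Eq`, since the `Elu` variant carries an `f64` payload; the tensor values and debug printing here are illustrative:

use candle::{Device, Tensor};
use candle_nn::Activation;

fn main() -> candle::Result<()> {
    let xs = Tensor::new(&[-2.0f32, -0.5, 0.0, 1.5], &Device::Cpu)?;
    // Gelu and Relu are stateless; Elu carries its alpha in the variant.
    let relu = Activation::Relu.forward(&xs)?;
    let elu = Activation::Elu(1.0).forward(&xs)?;
    println!("{relu:?}\n{elu:?}");
    Ok(())
}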
candle-nn/src/embedding.rs (new file, +29)
@@ -0,0 +1,29 @@
+use candle::{Result, Tensor};
+
+#[derive(Debug)]
+pub struct Embedding {
+    embeddings: Tensor,
+    hidden_size: usize,
+}
+
+impl Embedding {
+    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
+        Self {
+            embeddings,
+            hidden_size,
+        }
+    }
+
+    pub fn embeddings(&self) -> &Tensor {
+        &self.embeddings
+    }
+
+    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+        let mut final_dims = indexes.dims().to_vec();
+        final_dims.push(self.hidden_size);
+        let indexes = indexes.flatten_all()?;
+        let values = Tensor::embedding(&indexes, &self.embeddings)?;
+        let values = values.reshape(final_dims)?;
+        Ok(values)
+    }
+}
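And a usage sketch for the embedding module, showing the shape contract of `forward`: indexes of any shape come out with `hidden_size` appended as a trailing dimension. The zero-initialized weights are purely for brevity:

use candle::{DType, Device, Result, Tensor};
use candle_nn::Embedding;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A 4-token vocabulary with 3-dimensional embeddings.
    let weights = Tensor::zeros((4, 3), DType::F32, &device)?;
    let emb = Embedding::new(weights, 3);
    // forward flattens the indexes, gathers rows, then restores the
    // original shape with hidden_size appended: (2, 2) -> (2, 2, 3).
    let ids = Tensor::new(&[0u32, 2, 1, 3], &device)?.reshape((2, 2))?;
    let out = emb.forward(&ids)?;
    assert_eq!(out.dims(), &[2, 2, 3]);
    Ok(())
}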
@@ -1,5 +1,11 @@
 // For now this crate shares its error type with candle-core. We may introduce some separate
 // error type if needed or add some specialized cases on the candle-core side.
+mod activation;
+mod embedding;
 mod layer_norm;
 mod linear;
 
+pub use activation::Activation;
+pub use embedding::Embedding;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
 