diff --git a/Cargo.toml b/Cargo.toml
index c8fa56f2..d52bf3e3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,6 +4,7 @@ members = [
     "candle-examples",
     "candle-kernels",
     "candle-hub",
+    "candle-nn",
     "candle-pyo3",
 ]
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 943f1953..27e85eee 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -14,7 +14,8 @@ readme = "README.md"
 blas = { version = "0.22.0", optional = true }
 byteorder = "1.4.3"
 candle-kernels = { path = "../candle-kernels", optional = true }
-# cudarc = { version = "0.9.12", optional = true, features = ["f16"] }
+# Re-enable this once 0.9.13 has been released as it will include the cublas-f16 changes.
+# cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
 cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] }
 # TODO: Switch back to the official gemm implementation once something similar to
 # https://github.com/sarah-ek/gemm/pull/8 is available.
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 77441374..889e0051 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -12,6 +12,7 @@ readme = "README.md"

 [dependencies]
 candle = { path = "../candle-core", default-features=false }
+candle-nn = { path = "../candle-nn", default-features=false }
 serde = { version = "1.0.166", features = ["derive"] }
 serde_json = "1.0.99"
 num-traits = "0.2.15"
@@ -27,5 +28,5 @@ wav = "1.0.0"

 [features]
 default = ["cuda"]
-cuda = ["candle/cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 8b292f92..b2f92bbc 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -6,6 +6,7 @@ extern crate intel_mkl_src;
 use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
+use candle_nn::{LayerNorm, Linear};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -194,29 +195,10 @@ impl Embedding {
     }
 }

-struct Linear {
-    weight: Tensor,
-    bias: Tensor,
-}
-
-impl Linear {
-    fn new(weight: Tensor, bias: Tensor) -> Self {
-        Self { weight, bias }
-    }
-
-    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = vb.get(size2, &format!("{p}.bias"))?;
-        Ok(Self::new(weight, bias))
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        let x = x.broadcast_add(&self.bias)?;
-        Ok(x)
-    }
+fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = vb.get(size2, &format!("{p}.bias"))?;
+    Ok(Linear::new(weight, Some(bias)))
 }

 struct Dropout {
@@ -234,49 +216,24 @@ impl Dropout {
     }
 }

-// This layer norm version handles both weight and bias so removes the mean.
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
+fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
             }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
 }

 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
@@ -310,7 +267,7 @@ impl BertEmbeddings {
             &format!("{p}.token_type_embeddings"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(
+        let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
             &format!("{p}.LayerNorm"),
@@ -362,9 +319,9 @@ impl BertSelfAttention {
         let all_head_size = config.num_attention_heads * attention_head_size;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         let hidden_size = config.hidden_size;
-        let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
-        let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
-        let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
+        let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
+        let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
+        let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
         Ok(Self {
             query,
             key,
@@ -414,13 +371,13 @@ struct BertSelfOutput {

 impl BertSelfOutput {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = Linear::load(
+        let dense = linear(
             config.hidden_size,
             config.hidden_size,
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(
+        let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
             &format!("{p}.LayerNorm"),
@@ -437,7 +394,7 @@ impl BertSelfOutput {
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
-        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+        Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
     }
 }
@@ -472,7 +429,7 @@ struct BertIntermediate {

 impl BertIntermediate {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = Linear::load(
+        let dense = linear(
             config.hidden_size,
             config.intermediate_size,
             &format!("{p}.dense"),
@@ -500,13 +457,13 @@ struct BertOutput {

 impl BertOutput {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = Linear::load(
+        let dense = linear(
             config.intermediate_size,
             config.hidden_size,
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(
+        let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
             &format!("{p}.LayerNorm"),
@@ -523,7 +480,7 @@ impl BertOutput {
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
-        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+        Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
     }
 }
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index e7c53e50..e22b7b47 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
+use candle_nn::{LayerNorm, Linear};
 use std::collections::HashMap;

 const MAX_SEQ_LEN: usize = 5000;
@@ -61,80 +62,34 @@ impl<'a> VarBuilder<'a> {
     }
 }

-#[derive(Debug)]
-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
+fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = if bias {
+        Some(vb.get(size2, &format!("{p}.bias"))?)
+    } else {
+        None
+    };
+    Ok(Linear::new(weight, bias))
 }

-impl Linear {
-    fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = if bias {
-            Some(vb.get(size2, &format!("{p}.bias"))?)
-        } else {
-            None
-        };
-        Ok(Self { weight, bias })
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-#[derive(Debug)]
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
+fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
             }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let dtype = x.dtype();
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let x = x.to_dtype(DType::F32)?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
 }

 #[derive(Debug)]
@@ -378,14 +333,14 @@ impl FalconAttention {
         } else {
             3 * hidden_size
         };
-        let query_key_value = Linear::load(
+        let query_key_value = linear(
             hidden_size,
             qkv_out_dim,
             cfg.bias,
             &format!("{p}.query_key_value"),
             vb,
         )?;
-        let dense = Linear::load(
+        let dense = linear(
             hidden_size,
             hidden_size,
             cfg.bias,
@@ -497,8 +452,8 @@ impl FalconMlp {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let b = cfg.bias;
-        let dense_h_to_4h = Linear::load(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
-        let dense_4h_to_h = Linear::load(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
+        let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
+        let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
         let dropout = Dropout::new(cfg.hidden_dropout);
         Ok(Self {
             dense_h_to_4h,
@@ -526,7 +481,7 @@ struct FalconDecoderLayer {
 impl FalconDecoderLayer {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
-        let inp_layernorm = LayerNorm::load(
+        let inp_layernorm = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             &format!("{p}.input_layernorm"),
@@ -536,7 +491,7 @@ impl FalconDecoderLayer {
         let post_attention_layernorm = if cfg.parallel_attn {
             None
         } else {
-            let ln = LayerNorm::load(
+            let ln = layer_norm(
                 cfg.hidden_size,
                 cfg.layer_norm_epsilon,
                 &format!("{p}.post_attention_layernorm"),
@@ -617,13 +572,13 @@ impl Falcon {
         let blocks = (0..cfg.num_hidden_layers)
             .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
             .collect::<Result<Vec<_>>>()?;
-        let ln_f = LayerNorm::load(
+        let ln_f = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             "transformer.ln_f",
             vb,
         )?;
-        let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
         Ok(Self {
             word_embeddings,
             blocks,
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 62fd3b63..6ed4335a 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -1,4 +1,4 @@
-use crate::nn::{Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
+use crate::nn::{layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
 use crate::{encodec_model, t5_model};
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
@@ -146,10 +146,10 @@ impl MusicgenAttention {
         let h = cfg.hidden_size;
         let num_heads = cfg.num_attention_heads;
         let head_dim = h / num_heads;
-        let k_proj = Linear::load(h, h, false, &format!("{p}.k_proj"), vb)?;
-        let v_proj = Linear::load(h, h, false, &format!("{p}.v_proj"), vb)?;
-        let q_proj = Linear::load(h, h, false, &format!("{p}.q_proj"), vb)?;
-        let out_proj = Linear::load(h, h, false, &format!("{p}.out_proj"), vb)?;
+        let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?;
+        let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?;
+        let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?;
+        let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?;
         Ok(Self {
             scaling: 1. / (head_dim as f64).sqrt(),
             is_decoder: true,
@@ -213,14 +213,13 @@ impl MusicgenDecoderLayer {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
-        let self_attn_layer_norm =
-            LayerNorm::load(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
+        let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
         let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
         let encoder_attn_layer_norm =
-            LayerNorm::load(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
-        let fc1 = Linear::load(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
-        let fc2 = Linear::load(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
-        let final_layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
+            layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
+        let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
+        let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
+        let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
         Ok(Self {
             self_attn,
             self_attn_layer_norm,
@@ -290,7 +289,7 @@ impl MusicgenDecoder {
         let layers = (0..cfg.num_hidden_layers)
             .map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
             .collect::<Result<Vec<_>>>()?;
-        let layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
+        let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
         Ok(Self {
             embed_tokens,
             embed_positions,
@@ -341,7 +340,7 @@ impl MusicgenForCausalLM {
         let h = cfg.hidden_size;
         let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
         let lm_heads = (0..cfg.num_codebooks)
-            .map(|i| Linear::load(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
+            .map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             decoder,
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs
index 25b4901c..19f35586 100644
--- a/candle-examples/examples/musicgen/nn.rs
+++ b/candle-examples/examples/musicgen/nn.rs
@@ -63,80 +63,38 @@ impl<'a> VarBuilder<'a> {
     }
 }

-#[derive(Debug)]
-pub struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
+pub type Linear = candle_nn::Linear;
+
+pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = if bias {
+        Some(vb.get(size2, &format!("{p}.bias"))?)
+    } else {
+        None
+    };
+    Ok(Linear::new(weight, bias))
 }

-impl Linear {
-    pub fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = if bias {
-            Some(vb.get(size2, &format!("{p}.bias"))?)
-        } else {
-            None
-        };
-        Ok(Self { weight, bias })
-    }
+pub type LayerNorm = candle_nn::LayerNorm;

-    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    pub fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
+pub fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
             }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let dtype = x.dtype();
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let x = x.to_dtype(DType::F32)?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
 }

 #[derive(Debug)]
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index b3f682a7..9e37fbd8 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -1,7 +1,7 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

-use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder};
+use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;
@@ -104,8 +104,8 @@ struct T5DenseActDense {

 impl T5DenseActDense {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
-        let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
+        let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
+        let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             wi,
@@ -154,10 +154,10 @@ struct T5Attention {
 impl T5Attention {
     fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
-        let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
-        let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
-        let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
+        let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
+        let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
+        let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
+        let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
         let relative_attention_bias = if h {
             let emb = Embedding::load(
                 cfg.relative_attention_num_buckets,
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index e589e231..4c4ff4e7 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -2,6 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle_nn::{LayerNorm, Linear};
 use serde::Deserialize;
 use std::collections::HashMap;

@@ -138,35 +139,15 @@ impl Embedding {
     }
 }

-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
+fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = vb.get(size2, &format!("{p}.bias"))?;
+    Ok(Linear::new(weight, Some(bias)))
 }

-impl Linear {
-    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = vb.get(size2, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-        })
-    }
-
-    fn load_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        Ok(Self { weight, bias: None })
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
+fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    Ok(Linear::new(weight, None))
 }

 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -258,35 +239,10 @@ impl Dropout {
     }
 }

-// This layer norm version handles both weight and bias so removes the mean.
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get(size, &format!("{p}.weight"))?;
-        let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias,
-            eps: 1e-5,
-        })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
+fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let weight = vb.get(size, &format!("{p}.weight"))?;
+    let bias = vb.get(size, &format!("{p}.bias"))?;
+    Ok(LayerNorm::new(weight, bias, 1e-5))
 }

 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
@@ -300,10 +256,10 @@ struct MultiHeadAttention {

 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?;
-        let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?;
-        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
-        let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?;
+        let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
+        let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
+        let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
+        let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
         Ok(Self {
             query,
             key,
@@ -364,20 +320,19 @@ struct ResidualAttentionBlock {
 impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
         let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
-        let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
+        let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
         let cross_attn = if ca {
             let cross_attn =
                 MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
-            let cross_attn_ln =
-                LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
+            let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
             Some((cross_attn, cross_attn_ln))
         } else {
             None
         };
         let n_mlp = n_state * 4;
-        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
-        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
-        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?;
+        let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
+        let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
+        let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
         Ok(Self {
             attn,
             attn_ln,
@@ -456,7 +411,7 @@ impl AudioEncoder {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
             })
            .collect::<Result<Vec<_>>>()?;
-        let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
         Ok(Self {
             conv1,
             conv2,
@@ -503,7 +458,7 @@ impl TextDecoder {
                 ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
             })
            .collect::<Result<Vec<_>>>()?;
-        let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
         let mask: Vec<_> = (0..n_ctx)
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
new file mode 100644
index 00000000..73a4954c
--- /dev/null
+++ b/candle-nn/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "candle-nn"
+version = "0.1.0"
+edition = "2021"
+
+description = "Minimalist ML framework."
+repository = "https://github.com/LaurentMazare/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT/Apache-2.0"
+readme = "README.md"
+
+[dependencies]
+candle = { path = "../candle-core", default-features=false }
+thiserror = "1"
+intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
+
+[dev-dependencies]
+anyhow = { version = "1", features = ["backtrace"] }
+
+[features]
+default = ["cuda"]
+cuda = ["candle/cuda"]
+mkl = ["dep:intel-mkl-src", "candle/mkl"]
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
new file mode 100644
index 00000000..0b208c49
--- /dev/null
+++ b/candle-nn/src/layer_norm.rs
@@ -0,0 +1,34 @@
+use candle::{DType, Result, Tensor};
+
+// This layer norm version handles both weight and bias so removes the mean.
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        Self { weight, bias, eps }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
new file mode 100644
index 00000000..09fe65b9
--- /dev/null
+++ b/candle-nn/src/lib.rs
@@ -0,0 +1,5 @@
+mod layer_norm;
+mod linear;
+
+pub use layer_norm::LayerNorm;
+pub use linear::Linear;
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
new file mode 100644
index 00000000..2e65ca2d
--- /dev/null
+++ b/candle-nn/src/linear.rs
@@ -0,0 +1,25 @@
+use candle::Tensor;
+
+#[derive(Debug)]
+pub struct Linear {
+    weight: Tensor,
+    bias: Option<Tensor>,
+}
+
+impl Linear {
+    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+        Self { weight, bias }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+        let w = match x.dims() {
+            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+            _ => self.weight.t()?,
+        };
+        let x = x.matmul(&w)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
+    }
+}
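
For reference, and not part of the patch itself: a minimal usage sketch of the two layers exported by the new candle-nn crate, based only on the constructors and forward methods introduced above (Linear::new with an optional bias, LayerNorm::new with weight, bias and eps). The shapes and the all-ones/all-zeros parameter tensors are arbitrary placeholders chosen for illustration, and the sketch runs on CPU.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{LayerNorm, Linear};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Dummy (batch, seq_len, hidden) activations; LayerNorm::forward expects a rank-3 tensor.
    let xs = Tensor::ones((2, 3, 4), DType::F32, &device)?;

    // Linear stores a (out_features, in_features) weight and an optional bias,
    // mirroring the Linear::new(weight, Some(bias)) calls in the example helpers above.
    let w = Tensor::ones((8, 4), DType::F32, &device)?;
    let b = Tensor::zeros(8, DType::F32, &device)?;
    let proj = Linear::new(w, Some(b));
    let ys = proj.forward(&xs)?; // shape (2, 3, 8)

    // LayerNorm normalizes over the last dimension, then applies the weight and bias.
    let ln = LayerNorm::new(
        Tensor::ones(8, DType::F32, &device)?,
        Tensor::zeros(8, DType::F32, &device)?,
        1e-5,
    );
    let ys = ln.forward(&ys)?;
    println!("{:?}", ys.shape()); // the output keeps the (2, 3, 8) shape
    Ok(())
}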