Sketch the candle-nn crate. (#115)

* Sketch the candle-nn crate.

* Tweak the cuda dependencies.

* More cuda tweaks.
This commit is contained in:
Laurent Mazare
2023-07-10 08:50:09 +01:00
committed by GitHub
parent bc3be6f9b0
commit 9ce0f1c010
13 changed files with 230 additions and 315 deletions

View File

@ -4,6 +4,7 @@ members = [
"candle-examples", "candle-examples",
"candle-kernels", "candle-kernels",
"candle-hub", "candle-hub",
"candle-nn",
"candle-pyo3", "candle-pyo3",
] ]

View File

@ -14,7 +14,8 @@ readme = "README.md"
blas = { version = "0.22.0", optional = true } blas = { version = "0.22.0", optional = true }
byteorder = "1.4.3" byteorder = "1.4.3"
candle-kernels = { path = "../candle-kernels", optional = true } candle-kernels = { path = "../candle-kernels", optional = true }
# cudarc = { version = "0.9.12", optional = true, features = ["f16"] } # Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
# cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] } cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] }
# TODO: Switch back to the official gemm implementation once something similar to # TODO: Switch back to the official gemm implementation once something similar to
# https://github.com/sarah-ek/gemm/pull/8 is available. # https://github.com/sarah-ek/gemm/pull/8 is available.

View File

@ -12,6 +12,7 @@ readme = "README.md"
[dependencies] [dependencies]
candle = { path = "../candle-core", default-features=false } candle = { path = "../candle-core", default-features=false }
candle-nn = { path = "../candle-nn", default-features=false }
serde = { version = "1.0.166", features = ["derive"] } serde = { version = "1.0.166", features = ["derive"] }
serde_json = "1.0.99" serde_json = "1.0.99"
num-traits = "0.2.15" num-traits = "0.2.15"
@ -27,5 +28,5 @@ wav = "1.0.0"
[features] [features]
default = ["cuda"] default = ["cuda"]
cuda = ["candle/cuda"] cuda = ["candle/cuda", "candle-nn/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"] mkl = ["dep:intel-mkl-src", "candle/mkl"]

View File

@ -6,6 +6,7 @@ extern crate intel_mkl_src;
use anyhow::{anyhow, Error as E, Result}; use anyhow::{anyhow, Error as E, Result};
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use candle_hub::{api::sync::Api, Cache, Repo, RepoType}; use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
use candle_nn::{LayerNorm, Linear};
use clap::Parser; use clap::Parser;
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
@ -194,29 +195,10 @@ impl Embedding {
} }
} }
struct Linear { fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
weight: Tensor, let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
bias: Tensor, let bias = vb.get(size2, &format!("{p}.bias"))?;
} Ok(Linear::new(weight, Some(bias)))
impl Linear {
fn new(weight: Tensor, bias: Tensor) -> Self {
Self { weight, bias }
}
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
let bias = vb.get(size2, &format!("{p}.bias"))?;
Ok(Self::new(weight, bias))
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (bsize, _, _) = x.shape().r3()?;
let w = self.weight.broadcast_left(bsize)?.t()?;
let x = x.matmul(&w)?;
let x = x.broadcast_add(&self.bias)?;
Ok(x)
}
} }
struct Dropout { struct Dropout {
@ -234,49 +216,24 @@ impl Dropout {
} }
} }
// This layer norm version handles both weight and bias so removes the mean. fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
struct LayerNorm { let (weight, bias) = match (
weight: Tensor, vb.get(size, &format!("{p}.weight")),
bias: Tensor, vb.get(size, &format!("{p}.bias")),
eps: f64, ) {
} (Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
impl LayerNorm { if let (Ok(weight), Ok(bias)) = (
fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { vb.get(size, &format!("{p}.gamma")),
Self { weight, bias, eps } vb.get(size, &format!("{p}.beta")),
} ) {
(weight, bias)
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> { } else {
let (weight, bias) = match ( return Err(err.into());
vb.get(size, &format!("{p}.weight")),
vb.get(size, &format!("{p}.bias")),
) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (
vb.get(size, &format!("{p}.gamma")),
vb.get(size, &format!("{p}.beta")),
) {
(weight, bias)
} else {
return Err(err.into());
}
} }
}; }
Ok(Self { weight, bias, eps }) };
} Ok(LayerNorm::new(weight, bias, eps))
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
} }
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
@ -310,7 +267,7 @@ impl BertEmbeddings {
&format!("{p}.token_type_embeddings"), &format!("{p}.token_type_embeddings"),
vb, vb,
)?; )?;
let layer_norm = LayerNorm::load( let layer_norm = layer_norm(
config.hidden_size, config.hidden_size,
config.layer_norm_eps, config.layer_norm_eps,
&format!("{p}.LayerNorm"), &format!("{p}.LayerNorm"),
@ -362,9 +319,9 @@ impl BertSelfAttention {
let all_head_size = config.num_attention_heads * attention_head_size; let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
let hidden_size = config.hidden_size; let hidden_size = config.hidden_size;
let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?; let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?; let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?; let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
Ok(Self { Ok(Self {
query, query,
key, key,
@ -414,13 +371,13 @@ struct BertSelfOutput {
impl BertSelfOutput { impl BertSelfOutput {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Linear::load( let dense = linear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
&format!("{p}.dense"), &format!("{p}.dense"),
vb, vb,
)?; )?;
let layer_norm = LayerNorm::load( let layer_norm = layer_norm(
config.hidden_size, config.hidden_size,
config.layer_norm_eps, config.layer_norm_eps,
&format!("{p}.LayerNorm"), &format!("{p}.LayerNorm"),
@ -437,7 +394,7 @@ impl BertSelfOutput {
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?; let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?; let hidden_states = self.dropout.forward(&hidden_states)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?) Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
} }
} }
@ -472,7 +429,7 @@ struct BertIntermediate {
impl BertIntermediate { impl BertIntermediate {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Linear::load( let dense = linear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
&format!("{p}.dense"), &format!("{p}.dense"),
@ -500,13 +457,13 @@ struct BertOutput {
impl BertOutput { impl BertOutput {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Linear::load( let dense = linear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
&format!("{p}.dense"), &format!("{p}.dense"),
vb, vb,
)?; )?;
let layer_norm = LayerNorm::load( let layer_norm = layer_norm(
config.hidden_size, config.hidden_size,
config.layer_norm_eps, config.layer_norm_eps,
&format!("{p}.LayerNorm"), &format!("{p}.LayerNorm"),
@ -523,7 +480,7 @@ impl BertOutput {
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?; let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?; let hidden_states = self.dropout.forward(&hidden_states)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?) Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
} }
} }

View File

@ -1,5 +1,6 @@
use anyhow::Result; use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D}; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
use candle_nn::{LayerNorm, Linear};
use std::collections::HashMap; use std::collections::HashMap;
const MAX_SEQ_LEN: usize = 5000; const MAX_SEQ_LEN: usize = 5000;
@ -61,80 +62,34 @@ impl<'a> VarBuilder<'a> {
} }
} }
#[derive(Debug)] fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
struct Linear { let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
weight: Tensor, let bias = if bias {
bias: Option<Tensor>, Some(vb.get(size2, &format!("{p}.bias"))?)
} else {
None
};
Ok(Linear::new(weight, bias))
} }
impl Linear { fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> { let (weight, bias) = match (
let weight = vb.get((size2, size1), &format!("{p}.weight"))?; vb.get(size, &format!("{p}.weight")),
let bias = if bias { vb.get(size, &format!("{p}.bias")),
Some(vb.get(size2, &format!("{p}.bias"))?) ) {
} else { (Ok(weight), Ok(bias)) => (weight, bias),
None (Err(err), _) | (_, Err(err)) => {
}; if let (Ok(weight), Ok(bias)) = (
Ok(Self { weight, bias }) vb.get(size, &format!("{p}.gamma")),
} vb.get(size, &format!("{p}.beta")),
) {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { (weight, bias)
let (bsize, _, _) = x.shape().r3()?; } else {
let w = self.weight.broadcast_left(bsize)?.t()?; return Err(err.into());
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
#[derive(Debug)]
struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
}
impl LayerNorm {
fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
Self { weight, bias, eps }
}
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
let (weight, bias) = match (
vb.get(size, &format!("{p}.weight")),
vb.get(size, &format!("{p}.bias")),
) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (
vb.get(size, &format!("{p}.gamma")),
vb.get(size, &format!("{p}.beta")),
) {
(weight, bias)
} else {
return Err(err.into());
}
} }
}; }
Ok(Self { weight, bias, eps }) };
} Ok(LayerNorm::new(weight, bias, eps))
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let dtype = x.dtype();
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(DType::F32)?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -378,14 +333,14 @@ impl FalconAttention {
} else { } else {
3 * hidden_size 3 * hidden_size
}; };
let query_key_value = Linear::load( let query_key_value = linear(
hidden_size, hidden_size,
qkv_out_dim, qkv_out_dim,
cfg.bias, cfg.bias,
&format!("{p}.query_key_value"), &format!("{p}.query_key_value"),
vb, vb,
)?; )?;
let dense = Linear::load( let dense = linear(
hidden_size, hidden_size,
hidden_size, hidden_size,
cfg.bias, cfg.bias,
@ -497,8 +452,8 @@ impl FalconMlp {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size; let h = cfg.hidden_size;
let b = cfg.bias; let b = cfg.bias;
let dense_h_to_4h = Linear::load(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?; let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
let dense_4h_to_h = Linear::load(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?; let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
let dropout = Dropout::new(cfg.hidden_dropout); let dropout = Dropout::new(cfg.hidden_dropout);
Ok(Self { Ok(Self {
dense_h_to_4h, dense_h_to_4h,
@ -526,7 +481,7 @@ struct FalconDecoderLayer {
impl FalconDecoderLayer { impl FalconDecoderLayer {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?; let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
let inp_layernorm = LayerNorm::load( let inp_layernorm = layer_norm(
cfg.hidden_size, cfg.hidden_size,
cfg.layer_norm_epsilon, cfg.layer_norm_epsilon,
&format!("{p}.input_layernorm"), &format!("{p}.input_layernorm"),
@ -536,7 +491,7 @@ impl FalconDecoderLayer {
let post_attention_layernorm = if cfg.parallel_attn { let post_attention_layernorm = if cfg.parallel_attn {
None None
} else { } else {
let ln = LayerNorm::load( let ln = layer_norm(
cfg.hidden_size, cfg.hidden_size,
cfg.layer_norm_epsilon, cfg.layer_norm_epsilon,
&format!("{p}.post_attention_layernorm"), &format!("{p}.post_attention_layernorm"),
@ -617,13 +572,13 @@ impl Falcon {
let blocks = (0..cfg.num_hidden_layers) let blocks = (0..cfg.num_hidden_layers)
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg)) .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ln_f = LayerNorm::load( let ln_f = layer_norm(
cfg.hidden_size, cfg.hidden_size,
cfg.layer_norm_epsilon, cfg.layer_norm_epsilon,
"transformer.ln_f", "transformer.ln_f",
vb, vb,
)?; )?;
let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
Ok(Self { Ok(Self {
word_embeddings, word_embeddings,
blocks, blocks,

View File

@ -1,4 +1,4 @@
use crate::nn::{Embedding, HiddenAct, LayerNorm, Linear, VarBuilder}; use crate::nn::{layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
use crate::{encodec_model, t5_model}; use crate::{encodec_model, t5_model};
use anyhow::Result; use anyhow::Result;
use candle::{DType, Device, Tensor, D}; use candle::{DType, Device, Tensor, D};
@ -146,10 +146,10 @@ impl MusicgenAttention {
let h = cfg.hidden_size; let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads; let num_heads = cfg.num_attention_heads;
let head_dim = h / num_heads; let head_dim = h / num_heads;
let k_proj = Linear::load(h, h, false, &format!("{p}.k_proj"), vb)?; let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?;
let v_proj = Linear::load(h, h, false, &format!("{p}.v_proj"), vb)?; let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?;
let q_proj = Linear::load(h, h, false, &format!("{p}.q_proj"), vb)?; let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?;
let out_proj = Linear::load(h, h, false, &format!("{p}.out_proj"), vb)?; let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?;
Ok(Self { Ok(Self {
scaling: 1. / (head_dim as f64).sqrt(), scaling: 1. / (head_dim as f64).sqrt(),
is_decoder: true, is_decoder: true,
@ -213,14 +213,13 @@ impl MusicgenDecoderLayer {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size; let h = cfg.hidden_size;
let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?; let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
let self_attn_layer_norm = let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
LayerNorm::load(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?; let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
let encoder_attn_layer_norm = let encoder_attn_layer_norm =
LayerNorm::load(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?; layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
let fc1 = Linear::load(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?; let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
let fc2 = Linear::load(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?; let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
let final_layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?; let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
Ok(Self { Ok(Self {
self_attn, self_attn,
self_attn_layer_norm, self_attn_layer_norm,
@ -290,7 +289,7 @@ impl MusicgenDecoder {
let layers = (0..cfg.num_hidden_layers) let layers = (0..cfg.num_hidden_layers)
.map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg)) .map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.layer_norm"), vb)?; let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
Ok(Self { Ok(Self {
embed_tokens, embed_tokens,
embed_positions, embed_positions,
@ -341,7 +340,7 @@ impl MusicgenForCausalLM {
let h = cfg.hidden_size; let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?; let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
let lm_heads = (0..cfg.num_codebooks) let lm_heads = (0..cfg.num_codebooks)
.map(|i| Linear::load(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb)) .map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(Self { Ok(Self {
decoder, decoder,

View File

@ -63,80 +63,38 @@ impl<'a> VarBuilder<'a> {
} }
} }
#[derive(Debug)] pub type Linear = candle_nn::Linear;
pub struct Linear {
weight: Tensor, pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
bias: Option<Tensor>, let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
let bias = if bias {
Some(vb.get(size2, &format!("{p}.bias"))?)
} else {
None
};
Ok(Linear::new(weight, bias))
} }
impl Linear { pub type LayerNorm = candle_nn::LayerNorm;
pub fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
let bias = if bias {
Some(vb.get(size2, &format!("{p}.bias"))?)
} else {
None
};
Ok(Self { weight, bias })
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { pub fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
let (bsize, _, _) = x.shape().r3()?; let (weight, bias) = match (
let w = self.weight.broadcast_left(bsize)?.t()?; vb.get(size, &format!("{p}.weight")),
let x = x.matmul(&w)?; vb.get(size, &format!("{p}.bias")),
match &self.bias { ) {
None => Ok(x), (Ok(weight), Ok(bias)) => (weight, bias),
Some(bias) => x.broadcast_add(bias), (Err(err), _) | (_, Err(err)) => {
} if let (Ok(weight), Ok(bias)) = (
} vb.get(size, &format!("{p}.gamma")),
} vb.get(size, &format!("{p}.beta")),
) {
#[derive(Debug)] (weight, bias)
pub struct LayerNorm { } else {
weight: Tensor, return Err(err.into());
bias: Tensor,
eps: f64,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
Self { weight, bias, eps }
}
pub fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
let (weight, bias) = match (
vb.get(size, &format!("{p}.weight")),
vb.get(size, &format!("{p}.bias")),
) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (
vb.get(size, &format!("{p}.gamma")),
vb.get(size, &format!("{p}.beta")),
) {
(weight, bias)
} else {
return Err(err.into());
}
} }
}; }
Ok(Self { weight, bias, eps }) };
} Ok(LayerNorm::new(weight, bias, eps))
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let dtype = x.dtype();
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(DType::F32)?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
} }
#[derive(Debug)] #[derive(Debug)]

View File

@ -1,7 +1,7 @@
// T5 Text Encoder // T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder}; use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use anyhow::Result; use anyhow::Result;
use candle::Tensor; use candle::Tensor;
@ -104,8 +104,8 @@ struct T5DenseActDense {
impl T5DenseActDense { impl T5DenseActDense {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?; let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?; let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
let dropout = Dropout::new(cfg.dropout_rate); let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self { Ok(Self {
wi, wi,
@ -154,10 +154,10 @@ struct T5Attention {
impl T5Attention { impl T5Attention {
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv; let inner_dim = cfg.num_heads * cfg.d_kv;
let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?; let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?; let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?; let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?; let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
let relative_attention_bias = if h { let relative_attention_bias = if h {
let emb = Embedding::load( let emb = Embedding::load(
cfg.relative_attention_num_buckets, cfg.relative_attention_num_buckets,

View File

@ -2,6 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1. // back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result; use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use candle_nn::{LayerNorm, Linear};
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
@ -138,35 +139,15 @@ impl Embedding {
} }
} }
struct Linear { fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
weight: Tensor, let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
bias: Option<Tensor>, let bias = vb.get(size2, &format!("{p}.bias"))?;
Ok(Linear::new(weight, Some(bias)))
} }
impl Linear { fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> { let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
let weight = vb.get((size2, size1), &format!("{p}.weight"))?; Ok(Linear::new(weight, None))
let bias = vb.get(size2, &format!("{p}.bias"))?;
Ok(Self {
weight,
bias: Some(bias),
})
}
fn load_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
Ok(Self { weight, bias: None })
}
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let (bsize, _, _) = x.shape().r3()?;
let w = self.weight.broadcast_left(bsize)?.t()?;
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -258,35 +239,10 @@ impl Dropout {
} }
} }
// This layer norm version handles both weight and bias so removes the mean. fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
struct LayerNorm { let weight = vb.get(size, &format!("{p}.weight"))?;
weight: Tensor, let bias = vb.get(size, &format!("{p}.bias"))?;
bias: Tensor, Ok(LayerNorm::new(weight, bias, 1e-5))
eps: f64,
}
impl LayerNorm {
fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(size, &format!("{p}.weight"))?;
let bias = vb.get(size, &format!("{p}.bias"))?;
Ok(Self {
weight,
bias,
eps: 1e-5,
})
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
} }
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
@ -300,10 +256,10 @@ struct MultiHeadAttention {
impl MultiHeadAttention { impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> { fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?; let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?; let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?; let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?; let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
Ok(Self { Ok(Self {
query, query,
key, key,
@ -364,20 +320,19 @@ struct ResidualAttentionBlock {
impl ResidualAttentionBlock { impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> { fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?; let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?; let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
let cross_attn = if ca { let cross_attn = if ca {
let cross_attn = let cross_attn =
MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?; MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
let cross_attn_ln = let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
Some((cross_attn, cross_attn_ln)) Some((cross_attn, cross_attn_ln))
} else { } else {
None None
}; };
let n_mlp = n_state * 4; let n_mlp = n_state * 4;
let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?; let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?; let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?; let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
Ok(Self { Ok(Self {
attn, attn,
attn_ln, attn_ln,
@ -456,7 +411,7 @@ impl AudioEncoder {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb) ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?; let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
Ok(Self { Ok(Self {
conv1, conv1,
conv2, conv2,
@ -503,7 +458,7 @@ impl TextDecoder {
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb) ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?; let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
let mask: Vec<_> = (0..n_ctx) let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect(); .collect();

24
candle-nn/Cargo.toml Normal file
View File

@ -0,0 +1,24 @@
[package]
name = "candle-nn"
version = "0.1.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/LaurentMazare/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", default-features=false }
thiserror = "1"
intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
[features]
default = ["cuda"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]

View File

@ -0,0 +1,34 @@
use candle::{DType, Result, Tensor};
// This layer norm version handles both weight and bias so removes the mean.
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
Self { weight, bias, eps }
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}

5
candle-nn/src/lib.rs Normal file
View File

@ -0,0 +1,5 @@
mod layer_norm;
mod linear;
pub use layer_norm::LayerNorm;
pub use linear::Linear;

25
candle-nn/src/linear.rs Normal file
View File

@ -0,0 +1,25 @@
use candle::Tensor;
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}