mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate. * Tweak the cuda dependencies. * More cuda tweaks.
This commit is contained in:
@ -4,6 +4,7 @@ members = [
|
||||
"candle-examples",
|
||||
"candle-kernels",
|
||||
"candle-hub",
|
||||
"candle-nn",
|
||||
"candle-pyo3",
|
||||
]
|
||||
|
||||
|
@ -14,7 +14,8 @@ readme = "README.md"
|
||||
blas = { version = "0.22.0", optional = true }
|
||||
byteorder = "1.4.3"
|
||||
candle-kernels = { path = "../candle-kernels", optional = true }
|
||||
# cudarc = { version = "0.9.12", optional = true, features = ["f16"] }
|
||||
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
|
||||
# cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
|
||||
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] }
|
||||
# TODO: Switch back to the official gemm implementation once something similar to
|
||||
# https://github.com/sarah-ek/gemm/pull/8 is available.
|
||||
|
@ -12,6 +12,7 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", default-features=false }
|
||||
candle-nn = { path = "../candle-nn", default-features=false }
|
||||
serde = { version = "1.0.166", features = ["derive"] }
|
||||
serde_json = "1.0.99"
|
||||
num-traits = "0.2.15"
|
||||
@ -27,5 +28,5 @@ wav = "1.0.0"
|
||||
|
||||
[features]
|
||||
default = ["cuda"]
|
||||
cuda = ["candle/cuda"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
|
@ -6,6 +6,7 @@ extern crate intel_mkl_src;
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use candle_nn::{LayerNorm, Linear};
|
||||
use clap::Parser;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
@ -194,29 +195,10 @@ impl Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn new(weight: Tensor, bias: Tensor) -> Self {
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
||||
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
let bias = vb.get(size2, &format!("{p}.bias"))?;
|
||||
Ok(Self::new(weight, bias))
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (bsize, _, _) = x.shape().r3()?;
|
||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||
let x = x.matmul(&w)?;
|
||||
let x = x.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
@ -234,19 +216,7 @@ impl Dropout {
|
||||
}
|
||||
}
|
||||
|
||||
// This layer norm version handles both weight and bias so removes the mean.
|
||||
struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
Self { weight, bias, eps }
|
||||
}
|
||||
|
||||
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (
|
||||
vb.get(size, &format!("{p}.weight")),
|
||||
vb.get(size, &format!("{p}.bias")),
|
||||
@ -263,20 +233,7 @@ impl LayerNorm {
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self { weight, bias, eps })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
|
||||
@ -310,7 +267,7 @@ impl BertEmbeddings {
|
||||
&format!("{p}.token_type_embeddings"),
|
||||
vb,
|
||||
)?;
|
||||
let layer_norm = LayerNorm::load(
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
&format!("{p}.LayerNorm"),
|
||||
@ -362,9 +319,9 @@ impl BertSelfAttention {
|
||||
let all_head_size = config.num_attention_heads * attention_head_size;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let hidden_size = config.hidden_size;
|
||||
let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
|
||||
let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
|
||||
let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
|
||||
let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
|
||||
let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
|
||||
let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
@ -414,13 +371,13 @@ struct BertSelfOutput {
|
||||
|
||||
impl BertSelfOutput {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = Linear::load(
|
||||
let dense = linear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
&format!("{p}.dense"),
|
||||
vb,
|
||||
)?;
|
||||
let layer_norm = LayerNorm::load(
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
&format!("{p}.LayerNorm"),
|
||||
@ -437,7 +394,7 @@ impl BertSelfOutput {
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||
Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
|
||||
}
|
||||
}
|
||||
|
||||
@ -472,7 +429,7 @@ struct BertIntermediate {
|
||||
|
||||
impl BertIntermediate {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = Linear::load(
|
||||
let dense = linear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
&format!("{p}.dense"),
|
||||
@ -500,13 +457,13 @@ struct BertOutput {
|
||||
|
||||
impl BertOutput {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = Linear::load(
|
||||
let dense = linear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
&format!("{p}.dense"),
|
||||
vb,
|
||||
)?;
|
||||
let layer_norm = LayerNorm::load(
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
&format!("{p}.LayerNorm"),
|
||||
@ -523,7 +480,7 @@ impl BertOutput {
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||
Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
|
||||
use candle_nn::{LayerNorm, Linear};
|
||||
use std::collections::HashMap;
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
@ -61,47 +62,17 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, &format!("{p}.bias"))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self { weight, bias })
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let (bsize, _, _) = x.shape().r3()?;
|
||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
Self { weight, bias, eps }
|
||||
}
|
||||
|
||||
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (
|
||||
vb.get(size, &format!("{p}.weight")),
|
||||
vb.get(size, &format!("{p}.bias")),
|
||||
@ -118,23 +89,7 @@ impl LayerNorm {
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self { weight, bias, eps })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.to_dtype(dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -378,14 +333,14 @@ impl FalconAttention {
|
||||
} else {
|
||||
3 * hidden_size
|
||||
};
|
||||
let query_key_value = Linear::load(
|
||||
let query_key_value = linear(
|
||||
hidden_size,
|
||||
qkv_out_dim,
|
||||
cfg.bias,
|
||||
&format!("{p}.query_key_value"),
|
||||
vb,
|
||||
)?;
|
||||
let dense = Linear::load(
|
||||
let dense = linear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
cfg.bias,
|
||||
@ -497,8 +452,8 @@ impl FalconMlp {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let b = cfg.bias;
|
||||
let dense_h_to_4h = Linear::load(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
|
||||
let dense_4h_to_h = Linear::load(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
|
||||
let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
|
||||
let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
|
||||
let dropout = Dropout::new(cfg.hidden_dropout);
|
||||
Ok(Self {
|
||||
dense_h_to_4h,
|
||||
@ -526,7 +481,7 @@ struct FalconDecoderLayer {
|
||||
impl FalconDecoderLayer {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
|
||||
let inp_layernorm = LayerNorm::load(
|
||||
let inp_layernorm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
&format!("{p}.input_layernorm"),
|
||||
@ -536,7 +491,7 @@ impl FalconDecoderLayer {
|
||||
let post_attention_layernorm = if cfg.parallel_attn {
|
||||
None
|
||||
} else {
|
||||
let ln = LayerNorm::load(
|
||||
let ln = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
&format!("{p}.post_attention_layernorm"),
|
||||
@ -617,13 +572,13 @@ impl Falcon {
|
||||
let blocks = (0..cfg.num_hidden_layers)
|
||||
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = LayerNorm::load(
|
||||
let ln_f = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
"transformer.ln_f",
|
||||
vb,
|
||||
)?;
|
||||
let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
blocks,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::nn::{Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
|
||||
use crate::nn::{layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
|
||||
use crate::{encodec_model, t5_model};
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor, D};
|
||||
@ -146,10 +146,10 @@ impl MusicgenAttention {
|
||||
let h = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let head_dim = h / num_heads;
|
||||
let k_proj = Linear::load(h, h, false, &format!("{p}.k_proj"), vb)?;
|
||||
let v_proj = Linear::load(h, h, false, &format!("{p}.v_proj"), vb)?;
|
||||
let q_proj = Linear::load(h, h, false, &format!("{p}.q_proj"), vb)?;
|
||||
let out_proj = Linear::load(h, h, false, &format!("{p}.out_proj"), vb)?;
|
||||
let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?;
|
||||
let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?;
|
||||
let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?;
|
||||
let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?;
|
||||
Ok(Self {
|
||||
scaling: 1. / (head_dim as f64).sqrt(),
|
||||
is_decoder: true,
|
||||
@ -213,14 +213,13 @@ impl MusicgenDecoderLayer {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
|
||||
let self_attn_layer_norm =
|
||||
LayerNorm::load(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
|
||||
let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
|
||||
let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
|
||||
let encoder_attn_layer_norm =
|
||||
LayerNorm::load(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
|
||||
let fc1 = Linear::load(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
|
||||
let fc2 = Linear::load(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
|
||||
let final_layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
|
||||
layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
|
||||
let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
|
||||
let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
|
||||
let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
self_attn_layer_norm,
|
||||
@ -290,7 +289,7 @@ impl MusicgenDecoder {
|
||||
let layers = (0..cfg.num_hidden_layers)
|
||||
.map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
|
||||
let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
embed_positions,
|
||||
@ -341,7 +340,7 @@ impl MusicgenForCausalLM {
|
||||
let h = cfg.hidden_size;
|
||||
let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
|
||||
let lm_heads = (0..cfg.num_codebooks)
|
||||
.map(|i| Linear::load(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
|
||||
.map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
decoder,
|
||||
|
@ -63,47 +63,21 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
pub type Linear = candle_nn::Linear;
|
||||
|
||||
impl Linear {
|
||||
pub fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, &format!("{p}.bias"))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self { weight, bias })
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let (bsize, _, _) = x.shape().r3()?;
|
||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub type LayerNorm = candle_nn::LayerNorm;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
Self { weight, bias, eps }
|
||||
}
|
||||
|
||||
pub fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
pub fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (
|
||||
vb.get(size, &format!("{p}.weight")),
|
||||
vb.get(size, &format!("{p}.bias")),
|
||||
@ -120,23 +94,7 @@ impl LayerNorm {
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self { weight, bias, eps })
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.to_dtype(dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -1,7 +1,7 @@
|
||||
// T5 Text Encoder
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder};
|
||||
use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
|
||||
use anyhow::Result;
|
||||
use candle::Tensor;
|
||||
|
||||
@ -104,8 +104,8 @@ struct T5DenseActDense {
|
||||
|
||||
impl T5DenseActDense {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
|
||||
let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
|
||||
let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
|
||||
let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
wi,
|
||||
@ -154,10 +154,10 @@ struct T5Attention {
|
||||
impl T5Attention {
|
||||
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inner_dim = cfg.num_heads * cfg.d_kv;
|
||||
let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
|
||||
let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
|
||||
let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
|
||||
let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
|
||||
let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
|
||||
let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
|
||||
let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
|
||||
let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
|
||||
let relative_attention_bias = if h {
|
||||
let emb = Embedding::load(
|
||||
cfg.relative_attention_num_buckets,
|
||||
|
@ -2,6 +2,7 @@
|
||||
// back when using RUST_LIB_BACKTRACE=1.
|
||||
use anyhow::Result;
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||
use candle_nn::{LayerNorm, Linear};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -138,35 +139,15 @@ impl Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
let bias = vb.get(size2, &format!("{p}.bias"))?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
bias: Some(bias),
|
||||
})
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
fn load_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
Ok(Self { weight, bias: None })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let (bsize, _, _) = x.shape().r3()?;
|
||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -258,35 +239,10 @@ impl Dropout {
|
||||
}
|
||||
}
|
||||
|
||||
// This layer norm version handles both weight and bias so removes the mean.
|
||||
struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, &format!("{p}.weight"))?;
|
||||
let bias = vb.get(size, &format!("{p}.bias"))?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
bias,
|
||||
eps: 1e-5,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
Ok(LayerNorm::new(weight, bias, 1e-5))
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
|
||||
@ -300,10 +256,10 @@ struct MultiHeadAttention {
|
||||
|
||||
impl MultiHeadAttention {
|
||||
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?;
|
||||
let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?;
|
||||
let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
|
||||
let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?;
|
||||
let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
|
||||
let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
|
||||
let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
|
||||
let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
@ -364,20 +320,19 @@ struct ResidualAttentionBlock {
|
||||
impl ResidualAttentionBlock {
|
||||
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
|
||||
let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
|
||||
let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
|
||||
let cross_attn = if ca {
|
||||
let cross_attn =
|
||||
MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
|
||||
let cross_attn_ln =
|
||||
LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
|
||||
let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
|
||||
Some((cross_attn, cross_attn_ln))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let n_mlp = n_state * 4;
|
||||
let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
|
||||
let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
|
||||
let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?;
|
||||
let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
|
||||
let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
|
||||
let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
|
||||
Ok(Self {
|
||||
attn,
|
||||
attn_ln,
|
||||
@ -456,7 +411,7 @@ impl AudioEncoder {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
|
||||
let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
|
||||
Ok(Self {
|
||||
conv1,
|
||||
conv2,
|
||||
@ -503,7 +458,7 @@ impl TextDecoder {
|
||||
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
|
||||
let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
|
||||
let mask: Vec<_> = (0..n_ctx)
|
||||
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect();
|
||||
|
24
candle-nn/Cargo.toml
Normal file
24
candle-nn/Cargo.toml
Normal file
@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "candle-nn"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/LaurentMazare/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT/Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", default-features=false }
|
||||
thiserror = "1"
|
||||
intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
|
||||
[features]
|
||||
default = ["cuda"]
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
34
candle-nn/src/layer_norm.rs
Normal file
34
candle-nn/src/layer_norm.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use candle::{DType, Result, Tensor};
|
||||
|
||||
// This layer norm version handles both weight and bias so removes the mean.
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
Self { weight, bias, eps }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
5
candle-nn/src/lib.rs
Normal file
5
candle-nn/src/lib.rs
Normal file
@ -0,0 +1,5 @@
|
||||
mod layer_norm;
|
||||
mod linear;
|
||||
|
||||
pub use layer_norm::LayerNorm;
|
||||
pub use linear::Linear;
|
25
candle-nn/src/linear.rs
Normal file
25
candle-nn/src/linear.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use candle::Tensor;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
};
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user