Sketch the candle-nn crate. (#115)

* Sketch the candle-nn crate.
* Tweak the cuda dependencies.
* More cuda tweaks.
@@ -4,6 +4,7 @@ members = [
     "candle-examples",
     "candle-kernels",
     "candle-hub",
+    "candle-nn",
     "candle-pyo3",
 ]

@@ -14,7 +14,8 @@ readme = "README.md"
 blas = { version = "0.22.0", optional = true }
 byteorder = "1.4.3"
 candle-kernels = { path = "../candle-kernels", optional = true }
-# cudarc = { version = "0.9.12", optional = true, features = ["f16"] }
+# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
+# cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
 cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] }
 # TODO: Switch back to the official gemm implementation once something similar to
 # https://github.com/sarah-ek/gemm/pull/8 is available.
@@ -12,6 +12,7 @@ readme = "README.md"

 [dependencies]
 candle = { path = "../candle-core", default-features=false }
+candle-nn = { path = "../candle-nn", default-features=false }
 serde = { version = "1.0.166", features = ["derive"] }
 serde_json = "1.0.99"
 num-traits = "0.2.15"
@@ -27,5 +28,5 @@ wav = "1.0.0"

 [features]
 default = ["cuda"]
-cuda = ["candle/cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
@@ -6,6 +6,7 @@ extern crate intel_mkl_src;
 use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
+use candle_nn::{LayerNorm, Linear};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -194,29 +195,10 @@ impl Embedding {
     }
 }

-struct Linear {
-    weight: Tensor,
-    bias: Tensor,
-}
-
-impl Linear {
-    fn new(weight: Tensor, bias: Tensor) -> Self {
-        Self { weight, bias }
-    }
-
-    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = vb.get(size2, &format!("{p}.bias"))?;
-        Ok(Self::new(weight, bias))
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        let x = x.broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
+fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = vb.get(size2, &format!("{p}.bias"))?;
+    Ok(Linear::new(weight, Some(bias)))
+}

 struct Dropout {
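candle_nn::Linear stores its bias as an Option<Tensor>, so a bias-free loader is the same helper with None in place of Some(bias); the whisper example further down in this diff defines exactly such a helper. A hypothetical sketch against the BERT example's VarBuilder (not part of this hunk):

    fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
        // Same (out_features, in_features) weight layout as the helper above, just no bias.
        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
        Ok(Linear::new(weight, None))
    }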
@@ -234,49 +216,24 @@ impl Dropout {
     }
 }

-// This layer norm version handles both weight and bias so removes the mean.
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
-            }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
+fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
+            }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
+}

 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
@@ -310,7 +267,7 @@ impl BertEmbeddings {
             &format!("{p}.token_type_embeddings"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(
+        let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
             &format!("{p}.LayerNorm"),
@@ -362,9 +319,9 @@ impl BertSelfAttention {
         let all_head_size = config.num_attention_heads * attention_head_size;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         let hidden_size = config.hidden_size;
-        let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
-        let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
-        let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
+        let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
+        let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
+        let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
         Ok(Self {
             query,
             key,
@@ -414,13 +371,13 @@ struct BertSelfOutput {

 impl BertSelfOutput {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = Linear::load(
+        let dense = linear(
             config.hidden_size,
             config.hidden_size,
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(
+        let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
             &format!("{p}.LayerNorm"),
@@ -437,7 +394,7 @@ impl BertSelfOutput {
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
-        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+        Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
     }
 }

@@ -472,7 +429,7 @@ struct BertIntermediate {

 impl BertIntermediate {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = Linear::load(
+        let dense = linear(
             config.hidden_size,
             config.intermediate_size,
             &format!("{p}.dense"),
@@ -500,13 +457,13 @@ struct BertOutput {

 impl BertOutput {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = Linear::load(
+        let dense = linear(
             config.intermediate_size,
             config.hidden_size,
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(
+        let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
             &format!("{p}.LayerNorm"),
@@ -523,7 +480,7 @@ impl BertOutput {
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
-        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+        Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
     }
 }

@@ -1,5 +1,6 @@
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
+use candle_nn::{LayerNorm, Linear};
 use std::collections::HashMap;

 const MAX_SEQ_LEN: usize = 5000;
@@ -61,80 +62,34 @@ impl<'a> VarBuilder<'a> {
     }
 }

-#[derive(Debug)]
-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-}
-
-impl Linear {
-    fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = if bias {
-            Some(vb.get(size2, &format!("{p}.bias"))?)
-        } else {
-            None
-        };
-        Ok(Self { weight, bias })
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-#[derive(Debug)]
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
-            }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let dtype = x.dtype();
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let x = x.to_dtype(DType::F32)?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
+fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = if bias {
+        Some(vb.get(size2, &format!("{p}.bias"))?)
+    } else {
+        None
+    };
+    Ok(Linear::new(weight, bias))
+}
+
+fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
+            }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
+}

 #[derive(Debug)]
@@ -378,14 +333,14 @@ impl FalconAttention {
         } else {
             3 * hidden_size
         };
-        let query_key_value = Linear::load(
+        let query_key_value = linear(
             hidden_size,
             qkv_out_dim,
             cfg.bias,
             &format!("{p}.query_key_value"),
             vb,
         )?;
-        let dense = Linear::load(
+        let dense = linear(
             hidden_size,
             hidden_size,
             cfg.bias,
@@ -497,8 +452,8 @@ impl FalconMlp {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let b = cfg.bias;
-        let dense_h_to_4h = Linear::load(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
-        let dense_4h_to_h = Linear::load(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
+        let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
+        let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
         let dropout = Dropout::new(cfg.hidden_dropout);
         Ok(Self {
             dense_h_to_4h,
@@ -526,7 +481,7 @@ struct FalconDecoderLayer {
 impl FalconDecoderLayer {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
-        let inp_layernorm = LayerNorm::load(
+        let inp_layernorm = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             &format!("{p}.input_layernorm"),
@@ -536,7 +491,7 @@ impl FalconDecoderLayer {
         let post_attention_layernorm = if cfg.parallel_attn {
             None
         } else {
-            let ln = LayerNorm::load(
+            let ln = layer_norm(
                 cfg.hidden_size,
                 cfg.layer_norm_epsilon,
                 &format!("{p}.post_attention_layernorm"),
@@ -617,13 +572,13 @@ impl Falcon {
         let blocks = (0..cfg.num_hidden_layers)
             .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
             .collect::<Result<Vec<_>>>()?;
-        let ln_f = LayerNorm::load(
+        let ln_f = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             "transformer.ln_f",
             vb,
         )?;
-        let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
         Ok(Self {
             word_embeddings,
             blocks,
@@ -1,4 +1,4 @@
-use crate::nn::{Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
+use crate::nn::{layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
 use crate::{encodec_model, t5_model};
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
@@ -146,10 +146,10 @@ impl MusicgenAttention {
         let h = cfg.hidden_size;
         let num_heads = cfg.num_attention_heads;
         let head_dim = h / num_heads;
-        let k_proj = Linear::load(h, h, false, &format!("{p}.k_proj"), vb)?;
-        let v_proj = Linear::load(h, h, false, &format!("{p}.v_proj"), vb)?;
-        let q_proj = Linear::load(h, h, false, &format!("{p}.q_proj"), vb)?;
-        let out_proj = Linear::load(h, h, false, &format!("{p}.out_proj"), vb)?;
+        let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?;
+        let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?;
+        let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?;
+        let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?;
         Ok(Self {
             scaling: 1. / (head_dim as f64).sqrt(),
             is_decoder: true,
@@ -213,14 +213,13 @@ impl MusicgenDecoderLayer {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
-        let self_attn_layer_norm =
-            LayerNorm::load(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
+        let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
         let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
         let encoder_attn_layer_norm =
-            LayerNorm::load(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
-        let fc1 = Linear::load(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
-        let fc2 = Linear::load(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
-        let final_layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
+            layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
+        let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
+        let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
+        let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
         Ok(Self {
             self_attn,
             self_attn_layer_norm,
@@ -290,7 +289,7 @@ impl MusicgenDecoder {
         let layers = (0..cfg.num_hidden_layers)
             .map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
             .collect::<Result<Vec<_>>>()?;
-        let layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
+        let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
         Ok(Self {
             embed_tokens,
             embed_positions,
@@ -341,7 +340,7 @@ impl MusicgenForCausalLM {
         let h = cfg.hidden_size;
         let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
         let lm_heads = (0..cfg.num_codebooks)
-            .map(|i| Linear::load(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
+            .map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             decoder,
@@ -63,80 +63,38 @@ impl<'a> VarBuilder<'a> {
     }
 }

-#[derive(Debug)]
-pub struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-}
-
-impl Linear {
-    pub fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = if bias {
-            Some(vb.get(size2, &format!("{p}.bias"))?)
-        } else {
-            None
-        };
-        Ok(Self { weight, bias })
-    }
-
-    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    pub fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
-            }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let dtype = x.dtype();
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let x = x.to_dtype(DType::F32)?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
+pub type Linear = candle_nn::Linear;
+
+pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = if bias {
+        Some(vb.get(size2, &format!("{p}.bias"))?)
+    } else {
+        None
+    };
+    Ok(Linear::new(weight, bias))
+}
+
+pub type LayerNorm = candle_nn::LayerNorm;
+
+pub fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
+            }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
+}

 #[derive(Debug)]
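The musicgen nn module keeps its old public names by aliasing them to the candle-nn types, so code elsewhere in the example that spells the type as crate::nn::Linear or crate::nn::LayerNorm keeps compiling unchanged. A minimal sketch of that pattern with hypothetical module and function names, not taken from the commit:

    mod nn {
        // Alias the shared types; the old paths stay valid while the
        // implementation now lives in the candle-nn crate.
        pub type Linear = candle_nn::Linear;
        pub type LayerNorm = candle_nn::LayerNorm;
    }

    // An existing call site that names the aliased type still compiles as-is.
    fn project(layer: &nn::Linear, x: &candle::Tensor) -> candle::Result<candle::Tensor> {
        layer.forward(x)
    }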
@@ -1,7 +1,7 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

-use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder};
+use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;

@@ -104,8 +104,8 @@ struct T5DenseActDense {

 impl T5DenseActDense {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
-        let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
+        let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
+        let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             wi,
@@ -154,10 +154,10 @@ struct T5Attention {
 impl T5Attention {
     fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
-        let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
-        let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
-        let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
+        let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
+        let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
+        let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
+        let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
         let relative_attention_bias = if h {
             let emb = Embedding::load(
                 cfg.relative_attention_num_buckets,
@@ -2,6 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle_nn::{LayerNorm, Linear};
 use serde::Deserialize;
 use std::collections::HashMap;

@@ -138,35 +139,15 @@ impl Embedding {
     }
 }

-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-}
-
-impl Linear {
-    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = vb.get(size2, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-        })
-    }
-
-    fn load_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        Ok(Self { weight, bias: None })
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
+fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = vb.get(size2, &format!("{p}.bias"))?;
+    Ok(Linear::new(weight, Some(bias)))
+}
+
+fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    Ok(Linear::new(weight, None))
+}

 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -258,35 +239,10 @@ impl Dropout {
     }
 }

-// This layer norm version handles both weight and bias so removes the mean.
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get(size, &format!("{p}.weight"))?;
-        let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias,
-            eps: 1e-5,
-        })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
+fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let weight = vb.get(size, &format!("{p}.weight"))?;
+    let bias = vb.get(size, &format!("{p}.bias"))?;
+    Ok(LayerNorm::new(weight, bias, 1e-5))
+}

 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
@@ -300,10 +256,10 @@ struct MultiHeadAttention {

 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?;
-        let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?;
-        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
-        let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?;
+        let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
+        let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
+        let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
+        let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
         Ok(Self {
             query,
             key,
@@ -364,20 +320,19 @@ struct ResidualAttentionBlock {
 impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
         let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
-        let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
+        let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
         let cross_attn = if ca {
             let cross_attn =
                 MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
-            let cross_attn_ln =
-                LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
+            let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
             Some((cross_attn, cross_attn_ln))
         } else {
             None
         };
         let n_mlp = n_state * 4;
-        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
-        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
-        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?;
+        let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
+        let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
+        let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
         Ok(Self {
             attn,
             attn_ln,
@@ -456,7 +411,7 @@ impl AudioEncoder {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
         Ok(Self {
             conv1,
             conv2,
@@ -503,7 +458,7 @@ impl TextDecoder {
                 ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
         let mask: Vec<_> = (0..n_ctx)
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
candle-nn/Cargo.toml (new file)
@@ -0,0 +1,24 @@
+[package]
+name = "candle-nn"
+version = "0.1.0"
+edition = "2021"
+
+description = "Minimalist ML framework."
+repository = "https://github.com/LaurentMazare/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT/Apache-2.0"
+readme = "README.md"
+
+[dependencies]
+candle = { path = "../candle-core", default-features=false }
+thiserror = "1"
+intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
+
+[dev-dependencies]
+anyhow = { version = "1", features = ["backtrace"] }
+
+[features]
+default = ["cuda"]
+cuda = ["candle/cuda"]
+mkl = ["dep:intel-mkl-src", "candle/mkl"]
candle-nn/src/layer_norm.rs (new file)
@@ -0,0 +1,34 @@
+use candle::{DType, Result, Tensor};
+
+// This layer norm version handles both weight and bias so removes the mean.
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        Self { weight, bias, eps }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
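A minimal usage sketch for the new LayerNorm, not part of the commit; it assumes the candle-core Tensor::ones/Tensor::zeros constructors with a (shape, dtype, device) signature. forward expects a rank-3 (batch, seq_len, hidden) input because of the r3() call above, and f16/bf16 inputs are normalized internally in f32 before being cast back:

    use candle::{DType, Device, Tensor};
    use candle_nn::LayerNorm;

    fn layer_norm_demo() -> candle::Result<()> {
        let dev = Device::Cpu;
        // Per-channel scale and shift over the hidden dimension (size 8 here).
        let weight = Tensor::ones(8, DType::F32, &dev)?;
        let bias = Tensor::zeros(8, DType::F32, &dev)?;
        let ln = LayerNorm::new(weight, bias, 1e-5);
        // Normalization runs over the last dimension of a (batch, seq_len, hidden) tensor.
        let x = Tensor::ones((2, 4, 8), DType::F32, &dev)?;
        let y = ln.forward(&x)?;
        assert_eq!(y.dims(), &[2, 4, 8]);
        Ok(())
    }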
candle-nn/src/lib.rs (new file)
@@ -0,0 +1,5 @@
+mod layer_norm;
+mod linear;
+
+pub use layer_norm::LayerNorm;
+pub use linear::Linear;
candle-nn/src/linear.rs (new file)
@@ -0,0 +1,25 @@
+use candle::Tensor;
+
+#[derive(Debug)]
+pub struct Linear {
+    weight: Tensor,
+    bias: Option<Tensor>,
+}
+
+impl Linear {
+    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+        Self { weight, bias }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+        let w = match x.dims() {
+            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+            _ => self.weight.t()?,
+        };
+        let x = x.matmul(&w)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
+    }
+}
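And a matching sketch for the new Linear, with the same Tensor::ones/zeros assumption. The weight is stored as (out_features, in_features) and transposed inside forward; rank-3 inputs get the weight broadcast over the batch dimension, anything else falls through to the plain transposed weight:

    use candle::{DType, Device, Tensor};
    use candle_nn::Linear;

    fn linear_demo() -> candle::Result<()> {
        let dev = Device::Cpu;
        // (out_features, in_features) = (3, 4), plus an optional bias of size 3.
        let weight = Tensor::ones((3, 4), DType::F32, &dev)?;
        let bias = Tensor::zeros(3, DType::F32, &dev)?;
        let layer = Linear::new(weight, Some(bias));
        // Rank-3 input: (batch, seq_len, in_features) -> (batch, seq_len, out_features).
        let x3 = Tensor::ones((2, 5, 4), DType::F32, &dev)?;
        let y3 = layer.forward(&x3)?;
        assert_eq!(y3.dims(), &[2, 5, 3]);
        // Rank-2 input hits the `_ =>` arm and multiplies by the transposed weight directly.
        let x2 = Tensor::ones((5, 4), DType::F32, &dev)?;
        let y2 = layer.forward(&x2)?;
        assert_eq!(y2.dims(), &[5, 3]);
        Ok(())
    }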