Mirror of https://github.com/huggingface/candle.git (synced 2025-06-22 12:28:06 +00:00)
Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate.
* Tweak the cuda dependencies.
* More cuda tweaks.
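The diff below replaces the whisper example's hand-rolled Linear and LayerNorm structs with the types from the new candle-nn crate; small free functions (linear, linear_no_bias, layer_norm) keep doing the VarBuilder lookups and simply wrap the resulting tensors in the candle-nn types. As a minimal, hypothetical sketch of the constructors involved: Linear::new, LayerNorm::new(.., 1e-5) and forward are taken from the diff itself, while the tensor shapes, Device::Cpu and the main scaffold are made up for illustration (newer candle-nn versions also want use candle_nn::Module; in scope for forward).

    use candle::{DType, Device, Tensor};
    use candle_nn::{LayerNorm, Linear};

    fn main() -> anyhow::Result<()> {
        let dev = Device::Cpu;

        // A 4 -> 2 projection built the way the new `linear` helper builds it:
        // weight is (out, in), the optional bias is (out,).
        let weight = Tensor::zeros((2, 4), DType::F32, &dev)?;
        let bias = Tensor::zeros(2, DType::F32, &dev)?;
        let proj = Linear::new(weight, Some(bias));

        // LayerNorm over the hidden dimension, as returned by `layer_norm`:
        // per-channel weight/bias plus the 1e-5 epsilon used throughout the example.
        let ln = LayerNorm::new(
            Tensor::ones(4, DType::F32, &dev)?,
            Tensor::zeros(4, DType::F32, &dev)?,
            1e-5,
        );

        // (batch, seq, hidden) input, the shape the whisper example works with.
        let x = Tensor::zeros((1, 3, 4), DType::F32, &dev)?;
        let y = proj.forward(&ln.forward(&x)?)?; // `forward` as used in the example
        println!("{:?}", y.shape()); // (1, 3, 2)
        Ok(())
    }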
@@ -2,6 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle_nn::{LayerNorm, Linear};
 use serde::Deserialize;
 use std::collections::HashMap;
 
@@ -138,35 +139,15 @@ impl Embedding {
     }
 }
 
-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
+fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = vb.get(size2, &format!("{p}.bias"))?;
+    Ok(Linear::new(weight, Some(bias)))
 }
 
-impl Linear {
-    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = vb.get(size2, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-        })
-    }
-
-    fn load_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        Ok(Self { weight, bias: None })
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
+fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    Ok(Linear::new(weight, None))
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
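A note on the hunk above: the local Linear struct and its load / load_no_bias / forward methods disappear, and the new free functions only perform the VarBuilder lookups ({p}.weight and {p}.bias for some prefix p) before handing the tensors to candle_nn::Linear via Linear::new. The matmul-plus-broadcast_add forward pass that used to live here is assumed to be provided by candle-nn's Linear from this commit on.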
@@ -258,35 +239,10 @@ impl Dropout {
     }
 }
 
-// This layer norm version handles both weight and bias so removes the mean.
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get(size, &format!("{p}.weight"))?;
-        let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias,
-            eps: 1e-5,
-        })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
+fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let weight = vb.get(size, &format!("{p}.weight"))?;
+    let bias = vb.get(size, &format!("{p}.bias"))?;
+    Ok(LayerNorm::new(weight, bias, 1e-5))
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
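The removed LayerNorm::forward spelled the normalization out over the last (hidden) dimension: y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * weight + bias, with eps = 1e-5. The layer_norm helper that replaces it assumes candle_nn::LayerNorm, constructed via LayerNorm::new(weight, bias, 1e-5), computes the same thing.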
@@ -300,10 +256,10 @@ struct MultiHeadAttention {
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?;
-        let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?;
-        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
-        let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?;
+        let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
+        let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
+        let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
+        let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
         Ok(Self {
             query,
             key,
@@ -364,20 +320,19 @@ struct ResidualAttentionBlock {
 impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
         let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
-        let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
+        let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
         let cross_attn = if ca {
             let cross_attn =
                 MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
-            let cross_attn_ln =
-                LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
+            let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
             Some((cross_attn, cross_attn_ln))
         } else {
             None
         };
         let n_mlp = n_state * 4;
-        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
-        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
-        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?;
+        let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
+        let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
+        let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
         Ok(Self {
             attn,
             attn_ln,
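Only the constructors change in this hunk: the layer norms and the two MLP projections (n_state -> 4 * n_state -> n_state) now come from the layer_norm and linear helpers. The ca flag still decides whether the block gets the encoder-decoder cross-attention pair; the encoder passes false and the decoder passes true in the two hunks below.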
@@ -456,7 +411,7 @@ impl AudioEncoder {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
         Ok(Self {
             conv1,
             conv2,
@@ -503,7 +458,7 @@ impl TextDecoder {
                 ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
         let mask: Vec<_> = (0..n_ctx)
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
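The unchanged mask lines at the end of this hunk build the usual causal attention mask: row i keeps 0 for every column j <= i and puts f32::NEG_INFINITY wherever j > i, so those positions vanish after the softmax. For n_ctx = 3 the flattened Vec corresponds to:

    0     -inf  -inf
    0      0    -inf
    0      0     0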