Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate.
* Tweak the cuda dependencies.
* More cuda tweaks.
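For context, the code below now compiles against two modules provided by the new candle-nn crate. A minimal sketch of what this diff assumes about them, inferred from the call sites (the real definitions live in candle-nn and may carry more methods):

    // Sketch only: inferred from the call sites in this diff, not copied from candle-nn.
    use candle::Tensor;

    #[derive(Debug)]
    pub struct Linear {
        weight: Tensor,
        bias: Option<Tensor>, // None when a layer is loaded with bias = false
    }

    impl Linear {
        pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
            Self { weight, bias }
        }
    }

    #[derive(Debug)]
    pub struct LayerNorm {
        weight: Tensor,
        bias: Tensor,
        eps: f64, // numerical-stability epsilon added to the variance
    }

    impl LayerNorm {
        pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
            Self { weight, bias, eps }
        }
    }

Both types also expose a forward method over a Tensor; only the checkpoint-loading logic stays in the example, as the linear and layer_norm helpers introduced below.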
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
+use candle_nn::{LayerNorm, Linear};
 use std::collections::HashMap;
 
 const MAX_SEQ_LEN: usize = 5000;
@@ -61,80 +62,34 @@ impl<'a> VarBuilder<'a> {
     }
 }
 
-#[derive(Debug)]
-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-}
-
-impl Linear {
-    fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = if bias {
-            Some(vb.get(size2, &format!("{p}.bias"))?)
-        } else {
-            None
-        };
-        Ok(Self { weight, bias })
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-#[derive(Debug)]
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
-            }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let dtype = x.dtype();
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let x = x.to_dtype(DType::F32)?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
+fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = if bias {
+        Some(vb.get(size2, &format!("{p}.bias"))?)
+    } else {
+        None
+    };
+    Ok(Linear::new(weight, bias))
+}
+
+fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
+            }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
+}
 
 #[derive(Debug)]
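Note the gamma/beta fallback kept in layer_norm: some exported checkpoints store LayerNorm parameters under the legacy names gamma and beta instead of weight and bias, and the helper tries both before giving up. A hypothetical call site for the two new helpers (the prefix string here is illustrative; the real ones appear in the hunks below):

    // Illustrative only: `vb` is this example's VarBuilder; `hidden_size` and
    // `eps` would come from the model Config.
    let p = "transformer.h.0";
    let dense = linear(hidden_size, hidden_size, true, &format!("{p}.dense"), vb)?;
    let ln = layer_norm(hidden_size, eps, &format!("{p}.input_layernorm"), vb)?;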
@@ -378,14 +333,14 @@ impl FalconAttention {
         } else {
             3 * hidden_size
         };
-        let query_key_value = Linear::load(
+        let query_key_value = linear(
             hidden_size,
             qkv_out_dim,
             cfg.bias,
             &format!("{p}.query_key_value"),
             vb,
         )?;
-        let dense = Linear::load(
+        let dense = linear(
             hidden_size,
             hidden_size,
             cfg.bias,
@@ -497,8 +452,8 @@ impl FalconMlp {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let b = cfg.bias;
-        let dense_h_to_4h = Linear::load(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
-        let dense_4h_to_h = Linear::load(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
+        let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
+        let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
         let dropout = Dropout::new(cfg.hidden_dropout);
         Ok(Self {
             dense_h_to_4h,
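The MLP's forward pass is untouched by this diff; for orientation, a sketch of how these layers are presumably chained (assumed here, not shown in this change):

    // Assumed surrounding forward pass: project h -> 4h, apply GELU,
    // project back 4h -> h, then dropout. Not part of this diff.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.dense_h_to_4h.forward(x)?.gelu()?;
        let x = self.dense_4h_to_h.forward(&x)?;
        self.dropout.forward(&x)
    }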
@@ -526,7 +481,7 @@ struct FalconDecoderLayer {
 impl FalconDecoderLayer {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
-        let inp_layernorm = LayerNorm::load(
+        let inp_layernorm = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             &format!("{p}.input_layernorm"),
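The helpers build fully qualified tensor names from a dotted prefix, so each layer only passes its own prefix down. For example (values illustrative):

    // Layer prefixes take the form "transformer.h.{i}" (see Falcon::load below),
    // so layer 12's input layer norm weight resolves to:
    let p = format!("transformer.h.{}", 12);
    let name = format!("{p}.input_layernorm.weight"); // "transformer.h.12.input_layernorm.weight"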
@@ -536,7 +491,7 @@ impl FalconDecoderLayer {
         let post_attention_layernorm = if cfg.parallel_attn {
             None
         } else {
-            let ln = LayerNorm::load(
+            let ln = layer_norm(
                 cfg.hidden_size,
                 cfg.layer_norm_epsilon,
                 &format!("{p}.post_attention_layernorm"),
@@ -617,13 +572,13 @@ impl Falcon {
         let blocks = (0..cfg.num_hidden_layers)
             .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
             .collect::<Result<Vec<_>>>()?;
-        let ln_f = LayerNorm::load(
+        let ln_f = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             "transformer.ln_f",
             vb,
         )?;
-        let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
         Ok(Self {
             word_embeddings,
             blocks,
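For a sense of the shapes flowing through the bias-free head, a hypothetical forward call; the dimensions are Falcon-7B's and are assumed here purely for illustration:

    use candle::{DType, Device, Tensor};

    // Illustrative only: hidden_size = 4544 and vocab_size = 65024 are Falcon-7B values.
    let device = Device::Cpu;
    let hidden = Tensor::zeros((1, 7, 4544), DType::F32, &device)?; // (batch, seq, hidden)
    let logits = lm_head.forward(&hidden)?; // -> (1, 7, 65024) = (batch, seq, vocab)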