mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit. * Fix the cuda tests.
This commit is contained in:
@ -283,52 +283,38 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||||
);
|
);
|
||||||
|
|
||||||
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
|
if device.is_cpu() {
|
||||||
let y = x.interpolate1d(12)?.reshape(36)?;
|
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
|
||||||
|
let y = x.interpolate1d(12)?.reshape(36)?;
|
||||||
|
|
||||||
println!("y: {}", y.unsqueeze(1)?);
|
let z = Tensor::new(
|
||||||
#[rustfmt::skip]
|
&[
|
||||||
let z = Tensor::new(
|
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
|
||||||
&[
|
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
|
||||||
1_f32, 02., 03., 04.,
|
33., 34., 35., 36.,
|
||||||
05., 06., 07., 08.,
|
],
|
||||||
09., 10., 11., 12.,
|
device,
|
||||||
13., 14., 15., 16.,
|
)?;
|
||||||
17., 18., 19., 20.,
|
|
||||||
21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28.,
|
|
||||||
29., 30., 31., 32.,
|
|
||||||
33., 34., 35., 36.,
|
|
||||||
],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||||
|
|
||||||
let grads = loss.backward()?;
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(grad_x, 4)?,
|
||||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
|
||||||
|
);
|
||||||
println!("grad: {grad_x}");
|
}
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec3_round(grad_x, 4)?,
|
|
||||||
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
|
|
||||||
);
|
|
||||||
|
|
||||||
// manually checked: see comments
|
// manually checked: see comments
|
||||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
let z = Tensor::new(
|
||||||
&[
|
&[
|
||||||
1_f32, 02., 03., 04., 05., 06.,
|
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||||
07., 08., 09., 10., 11., 12.,
|
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||||
13., 14., 15., 16., 17., 18.,
|
35., 36.,
|
||||||
19., 20., 21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28., 29., 30.,
|
|
||||||
31., 32., 33., 34., 35., 36.,
|
|
||||||
],
|
],
|
||||||
device,
|
device,
|
||||||
)?;
|
)?;
|
||||||
@ -359,15 +345,11 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
let z = Tensor::new(
|
||||||
&[
|
&[
|
||||||
1_f32, 02., 03., 04., 05., 06.,
|
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||||
07., 08., 09., 10., 11., 12.,
|
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||||
13., 14., 15., 16., 17., 18.,
|
35., 36.,
|
||||||
19., 20., 21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28., 29., 30.,
|
|
||||||
31., 32., 33., 34., 35., 36.,
|
|
||||||
],
|
],
|
||||||
device,
|
device,
|
||||||
)?;
|
)?;
|
||||||
|
@ -28,7 +28,7 @@ pub use func::{func, func_t, Func, FuncT};
|
|||||||
pub use group_norm::{group_norm, GroupNorm};
|
pub use group_norm::{group_norm, GroupNorm};
|
||||||
pub use init::Init;
|
pub use init::Init;
|
||||||
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
|
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
|
||||||
pub use linear::{linear, linear_no_bias, Linear};
|
pub use linear::{linear, linear_b, linear_no_bias, Linear};
|
||||||
pub use ops::Dropout;
|
pub use ops::Dropout;
|
||||||
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
|
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
|
||||||
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
|
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
|
||||||
|
@ -57,21 +57,34 @@ impl super::Module for Linear {
|
|||||||
/// Create or initialize a new linear layer.
|
/// Create or initialize a new linear layer.
|
||||||
///
|
///
|
||||||
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
|
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
|
||||||
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
|
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
|
||||||
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||||
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
|
let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
|
||||||
let bound = 1. / (in_dim as f64).sqrt();
|
let bound = 1. / (in_dim as f64).sqrt();
|
||||||
let init_bs = crate::Init::Uniform {
|
let init_bs = crate::Init::Uniform {
|
||||||
lo: -bound,
|
lo: -bound,
|
||||||
up: bound,
|
up: bound,
|
||||||
};
|
};
|
||||||
let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
|
let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
|
||||||
Ok(Linear::new(ws, Some(bs)))
|
Ok(Linear::new(ws, Some(bs)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create or initialize a new linear layer without biases.
|
/// Create or initialize a new linear layer without biases.
|
||||||
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
|
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
|
||||||
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||||
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
|
let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
|
||||||
Ok(Linear::new(ws, None))
|
Ok(Linear::new(ws, None))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn linear_b(
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
bias: bool,
|
||||||
|
vb: crate::VarBuilder,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
if bias {
|
||||||
|
linear(in_dim, out_dim, vb)
|
||||||
|
} else {
|
||||||
|
linear_no_bias(in_dim, out_dim, vb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,15 +1,5 @@
|
|||||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
|
||||||
let weight = vb.get((size2, size1), "weight")?;
|
|
||||||
let bias = if bias {
|
|
||||||
Some(vb.get(size2, "bias")?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(Linear::new(weight, bias))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
let weight = vb.get(size, "weight")?;
|
let weight = vb.get(size, "weight")?;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::models::with_tracing::Linear;
|
use crate::models::with_tracing::{linear_b as linear, Linear};
|
||||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
@ -51,14 +51,6 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
|
||||||
if bias {
|
|
||||||
crate::models::with_tracing::linear(in_dim, out_dim, vb)
|
|
||||||
} else {
|
|
||||||
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct RotaryEmbedding {
|
struct RotaryEmbedding {
|
||||||
cache: Tensor,
|
cache: Tensor,
|
||||||
|
@ -1,18 +1,8 @@
|
|||||||
use candle::{DType, Device, Result, Tensor, D};
|
use candle::{DType, Device, Result, Tensor, D};
|
||||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 5000;
|
const MAX_SEQ_LEN: usize = 5000;
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
|
||||||
let weight = vb.get((size2, size1), "weight")?;
|
|
||||||
let bias = if bias {
|
|
||||||
Some(vb.get(size2, "bias")?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(Linear::new(weight, bias))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
use candle_nn::{linear_no_bias, Linear, VarBuilder};
|
use candle_nn::{linear_b as linear, Linear, VarBuilder};
|
||||||
|
|
||||||
fn default_max_position_embeddings() -> usize {
|
fn default_max_position_embeddings() -> usize {
|
||||||
4096
|
4096
|
||||||
@ -119,9 +119,9 @@ impl MLP {
|
|||||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let hidden_sz = cfg.hidden_size;
|
let hidden_sz = cfg.hidden_size;
|
||||||
let intermediate_sz = cfg.intermediate_size;
|
let intermediate_sz = cfg.intermediate_size;
|
||||||
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
|
||||||
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
|
||||||
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
gate_proj,
|
gate_proj,
|
||||||
up_proj,
|
up_proj,
|
||||||
@ -160,10 +160,11 @@ impl Attention {
|
|||||||
let num_kv_heads = cfg.num_key_value_heads;
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
let num_kv_groups = num_heads / num_kv_heads;
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
let head_dim = cfg.head_dim;
|
let head_dim = cfg.head_dim;
|
||||||
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
let bias = cfg.attention_bias;
|
||||||
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
|
||||||
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
|
||||||
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
q_proj,
|
q_proj,
|
||||||
k_proj,
|
k_proj,
|
||||||
|
@ -47,6 +47,12 @@ impl Linear {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let inner = candle_nn::linear_b(d1, d2, b, vb)?;
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||||
|
Ok(Linear { inner, span })
|
||||||
|
}
|
||||||
|
|
||||||
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
|
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
let inner = candle_nn::linear(d1, d2, vb)?;
|
let inner = candle_nn::linear(d1, d2, vb)?;
|
||||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||||
|
Reference in New Issue
Block a user