From c753f72c8552ba3e108bd3f1a04971e8abbf3012 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 22 Feb 2024 09:35:28 +0100 Subject: [PATCH] Support for attention bias in gemma + refactor things a bit. (#1744) * Support for attention bias in gemma + refactor things a bit. * Fix the cuda tests. --- candle-core/tests/grad_tests.rs | 68 +++++++------------ candle-nn/src/lib.rs | 2 +- candle-nn/src/linear.rs | 23 +++++-- candle-transformers/src/models/bigcode.rs | 12 +--- candle-transformers/src/models/chatglm.rs | 10 +-- candle-transformers/src/models/falcon.rs | 12 +--- candle-transformers/src/models/gemma.rs | 17 ++--- .../src/models/with_tracing.rs | 6 ++ 8 files changed, 62 insertions(+), 88 deletions(-) diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 4fbb21a7..a4d81618 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -283,52 +283,38 @@ fn unary_grad(device: &Device) -> Result<()> { [1.0881, 0.9277, 1.0527, 0.5747], ); - let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?; - let y = x.interpolate1d(12)?.reshape(36)?; + if device.is_cpu() { + let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?; + let y = x.interpolate1d(12)?.reshape(36)?; - println!("y: {}", y.unsqueeze(1)?); - #[rustfmt::skip] - let z = Tensor::new( - &[ - 1_f32, 02., 03., 04., - 05., 06., 07., 08., - 09., 10., 11., 12., - 13., 14., 15., 16., - 17., 18., 19., 20., - 21., 22., 23., 24., - 25., 26., 27., 28., - 29., 30., 31., 32., - 33., 34., 35., 36., - ], - device, - )?; + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., + 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 34., 35., 36., + ], + device, + )?; - let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + let grads = loss.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; - let grads = loss.backward()?; - - let grad_x = grads.get(&x).context("no grad for x")?; - - println!("grad: {grad_x}"); - - assert_eq!( - test_utils::to_vec3_round(grad_x, 4)?, - [[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]] - ); + assert_eq!( + test_utils::to_vec3_round(grad_x, 4)?, + [[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]] + ); + } // manually checked: see comments let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?; let y = x.interpolate2d(6, 6)?.reshape(36)?; - #[rustfmt::skip] let z = Tensor::new( &[ - 1_f32, 02., 03., 04., 05., 06., - 07., 08., 09., 10., 11., 12., - 13., 14., 15., 16., 17., 18., - 19., 20., 21., 22., 23., 24., - 25., 26., 27., 28., 29., 30., - 31., 32., 33., 34., 35., 36., + 1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., + 35., 36., ], device, )?; @@ -359,15 +345,11 @@ fn unary_grad(device: &Device) -> Result<()> { let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?; let y = x.interpolate2d(6, 6)?.reshape(36)?; - #[rustfmt::skip] let z = Tensor::new( &[ - 1_f32, 02., 03., 04., 05., 06., - 07., 08., 09., 10., 11., 12., - 13., 14., 15., 16., 17., 18., - 19., 20., 21., 22., 23., 24., - 25., 26., 27., 28., 29., 30., - 31., 32., 33., 34., 35., 36., + 1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 
23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., + 35., 36., ], device, )?; diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 3d0e6939..1bcb78d9 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -28,7 +28,7 @@ pub use func::{func, func_t, Func, FuncT}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; -pub use linear::{linear, linear_no_bias, Linear}; +pub use linear::{linear, linear_b, linear_no_bias, Linear}; pub use ops::Dropout; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index 59a4db8a..96409042 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -57,21 +57,34 @@ impl super::Module for Linear { /// Create or initialize a new linear layer. /// /// This uses some default names for weights and biases, namely `"weight"` and `"bias"`. -pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result { +pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; let bound = 1. / (in_dim as f64).sqrt(); let init_bs = crate::Init::Uniform { lo: -bound, up: bound, }; - let bs = vs.get_with_hints(out_dim, "bias", init_bs)?; + let bs = vb.get_with_hints(out_dim, "bias", init_bs)?; Ok(Linear::new(ws, Some(bs))) } /// Create or initialize a new linear layer without biases. -pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result { +pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; Ok(Linear::new(ws, None)) } + +pub fn linear_b( + in_dim: usize, + out_dim: usize, + bias: bool, + vb: crate::VarBuilder, +) -> Result { + if bias { + linear(in_dim, out_dim, vb) + } else { + linear_no_bias(in_dim, out_dim, vb) + } +} diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs index e69f08c8..2e1bbd37 100644 --- a/candle-transformers/src/models/bigcode.rs +++ b/candle-transformers/src/models/bigcode.rs @@ -1,15 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; - -fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - let bias = if bias { - Some(vb.get(size2, "bias")?) 
- } else { - None - }; - Ok(Linear::new(weight, bias)) -} +use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { let weight = vb.get(size, "weight")?; diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs index 95466b34..0686b34e 100644 --- a/candle-transformers/src/models/chatglm.rs +++ b/candle-transformers/src/models/chatglm.rs @@ -1,4 +1,4 @@ -use crate::models::with_tracing::Linear; +use crate::models::with_tracing::{linear_b as linear, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::VarBuilder; @@ -51,14 +51,6 @@ impl Config { } } -fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result { - if bias { - crate::models::with_tracing::linear(in_dim, out_dim, vb) - } else { - crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb) - } -} - #[derive(Debug, Clone)] struct RotaryEmbedding { cache: Tensor, diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index ef5a92fc..86cf8451 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,18 +1,8 @@ use candle::{DType, Device, Result, Tensor, D}; -use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; const MAX_SEQ_LEN: usize = 5000; -fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - let bias = if bias { - Some(vb.get(size2, "bias")?) - } else { - None - }; - Ok(Linear::new(weight, bias)) -} - fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { (Ok(weight), Ok(bias)) => (weight, bias), diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index e2be8406..282d5eb2 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{linear_no_bias, Linear, VarBuilder}; +use candle_nn::{linear_b as linear, Linear, VarBuilder}; fn default_max_position_embeddings() -> usize { 4096 @@ -119,9 +119,9 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?; + let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?; + let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?; Ok(Self { gate_proj, up_proj, @@ -160,10 +160,11 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; let head_dim = cfg.head_dim; - let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = 
linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let bias = cfg.attention_bias; + let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; + let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; + let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; + let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; Ok(Self { q_proj, k_proj, diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 53e21551..383ae71c 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -47,6 +47,12 @@ impl Linear { } } +pub fn linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result { + let inner = candle_nn::linear_b(d1, d2, b, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) +} + pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result { let inner = candle_nn::linear(d1, d2, vb)?; let span = tracing::span!(tracing::Level::TRACE, "linear");
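
Usage note, not part of the patch: the sketch below illustrates how the new `candle_nn::linear_b` helper introduced above can be called. The `main` wrapper, the zero-initialized `VarBuilder`, and the chosen dimensions are illustrative assumptions, not code from this commit.

    use candle::{DType, Device, Module, Result, Tensor};
    use candle_nn::{linear_b, VarBuilder};

    fn main() -> Result<()> {
        // Zero-initialized weights are enough for a quick shape check.
        let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
        // `linear_b` dispatches to `linear` or `linear_no_bias` depending on the
        // boolean flag, which is how the gemma attention layers now honor
        // `cfg.attention_bias` instead of hard-coding `linear_no_bias`.
        let proj = linear_b(16, 32, /* bias */ true, vb.pp("proj"))?;
        let xs = Tensor::zeros((2, 16), DType::F32, &Device::Cpu)?;
        let ys = proj.forward(&xs)?;
        println!("{:?}", ys.dims()); // [2, 32]
        Ok(())
    }

The same flag-driven choice is mirrored by `with_tracing::linear_b`, which lets chatglm drop its local `linear` helper while keeping span instrumentation; bigcode and falcon switch to `candle_nn::linear_b` directly.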