mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit. * Fix the cuda tests.
This commit is contained in:
@ -1,15 +1,5 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::models::with_tracing::Linear;
|
||||
use crate::models::with_tracing::{linear_b as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
@ -51,14 +51,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
if bias {
|
||||
crate::models::with_tracing::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
cache: Tensor,
|
||||
|
@ -1,18 +1,8 @@
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{linear_no_bias, Linear, VarBuilder};
|
||||
use candle_nn::{linear_b as linear, Linear, VarBuilder};
|
||||
|
||||
fn default_max_position_embeddings() -> usize {
|
||||
4096
|
||||
@ -119,9 +119,9 @@ impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let intermediate_sz = cfg.intermediate_size;
|
||||
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||
let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
|
||||
let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
@ -160,10 +160,11 @@ impl Attention {
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
let head_dim = cfg.head_dim;
|
||||
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
||||
let bias = cfg.attention_bias;
|
||||
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
|
||||
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
|
@ -47,6 +47,12 @@ impl Linear {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let inner = candle_nn::linear_b(d1, d2, b, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let inner = candle_nn::linear(d1, d2, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
|
Reference in New Issue
Block a user