Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 03:28:50 +00:00
Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit.
* Fix the cuda tests.
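The change replaces gemma's hard-coded `linear_no_bias` constructors with `candle_nn::linear_b`, which takes the bias choice as a runtime argument, so the attention projections can honor the model's `attention_bias` config flag. A minimal sketch of the two constructors side by side (the dimensions and `vb` paths here are illustrative, not taken from the diff):

    use candle_nn::{linear_b, linear_no_bias, Linear, VarBuilder};

    fn example(vb: VarBuilder) -> candle::Result<(Linear, Linear)> {
        // Fixed behavior: never loads a bias tensor.
        let fixed = linear_no_bias(1024, 4096, vb.pp("up_proj"))?;
        // Parameterized: the bool decides whether a "bias" tensor is
        // loaded alongside "weight" from the VarBuilder.
        let flexible = linear_b(1024, 4096, /* bias = */ true, vb.pp("q_proj"))?;
        Ok((fixed, flexible))
    }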
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_no_bias, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Linear, VarBuilder};
 
 fn default_max_position_embeddings() -> usize {
     4096
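Aliasing the import as `linear_b as linear` keeps the call sites short while still threading the bias flag through. Inside candle_nn, `linear_b` is essentially a dispatch over the two existing constructors; roughly (a paraphrase, not a verbatim copy of the library source):

    pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
        if bias {
            linear(in_dim, out_dim, vb) // loads "weight" and "bias"
        } else {
            linear_no_bias(in_dim, out_dim, vb) // loads "weight" only
        }
    }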
@@ -119,9 +119,9 @@ impl MLP {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let hidden_sz = cfg.hidden_size;
         let intermediate_sz = cfg.intermediate_size;
-        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
-        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
-        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
+        let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
+        let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
+        let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
         Ok(Self {
             gate_proj,
             up_proj,
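The MLP projections pass a literal `false`, so this hunk is a pure refactor with no behavioral change: gemma's feed-forward layers stay bias-free. For context, these three projections feed the usual gated-MLP forward pass, which elsewhere in the file (not part of this diff) looks roughly like:

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // down_proj(act_fn(gate_proj(x)) * up_proj(x))
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }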
@@ -160,10 +160,11 @@ impl Attention {
         let num_kv_heads = cfg.num_key_value_heads;
         let num_kv_groups = num_heads / num_kv_heads;
         let head_dim = cfg.head_dim;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
-        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+        let bias = cfg.attention_bias;
+        let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
+        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
+        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
+        let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
             k_proj,
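In the attention layer the flag comes from the config, so checkpoints whose config.json sets `attention_bias` can now load q/k/v/o bias tensors. This relies on an `attention_bias` field existing on `Config`; a hedged sketch of the corresponding struct change (the `#[serde(default)]` fallback to `false` is an assumption for configs that omit the key):

    #[derive(Debug, Clone, serde::Deserialize)]
    pub struct Config {
        #[serde(default)] // assumed: treat a missing key as `false` (no bias)
        pub attention_bias: bool,
        // ... remaining fields unchanged ...
    }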