Quantized support for stable-lm2. (#1654)

* Quantized support for stable-lm2.

* Quantized support for v2-zephyr.
Laurent Mazare
2024-02-04 11:57:05 +01:00
committed by GitHub
parent 58cc896e69
commit 50be8a98ba
3 changed files with 36 additions and 10 deletions

@@ -1,4 +1,4 @@
-use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
+use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -67,9 +67,14 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
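
The key point of the hunk above is that `linear` and `linear_no_bias` share the same signature, so the branch on `cfg.use_qkv_bias` can resolve once to a plain function pointer that all three q/k/v projections reuse, while `o_proj` stays bias-free; the stable-lm2 / zephyr checkpoints evidently ship biased q/k/v weights where the original stable-lm ones do not. Below is a minimal standalone sketch of that dispatch pattern; `Linear` and the two constructors here are simplified stand-ins, not candle's actual `quantized_nn` API.

// Minimal sketch of the function-pointer dispatch used in the diff above.
// `Linear`, `linear`, and `linear_no_bias` are simplified stand-ins for
// candle's quantized_nn constructors, kept only to show the pattern.
type Result<T> = std::result::Result<T, String>;

#[derive(Debug)]
struct Linear {
    in_dim: usize,
    out_dim: usize,
    has_bias: bool,
}

fn linear(in_dim: usize, out_dim: usize) -> Result<Linear> {
    Ok(Linear { in_dim, out_dim, has_bias: true })
}

fn linear_no_bias(in_dim: usize, out_dim: usize) -> Result<Linear> {
    Ok(Linear { in_dim, out_dim, has_bias: false })
}

fn main() -> Result<()> {
    // In the real model this flag comes from the parsed config (cfg.use_qkv_bias).
    let use_qkv_bias = true;
    // Both constructors have the same signature, so the branch resolves once
    // to a plain `fn` pointer and every projection reuses the same choice.
    let linear_layer: fn(usize, usize) -> Result<Linear> =
        if use_qkv_bias { linear } else { linear_no_bias };
    let q_proj = linear_layer(2048, 2048)?;
    let o_proj = linear_no_bias(2048, 2048)?; // o_proj stays bias-free in the diff
    println!("{q_proj:?} {o_proj:?}");
    Ok(())
}

The explicit `fn(usize, usize) -> Result<Linear>` annotation spells out the coercion from the two distinct function-item types to a common function pointer; Rust would also infer it, but the annotation documents the shared signature the pattern depends on.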