Mirror of https://github.com/huggingface/candle.git (synced 2025-06-20 12:06:35 +00:00)
Quantized support for stable-lm2. (#1654)

* Quantized support for stable-lm2.
* Quantized support for v2-zephyr.
@@ -1,4 +1,4 @@
-use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
+use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -67,9 +67,14 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
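The change above compiles because `linear` and `linear_no_bias` share the same signature, so Rust coerces both arms of the `if` expression to a single `fn` pointer type. A minimal standalone sketch of that pattern, using simplified stand-ins (`Config`, `Linear`, and the two constructors here are illustrative, not the actual candle-transformers types):

    // Stand-in for a linear layer; only tracks whether a bias is present.
    #[derive(Debug)]
    struct Linear {
        bias: bool,
    }

    // Two constructors with identical signatures, mirroring the shape of
    // quantized_nn::linear and quantized_nn::linear_no_bias.
    fn linear(_in_dim: usize, _out_dim: usize) -> Linear {
        Linear { bias: true }
    }

    fn linear_no_bias(_in_dim: usize, _out_dim: usize) -> Linear {
        Linear { bias: false }
    }

    struct Config {
        use_qkv_bias: bool,
    }

    fn main() {
        let cfg = Config { use_qkv_bias: true };
        // Both arms coerce to the same `fn(usize, usize) -> Linear` pointer,
        // so the constructor is chosen once rather than at every call site.
        let linear_layer = if cfg.use_qkv_bias {
            linear
        } else {
            linear_no_bias
        };
        let q_proj = linear_layer(2048, 2048);
        println!("q_proj has bias: {}", q_proj.bias);
    }

Selecting the constructor once keeps the three q/k/v projections symmetric and avoids repeating the `use_qkv_bias` branch for each of them.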