mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.
* Fix the config file paths.
* Use the standard matmul where possible.
This commit is contained in:
@@ -33,10 +33,10 @@ struct QMatMul {
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
fn from_qtensor(qtensor: QTensor) -> Self {
|
||||
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
|
||||
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Self { inner, span }
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
@@ -217,14 +217,14 @@ impl ModelWeights {
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
layers.push(LayerWeights {
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv),
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo),
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq)?,
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
|
||||
n_head: ct.hparams.n_head as usize,
|
||||
n_kv_head: ct.hparams.n_head as usize / gqa,
|
||||
@@ -243,7 +243,7 @@ impl ModelWeights {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output),
|
||||
output: QMatMul::from_qtensor(output)?,
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
@@ -294,14 +294,14 @@ impl ModelWeights {
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
layers.push(LayerWeights {
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv),
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo),
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq)?,
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
@@ -320,7 +320,7 @@ impl ModelWeights {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output),
|
||||
output: QMatMul::from_qtensor(output)?,
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
|
@@ -90,7 +90,7 @@ impl QMatMul {
|
||||
vb: crate::quantized_var_builder::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let ws = vb.get((in_dim, out_dim), "weight")?;
|
||||
let inner = candle::quantized::QMatMul::from_arc(ws);
|
||||
let inner = candle::quantized::QMatMul::from_arc(ws)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
Reference in New Issue
Block a user