Improve the quantized whisper setup. (#1018)

* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
This commit is contained in:
Laurent Mazare
2023-10-02 17:17:46 +01:00
committed by GitHub
parent e04c789230
commit 089fc3b584
8 changed files with 66 additions and 49 deletions

View File

@ -33,10 +33,10 @@ struct QMatMul {
}
impl QMatMul {
fn from_qtensor(qtensor: QTensor) -> Self {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Self { inner, span }
Ok(Self { inner, span })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
@ -217,14 +217,14 @@ impl ModelWeights {
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
@ -243,7 +243,7 @@ impl ModelWeights {
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
layers,
norm,
output: QMatMul::from_qtensor(output),
output: QMatMul::from_qtensor(output)?,
masks: HashMap::new(),
span,
span_output,
@ -294,14 +294,14 @@ impl ModelWeights {
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
@ -320,7 +320,7 @@ impl ModelWeights {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
norm,
output: QMatMul::from_qtensor(output),
output: QMatMul::from_qtensor(output)?,
masks: HashMap::new(),
span,
span_output,

View File

@ -90,7 +90,7 @@ impl QMatMul {
vb: crate::quantized_var_builder::VarBuilder,
) -> Result<Self> {
let ws = vb.get((in_dim, out_dim), "weight")?;
let inner = candle::quantized::QMatMul::from_arc(ws);
let inner = candle::quantized::QMatMul::from_arc(ws)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}