mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.
* Fix the config file paths.
* Use the standard matmul where possible.
This commit is contained in:
@ -232,19 +232,25 @@ impl QTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct QMatMul(std::sync::Arc<QTensor>);
|
pub enum QMatMul {
|
||||||
|
QTensor(std::sync::Arc<QTensor>),
|
||||||
|
Tensor(Tensor),
|
||||||
|
}
|
||||||
|
|
||||||
impl QMatMul {
|
impl QMatMul {
|
||||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
|
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
||||||
Self(qtensor)
|
let t = match qtensor.dtype() {
|
||||||
|
GgmlDType::F32 | GgmlDType::F16 => {
|
||||||
|
let tensor = qtensor.dequantize(&Device::Cpu)?;
|
||||||
|
Self::Tensor(tensor)
|
||||||
|
}
|
||||||
|
_ => Self::QTensor(qtensor),
|
||||||
|
};
|
||||||
|
Ok(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_qtensor(qtensor: QTensor) -> Self {
|
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||||
Self(std::sync::Arc::new(qtensor))
|
Self::from_arc(std::sync::Arc::new(qtensor))
|
||||||
}
|
|
||||||
|
|
||||||
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
|
|
||||||
&self.0
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,6 +295,9 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
|
|
||||||
impl QMatMul {
|
impl QMatMul {
|
||||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
xs.apply_op1_no_bwd(self.0.as_ref())
|
match self {
|
||||||
|
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
||||||
|
Self::Tensor(t) => xs.matmul(&t.t()?),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
@ -576,7 +576,7 @@ fn quantized_matmul_q2k() -> Result<()> {
|
|||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
assert_eq!(mm.dims(), [m, n]);
|
assert_eq!(mm.dims(), [m, n]);
|
||||||
@ -602,7 +602,7 @@ fn quantized_matmul_q3k() -> Result<()> {
|
|||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
assert_eq!(mm.dims(), [m, n]);
|
assert_eq!(mm.dims(), [m, n]);
|
||||||
@ -628,7 +628,7 @@ fn quantized_matmul_q4k() -> Result<()> {
|
|||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
assert_eq!(mm.dims(), [m, n]);
|
assert_eq!(mm.dims(), [m, n]);
|
||||||
@ -654,7 +654,7 @@ fn quantized_matmul_q5k() -> Result<()> {
|
|||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
assert_eq!(mm.dims(), [m, n]);
|
assert_eq!(mm.dims(), [m, n]);
|
||||||
@ -681,7 +681,7 @@ fn quantized_matmul_q6k() -> Result<()> {
|
|||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
assert_eq!(mm.dims(), [m, n]);
|
assert_eq!(mm.dims(), [m, n]);
|
||||||
|
@ -484,17 +484,25 @@ fn main() -> Result<()> {
|
|||||||
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
||||||
dataset.get("samples_jfk.wav")?
|
dataset.get("samples_jfk.wav")?
|
||||||
};
|
};
|
||||||
let config = if args.quantized {
|
let (config, tokenizer, model) = if args.quantized {
|
||||||
repo.get("config-tiny.json")?
|
let ext = match args.model {
|
||||||
|
WhichModel::TinyEn => "tiny-en",
|
||||||
|
WhichModel::Tiny => "tiny",
|
||||||
|
_ => unimplemented!("no quantized support for {:?}", args.model),
|
||||||
|
};
|
||||||
|
(
|
||||||
|
repo.get(&format!("config-{ext}.json"))?,
|
||||||
|
repo.get(&format!("tokenizer-{ext}.json"))?,
|
||||||
|
repo.get(&format!("model-{ext}-q40.gguf"))?,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
repo.get("config.json")?
|
(
|
||||||
|
repo.get("config.json")?,
|
||||||
|
repo.get("tokenizer.json")?,
|
||||||
|
repo.get("model.safetensors")?,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
let model = if args.quantized {
|
(config, tokenizer, model, sample)
|
||||||
repo.get("model-tiny-q40.gguf")?
|
|
||||||
} else {
|
|
||||||
repo.get("model.safetensors")?
|
|
||||||
};
|
|
||||||
(config, repo.get("tokenizer.json")?, model, sample)
|
|
||||||
};
|
};
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
@ -206,7 +206,7 @@ impl Benchmark for QMatMul {
|
|||||||
fn preprocess() -> Result<Self::PreProcessData> {
|
fn preprocess() -> Result<Self::PreProcessData> {
|
||||||
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
||||||
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
|
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
|
||||||
let mm = candle::quantized::QMatMul::from_qtensor(mm);
|
let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
|
||||||
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
||||||
Ok((mm, arg))
|
Ok((mm, arg))
|
||||||
}
|
}
|
||||||
|
@ -867,7 +867,7 @@ impl PyQTensor {
|
|||||||
/// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
|
/// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
|
fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
|
||||||
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
|
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()).map_err(wrap_err)?;
|
||||||
let res = qmatmul.forward(lhs).map_err(wrap_err)?;
|
let res = qmatmul.forward(lhs).map_err(wrap_err)?;
|
||||||
Ok(PyTensor(res))
|
Ok(PyTensor(res))
|
||||||
}
|
}
|
||||||
|
@ -33,10 +33,10 @@ struct QMatMul {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl QMatMul {
|
impl QMatMul {
|
||||||
fn from_qtensor(qtensor: QTensor) -> Self {
|
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||||
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
|
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||||
Self { inner, span }
|
Ok(Self { inner, span })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
@ -217,14 +217,14 @@ impl ModelWeights {
|
|||||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||||
layers.push(LayerWeights {
|
layers.push(LayerWeights {
|
||||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
attention_wq: QMatMul::from_qtensor(attention_wq)?,
|
||||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||||
attention_wv: QMatMul::from_qtensor(attention_wv),
|
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||||
attention_wo: QMatMul::from_qtensor(attention_wo),
|
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||||
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
|
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
|
||||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
|
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
|
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||||
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
|
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
|
||||||
n_head: ct.hparams.n_head as usize,
|
n_head: ct.hparams.n_head as usize,
|
||||||
n_kv_head: ct.hparams.n_head as usize / gqa,
|
n_kv_head: ct.hparams.n_head as usize / gqa,
|
||||||
@ -243,7 +243,7 @@ impl ModelWeights {
|
|||||||
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
|
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
output: QMatMul::from_qtensor(output),
|
output: QMatMul::from_qtensor(output)?,
|
||||||
masks: HashMap::new(),
|
masks: HashMap::new(),
|
||||||
span,
|
span,
|
||||||
span_output,
|
span_output,
|
||||||
@ -294,14 +294,14 @@ impl ModelWeights {
|
|||||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||||
layers.push(LayerWeights {
|
layers.push(LayerWeights {
|
||||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
attention_wq: QMatMul::from_qtensor(attention_wq)?,
|
||||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||||
attention_wv: QMatMul::from_qtensor(attention_wv),
|
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||||
attention_wo: QMatMul::from_qtensor(attention_wo),
|
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||||
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
|
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
|
||||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
|
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
|
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||||
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
|
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
|
||||||
n_head: head_count,
|
n_head: head_count,
|
||||||
n_kv_head: head_count_kv,
|
n_kv_head: head_count_kv,
|
||||||
@ -320,7 +320,7 @@ impl ModelWeights {
|
|||||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
output: QMatMul::from_qtensor(output),
|
output: QMatMul::from_qtensor(output)?,
|
||||||
masks: HashMap::new(),
|
masks: HashMap::new(),
|
||||||
span,
|
span,
|
||||||
span_output,
|
span_output,
|
||||||
|
@ -90,7 +90,7 @@ impl QMatMul {
|
|||||||
vb: crate::quantized_var_builder::VarBuilder,
|
vb: crate::quantized_var_builder::VarBuilder,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let ws = vb.get((in_dim, out_dim), "weight")?;
|
let ws = vb.get((in_dim, out_dim), "weight")?;
|
||||||
let inner = candle::quantized::QMatMul::from_arc(ws);
|
let inner = candle::quantized::QMatMul::from_arc(ws)?;
|
||||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||||
Ok(Self { inner, span })
|
Ok(Self { inner, span })
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
|
Reference in New Issue
Block a user