Improve the quantized whisper setup. (#1018)

* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
This commit is contained in:
Laurent Mazare
2023-10-02 17:17:46 +01:00
committed by GitHub
parent e04c789230
commit 089fc3b584
8 changed files with 66 additions and 49 deletions

View File

@ -232,19 +232,25 @@ impl QTensor {
}
#[derive(Clone, Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
}
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let t = match qtensor.dtype() {
GgmlDType::F32 | GgmlDType::F16 => {
let tensor = qtensor.dequantize(&Device::Cpu)?;
Self::Tensor(tensor)
}
_ => Self::QTensor(qtensor),
};
Ok(t)
}
pub fn from_qtensor(qtensor: QTensor) -> Self {
Self(std::sync::Arc::new(qtensor))
}
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
&self.0
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
}
}
@ -289,6 +295,9 @@ impl crate::CustomOp1 for QTensor {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(self.0.as_ref())
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(t) => xs.matmul(&t.t()?),
}
}
}

View File

@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
@ -576,7 +576,7 @@ fn quantized_matmul_q2k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -602,7 +602,7 @@ fn quantized_matmul_q3k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -628,7 +628,7 @@ fn quantized_matmul_q4k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -654,7 +654,7 @@ fn quantized_matmul_q5k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -681,7 +681,7 @@ fn quantized_matmul_q6k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);