Better control over the optional dequantization in QMatMul (#1049)

* Cosmetic change to the quantized whisper model.

* Fix the dequantization.

* Add the dequantize-all environment variable (CANDLE_DEQUANTIZE_ALL).
Author: Laurent Mazare
Date:   2023-10-07 10:16:18 +01:00
Committed by: GitHub
Parent: 955e00b2e8
Commit: aa53368aeb

2 changed files with 33 additions and 13 deletions

File 1 of 2: the QMatMul implementation in the quantized module.

```diff
@@ -237,14 +237,28 @@ pub enum QMatMul {
     Tensor(Tensor),
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
-        let t = match qtensor.dtype() {
-            GgmlDType::F32 | GgmlDType::F16 => {
-                let tensor = qtensor.dequantize(&Device::Cpu)?;
-                Self::Tensor(tensor)
-            }
-            _ => Self::QTensor(qtensor),
+        let dequantize = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => true,
+            _ => DEQUANTIZE_ALL.with(|b| *b),
         };
+        let t = if dequantize {
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
+            Self::Tensor(tensor)
+        } else {
+            Self::QTensor(qtensor)
+        };
         Ok(t)
     }
```
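
The thread_local parses CANDLE_DEQUANTIZE_ALL once per thread: any non-empty value other than "0" enables it, in which case from_arc dequantizes every tensor up front and forward (next hunk) runs a plain matmul instead of the custom quantized op. This is handy for checking whether an accuracy problem comes from the quantized kernels. A minimal standalone sketch of the same flag-parsing pattern, runnable outside candle:

```rust
// Sketch of the env-flag pattern introduced above; the initializer
// runs once per thread, on first access to DEQUANTIZE_ALL.
thread_local! {
    static DEQUANTIZE_ALL: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
            // Any non-empty value other than "0" enables the flag.
            Ok(s) => !s.is_empty() && s != "0",
            Err(_) => false,
        }
    };
}

fn main() {
    // Try: CANDLE_DEQUANTIZE_ALL=1 cargo run
    println!("dequantize all: {}", DEQUANTIZE_ALL.with(|b| *b));
}
```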
```diff
@@ -297,7 +311,14 @@ impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
-            Self::Tensor(t) => xs.matmul(&t.t()?),
+            Self::Tensor(w) => {
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.matmul(&w)
+            }
         }
     }
 }
```
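
The dequantized weight is stored as (out_features, in_features), and candle's plain matmul expects both operands to carry matching batch dimensions, so for 3D or 4D inputs the weight has to be broadcast over the batch dims before the transpose-and-multiply. A minimal sketch of the 3D arm, assuming candle-core as a dependency and illustrative shapes:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::zeros((2, 5, 8), DType::F32, &dev)?; // (batch, seq, in)
    let w = Tensor::zeros((16, 8), DType::F32, &dev)?; // (out, in)
    // Mirrors the [bsize, _, _] arm of the match in `forward`:
    // (16, 8) -> (2, 16, 8) -> transpose last two dims -> (2, 8, 16).
    let w = w.broadcast_left(2)?.t()?;
    let ys = xs.matmul(&w)?; // (2, 5, 8) x (2, 8, 16) -> (2, 5, 16)
    assert_eq!(ys.dims(), &[2, 5, 16]);
    Ok(())
}
```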

File 2 of 2: the quantized whisper model (ResidualAttentionBlock).

```diff
@@ -216,12 +216,11 @@ impl ResidualAttentionBlock {
         if let Some((attn, ln)) = &mut self.cross_attn {
             x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
         }
-        let mlp = self.mlp_linear2.forward(
-            &self
-                .mlp_linear1
-                .forward(&self.mlp_ln.forward(&x)?)?
-                .gelu()?,
-        )?;
+        let mlp = x
+            .apply(&self.mlp_ln)?
+            .apply(&self.mlp_linear1)?
+            .gelu()?
+            .apply(&self.mlp_linear2)?;
         x + mlp
     }
 }
```
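
This is the cosmetic change from the commit message: the nested forward calls become a Tensor::apply chain, where xs.apply(&m) simply calls m.forward(&xs), so the data flow reads top to bottom with no behavior change. A small sketch of the equivalence; Scale is a toy module made up for this example:

```rust
use candle_core::{DType, Device, Module, Result, Tensor};

// Toy module that multiplies its input by a constant.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.0, 0.0)
    }
}

fn main() -> Result<()> {
    let xs = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    let (a, b) = (Scale(2.0), Scale(3.0));
    let nested = b.forward(&a.forward(&xs)?)?; // old, inside-out style
    let chained = xs.apply(&a)?.apply(&b)?; // new, chained style
    assert_eq!(nested.to_vec2::<f32>()?, chained.to_vec2::<f32>()?);
    Ok(())
}
```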