Better control on the optional dequantization in QMatMul (#1049)

* Cosmetic change to the quantized whisper model.

* Fix the dequantized matmul path so batched (3D/4D) inputs broadcast the weight before the matmul.

* Add the DEQUANTIZE_ALL variable, controlled by the CANDLE_DEQUANTIZE_ALL environment variable (usage sketched below).
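
For context, a minimal usage sketch: only the CANDLE_DEQUANTIZE_ALL name and its effect on QMatMul::from_arc come from the diff below; the call site itself is hypothetical.

fn main() {
    // Opt in to full dequantization for this process. Set the variable before
    // the first QMatMul is built: the flag is read lazily, once per thread.
    std::env::set_var("CANDLE_DEQUANTIZE_ALL", "1");
    // ... then load the quantized model as usual; QMatMul::from_arc will now
    // dequantize every weight into a plain Tensor instead of keeping a QTensor.
}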
Author: Laurent Mazare
Date: 2023-10-07 10:16:18 +01:00
Committed by: GitHub
Parent: 955e00b2e8
Commit: aa53368aeb
2 changed files with 33 additions and 13 deletions

@@ -237,14 +237,28 @@ pub enum QMatMul {
     Tensor(Tensor),
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
-        let t = match qtensor.dtype() {
-            GgmlDType::F32 | GgmlDType::F16 => {
-                let tensor = qtensor.dequantize(&Device::Cpu)?;
-                Self::Tensor(tensor)
-            }
-            _ => Self::QTensor(qtensor),
+        let dequantize = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => true,
+            _ => DEQUANTIZE_ALL.with(|b| *b),
+        };
+        let t = if dequantize {
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
+            Self::Tensor(tensor)
+        } else {
+            Self::QTensor(qtensor)
         };
         Ok(t)
     }
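
The thread_local above treats the flag as enabled when the variable is set, non-empty, and not "0". A standalone sketch of that rule, assuming nothing beyond std (the helper name is ours):

fn dequantize_all_from_env() -> bool {
    // Same truthiness rule as DEQUANTIZE_ALL in the hunk above.
    match std::env::var("CANDLE_DEQUANTIZE_ALL") {
        Ok(s) => !s.is_empty() && s != "0",
        Err(_) => false,
    }
}

fn main() {
    std::env::set_var("CANDLE_DEQUANTIZE_ALL", "1");
    assert!(dequantize_all_from_env());
    std::env::set_var("CANDLE_DEQUANTIZE_ALL", "0");
    assert!(!dequantize_all_from_env());
    std::env::set_var("CANDLE_DEQUANTIZE_ALL", "");
    assert!(!dequantize_all_from_env());
    std::env::remove_var("CANDLE_DEQUANTIZE_ALL");
    assert!(!dequantize_all_from_env());
}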
@@ -297,7 +311,14 @@ impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
-            Self::Tensor(t) => xs.matmul(&t.t()?),
+            Self::Tensor(w) => {
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.matmul(&w)
+            }
         }
     }
 }
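
Why the Self::Tensor arm now broadcasts: the dequantized weight is a 2D [out, in] tensor while the activations are batched 3D or 4D tensors, so the previous xs.matmul(&t.t()?) had mismatched batch dimensions. A sketch under made-up shapes; broadcast_left, t and matmul are the same candle_core calls used in the hunk, everything else is assumed:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::zeros((2usize, 5, 8), DType::F32, &dev)?; // [bsize, seq, in]
    let w = Tensor::zeros((16usize, 8), DType::F32, &dev)?; // [out, in]
    // Mirrors the [bsize, _, _] arm: replicate the 2D weight across the batch
    // dimension, then transpose so matmul contracts over the `in` dimension.
    let w3 = w.broadcast_left(2usize)?.t()?; // [bsize, in, out]
    let ys = xs.matmul(&w3)?;
    assert_eq!(ys.dims(), &[2usize, 5, 16]);
    Ok(())
}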

@@ -216,12 +216,11 @@ impl ResidualAttentionBlock {
         if let Some((attn, ln)) = &mut self.cross_attn {
             x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
         }
-        let mlp = self.mlp_linear2.forward(
-            &self
-                .mlp_linear1
-                .forward(&self.mlp_ln.forward(&x)?)?
-                .gelu()?,
-        )?;
+        let mlp = x
+            .apply(&self.mlp_ln)?
+            .apply(&self.mlp_linear1)?
+            .gelu()?
+            .apply(&self.mlp_linear2)?;
         x + mlp
     }
 }
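
The whisper change is purely cosmetic: nested forward calls become a Tensor::apply chain, and apply simply invokes Module::forward on its argument. A toy illustration; Scale is an invented module, while Module, Tensor::apply and Tensor::affine are existing candle_core items:

use candle_core::{DType, Device, Module, Result, Tensor};

// Any type implementing Module can be used with Tensor::apply.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.0, 0.0)
    }
}

fn main() -> Result<()> {
    let xs = Tensor::ones((2usize, 3), DType::F32, &Device::Cpu)?;
    let (a, b) = (Scale(2.0), Scale(0.5));
    // The nested style used before the refactor and the chained style after it
    // compute the same result.
    let nested = b.forward(&a.forward(&xs)?)?;
    let chained = xs.apply(&a)?.apply(&b)?;
    assert_eq!(nested.to_vec2::<f32>()?, chained.to_vec2::<f32>()?);
    Ok(())
}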