diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 94e6bd23..58f261b4 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -237,14 +237,28 @@ pub enum QMatMul {
     Tensor(Tensor),
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
-        let t = match qtensor.dtype() {
-            GgmlDType::F32 | GgmlDType::F16 => {
-                let tensor = qtensor.dequantize(&Device::Cpu)?;
-                Self::Tensor(tensor)
-            }
-            _ => Self::QTensor(qtensor),
+        let dequantize = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => true,
+            _ => DEQUANTIZE_ALL.with(|b| *b),
+        };
+        let t = if dequantize {
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
+            Self::Tensor(tensor)
+        } else {
+            Self::QTensor(qtensor)
         };
         Ok(t)
     }
@@ -297,7 +311,14 @@ impl QMatMul {
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
-            Self::Tensor(t) => xs.matmul(&t.t()?),
+            Self::Tensor(w) => {
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.matmul(&w)
+            }
         }
     }
 }
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
index 59942cbf..26ec6c94 100644
--- a/candle-transformers/src/models/whisper/quantized_model.rs
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -216,12 +216,11 @@ impl ResidualAttentionBlock {
         if let Some((attn, ln)) = &mut self.cross_attn {
             x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
         }
-        let mlp = self.mlp_linear2.forward(
-            &self
-                .mlp_linear1
-                .forward(&self.mlp_ln.forward(&x)?)?
-                .gelu()?,
-        )?;
+        let mlp = x
+            .apply(&self.mlp_ln)?
+            .apply(&self.mlp_linear1)?
+            .gelu()?
+            .apply(&self.mlp_linear2)?;
         x + mlp
     }
 }
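
Note on the QMatMul::forward change above: when the weight has been dequantized to a plain (n, k) tensor, a batched input can no longer be multiplied against it directly, so the patch broadcasts the weight across the leading batch dimensions before transposing and calling matmul. The sketch below is not part of the patch; it only reproduces the shape handling of the rank-3 branch using candle's public Tensor API, with made-up sizes. Separately, setting CANDLE_DEQUANTIZE_ALL to any non-empty value other than "0" makes from_arc take the dequantized path for every dtype, which can be useful when comparing against the quantized kernels.

// Minimal sketch (assumption: standalone example, not code from the patch) of the
// rank-3 shape logic in the Self::Tensor arm of QMatMul::forward.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::zeros((2, 4, 8), DType::F32, &dev)?; // input of shape (bsize, m, k)
    let w = Tensor::zeros((16, 8), DType::F32, &dev)?; // weight stored as (n, k)
    // Broadcast the weight over the batch dimension, then transpose the last two dims,
    // mirroring w.broadcast_left(bsize)?.t()? in the patched forward.
    let w = w.broadcast_left(2)?.t()?; // now (bsize, k, n)
    let ys = xs.matmul(&w)?; // batched matmul -> (bsize, m, n)
    assert_eq!(ys.dims3()?, (2, 4, 16));
    Ok(())
}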