Mirror of https://github.com/huggingface/candle.git
Better control on the optional dequantization in QMatMul (#1049)
* Cosmetic change to the quantized whisper model.
* Fix the dequantization.
* Add the dequantize all variable.
@@ -237,14 +237,28 @@ pub enum QMatMul {
     Tensor(Tensor),
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
-        let t = match qtensor.dtype() {
-            GgmlDType::F32 | GgmlDType::F16 => {
-                let tensor = qtensor.dequantize(&Device::Cpu)?;
-                Self::Tensor(tensor)
-            }
-            _ => Self::QTensor(qtensor),
+        let dequantize = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => true,
+            _ => DEQUANTIZE_ALL.with(|b| *b),
+        };
+        let t = if dequantize {
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
+            Self::Tensor(tensor)
+        } else {
+            Self::QTensor(qtensor)
         };
         Ok(t)
     }
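The new DEQUANTIZE_ALL thread-local gates eager dequantization on the CANDLE_DEQUANTIZE_ALL environment variable: any non-empty value other than "0" makes from_arc dequantize every dtype, not only F16/F32. A minimal usage sketch, assuming the crate is consumed as candle_core, the enum has the two variants shown in this commit, and a QTensor is already loaded; build_matmul is a hypothetical helper, and the variable must be set in the environment before the flag is first read on the current thread:

use std::sync::Arc;
use candle_core::quantized::{QMatMul, QTensor};
use candle_core::Result;

// Hypothetical helper: builds the matmul and reports which path was taken.
fn build_matmul(qt: Arc<QTensor>) -> Result<QMatMul> {
    // With CANDLE_DEQUANTIZE_ALL=1 set before this point, from_arc eagerly
    // dequantizes the weights and returns the Tensor variant even for
    // quantized dtypes such as Q4_0 or Q8_0.
    let mm = QMatMul::from_arc(qt)?;
    match &mm {
        QMatMul::Tensor(_) => println!("weights were dequantized to a plain Tensor"),
        QMatMul::QTensor(_) => println!("weights stay quantized"),
    }
    Ok(mm)
}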
@@ -297,7 +311,14 @@ impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
-            Self::Tensor(t) => xs.matmul(&t.t()?),
+            Self::Tensor(w) => {
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.matmul(&w)
+            }
         }
     }
 }
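Because from_arc can now keep any quantized weight as a plain 2D (out_dim, in_dim) Tensor, forward has to cope with 3D or 4D activations: the weight is broadcast over the leading batch dims and its last two dims are transposed so the batched matmul lines up. A small shape sketch with plain tensors, shapes chosen arbitrarily for illustration:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Activations with a batch dimension: (bsize, seq, in_dim) = (2, 4, 8).
    let xs = Tensor::zeros((2, 4, 8), DType::F32, &dev)?;
    // Weight stored as (out_dim, in_dim) = (16, 8), as in the Tensor variant.
    let w = Tensor::zeros((16, 8), DType::F32, &dev)?;
    // Broadcast to (2, 16, 8), then transpose the last two dims to (2, 8, 16),
    // so the batched matmul (2, 4, 8) x (2, 8, 16) -> (2, 4, 16) is valid.
    let w = w.broadcast_left(2usize)?.t()?;
    let ys = xs.matmul(&w)?;
    assert_eq!(ys.dims(), &[2, 4, 16]);
    Ok(())
}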
@@ -216,12 +216,11 @@ impl ResidualAttentionBlock {
         if let Some((attn, ln)) = &mut self.cross_attn {
             x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
         }
-        let mlp = self.mlp_linear2.forward(
-            &self
-                .mlp_linear1
-                .forward(&self.mlp_ln.forward(&x)?)?
-                .gelu()?,
-        )?;
+        let mlp = x
+            .apply(&self.mlp_ln)?
+            .apply(&self.mlp_linear1)?
+            .gelu()?
+            .apply(&self.mlp_linear2)?;
         x + mlp
     }
 }
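This last hunk is the cosmetic part of the commit: the nested forward calls in the whisper MLP are rewritten as a Tensor::apply chain, where apply simply runs Module::forward on the layer it receives. A minimal sketch of the equivalence, with hand-built candle_nn layers standing in for the block's fields (real code would load them through a VarBuilder):

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{LayerNorm, Linear};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let d = 4usize;
    // Toy layer norm plus two linear layers with randomly initialised weights.
    let ln = LayerNorm::new(
        Tensor::ones(d, DType::F32, &dev)?,
        Tensor::zeros(d, DType::F32, &dev)?,
        1e-5,
    );
    let lin1 = Linear::new(Tensor::randn(0f32, 1., (2 * d, d), &dev)?, None);
    let lin2 = Linear::new(Tensor::randn(0f32, 1., (d, 2 * d), &dev)?, None);
    let x = Tensor::randn(0f32, 1., (1, 3, d), &dev)?;

    // Nested style, as before the change ...
    let y1 = lin2.forward(&lin1.forward(&ln.forward(&x)?)?.gelu()?)?;
    // ... and the equivalent apply chain, as after the change.
    let y2 = x.apply(&ln)?.apply(&lin1)?.gelu()?.apply(&lin2)?;
    assert_eq!(y1.dims(), y2.dims());
    Ok(())
}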