Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Better control on the optional dequantization in QMatMul (#1049)
* Cosmetic change to the quantized whisper model.
* Fix the dequantization.
* Add the dequantize all variable.
@@ -237,14 +237,28 @@ pub enum QMatMul {
     Tensor(Tensor),
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
-        let t = match qtensor.dtype() {
-            GgmlDType::F32 | GgmlDType::F16 => {
-                let tensor = qtensor.dequantize(&Device::Cpu)?;
-                Self::Tensor(tensor)
-            }
-            _ => Self::QTensor(qtensor),
+        let dequantize = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => true,
+            _ => DEQUANTIZE_ALL.with(|b| *b),
+        };
+        let t = if dequantize {
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
+            Self::Tensor(tensor)
+        } else {
+            Self::QTensor(qtensor)
         };
         Ok(t)
     }
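A note on the new flag, not part of the commit text itself: DEQUANTIZE_ALL is a thread-local read from the CANDLE_DEQUANTIZE_ALL environment variable, treated as enabled when the variable is set to any non-empty value other than "0". A minimal standalone sketch of that parsing rule (plain Rust, no candle dependency; the helper name env_flag is ours for illustration):

// Sketch of the truthiness rule applied to CANDLE_DEQUANTIZE_ALL:
// enabled when the variable is set, non-empty, and not exactly "0".
fn env_flag(name: &str) -> bool {
    match std::env::var(name) {
        Ok(s) => !s.is_empty() && s != "0",
        Err(_) => false,
    }
}

fn main() {
    // A typical invocation would export the flag first,
    // e.g. CANDLE_DEQUANTIZE_ALL=1 before running the binary.
    println!("dequantize all: {}", env_flag("CANDLE_DEQUANTIZE_ALL"));
}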
@@ -297,7 +311,14 @@ impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
-            Self::Tensor(t) => xs.matmul(&t.t()?),
+            Self::Tensor(w) => {
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.matmul(&w)
+            }
         }
     }
 }
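Why forward() changed, as we read it: when the weight has been eagerly dequantized into a plain Tensor, a batched 3-D or 4-D activation cannot be multiplied directly against the 2-D weight, so the weight is first broadcast over the leading batch dimensions and then transposed. A minimal shape sketch of that path, assuming only the public candle_core tensor API (the concrete sizes are made up):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Batched input (bsize, m, k) and a 2-D weight stored as (n, k).
    let xs = Tensor::zeros((4usize, 7, 16), DType::F32, &dev)?;
    let w = Tensor::zeros((32usize, 16), DType::F32, &dev)?;
    // Broadcast to (bsize, n, k), transpose the last two dims to (bsize, k, n),
    // then matmul: (bsize, m, k) x (bsize, k, n) -> (bsize, m, n).
    let w = w.broadcast_left(4usize)?.t()?;
    let ys = xs.matmul(&w)?;
    assert_eq!(ys.dims(), &[4, 7, 32]);
    Ok(())
}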
@@ -216,12 +216,11 @@ impl ResidualAttentionBlock {
         if let Some((attn, ln)) = &mut self.cross_attn {
             x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
         }
-        let mlp = self.mlp_linear2.forward(
-            &self
-                .mlp_linear1
-                .forward(&self.mlp_ln.forward(&x)?)?
-                .gelu()?,
-        )?;
+        let mlp = x
+            .apply(&self.mlp_ln)?
+            .apply(&self.mlp_linear1)?
+            .gelu()?
+            .apply(&self.mlp_linear2)?;
         x + mlp
     }
 }
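The whisper change is the "cosmetic" part of the commit: the nested forward calls are rewritten as a left-to-right Tensor::apply chain with the same layer order (mlp_ln, mlp_linear1, gelu, mlp_linear2). A rough equivalent of that chaining style, using candle_nn::Linear as a stand-in for the model's quantized linear layers (the layer-norm step is left out of the sketch and the sizes are arbitrary):

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::Linear;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Weights are stored as (out_features, in_features); biases omitted here.
    let linear1 = Linear::new(Tensor::zeros((8usize, 4), DType::F32, &dev)?, None);
    let linear2 = Linear::new(Tensor::zeros((4usize, 8), DType::F32, &dev)?, None);
    let x = Tensor::zeros((2usize, 3, 4), DType::F32, &dev)?;
    // Same computation as linear2.forward(&linear1.forward(&x)?.gelu()?)?,
    // written as a chain; apply(m) just calls m.forward(self).
    let y = x.apply(&linear1)?.gelu()?.apply(&linear2)?;
    assert_eq!(y.dims(), &[2, 3, 4]);
    Ok(())
}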