mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a forward_via_f16 method to the qmatmul op. (#2138)
This commit is contained in:
@ -439,6 +439,25 @@ impl QMatMul {
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||
Self::from_arc(std::sync::Arc::new(qtensor))
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::QTensor(t) => t.dequantize_f16(&t.device()),
|
||||
Self::Tensor(t) => t.to_dtype(DType::F16),
|
||||
Self::TensorF16(t) => Ok(t.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let w = self.dequantize_f16()?;
|
||||
let in_dtype = xs.dtype();
|
||||
let w = match *xs.dims() {
|
||||
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||
_ => w.t()?,
|
||||
};
|
||||
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for QTensor {
|
||||
|
Reference in New Issue
Block a user