Add the cuda dequantize f16 kernels. (#2137)

* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
This commit is contained in:
Laurent Mazare
2024-04-28 20:05:05 +02:00
committed by GitHub
parent c68ed8963f
commit eb26e2467e
5 changed files with 317 additions and 55 deletions

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@ -360,9 +360,24 @@ impl QTensor {
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
let is_variable = false;
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
.to_device(device)
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
}
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
// In the CUDA case, we have a specialized kernel as this can be useful for volta
// architectures. https://github.com/huggingface/candle/issues/2136
match &self.storage {
QStorage::Cuda(s) => {
let s = s.dequantize_f16(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
.to_device(device)
}
_ => {
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
Ok(s)
}
}
}
pub fn storage_size_in_bytes(&self) -> usize {
@ -378,6 +393,7 @@ impl QTensor {
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
TensorF16(Tensor),
}
thread_local! {
@ -391,6 +407,17 @@ thread_local! {
}
}
thread_local! {
static DEQUANTIZE_ALL_F16: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
@ -400,6 +427,9 @@ impl QMatMul {
let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor)
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
Self::TensorF16(tensor)
} else {
Self::QTensor(qtensor)
};
@ -486,6 +516,15 @@ impl crate::Module for QMatMul {
};
xs.matmul(&w)
}
Self::TensorF16(w) => {
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
}
}