Support for quantized tensors in the python api. (#706)

* Add more pyo3 support.

* Add some support for quantized tensors in pyo3.

* Add an arc layer on qmatmul.

* Add the quantized matmul.

* Quantization support.

* More quantization support.

* Test the python quantization.
This commit is contained in:
Laurent Mazare
2023-09-01 16:53:42 +02:00
committed by GitHub
parent 237323c2bc
commit 2ed78ab336
3 changed files with 172 additions and 7 deletions

View File

@ -3,10 +3,11 @@
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use std::sync::Arc;
use half::{bf16, f16};
use ::candle::{DType, Device, Tensor, WithDType};
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
@ -261,6 +262,38 @@ impl PyTensor {
self.__repr__()
}
fn sin(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
}
fn cos(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
}
fn log(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.log().map_err(wrap_err)?))
}
fn sqr(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
}
fn sqrt(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
}
fn recip(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
}
fn exp(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
}
fn powf(&self, p: f64) -> PyResult<Self> {
Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
}
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
}
@ -344,8 +377,12 @@ impl PyTensor {
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
// TODO: Support a single dim as input?
fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
let dims = if let Ok(dim) = dims.extract::<usize>(py) {
vec![dim]
} else {
dims.extract::<Vec<usize>>(py)?
};
Ok(PyTensor(
self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
))
@ -355,6 +392,13 @@ impl PyTensor {
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
}
fn mean_all(&self) -> PyResult<Self> {
let elements = self.0.elem_count();
let sum = self.0.sum_all().map_err(wrap_err)?;
let mean = (sum / elements as f64).map_err(wrap_err)?;
Ok(PyTensor(mean))
}
fn flatten_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
}
@ -392,6 +436,32 @@ impl PyTensor {
let device = device.as_device()?;
Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
}
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype {
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
"q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
"q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
"q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
"q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
"q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
"q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
"q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
"q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
"q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
"f16" => quantized::QTensor::quantize::<f16>(self),
"f32" => quantized::QTensor::quantize::<f32>(self),
dt => {
return Err(PyErr::new::<PyValueError, _>(format!(
"unknown quantized-dtype {dt}"
)))
}
};
Ok(PyQTensor(Arc::new(res.map_err(wrap_err)?)))
}
}
/// Concatenate the tensors across one axis.
@ -464,9 +534,90 @@ fn zeros(
Ok(PyTensor(tensor))
}
#[derive(Debug)]
#[pyclass(name = "QTensor")]
struct PyQTensor(Arc<QTensor>);
impl std::ops::Deref for PyQTensor {
type Target = QTensor;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
#[pymethods]
impl PyQTensor {
#[getter]
fn ggml_dtype(&self) -> String {
format!("{:?}", self.0.dtype())
}
#[getter]
fn rank(&self) -> usize {
self.0.rank()
}
#[getter]
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.shape().dims()).to_object(py)
}
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
fn __str__(&self) -> String {
self.__repr__()
}
fn dequantize(&self) -> PyResult<PyTensor> {
let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
let res = qmatmul.forward(lhs).map_err(wrap_err)?;
Ok(PyTensor(res))
}
}
#[pyfunction]
fn cuda_is_available() -> bool {
::candle::utils::cuda_is_available()
}
#[pyfunction]
fn has_accelerate() -> bool {
::candle::utils::has_accelerate()
}
#[pyfunction]
fn has_mkl() -> bool {
::candle::utils::has_mkl()
}
#[pyfunction]
fn get_num_threads() -> usize {
::candle::utils::get_num_threads()
}
fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
Ok(())
}
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
candle_utils(py, utils)?;
m.add_submodule(utils)?;
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;

View File

@ -33,3 +33,9 @@ print(t.to_dtype("u8"))
t = candle.randn((5, 3))
print(t)
print(t.dtype)
t = candle.randn((16, 256))
quant_t = t.quantize("q6k")
dequant_t = quant_t.dequantize()
diff2 = (t - dequant_t).sqr()
print(diff2.mean_all())