From 2ed78ab336c99080b1f8830f48ea40e2e1026249 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 1 Sep 2023 16:53:42 +0200
Subject: [PATCH] Support for quantized tensors in the python api. (#706)

* Add more pyo3 support.

* Add some support for quantized tensors in pyo3.

* Add an arc layer on qmatmul.

* Add the quantized matmul.

* Quantization support.

* More quantization support.

* Test the python quantization.
---
 candle-core/src/quantized/mod.rs |  14 ++-
 candle-pyo3/src/lib.rs           | 159 ++++++++++++++++++++++++++++++-
 candle-pyo3/test.py              |   6 ++
 3 files changed, 172 insertions(+), 7 deletions(-)

diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index d87d2d5a..5c2bb2b2 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -230,12 +230,20 @@ impl QTensor {
 }
 
 #[derive(Debug)]
-pub struct QMatMul(QTensor);
+pub struct QMatMul(std::sync::Arc<QTensor>);
 
 impl QMatMul {
-    pub fn from_qtensor(qtensor: QTensor) -> Self {
+    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
         Self(qtensor)
     }
+
+    pub fn from_qtensor(qtensor: QTensor) -> Self {
+        Self(std::sync::Arc::new(qtensor))
+    }
+
+    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
+        &self.0
+    }
 }
 
 impl crate::CustomOp1 for QTensor {
@@ -279,6 +287,6 @@ impl crate::CustomOp1 for QTensor {
 
 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply_op1_no_bwd(&self.0)
+        xs.apply_op1_no_bwd(self.0.as_ref())
     }
 }
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index e93f1b17..7c28eb52 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -3,10 +3,11 @@ use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
+use std::sync::Arc;
 
 use half::{bf16, f16};
 
-use ::candle::{DType, Device, Tensor, WithDType};
+use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
 
 pub fn wrap_err(err: ::candle::Error) -> PyErr {
     PyErr::new::<PyValueError, _>(format!("{err:?}"))
 }
@@ -261,6 +262,38 @@ impl PyTensor {
         self.__repr__()
     }
 
+    fn sin(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
+    }
+
+    fn cos(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
+    }
+
+    fn log(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.log().map_err(wrap_err)?))
+    }
+
+    fn sqr(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
+    }
+
+    fn sqrt(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
+    }
+
+    fn recip(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
+    }
+
+    fn exp(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
+    }
+
+    fn powf(&self, p: f64) -> PyResult<Self> {
+        Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
+    }
+
     fn matmul(&self, rhs: &Self) -> PyResult<Self> {
         Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
     }
@@ -344,8 +377,12 @@ impl PyTensor {
         Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
     }
 
-    fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
-        // TODO: Support a single dim as input?
+    fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
+        let dims = if let Ok(dim) = dims.extract::<usize>(py) {
+            vec![dim]
+        } else {
+            dims.extract::<Vec<usize>>(py)?
+        };
         Ok(PyTensor(
             self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
         ))
     }
@@ -355,6 +392,13 @@ impl PyTensor {
         Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
     }
 
+    fn mean_all(&self) -> PyResult<Self> {
+        let elements = self.0.elem_count();
+        let sum = self.0.sum_all().map_err(wrap_err)?;
+        let mean = (sum / elements as f64).map_err(wrap_err)?;
+        Ok(PyTensor(mean))
+    }
+
     fn flatten_all(&self) -> PyResult<Self> {
         Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
     }
@@ -392,6 +436,32 @@ impl PyTensor {
         let device = device.as_device()?;
         Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
     }
+
+    fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
+        use ::candle::quantized;
+        let res = match quantized_dtype {
+            "q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
+            "q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
+            "q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
+            "q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
+            "q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
+            "q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
+            "q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
+            "q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
+            "q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
+            "q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
+            "q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
+            "q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
+            "f16" => quantized::QTensor::quantize::<f16>(self),
+            "f32" => quantized::QTensor::quantize::<f32>(self),
+            dt => {
+                return Err(PyErr::new::<PyValueError, _>(format!(
+                    "unknown quantized-dtype {dt}"
+                )))
+            }
+        };
+        Ok(PyQTensor(Arc::new(res.map_err(wrap_err)?)))
+    }
 }
 
 /// Concatenate the tensors across one axis.
@@ -464,9 +534,90 @@ fn zeros(
     Ok(PyTensor(tensor))
 }
 
+#[derive(Debug)]
+#[pyclass(name = "QTensor")]
+struct PyQTensor(Arc<QTensor>);
+
+impl std::ops::Deref for PyQTensor {
+    type Target = QTensor;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+
+#[pymethods]
+impl PyQTensor {
+    #[getter]
+    fn ggml_dtype(&self) -> String {
+        format!("{:?}", self.0.dtype())
+    }
+
+    #[getter]
+    fn rank(&self) -> usize {
+        self.0.rank()
+    }
+
+    #[getter]
+    fn shape(&self, py: Python<'_>) -> PyObject {
+        PyTuple::new(py, self.0.shape().dims()).to_object(py)
+    }
+
+    fn __repr__(&self) -> String {
+        format!("{:?}", self.0)
+    }
+
+    fn __str__(&self) -> String {
+        self.__repr__()
+    }
+
+    fn dequantize(&self) -> PyResult<PyTensor> {
+        let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
+        Ok(PyTensor(tensor))
+    }
+
+    fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
+        let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
+        let res = qmatmul.forward(lhs).map_err(wrap_err)?;
+        Ok(PyTensor(res))
+    }
+}
+
+#[pyfunction]
+fn cuda_is_available() -> bool {
+    ::candle::utils::cuda_is_available()
+}
+
+#[pyfunction]
+fn has_accelerate() -> bool {
+    ::candle::utils::has_accelerate()
+}
+
+#[pyfunction]
+fn has_mkl() -> bool {
+    ::candle::utils::has_mkl()
+}
+
+#[pyfunction]
+fn get_num_threads() -> usize {
+    ::candle::utils::get_num_threads()
+}
+
+fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
+    m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
+    m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
+    m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
+    Ok(())
+}
+
 #[pymodule]
-fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    let utils = PyModule::new(py, "utils")?;
+    candle_utils(py, utils)?;
+    m.add_submodule(utils)?;
     m.add_class::<PyTensor>()?;
+    m.add_class::<PyQTensor>()?;
     m.add_class::<PyDType>()?;
     m.add("u8", PyDType(DType::U8))?;
     m.add("u32", PyDType(DType::U32))?;
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index 1711cdad..f76dee9b 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -33,3 +33,9 @@ print(t.to_dtype("u8"))
 t = candle.randn((5, 3))
 print(t)
 print(t.dtype)
+
+t = candle.randn((16, 256))
+quant_t = t.quantize("q6k")
+dequant_t = quant_t.dequantize()
+diff2 = (t - dequant_t).sqr()
+print(diff2.mean_all())
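
A short usage sketch of the Python surface added by this patch (not part of the patch itself; it
assumes the extension has been built and is importable as `candle`, and the tensor names and
shapes are only illustrative):

import candle

# Helpers exposed through the new candle.utils submodule.
print(candle.utils.cuda_is_available())
print(candle.utils.get_num_threads())

# Quantize a float tensor. The k-quant formats ("q2k" .. "q8k") work on 256-element
# blocks, so the innermost dimension is chosen as a multiple of 256 here.
weight = candle.randn((64, 256))
qweight = weight.quantize("q4k")
print(qweight.ggml_dtype, qweight.rank, qweight.shape)

# matmul_t multiplies its argument by the transposed quantized tensor, mirroring
# QMatMul::forward on the Rust side: (16, 256) @ (64, 256)^T -> (16, 64).
x = candle.randn((16, 256))
y = qweight.matmul_t(x)
print(y)

# Dequantize back to a float tensor and look at the round-trip error, as in test.py above.
err = (weight - qweight.dequantize()).sqr().mean_all()
print(err)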