mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Support for quantized tensors in the python api. (#706)
* Add more pyo3 support. * Add some support for quantized tensors in pyo3. * Add an arc layer on qmatmul. * Add the quantized matmul. * Quantization support. * More quantization support. * Test the python quantization.
This commit is contained in:
@ -230,12 +230,20 @@ impl QTensor {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct QMatMul(QTensor);
|
||||
pub struct QMatMul(std::sync::Arc<QTensor>);
|
||||
|
||||
impl QMatMul {
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Self {
|
||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
|
||||
Self(qtensor)
|
||||
}
|
||||
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Self {
|
||||
Self(std::sync::Arc::new(qtensor))
|
||||
}
|
||||
|
||||
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for QTensor {
|
||||
@ -279,6 +287,6 @@ impl crate::CustomOp1 for QTensor {
|
||||
|
||||
impl QMatMul {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply_op1_no_bwd(&self.0)
|
||||
xs.apply_op1_no_bwd(self.0.as_ref())
|
||||
}
|
||||
}
|
||||
|
@ -3,10 +3,11 @@
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyTuple;
|
||||
use std::sync::Arc;
|
||||
|
||||
use half::{bf16, f16};
|
||||
|
||||
use ::candle::{DType, Device, Tensor, WithDType};
|
||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
@ -261,6 +262,38 @@ impl PyTensor {
|
||||
self.__repr__()
|
||||
}
|
||||
|
||||
fn sin(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn cos(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn log(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.log().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn sqr(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn sqrt(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn recip(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn exp(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn powf(&self, p: f64) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
|
||||
}
|
||||
@ -344,8 +377,12 @@ impl PyTensor {
|
||||
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
|
||||
// TODO: Support a single dim as input?
|
||||
fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
|
||||
let dims = if let Ok(dim) = dims.extract::<usize>(py) {
|
||||
vec![dim]
|
||||
} else {
|
||||
dims.extract::<Vec<usize>>(py)?
|
||||
};
|
||||
Ok(PyTensor(
|
||||
self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
|
||||
))
|
||||
@ -355,6 +392,13 @@ impl PyTensor {
|
||||
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn mean_all(&self) -> PyResult<Self> {
|
||||
let elements = self.0.elem_count();
|
||||
let sum = self.0.sum_all().map_err(wrap_err)?;
|
||||
let mean = (sum / elements as f64).map_err(wrap_err)?;
|
||||
Ok(PyTensor(mean))
|
||||
}
|
||||
|
||||
fn flatten_all(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
|
||||
}
|
||||
@ -392,6 +436,32 @@ impl PyTensor {
|
||||
let device = device.as_device()?;
|
||||
Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
|
||||
use ::candle::quantized;
|
||||
let res = match quantized_dtype {
|
||||
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
|
||||
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
|
||||
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
|
||||
"q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
|
||||
"q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
|
||||
"q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
|
||||
"q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
|
||||
"q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
|
||||
"q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
|
||||
"q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
|
||||
"q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
|
||||
"q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
|
||||
"f16" => quantized::QTensor::quantize::<f16>(self),
|
||||
"f32" => quantized::QTensor::quantize::<f32>(self),
|
||||
dt => {
|
||||
return Err(PyErr::new::<PyValueError, _>(format!(
|
||||
"unknown quantized-dtype {dt}"
|
||||
)))
|
||||
}
|
||||
};
|
||||
Ok(PyQTensor(Arc::new(res.map_err(wrap_err)?)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Concatenate the tensors across one axis.
|
||||
@ -464,9 +534,90 @@ fn zeros(
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[pyclass(name = "QTensor")]
|
||||
struct PyQTensor(Arc<QTensor>);
|
||||
|
||||
impl std::ops::Deref for PyQTensor {
|
||||
type Target = QTensor;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyQTensor {
|
||||
#[getter]
|
||||
fn ggml_dtype(&self) -> String {
|
||||
format!("{:?}", self.0.dtype())
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn rank(&self) -> usize {
|
||||
self.0.rank()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn shape(&self, py: Python<'_>) -> PyObject {
|
||||
PyTuple::new(py, self.0.shape().dims()).to_object(py)
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("{:?}", self.0)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.__repr__()
|
||||
}
|
||||
|
||||
fn dequantize(&self) -> PyResult<PyTensor> {
|
||||
let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
|
||||
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
|
||||
let res = qmatmul.forward(lhs).map_err(wrap_err)?;
|
||||
Ok(PyTensor(res))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn cuda_is_available() -> bool {
|
||||
::candle::utils::cuda_is_available()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn has_accelerate() -> bool {
|
||||
::candle::utils::has_accelerate()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn has_mkl() -> bool {
|
||||
::candle::utils::has_mkl()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn get_num_threads() -> usize {
|
||||
::candle::utils::get_num_threads()
|
||||
}
|
||||
|
||||
fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let utils = PyModule::new(py, "utils")?;
|
||||
candle_utils(py, utils)?;
|
||||
m.add_submodule(utils)?;
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_class::<PyQTensor>()?;
|
||||
m.add_class::<PyDType>()?;
|
||||
m.add("u8", PyDType(DType::U8))?;
|
||||
m.add("u32", PyDType(DType::U32))?;
|
||||
|
@ -33,3 +33,9 @@ print(t.to_dtype("u8"))
|
||||
t = candle.randn((5, 3))
|
||||
print(t)
|
||||
print(t.dtype)
|
||||
|
||||
t = candle.randn((16, 256))
|
||||
quant_t = t.quantize("q6k")
|
||||
dequant_t = quant_t.dequantize()
|
||||
diff2 = (t - dequant_t).sqr()
|
||||
print(diff2.mean_all())
|
||||
|
Reference in New Issue
Block a user