Support for quantized tensors in the python api. (#706)

* Add more pyo3 support.
* Add some support for quantized tensors in pyo3.
* Add an arc layer on qmatmul.
* Add the quantized matmul.
* Quantization support.
* More quantization support.
* Test the python quantization.
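Taken together, these changes expose an end-to-end quantization workflow from Python. A minimal sketch of the new surface, assuming the `candle` extension module is built from this revision (shapes are illustrative, and `matmul_t` is assumed to multiply by the transpose of the quantized tensor, following the usual ggml weight layout):

import candle

# Quantize an f32 tensor; the last dimension must be a multiple of the
# block size of the chosen format (256 for q6k).
t = candle.randn((16, 256))
qt = t.quantize("q6k")
print(qt.ggml_dtype, qt.rank, qt.shape)

# Round-trip back to f32 and measure the quantization error.
err = (t - qt.dequantize()).sqr().mean_all()
print(err)

# Quantized matmul against the transposed weights: (3, 256) -> (3, 16),
# under the layout assumption above.
xs = candle.randn((3, 256))
print(qt.matmul_t(xs))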
@@ -230,12 +230,20 @@ impl QTensor {
 }
 
 #[derive(Debug)]
-pub struct QMatMul(QTensor);
+pub struct QMatMul(std::sync::Arc<QTensor>);
 
 impl QMatMul {
-    pub fn from_qtensor(qtensor: QTensor) -> Self {
+    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
         Self(qtensor)
     }
+
+    pub fn from_qtensor(qtensor: QTensor) -> Self {
+        Self(std::sync::Arc::new(qtensor))
+    }
+
+    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
+        &self.0
+    }
 }
 
 impl crate::CustomOp1 for QTensor {
@@ -279,6 +287,6 @@ impl crate::CustomOp1 for QTensor {
 
 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply_op1_no_bwd(&self.0)
+        xs.apply_op1_no_bwd(self.0.as_ref())
     }
 }
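Because `QMatMul` now holds an `Arc<QTensor>`, wrapping the same quantized weights in several matmuls (or in repeated `matmul_t` calls from Python, below) clones a reference rather than the data. A hypothetical illustration from the Python side:

import candle

# One set of quantized weights, shared by every call; q4_0 uses
# 32-element blocks, so 256 columns is a valid width.
w = candle.randn((16, 256)).quantize("q4_0")
y1 = w.matmul_t(candle.randn((2, 256)))
y2 = w.matmul_t(candle.randn((5, 256)))
print(y1, y2)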
@@ -3,10 +3,11 @@
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
+use std::sync::Arc;
 
 use half::{bf16, f16};
 
-use ::candle::{DType, Device, Tensor, WithDType};
+use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
 
 pub fn wrap_err(err: ::candle::Error) -> PyErr {
     PyErr::new::<PyValueError, _>(format!("{err:?}"))
@@ -261,6 +262,38 @@ impl PyTensor {
         self.__repr__()
     }
 
+    fn sin(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
+    }
+
+    fn cos(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
+    }
+
+    fn log(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.log().map_err(wrap_err)?))
+    }
+
+    fn sqr(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
+    }
+
+    fn sqrt(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
+    }
+
+    fn recip(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
+    }
+
+    fn exp(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
+    }
+
+    fn powf(&self, p: f64) -> PyResult<Self> {
+        Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
+    }
+
     fn matmul(&self, rhs: &Self) -> PyResult<Self> {
         Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
     }
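The new unary methods simply forward to the corresponding `Tensor` operations; for instance, in a build of this revision:

import candle

t = candle.randn((5, 3))
print(t.sin())        # element-wise sine
print(t.exp().log())  # approximately t, up to rounding
print(t.powf(2.0))    # same result as t.sqr()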
@@ -344,8 +377,12 @@ impl PyTensor {
         Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
     }
 
-    fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
-        // TODO: Support a single dim as input?
+    fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
+        let dims = if let Ok(dim) = dims.extract::<usize>(py) {
+            vec![dim]
+        } else {
+            dims.extract::<Vec<usize>>(py)?
+        };
         Ok(PyTensor(
             self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
         ))
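`sum_keepdim` now takes either a single dimension or a sequence: the `PyObject` is first tried as a `usize` and falls back to `Vec<usize>`. For example:

import candle

t = candle.randn((5, 3))
print(t.sum_keepdim(0))       # shape (1, 3)
print(t.sum_keepdim([0, 1]))  # shape (1, 1)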
@@ -355,6 +392,13 @@ impl PyTensor {
         Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
     }
 
+    fn mean_all(&self) -> PyResult<Self> {
+        let elements = self.0.elem_count();
+        let sum = self.0.sum_all().map_err(wrap_err)?;
+        let mean = (sum / elements as f64).map_err(wrap_err)?;
+        Ok(PyTensor(mean))
+    }
+
     fn flatten_all(&self) -> PyResult<Self> {
         Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
     }
@@ -392,6 +436,32 @@ impl PyTensor {
         let device = device.as_device()?;
         Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
     }
 
+    fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
+        use ::candle::quantized;
+        let res = match quantized_dtype {
+            "q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
+            "q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
+            "q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
+            "q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
+            "q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
+            "q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
+            "q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
+            "q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
+            "q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
+            "q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
+            "q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
+            "q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
+            "f16" => quantized::QTensor::quantize::<f16>(self),
+            "f32" => quantized::QTensor::quantize::<f32>(self),
+            dt => {
+                return Err(PyErr::new::<PyValueError, _>(format!(
+                    "unknown quantized-dtype {dt}"
+                )))
+            }
+        };
+        Ok(PyQTensor(Arc::new(res.map_err(wrap_err)?)))
+    }
 }
 
 /// Concatenate the tensors across one axis.
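`quantize` picks the block format from a string and rejects anything outside the table above with a `ValueError`. A short sketch:

import candle

t = candle.randn((16, 256))
for dtype in ["q4_0", "q8_0", "q6k", "f16"]:
    print(dtype, t.quantize(dtype).ggml_dtype)

try:
    t.quantize("q7k")  # not one of the supported formats
except ValueError as e:
    print("rejected:", e)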
@@ -464,9 +534,90 @@ fn zeros(
     Ok(PyTensor(tensor))
 }
 
+#[derive(Debug)]
+#[pyclass(name = "QTensor")]
+struct PyQTensor(Arc<QTensor>);
+
+impl std::ops::Deref for PyQTensor {
+    type Target = QTensor;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+
+#[pymethods]
+impl PyQTensor {
+    #[getter]
+    fn ggml_dtype(&self) -> String {
+        format!("{:?}", self.0.dtype())
+    }
+
+    #[getter]
+    fn rank(&self) -> usize {
+        self.0.rank()
+    }
+
+    #[getter]
+    fn shape(&self, py: Python<'_>) -> PyObject {
+        PyTuple::new(py, self.0.shape().dims()).to_object(py)
+    }
+
+    fn __repr__(&self) -> String {
+        format!("{:?}", self.0)
+    }
+
+    fn __str__(&self) -> String {
+        self.__repr__()
+    }
+
+    fn dequantize(&self) -> PyResult<PyTensor> {
+        let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
+        Ok(PyTensor(tensor))
+    }
+
+    fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
+        let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
+        let res = qmatmul.forward(lhs).map_err(wrap_err)?;
+        Ok(PyTensor(res))
+    }
+}
+
+#[pyfunction]
+fn cuda_is_available() -> bool {
+    ::candle::utils::cuda_is_available()
+}
+
+#[pyfunction]
+fn has_accelerate() -> bool {
+    ::candle::utils::has_accelerate()
+}
+
+#[pyfunction]
+fn has_mkl() -> bool {
+    ::candle::utils::has_mkl()
+}
+
+#[pyfunction]
+fn get_num_threads() -> usize {
+    ::candle::utils::get_num_threads()
+}
+
+fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
+    m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
+    m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
+    m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
+    Ok(())
+}
+
 #[pymodule]
-fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    let utils = PyModule::new(py, "utils")?;
+    candle_utils(py, utils)?;
+    m.add_submodule(utils)?;
     m.add_class::<PyTensor>()?;
+    m.add_class::<PyQTensor>()?;
     m.add_class::<PyDType>()?;
     m.add("u8", PyDType(DType::U8))?;
     m.add("u32", PyDType(DType::U32))?;
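The build-introspection helpers are registered on a `candle.utils` submodule rather than on the top-level module, so with this build they are reached as attributes of `candle.utils`:

import candle

print(candle.utils.cuda_is_available())
print(candle.utils.has_accelerate())
print(candle.utils.has_mkl())
print(candle.utils.get_num_threads())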
@@ -33,3 +33,9 @@ print(t.to_dtype("u8"))
 t = candle.randn((5, 3))
 print(t)
 print(t.dtype)
+
+t = candle.randn((16, 256))
+quant_t = t.quantize("q6k")
+dequant_t = quant_t.dequantize()
+diff2 = (t - dequant_t).sqr()
+print(diff2.mean_all())