Support for quantized tensors in the python api. (#706)

* Add more pyo3 support.

* Add some support for quantized tensors in pyo3.

* Add an arc layer on qmatmul.

* Add the quantized matmul.

* Quantization support.

* More quantization support.

* Test the python quantization.
This commit is contained in:
Laurent Mazare
2023-09-01 16:53:42 +02:00
committed by GitHub
parent 237323c2bc
commit 2ed78ab336
3 changed files with 172 additions and 7 deletions

View File

@ -230,12 +230,20 @@ impl QTensor {
}
// NOTE(review): this span looks like a diff rendering whose +/- markers were
// stripped, so removed and added lines appear next to each other. Going by the
// commit message ("Add an arc layer on qmatmul"), the `Arc` variants are
// presumably the added lines — confirm against the raw diff.
/// Newtype wrapper around a `QTensor`; in the `Arc` variant the tensor is held
/// behind a `std::sync::Arc`, so a caller can hand over an `Arc` it already owns.
#[derive(Debug)]
pub struct QMatMul(QTensor);
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
// Presumably the pre-change signature of `from_qtensor`; the `Self(qtensor)`
// body below would have belonged to it before the change — TODO confirm.
pub fn from_qtensor(qtensor: QTensor) -> Self {
/// Wraps an existing `Arc<QTensor>` as-is; no new `Arc` is created here.
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
}
/// Takes ownership of `qtensor` and stores it in a freshly created `Arc`.
pub fn from_qtensor(qtensor: QTensor) -> Self {
Self(std::sync::Arc::new(qtensor))
}
/// Returns a reference to the wrapped `Arc<QTensor>`.
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
&self.0
}
}
impl crate::CustomOp1 for QTensor {
@ -279,6 +287,6 @@ impl crate::CustomOp1 for QTensor {
// NOTE(review): the body below carries two call expressions back to back, which
// looks like the removed and added line of a diff with the +/- markers stripped;
// the `as_ref()` form matches the `Arc`-wrapped field — confirm against the raw diff.
impl QMatMul {
/// Passes the wrapped `QTensor` to `xs.apply_op1_no_bwd` and returns that
/// call's `Result<Tensor>` unchanged.
///
/// NOTE(review): the `_no_bwd` suffix suggests no backward pass is recorded —
/// confirm against the definition of `Tensor::apply_op1_no_bwd`.
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// Form used when the field is a plain `QTensor`: borrow it directly.
xs.apply_op1_no_bwd(&self.0)
// Form used when the field is an `Arc<QTensor>`: `as_ref()` borrows the
// `QTensor` out of the `Arc`.
xs.apply_op1_no_bwd(self.0.as_ref())
}
}