Relax the requirements on CustomOp. (#486)

* Relax the requirements on CustomOp.

* Simplify the custom-ops when no backward is required.
This commit is contained in:
Laurent Mazare
2023-08-17 11:12:05 +01:00
committed by GitHub
parent d32e8199cd
commit 03be33eea4
8 changed files with 81 additions and 31 deletions

View File

@ -147,11 +147,11 @@ impl QTensor {
}
}
pub struct QMatMul(std::sync::Arc<Box<dyn crate::CustomOp1>>);
pub struct QMatMul(QTensor);
impl QMatMul {
pub fn from_qtensor(qtensor: QTensor) -> Self {
Self(std::sync::Arc::new(Box::new(qtensor)))
Self(qtensor)
}
}
@ -196,6 +196,6 @@ impl crate::CustomOp1 for QTensor {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.custom_op1_arc(self.0.clone())
xs.apply_op1_no_bwd(&self.0)
}
}