Add matmul/where_cond.

This commit is contained in:
laurent
2023-07-02 07:34:14 +01:00
parent 9a9858bbe0
commit 5b8c6764b0

View File

@ -149,6 +149,16 @@ impl PyTensor {
self.__repr__() self.__repr__()
} }
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
}
fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
Ok(PyTensor(
self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
))
}
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> { fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() { let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 + &rhs.0).map_err(wrap_err)? (&self.0 + &rhs.0).map_err(wrap_err)?
@ -219,6 +229,15 @@ impl PyTensor {
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
} }
fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
// TODO: Support a single dim as input?
Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
}
fn sum_all(&self) -> PyResult<Self> { fn sum_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?)) Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
} }