mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add matmul/where_cond.
This commit is contained in:
@ -149,6 +149,16 @@ impl PyTensor {
|
|||||||
self.__repr__()
|
self.__repr__()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(
|
||||||
|
self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||||
(&self.0 + &rhs.0).map_err(wrap_err)?
|
(&self.0 + &rhs.0).map_err(wrap_err)?
|
||||||
@ -219,6 +229,15 @@ impl PyTensor {
|
|||||||
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
|
||||||
|
// TODO: Support a single dim as input?
|
||||||
|
Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
fn sum_all(&self) -> PyResult<Self> {
|
fn sum_all(&self) -> PyResult<Self> {
|
||||||
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
|
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user