mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add matmul/where_cond.
This commit is contained in:
@ -149,6 +149,16 @@ impl PyTensor {
|
||||
self.__repr__()
|
||||
}
|
||||
|
||||
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
|
||||
Ok(PyTensor(
|
||||
self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 + &rhs.0).map_err(wrap_err)?
|
||||
@ -219,6 +229,15 @@ impl PyTensor {
|
||||
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
|
||||
// TODO: Support a single dim as input?
|
||||
Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn sum_all(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
|
||||
}
|
||||
|
Reference in New Issue
Block a user