Use the same default as pytorch for sum. (#164)

This commit is contained in:
Laurent Mazare
2023-07-13 21:32:32 +01:00
committed by GitHub
parent 57be3638d8
commit 2bfa791336
13 changed files with 123 additions and 56 deletions

View File

@ -312,9 +312,11 @@ impl PyTensor {
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
// TODO: Support a single dim as input?
Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
Ok(PyTensor(
self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
))
}
fn sum_all(&self) -> PyResult<Self> {