mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Use the same default as pytorch for sum. (#164)
This commit is contained in:
@ -312,9 +312,11 @@ impl PyTensor {
|
||||
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
|
||||
fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
|
||||
// TODO: Support a single dim as input?
|
||||
Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
|
||||
Ok(PyTensor(
|
||||
self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn sum_all(&self) -> PyResult<Self> {
|
||||
|
Reference in New Issue
Block a user