Add a simple implementation of cumsum. (#1334)

* Add a simple implementation of cumsum.

* Add another test.
This commit is contained in:
Laurent Mazare
2023-11-15 21:11:15 +00:00
committed by GitHub
parent 347e31c9ff
commit c6763e3b41
2 changed files with 49 additions and 0 deletions

View File

@ -2474,6 +2474,28 @@ impl Tensor {
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.eq(&t2)?.to_dtype(dtype)
}
/// Returns the cumulative sum of elements of the input tensor summed over the specified
/// dimension.
///
/// This operation is most efficient when dim is the last dimension of the tensor.
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "cumsum")?;
let rank = self.rank();
if rank == 0 {
return Ok(self.clone());
}
let n_axis = self.dim(dim)?;
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
if rank == 1 {
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
} else {
let last = rank - 1;
let t = self.transpose(dim, last)?;
let t = t.broadcast_matmul(&triu)?;
t.transpose(dim, last)
}
}
}
macro_rules! bin_trait {