Start adding index-add.

This commit is contained in:
laurent
2023-07-21 20:12:48 +01:00
parent 5cc843550d
commit 27174a82aa
8 changed files with 97 additions and 3 deletions

View File

@ -945,6 +945,29 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
let storage = self.storage().index_add(
self.layout(),
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = if indexes.track_op() || self.track_op() {
Some(Op::IndexAdd(
self.clone(),
indexes.clone(),
source.clone(),
dim,
))
} else {
None
};
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {