Polish the index-add op and use it in the index-select backprop (#218)

* Add the cpu version of index-add.

* More cpu support for index-add.

* Use index-add in the backprop.
This commit is contained in:
Laurent Mazare
2023-07-22 06:31:46 +02:00
committed by GitHub
parent 27174a82aa
commit 6eeea1b04e
4 changed files with 108 additions and 36 deletions

View File

@ -947,6 +947,30 @@ impl Tensor {
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
let source_dims = source.dims();
let self_dims = self.dims();
let mismatch = if source_dims.len() != self_dims.len() {
true
} else {
let mut mismatch = false;
for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
if i != dim && d1 != d2 {
mismatch = true;
break;
}
}
// The number of element in indexes must match the dimension on which the add is
// performed on the source tensor (and the index values from `indexes` are taken from
// the target tensor self)
mismatch || source_dims[dim] != indexes.shape().r1()?
};
if mismatch {
Err(Error::ShapeMismatchBinaryOp {
op: "index-add",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
})?
}
let storage = self.storage().index_add(
self.layout(),
&indexes.storage(),