mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Polish the index-add op and use it in the index-select backprop (#218)
* Add the cpu version of index-add. * More cpu support for index-add. * Use index-add in the backprop.
This commit is contained in:
@ -947,6 +947,30 @@ impl Tensor {
|
||||
|
||||
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-add")?;
|
||||
let source_dims = source.dims();
|
||||
let self_dims = self.dims();
|
||||
let mismatch = if source_dims.len() != self_dims.len() {
|
||||
true
|
||||
} else {
|
||||
let mut mismatch = false;
|
||||
for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
|
||||
if i != dim && d1 != d2 {
|
||||
mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// The number of element in indexes must match the dimension on which the add is
|
||||
// performed on the source tensor (and the index values from `indexes` are taken from
|
||||
// the target tensor self)
|
||||
mismatch || source_dims[dim] != indexes.shape().r1()?
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "index-add",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
})?
|
||||
}
|
||||
let storage = self.storage().index_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
|
Reference in New Issue
Block a user