mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add some testing for index-add (#237)
* Add some testing for index-add. * Fix the cpu implementation for index-add.
This commit is contained in:
@ -1010,18 +1010,26 @@ impl Tensor {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// The number of element in indexes must match the dimension on which the add is
|
||||
// performed on the source tensor (and the index values from `indexes` are taken from
|
||||
// the target tensor self)
|
||||
mismatch || source_dims[dim] != indexes.dims1()?
|
||||
mismatch
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "index-add",
|
||||
op: "index-add (self, source)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
})?
|
||||
}
|
||||
// The number of element in indexes must match the dimension on which the add is
|
||||
// performed on the source tensor (and the index values from `indexes` are taken from
|
||||
// the target tensor self)
|
||||
let indexes_len = indexes.dims1()?;
|
||||
if source_dims[dim] != indexes_len {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "index-add (ids, source))",
|
||||
lhs: indexes.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
})?
|
||||
}
|
||||
let storage = self.storage().index_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
|
Reference in New Issue
Block a user