mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Start adding index-add.
This commit is contained in:
@ -945,6 +945,29 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-add")?;
|
||||
let storage = self.storage().index_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::IndexAdd(
|
||||
self.clone(),
|
||||
indexes.clone(),
|
||||
source.clone(),
|
||||
dim,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-select")?;
|
||||
let indexes_len = match indexes.dims() {
|
||||
|
Reference in New Issue
Block a user