Add the index-select op. (#209)

* Add the index-select op.

* Cpu implementation of index-select.

* Add the cpu implementation for index-select.
This commit is contained in:
Laurent Mazare
2023-07-20 15:01:03 +02:00
committed by GitHub
parent 2a8f28d687
commit fa08fb3126
10 changed files with 168 additions and 20 deletions

View File

@ -960,6 +960,33 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {
[l] => *l,
_ => Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: indexes.shape().clone(),
op: "index-select",
}
.bt())?,
};
let storage = self.storage().index_select(
&indexes.storage(),
self.layout(),
indexes.layout(),
dim,
)?;
let mut dims = self.dims().to_vec();
dims[dim] = indexes_len;
let op = if indexes.track_op() || self.track_op() {
Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
} else {
None
};
Ok(from_storage(storage, dims, op, false))
}
/// Returns an iterator over position of the elements in the storage when ranging over the
/// index tuples in lexicographic order.
pub fn strided_index(&self) -> crate::StridedIndex {