mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add the index-select op. (#209)
* Add the index-select op. * Cpu implementation of index-select. * Add the cpu implementation for index-select.
This commit is contained in:
@ -960,6 +960,33 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-select")?;
|
||||
let indexes_len = match indexes.dims() {
|
||||
[l] => *l,
|
||||
_ => Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: indexes.shape().clone(),
|
||||
op: "index-select",
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
let storage = self.storage().index_select(
|
||||
&indexes.storage(),
|
||||
self.layout(),
|
||||
indexes.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = indexes_len;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
/// Returns an iterator over position of the elements in the storage when ranging over the
|
||||
/// index tuples in lexicographic order.
|
||||
pub fn strided_index(&self) -> crate::StridedIndex {
|
||||
|
Reference in New Issue
Block a user