Add the index-select op. (#209)

* Add the index-select op.

* Cpu implementation of index-select.

* Add the cpu implementation for index-select.
This commit is contained in:
Laurent Mazare
2023-07-20 15:01:03 +02:00
committed by GitHub
parent 2a8f28d687
commit fa08fb3126
10 changed files with 168 additions and 20 deletions

View File

@ -515,6 +515,58 @@ impl Map1 for Affine {
}
}
/// Parameters of the CPU index-select op, applied to a source buffer through `Map1`.
struct IndexSelect<'a> {
/// Raw `u32` index values; read through `ids_l` (start offset and stride).
ids: &'a [u32],
/// Layout describing how to walk `ids`; the op requires it to have exactly one dimension.
ids_l: &'a Layout,
/// Dimension of the source along which entries are selected.
dim: usize,
}
impl<'a> Map1 for IndexSelect<'a> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "index-select" })?,
};
let dim = self.dim;
let n_ids = match self.ids_l.dims() {
[n_ids] => *n_ids,
d => Err(Error::UnexpectedNumberOfDims {
expected: 1,
got: d.len(),
shape: self.ids_l.shape().clone(),
})?,
};
let stride_ids = self.ids_l.stride()[0];
let mut dst_dims = layout.dims().to_vec();
let src_dim = dst_dims[dim];
dst_dims[dim] = n_ids;
let dst_len: usize = dst_dims.iter().product();
let left_len: usize = dst_dims[..dim].iter().product();
let right_len: usize = dst_dims[dim + 1..].iter().product();
let mut dst = vec![T::zero(); dst_len];
for left_i in 0..left_len {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
if index >= src_dim {
Err(Error::InvalidIndex {
index,
src_size: src_dim,
op: "index-select",
}
.bt())?
}
let start_src_idx = start_src_idx + index * right_len;
let start_dst_idx = start_dst_idx + i * right_len;
dst[start_dst_idx..start_dst_idx + right_len]
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
}
}
Ok(dst)
}
}
struct Embedding<'a> {
vocab_size: usize,
hidden_size: usize,
@ -533,7 +585,7 @@ impl<'a> Map1 for Embedding<'a> {
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
vocab_size: self.vocab_size,
src_size: self.vocab_size,
op: "take",
}
.bt())?
@ -1330,6 +1382,11 @@ impl BackendStorage for CpuStorage {
.map(rhs, rhs_l)
}
/// Runs the CPU index-select op on `self` (laid out as `l`), picking entries
/// along `dim` using the `u32` values stored in `ids` (laid out as `ids_l`).
///
/// Fails if `ids` cannot be viewed as a `u32` slice, or with whatever error
/// the underlying `IndexSelect` mapping reports.
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
    let op = IndexSelect {
        ids: ids.as_slice::<u32>()?,
        ids_l,
        dim,
    };
    op.map(self, l)
}
fn matmul(
&self,
rhs: &Self,