mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add the index-select op. (#209)
* Add the index-select op. * Cpu implementation of index-select. * Add the cpu implementation for index-select.
This commit is contained in:
@ -515,6 +515,58 @@ impl Map1 for Affine {
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexSelect<'a> {
|
||||
ids: &'a [u32],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for IndexSelect<'a> {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-select" })?,
|
||||
};
|
||||
let dim = self.dim;
|
||||
let n_ids = match self.ids_l.dims() {
|
||||
[n_ids] => *n_ids,
|
||||
d => Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 1,
|
||||
got: d.len(),
|
||||
shape: self.ids_l.shape().clone(),
|
||||
})?,
|
||||
};
|
||||
let stride_ids = self.ids_l.stride()[0];
|
||||
let mut dst_dims = layout.dims().to_vec();
|
||||
let src_dim = dst_dims[dim];
|
||||
dst_dims[dim] = n_ids;
|
||||
let dst_len: usize = dst_dims.iter().product();
|
||||
let left_len: usize = dst_dims[..dim].iter().product();
|
||||
let right_len: usize = dst_dims[dim + 1..].iter().product();
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
for left_i in 0..left_len {
|
||||
let start_src_idx = left_i * right_len * src_dim;
|
||||
let start_dst_idx = left_i * right_len * n_ids;
|
||||
for i in 0..n_ids {
|
||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
|
||||
if index >= src_dim {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
src_size: src_dim,
|
||||
op: "index-select",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let start_src_idx = start_src_idx + index * right_len;
|
||||
let start_dst_idx = start_dst_idx + i * right_len;
|
||||
dst[start_dst_idx..start_dst_idx + right_len]
|
||||
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a> {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
@ -533,7 +585,7 @@ impl<'a> Map1 for Embedding<'a> {
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size: self.vocab_size,
|
||||
src_size: self.vocab_size,
|
||||
op: "take",
|
||||
}
|
||||
.bt())?
|
||||
@ -1330,6 +1382,11 @@ impl BackendStorage for CpuStorage {
|
||||
.map(rhs, rhs_l)
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
IndexSelect { ids, ids_l, dim }.map(self, l)
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
|
Reference in New Issue
Block a user