Add the index-select op. (#209)

* Add the index-select op.

* CPU implementation of index-select.

* Add the CPU implementation for index-select.
Author: Laurent Mazare
Date: 2023-07-20 15:01:03 +02:00
Committed by: GitHub
Parent: 2a8f28d687
Commit: fa08fb3126

10 changed files with 168 additions and 20 deletions

View File

@@ -40,6 +40,7 @@ pub(crate) trait BackendStorage: Sized {
) -> Result<Self>;
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn matmul(
&self,

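The new trait method is declared with `_` placeholders; annotated with the parameter names that the CPU implementation further down actually uses, it reads as follows (a restatement for orientation, not a line from the diff):

fn index_select(
    &self,          // source storage to gather from
    ids: &Self,     // storage holding the u32 indices
    l: &Layout,     // layout of the source storage
    ids_l: &Layout, // layout of the indices (expected to be rank 1)
    dim: usize,     // dimension of the source to select along
) -> Result<Self>;
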
View File

@@ -40,6 +40,7 @@ impl Tensor {
..
}
| Op::Binary(lhs, rhs, _)
+| Op::IndexSelect(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -143,9 +144,12 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
-Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
+Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+Op::IndexSelect(_lhs, _rhs, _) => {
+Err(Error::BackwardNotSupported { op: "index-select" })?
+}
Op::Embedding(_lhs, _rhs) => {
-return Err(Error::BackwardNotSupported { op: "embedding" })
+Err(Error::BackwardNotSupported { op: "embedding" })?
}
Op::Matmul(lhs, rhs) => {
// Skipping checks, the op went ok, we can skip
@@ -195,10 +199,10 @@
}
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
Op::Reduce(_args, ReduceOp::Max, _) => {
-return Err(Error::BackwardNotSupported { op: "max" })
+Err(Error::BackwardNotSupported { op: "max" })?
}
Op::Reduce(_args, ReduceOp::Min, _) => {
-return Err(Error::BackwardNotSupported { op: "min" })
+Err(Error::BackwardNotSupported { op: "min" })?
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
@@ -221,9 +225,7 @@
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
-Op::Unary(_, UnaryOp::Abs) => {
-return Err(Error::BackwardNotSupported { op: "abs" })
-}
+Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad / arg)?)?
@@ -258,21 +260,15 @@
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
-Op::Softmax(_arg, _) => {
-return Err(Error::BackwardNotSupported { op: "softmax" })
-}
+Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
-Op::Unary(_, UnaryOp::Gelu) => {
-return Err(Error::BackwardNotSupported { op: "gelu" })
-}
-Op::Unary(_, UnaryOp::Relu) => {
-return Err(Error::BackwardNotSupported { op: "relu" })
-}
-Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
+Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
+Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Unary(arg, UnaryOp::Sqr) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;

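Most of the churn in this file is a mechanical style change: `return Err(...)` in match arms becomes `Err(...)?`. A minimal self-contained sketch of the pattern (the `Error` type here is a stand-in for the crate's own):

#[derive(Debug)]
enum Error {
    BackwardNotSupported { op: &'static str },
}

fn backward(op_name: &str) -> Result<(), Error> {
    match op_name {
        // Before: "gelu" => return Err(Error::BackwardNotSupported { op: "gelu" }),
        "gelu" => Err(Error::BackwardNotSupported { op: "gelu" })?,
        _ => {}
    }
    Ok(())
}

fn main() {
    assert!(backward("gelu").is_err());
    assert!(backward("sqr").is_ok());
}

Both forms propagate the same error; the `?` form keeps each arm a plain expression, routes through the usual `From` conversion, and collapses the three-line arms into one-liners. Note that the sweep is not complete yet: the `Op::Cmp` arm above still uses the `return` form.
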
View File

@@ -515,6 +515,58 @@ impl Map1 for Affine {
}
}
+struct IndexSelect<'a> {
+    ids: &'a [u32],
+    ids_l: &'a Layout,
+    dim: usize,
+}
+
+impl<'a> Map1 for IndexSelect<'a> {
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let src = match layout.contiguous_offsets() {
+            Some((a, b)) => &src[a..b],
+            None => Err(Error::RequiresContiguous { op: "index-select" })?,
+        };
+        let dim = self.dim;
+        let n_ids = match self.ids_l.dims() {
+            [n_ids] => *n_ids,
+            d => Err(Error::UnexpectedNumberOfDims {
+                expected: 1,
+                got: d.len(),
+                shape: self.ids_l.shape().clone(),
+            })?,
+        };
+        let stride_ids = self.ids_l.stride()[0];
+        let mut dst_dims = layout.dims().to_vec();
+        let src_dim = dst_dims[dim];
+        dst_dims[dim] = n_ids;
+        let dst_len: usize = dst_dims.iter().product();
+        let left_len: usize = dst_dims[..dim].iter().product();
+        let right_len: usize = dst_dims[dim + 1..].iter().product();
+        let mut dst = vec![T::zero(); dst_len];
+        for left_i in 0..left_len {
+            let start_src_idx = left_i * right_len * src_dim;
+            let start_dst_idx = left_i * right_len * n_ids;
+            for i in 0..n_ids {
+                let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
+                if index >= src_dim {
+                    Err(Error::InvalidIndex {
+                        index,
+                        src_size: src_dim,
+                        op: "index-select",
+                    }
+                    .bt())?
+                }
+                let start_src_idx = start_src_idx + index * right_len;
+                let start_dst_idx = start_dst_idx + i * right_len;
+                dst[start_dst_idx..start_dst_idx + right_len]
+                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
+            }
+        }
+        Ok(dst)
+    }
+}
+
struct Embedding<'a> {
vocab_size: usize,
hidden_size: usize,
@@ -533,7 +585,7 @@ impl<'a> Map1 for Embedding<'a> {
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
-vocab_size: self.vocab_size,
+src_size: self.vocab_size,
op: "take",
}
.bt())?
@@ -1330,6 +1382,11 @@ impl BackendStorage for CpuStorage {
.map(rhs, rhs_l)
}
+
+    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let ids = ids.as_slice::<u32>()?;
+        IndexSelect { ids, ids_l, dim }.map(self, l)
+    }
fn matmul(
&self,
rhs: &Self,

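A few things about the kernel above are easy to miss: the generic `f` is instantiated per dtype by the `Map1` machinery (`.map(self, l)` dispatches on the storage's element type while the indices stay `u32`), the source must be contiguous (`contiguous_offsets`, otherwise `RequiresContiguous`), and the gather views the source as a `[left_len, src_dim, right_len]` block so each selected slice is one `copy_from_slice`. A self-contained sketch of the same loop on a flat, contiguous f32 buffer (no `Layout` handling, the kernel's `InvalidIndex` bounds check elided, names are mine):

fn index_select_flat(src: &[f32], dims: &[usize], dim: usize, ids: &[u32]) -> Vec<f32> {
    let src_dim = dims[dim];
    // View the source as [left_len, src_dim, right_len].
    let left_len: usize = dims[..dim].iter().product();
    let right_len: usize = dims[dim + 1..].iter().product();
    let mut dst = vec![0f32; left_len * ids.len() * right_len];
    for left_i in 0..left_len {
        for (i, &index) in ids.iter().enumerate() {
            // Same arithmetic as the kernel: fold (left_i, index) resp.
            // (left_i, i) down to a start offset in the flat buffer.
            let src_o = (left_i * src_dim + index as usize) * right_len;
            let dst_o = (left_i * ids.len() + i) * right_len;
            dst[dst_o..dst_o + right_len].copy_from_slice(&src[src_o..src_o + right_len]);
        }
    }
    dst
}

fn main() {
    let t: Vec<f32> = (0..12).map(|v| v as f32).collect();
    // dim 0 of a (4, 3) tensor: pick rows 0, 2, 1.
    let out = index_select_flat(&t, &[4, 3], 0, &[0, 2, 1]);
    assert_eq!(out, [0., 1., 2., 6., 7., 8., 3., 4., 5.]);
}

For the (4, 3) tensor in the test below with dim = 1, left_len = 4 and right_len = 1, so the copy degenerates to one element per index; with dim = 0 it copies whole rows of length 3. In the real kernel, out-of-range indices are rejected up front by the `InvalidIndex` check rather than panicking on the slice.
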
View File

@@ -1059,6 +1059,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+
+    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement index-select").into())
+    }
fn matmul(
&self,
rhs: &Self,

View File

@@ -82,6 +82,9 @@ impl crate::backend::BackendStorage for CudaStorage {
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
+    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
fn matmul(
&self,

View File

@@ -109,11 +109,11 @@ pub enum Error {
msg: &'static str,
},
#[error("{op} invalid index {index} with vocab {vocab_size}")]
#[error("{op} invalid index {index} with src dim size {src_size}")]
InvalidIndex {
op: &'static str,
index: usize,
-vocab_size: usize,
+src_size: usize,
},
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]

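Since the `#[error(...)]` attribute syntax suggests `thiserror`, the rename from `vocab_size` to `src_size` surfaces directly in the formatted message. A sketch (assuming `thiserror`; the variant is copied from the diff):

use thiserror::Error;

#[derive(Debug, Error)]
enum Error {
    #[error("{op} invalid index {index} with src dim size {src_size}")]
    InvalidIndex { op: &'static str, index: usize, src_size: usize },
}

fn main() {
    let e = Error::InvalidIndex { op: "index-select", index: 7, src_size: 4 };
    assert_eq!(e.to_string(), "index-select invalid index 7 with src dim size 4");
}

The rename matters because the same variant is now raised both by `embedding` (where the dimension really is a vocabulary size) and by `index-select` (where it is an arbitrary source dimension).
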
View File

@@ -51,6 +51,7 @@ pub(crate) enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
+IndexSelect(Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),
#[allow(dead_code)]

View File

@@ -267,6 +267,32 @@ impl Storage {
}
}
+
+    pub(crate) fn index_select(
+        &self,
+        rhs: &Self,
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+        d: usize,
+    ) -> Result<Self> {
+        self.same_device(rhs, "index-select")?;
+        match (self, rhs) {
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+                Ok(Self::Cuda(storage))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "index-select",
+            }
+            .bt()),
+        }
+    }
pub(crate) fn matmul(
&self,
rhs: &Self,

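One design note on the dispatch above: `same_device` already rejects mixed-device calls, so the final match arm is presumably unreachable in practice and mainly keeps the match exhaustive. A reduced, self-contained sketch of the per-backend routing pattern (all type names are stand-ins):

enum Storage {
    Cpu(Vec<f32>),
    Cuda(u64), // stand-in for a device buffer handle
}

impl Storage {
    fn index_select(&self, rhs: &Storage) -> Result<Storage, String> {
        match (self, rhs) {
            // Each backend supplies its own kernel; the wrapper only routes.
            (Storage::Cpu(_), Storage::Cpu(_)) => Ok(Storage::Cpu(vec![])),
            (Storage::Cuda(_), Storage::Cuda(_)) => Ok(Storage::Cuda(0)),
            // Mixing devices is an error rather than an implicit transfer.
            _ => Err("device mismatch in index-select".to_string()),
        }
    }
}

fn main() {
    let cpu = Storage::Cpu(vec![1.0]);
    let gpu = Storage::Cuda(0);
    assert!(cpu.index_select(&gpu).is_err());
}
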
View File

@@ -960,6 +960,33 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+
+    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "index-select")?;
+        let indexes_len = match indexes.dims() {
+            [l] => *l,
+            _ => Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: indexes.shape().clone(),
+                op: "index-select",
+            }
+            .bt())?,
+        };
+        let storage = self.storage().index_select(
+            &indexes.storage(),
+            self.layout(),
+            indexes.layout(),
+            dim,
+        )?;
+        let mut dims = self.dims().to_vec();
+        dims[dim] = indexes_len;
+        let op = if indexes.track_op() || self.track_op() {
+            Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, dims, op, false))
+    }
/// Returns an iterator over position of the elements in the storage when ranging over the
/// index tuples in lexicographic order.
pub fn strided_index(&self) -> crate::StridedIndex {

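At the tensor level the contract is: `indexes` must be rank 1, and the output shape is the source shape with `dims[dim]` replaced by the number of indices. A hypothetical usage sketch on a 3d tensor (the `use` path is an assumption; the constructors mirror the unit test below):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = &Device::Cpu;
    // A (2, 4, 5) source; selecting 3 indices along dim 1 gives (2, 3, 5).
    let t = Tensor::arange(0f32, 40f32, device)?.reshape((2, 4, 5))?;
    let ids = Tensor::new(&[3u32, 0u32, 1u32], device)?;
    let hs = t.index_select(&ids, 1)?;
    assert_eq!(hs.dims(), &[2, 3, 5]);
    Ok(())
}

Gradient tracking is already wired up (the op is recorded whenever either input tracks ops), but as the backprop diff above shows, backward for index-select still returns `BackwardNotSupported` in this commit.
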
View File

@@ -301,6 +301,39 @@ fn embeddings(device: &Device) -> Result<()> {
Ok(())
}
+
+#[test]
+fn index_select() -> Result<()> {
+    // TODO: Test on cuda once the kernel is available.
+    let device = &Device::Cpu;
+    let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
+    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    let hs = t.index_select(&ids, 1)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[
+            [0.0, 2.0, 1.0],
+            [3.0, 5.0, 4.0],
+            [6.0, 8.0, 7.0],
+            [9.0, 11.0, 10.0]
+        ]
+    );
+    let hs = t.index_select(&ids, 0)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
+    );
+    Ok(())
+}
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;