mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the index-select op. (#209)
* Add the index-select op. * Cpu implementation of index-select. * Add the cpu implementation for index-select.
This commit is contained in:
@ -40,6 +40,7 @@ pub(crate) trait BackendStorage: Sized {
|
||||
) -> Result<Self>;
|
||||
|
||||
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
|
@ -40,6 +40,7 @@ impl Tensor {
|
||||
..
|
||||
}
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::IndexSelect(lhs, rhs, _)
|
||||
| Op::Embedding(lhs, rhs)
|
||||
| Op::Matmul(lhs, rhs) => {
|
||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||
@ -143,9 +144,12 @@ impl Tensor {
|
||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::IndexSelect(_lhs, _rhs, _) => {
|
||||
Err(Error::BackwardNotSupported { op: "index-select" })?
|
||||
}
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||
Err(Error::BackwardNotSupported { op: "embedding" })?
|
||||
}
|
||||
Op::Matmul(lhs, rhs) => {
|
||||
// Skipping checks, the op went ok, we can skip
|
||||
@ -195,10 +199,10 @@ impl Tensor {
|
||||
}
|
||||
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
|
||||
Op::Reduce(_args, ReduceOp::Max, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "max" })
|
||||
Err(Error::BackwardNotSupported { op: "max" })?
|
||||
}
|
||||
Op::Reduce(_args, ReduceOp::Min, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "min" })
|
||||
Err(Error::BackwardNotSupported { op: "min" })?
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
@ -221,9 +225,7 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Abs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "abs" })
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
|
||||
Op::Unary(arg, UnaryOp::Exp) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&(&grad / arg)?)?
|
||||
@ -258,21 +260,15 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Softmax(_arg, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "softmax" })
|
||||
}
|
||||
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
|
||||
Op::Reshape(arg) => {
|
||||
let arg_grad = grad.reshape(arg.dims())?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Gelu) => {
|
||||
return Err(Error::BackwardNotSupported { op: "gelu" })
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Relu) => {
|
||||
return Err(Error::BackwardNotSupported { op: "relu" })
|
||||
}
|
||||
Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
|
||||
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
|
||||
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::Unary(arg, UnaryOp::Sqr) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -515,6 +515,58 @@ impl Map1 for Affine {
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexSelect<'a> {
|
||||
ids: &'a [u32],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for IndexSelect<'a> {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-select" })?,
|
||||
};
|
||||
let dim = self.dim;
|
||||
let n_ids = match self.ids_l.dims() {
|
||||
[n_ids] => *n_ids,
|
||||
d => Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 1,
|
||||
got: d.len(),
|
||||
shape: self.ids_l.shape().clone(),
|
||||
})?,
|
||||
};
|
||||
let stride_ids = self.ids_l.stride()[0];
|
||||
let mut dst_dims = layout.dims().to_vec();
|
||||
let src_dim = dst_dims[dim];
|
||||
dst_dims[dim] = n_ids;
|
||||
let dst_len: usize = dst_dims.iter().product();
|
||||
let left_len: usize = dst_dims[..dim].iter().product();
|
||||
let right_len: usize = dst_dims[dim + 1..].iter().product();
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
for left_i in 0..left_len {
|
||||
let start_src_idx = left_i * right_len * src_dim;
|
||||
let start_dst_idx = left_i * right_len * n_ids;
|
||||
for i in 0..n_ids {
|
||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
|
||||
if index >= src_dim {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
src_size: src_dim,
|
||||
op: "index-select",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let start_src_idx = start_src_idx + index * right_len;
|
||||
let start_dst_idx = start_dst_idx + i * right_len;
|
||||
dst[start_dst_idx..start_dst_idx + right_len]
|
||||
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a> {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
@ -533,7 +585,7 @@ impl<'a> Map1 for Embedding<'a> {
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size: self.vocab_size,
|
||||
src_size: self.vocab_size,
|
||||
op: "take",
|
||||
}
|
||||
.bt())?
|
||||
@ -1330,6 +1382,11 @@ impl BackendStorage for CpuStorage {
|
||||
.map(rhs, rhs_l)
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
IndexSelect { ids, ids_l, dim }.map(self, l)
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
|
@ -1059,6 +1059,10 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement index-select").into())
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
|
@ -82,6 +82,9 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
|
@ -109,11 +109,11 @@ pub enum Error {
|
||||
msg: &'static str,
|
||||
},
|
||||
|
||||
#[error("{op} invalid index {index} with vocab {vocab_size}")]
|
||||
#[error("{op} invalid index {index} with src dim size {src_size}")]
|
||||
InvalidIndex {
|
||||
op: &'static str,
|
||||
index: usize,
|
||||
vocab_size: usize,
|
||||
src_size: usize,
|
||||
},
|
||||
|
||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||
|
@ -51,6 +51,7 @@ pub(crate) enum Op {
|
||||
Reduce(Tensor, ReduceOp, Vec<usize>),
|
||||
Matmul(Tensor, Tensor),
|
||||
Embedding(Tensor, Tensor),
|
||||
IndexSelect(Tensor, Tensor, usize),
|
||||
WhereCond(Tensor, Tensor, Tensor),
|
||||
|
||||
#[allow(dead_code)]
|
||||
|
@ -267,6 +267,32 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn index_select(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
self.same_device(rhs, "index-select")?;
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "index-select",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
|
@ -960,6 +960,33 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-select")?;
|
||||
let indexes_len = match indexes.dims() {
|
||||
[l] => *l,
|
||||
_ => Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: indexes.shape().clone(),
|
||||
op: "index-select",
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
let storage = self.storage().index_select(
|
||||
&indexes.storage(),
|
||||
self.layout(),
|
||||
indexes.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = indexes_len;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
/// Returns an iterator over position of the elements in the storage when ranging over the
|
||||
/// index tuples in lexicographic order.
|
||||
pub fn strided_index(&self) -> crate::StridedIndex {
|
||||
|
@ -301,6 +301,39 @@ fn embeddings(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select() -> Result<()> {
|
||||
// TODO: Test on cuda once the kernel is available.
|
||||
let device = &Device::Cpu;
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 2.0, 1.0],
|
||||
[3.0, 5.0, 4.0],
|
||||
[6.0, 8.0, 7.0],
|
||||
[9.0, 11.0, 10.0]
|
||||
]
|
||||
);
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn matmul(device: &Device) -> Result<()> {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
||||
|
Reference in New Issue
Block a user