Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00.

Add the index-select op. (#209)

* Add the index-select op.
* Cpu implementation of index-select.
* Add the cpu implementation for index-select.
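In brief, `index_select` gathers slices of a tensor along a chosen dimension, driven by a 1-dimensional `u32` tensor of indices. A minimal usage sketch, mirroring the test added in this commit (CPU only, since the CUDA kernel is still a TODO; this assumes the crate is imported as `candle`, as it was at the time):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = &Device::Cpu;
    // A 4x3 tensor whose rows are [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11].
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
    // Selecting along dim 0 reorders the rows to 0, 2, 1.
    let rows = t.index_select(&ids, 0)?;
    assert_eq!(
        rows.to_vec2::<f32>()?,
        &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
    );
    Ok(())
}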
@@ -40,6 +40,7 @@ pub(crate) trait BackendStorage: Sized {
     ) -> Result<Self>;

     fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;

     fn matmul(
         &self,
@@ -40,6 +40,7 @@ impl Tensor {
                     ..
                 }
                 | Op::Binary(lhs, rhs, _)
+                | Op::IndexSelect(lhs, rhs, _)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
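Adding `Op::IndexSelect` to this chain of two-tensor patterns means the dependency walk visits both the source tensor and the index tensor when building the backward graph, so gradients can later be tracked through either argument.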
@@ -143,9 +144,12 @@ impl Tensor {
                     let f_grad = pred.where_cond(&zeros, &grad)?;
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
-                Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
+                Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                Op::IndexSelect(_lhs, _rhs, _) => {
+                    Err(Error::BackwardNotSupported { op: "index-select" })?
+                }
                 Op::Embedding(_lhs, _rhs) => {
-                    return Err(Error::BackwardNotSupported { op: "embedding" })
+                    Err(Error::BackwardNotSupported { op: "embedding" })?
                 }
                 Op::Matmul(lhs, rhs) => {
                     // Skipping checks, the op went ok, we can skip
@@ -195,10 +199,10 @@ impl Tensor {
                 }
                 Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
                 Op::Reduce(_args, ReduceOp::Max, _) => {
-                    return Err(Error::BackwardNotSupported { op: "max" })
+                    Err(Error::BackwardNotSupported { op: "max" })?
                 }
                 Op::Reduce(_args, ReduceOp::Min, _) => {
-                    return Err(Error::BackwardNotSupported { op: "min" })
+                    Err(Error::BackwardNotSupported { op: "min" })?
                 }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
@@ -221,9 +225,7 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                 }
-                Op::Unary(_, UnaryOp::Abs) => {
-                    return Err(Error::BackwardNotSupported { op: "abs" })
-                }
+                Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
                 Op::Unary(arg, UnaryOp::Exp) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&(&grad / arg)?)?
@@ -258,21 +260,15 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
-                Op::Softmax(_arg, _) => {
-                    return Err(Error::BackwardNotSupported { op: "softmax" })
-                }
+                Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
                 Op::Reshape(arg) => {
                     let arg_grad = grad.reshape(arg.dims())?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
-                Op::Unary(_, UnaryOp::Gelu) => {
-                    return Err(Error::BackwardNotSupported { op: "gelu" })
-                }
-                Op::Unary(_, UnaryOp::Relu) => {
-                    return Err(Error::BackwardNotSupported { op: "relu" })
-                }
-                Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
+                Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+                Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
+                Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                 Op::Unary(arg, UnaryOp::Sqr) => {
                     let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
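The arms above also migrate from `return Err(e)` to `Err(e)?`. Both forms return early with the error, but the `?` version additionally routes the error through `From::from` and is a plain expression, so an arm can collapse onto a single line without braces. A standalone sketch of the general Rust pattern (the `MyError` type is made up for illustration):

#[derive(Debug)]
struct MyError(&'static str);

fn check(flag: bool) -> Result<u32, MyError> {
    let v = match flag {
        true => 1,
        // `Err(e)?` returns early like `return Err(e)`, but also applies
        // `From::from` to the error and fits in expression position.
        false => Err(MyError("unsupported"))?,
    };
    Ok(v)
}

fn main() {
    assert_eq!(check(true).unwrap(), 1);
    assert!(check(false).is_err());
}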
@@ -515,6 +515,58 @@ impl Map1 for Affine {
         }
     }
 }

+struct IndexSelect<'a> {
+    ids: &'a [u32],
+    ids_l: &'a Layout,
+    dim: usize,
+}
+
+impl<'a> Map1 for IndexSelect<'a> {
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let src = match layout.contiguous_offsets() {
+            Some((a, b)) => &src[a..b],
+            None => Err(Error::RequiresContiguous { op: "index-select" })?,
+        };
+        let dim = self.dim;
+        let n_ids = match self.ids_l.dims() {
+            [n_ids] => *n_ids,
+            d => Err(Error::UnexpectedNumberOfDims {
+                expected: 1,
+                got: d.len(),
+                shape: self.ids_l.shape().clone(),
+            })?,
+        };
+        let stride_ids = self.ids_l.stride()[0];
+        let mut dst_dims = layout.dims().to_vec();
+        let src_dim = dst_dims[dim];
+        dst_dims[dim] = n_ids;
+        let dst_len: usize = dst_dims.iter().product();
+        let left_len: usize = dst_dims[..dim].iter().product();
+        let right_len: usize = dst_dims[dim + 1..].iter().product();
+        let mut dst = vec![T::zero(); dst_len];
+        for left_i in 0..left_len {
+            let start_src_idx = left_i * right_len * src_dim;
+            let start_dst_idx = left_i * right_len * n_ids;
+            for i in 0..n_ids {
+                let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
+                if index >= src_dim {
+                    Err(Error::InvalidIndex {
+                        index,
+                        src_size: src_dim,
+                        op: "index-select",
+                    }
+                    .bt())?
+                }
+                let start_src_idx = start_src_idx + index * right_len;
+                let start_dst_idx = start_dst_idx + i * right_len;
+                dst[start_dst_idx..start_dst_idx + right_len]
+                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Embedding<'a> {
     vocab_size: usize,
     hidden_size: usize,

@@ -533,7 +585,7 @@ impl<'a> Map1 for Embedding<'a> {
             if index >= self.vocab_size {
                 Err(Error::InvalidIndex {
                     index,
-                    vocab_size: self.vocab_size,
+                    src_size: self.vocab_size,
                     op: "take",
                 }
                 .bt())?
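The kernel above views the contiguous source as a 3-d block [left_len, src_dim, right_len]: left_len is the product of the dimensions before `dim`, right_len the product of those after it. For each of the left_len outer slices it copies one right_len-sized contiguous run per index, after bounds-checking the index against src_dim. A standalone sketch of the same index arithmetic on a plain Vec (names are mine, not the crate's):

// Gather `ids` along dimension `dim` of a contiguous tensor with shape `dims`.
fn index_select_flat(src: &[f32], dims: &[usize], ids: &[usize], dim: usize) -> Vec<f32> {
    let src_dim = dims[dim];
    let left_len: usize = dims[..dim].iter().product();
    let right_len: usize = dims[dim + 1..].iter().product();
    let mut dst = vec![0f32; left_len * ids.len() * right_len];
    for left_i in 0..left_len {
        let start_src = left_i * src_dim * right_len;
        let start_dst = left_i * ids.len() * right_len;
        for (i, &index) in ids.iter().enumerate() {
            // Same bounds check as the kernel's InvalidIndex error.
            assert!(index < src_dim, "index {index} out of bounds for dim of size {src_dim}");
            let s = start_src + index * right_len;
            let d = start_dst + i * right_len;
            dst[d..d + right_len].copy_from_slice(&src[s..s + right_len]);
        }
    }
    dst
}

fn main() {
    // 2x3 matrix [[0, 1, 2], [3, 4, 5]]; select columns 2 and 0.
    let src: Vec<f32> = (0..6).map(|x| x as f32).collect();
    let out = index_select_flat(&src, &[2, 3], &[2, 0], 1);
    assert_eq!(out, vec![2.0, 0.0, 5.0, 3.0]);
}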
@@ -1330,6 +1382,11 @@ impl BackendStorage for CpuStorage {
         .map(rhs, rhs_l)
     }

+    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let ids = ids.as_slice::<u32>()?;
+        IndexSelect { ids, ids_l, dim }.map(self, l)
+    }
+
     fn matmul(
         &self,
         rhs: &Self,
@@ -1059,6 +1059,10 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }

+    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement index-select").into())
+    }
+
     fn matmul(
         &self,
         rhs: &Self,
@@ -82,6 +82,9 @@ impl crate::backend::BackendStorage for CudaStorage {
     fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }

     fn matmul(
         &self,
@@ -109,11 +109,11 @@ pub enum Error {
         msg: &'static str,
     },

-    #[error("{op} invalid index {index} with vocab {vocab_size}")]
+    #[error("{op} invalid index {index} with src dim size {src_size}")]
     InvalidIndex {
         op: &'static str,
         index: usize,
-        vocab_size: usize,
+        src_size: usize,
     },

     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
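Renaming the field from `vocab_size` to `src_size` generalizes `InvalidIndex` beyond embeddings, so the same error variant can report out-of-range indices for both `embedding` and `index-select`.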
@@ -51,6 +51,7 @@ pub(crate) enum Op {
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
+    IndexSelect(Tensor, Tensor, usize),
     WhereCond(Tensor, Tensor, Tensor),

     #[allow(dead_code)]
@@ -267,6 +267,32 @@ impl Storage {
         }
     }

+    pub(crate) fn index_select(
+        &self,
+        rhs: &Self,
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+        d: usize,
+    ) -> Result<Self> {
+        self.same_device(rhs, "index-select")?;
+        match (self, rhs) {
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+                Ok(Self::Cuda(storage))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "index-select",
+            }
+            .bt()),
+        }
+    }
+
     pub(crate) fn matmul(
         &self,
         rhs: &Self,
@@ -960,6 +960,33 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }

+    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "index-select")?;
+        let indexes_len = match indexes.dims() {
+            [l] => *l,
+            _ => Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: indexes.shape().clone(),
+                op: "index-select",
+            }
+            .bt())?,
+        };
+        let storage = self.storage().index_select(
+            &indexes.storage(),
+            self.layout(),
+            indexes.layout(),
+            dim,
+        )?;
+        let mut dims = self.dims().to_vec();
+        dims[dim] = indexes_len;
+        let op = if indexes.track_op() || self.track_op() {
+            Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, dims, op, false))
+    }
+
     /// Returns an iterator over position of the elements in the storage when ranging over the
     /// index tuples in lexicographic order.
     pub fn strided_index(&self) -> crate::StridedIndex {
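As the code above shows, the public `Tensor::index_select` only accepts a 1-dimensional index tensor (anything else is reported as a `ShapeMismatchBinaryOp`), and the output shape matches the input except along `dim`, which takes the length of `indexes`. The `track_op` check keeps the resulting tensor out of the autograd graph when neither argument requires gradients.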
@@ -301,6 +301,39 @@ fn embeddings(device: &Device) -> Result<()> {
     Ok(())
 }

+#[test]
+fn index_select() -> Result<()> {
+    // TODO: Test on cuda once the kernel is available.
+    let device = &Device::Cpu;
+    let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
+    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    let hs = t.index_select(&ids, 1)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[
+            [0.0, 2.0, 1.0],
+            [3.0, 5.0, 4.0],
+            [6.0, 8.0, 7.0],
+            [9.0, 11.0, 10.0]
+        ]
+    );
+    let hs = t.index_select(&ids, 0)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
+    );
+    Ok(())
+}
+
 fn matmul(device: &Device) -> Result<()> {
     let data = vec![1.0f32, 2.0, 3.0, 4.0];
     let a = Tensor::from_slice(&data, (2, 2), device)?;