mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Polish the index-add op and use it in the index-select backprop (#218)
* Add the cpu version of index-add. * More cpu support for index-add. * Use index-add in the backprop.
This commit is contained in:
@ -164,26 +164,8 @@ impl Tensor {
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
|
||||
Op::IndexSelect(arg, indexes, dim) => {
|
||||
let dim = *dim;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// TODO: This is very very very inefficient, have some dedicated kernel for this.
|
||||
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add.html
|
||||
let indexes = indexes.to_vec1::<u32>()?;
|
||||
for (dst_index, src_index) in indexes.iter().enumerate() {
|
||||
let src_index = *src_index as usize;
|
||||
let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
|
||||
let mut pre_dims = arg.dims().to_vec();
|
||||
pre_dims[dim] = src_index;
|
||||
let pre_zeros =
|
||||
Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
|
||||
let mut post_dims = arg.dims().to_vec();
|
||||
post_dims[dim] = post_dims[dim] - src_index - 1;
|
||||
let post_zeros =
|
||||
Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
|
||||
let src_grad =
|
||||
Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
|
||||
*sum_grad = sum_grad.add(&src_grad)?;
|
||||
}
|
||||
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
|
||||
}
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
Err(Error::BackwardNotSupported { op: "embedding" })?
|
||||
|
@ -665,7 +665,7 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
if index >= src_dim {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
src_size: src_dim,
|
||||
size: src_dim,
|
||||
op: "index-select",
|
||||
}
|
||||
.bt())?
|
||||
@ -680,6 +680,72 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexAdd<'a> {
|
||||
ids: &'a [u32],
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map2 for IndexAdd<'a> {
|
||||
const OP: &'static str = "index-add";
|
||||
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
|
||||
// v1, l1 -> self
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let dst_len = l1.shape().elem_count();
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
let dim = self.dim;
|
||||
let max_idx = l1.dims()[dim];
|
||||
let stride = src_l.stride()[dim];
|
||||
if dim == 0 {
|
||||
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
|
||||
let dst_idx = dst_idx as usize;
|
||||
if dst_idx >= max_idx {
|
||||
Err(Error::InvalidIndex {
|
||||
index: dst_idx,
|
||||
op: "index-add",
|
||||
size: max_idx,
|
||||
})?
|
||||
}
|
||||
let src_idx = src_idx * stride;
|
||||
let dst_idx = dst_idx * stride;
|
||||
let src = &src[src_idx..src_idx + stride];
|
||||
let dst = &mut dst[dst_idx..dst_idx + stride];
|
||||
for (d, &s) in dst.iter_mut().zip(src.iter()) {
|
||||
*d += s
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
|
||||
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
||||
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
|
||||
let dst_idx = dst_idx as usize;
|
||||
if dst_idx >= max_idx {
|
||||
Err(Error::InvalidIndex {
|
||||
index: dst_idx,
|
||||
op: "index-add",
|
||||
size: max_idx,
|
||||
})?
|
||||
}
|
||||
for pre_i in 0..pre_dim {
|
||||
let pre_i = pre_i * stride;
|
||||
let pre_src_i = (pre_i + src_idx) * post_dim;
|
||||
let pre_dst_i = (pre_i + dst_idx) * post_dim;
|
||||
let src = &src[pre_src_i..pre_src_i + post_dim];
|
||||
let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
|
||||
for (d, &s) in dst.iter_mut().zip(src.iter()) {
|
||||
*d += s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a> {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
@ -698,7 +764,7 @@ impl<'a> Map1 for Embedding<'a> {
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
src_size: self.vocab_size,
|
||||
size: self.vocab_size,
|
||||
op: "take",
|
||||
}
|
||||
.bt())?
|
||||
@ -711,12 +777,7 @@ impl<'a> Map1 for Embedding<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
||||
src: &[T],
|
||||
dst: &mut [T],
|
||||
dst_offset: usize,
|
||||
src_l: &Layout,
|
||||
) {
|
||||
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
|
||||
match src_l.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
let to_copy = (dst.len() - dst_offset).min(len);
|
||||
@ -1534,14 +1595,19 @@ impl BackendStorage for CpuStorage {
|
||||
|
||||
fn index_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
|
@ -112,11 +112,11 @@ pub enum Error {
|
||||
msg: &'static str,
|
||||
},
|
||||
|
||||
#[error("{op} invalid index {index} with src dim size {src_size}")]
|
||||
#[error("{op} invalid index {index} with dim size {size}")]
|
||||
InvalidIndex {
|
||||
op: &'static str,
|
||||
index: usize,
|
||||
src_size: usize,
|
||||
size: usize,
|
||||
},
|
||||
|
||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||
|
@ -947,6 +947,30 @@ impl Tensor {
|
||||
|
||||
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-add")?;
|
||||
let source_dims = source.dims();
|
||||
let self_dims = self.dims();
|
||||
let mismatch = if source_dims.len() != self_dims.len() {
|
||||
true
|
||||
} else {
|
||||
let mut mismatch = false;
|
||||
for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
|
||||
if i != dim && d1 != d2 {
|
||||
mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// The number of element in indexes must match the dimension on which the add is
|
||||
// performed on the source tensor (and the index values from `indexes` are taken from
|
||||
// the target tensor self)
|
||||
mismatch || source_dims[dim] != indexes.shape().r1()?
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "index-add",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
})?
|
||||
}
|
||||
let storage = self.storage().index_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
|
Reference in New Issue
Block a user