Polish the index-add op and use it in the index-select backprop (#218)

* Add the CPU version of index-add.

* More CPU support for index-add.

* Use index-add in the backprop.
Laurent Mazare
2023-07-22 06:31:46 +02:00
committed by GitHub
parent 27174a82aa
commit 6eeea1b04e
4 changed files with 108 additions and 36 deletions
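
For reference, index-add here follows the semantics of PyTorch's `Tensor.index_add_` (linked from the kernel below): along `dim`, slice `i` of the source is accumulated into slice `ids[i]` of the destination, and repeated indices accumulate rather than overwrite. A minimal standalone sketch of the `dim == 0` case on row-major data, in plain Rust and independent of the candle API:

```rust
// Sketch of index-add semantics for dim == 0 on row-major 2D data.
// `dst` has shape (dst_rows, cols); `src` has shape (ids.len(), cols).
fn index_add_dim0(dst: &mut [f32], src: &[f32], ids: &[u32], cols: usize) {
    for (src_row, &dst_row) in ids.iter().enumerate() {
        let dst_row = dst_row as usize;
        let s = &src[src_row * cols..(src_row + 1) * cols];
        let d = &mut dst[dst_row * cols..(dst_row + 1) * cols];
        for (d, s) in d.iter_mut().zip(s) {
            *d += s; // repeated indices accumulate rather than overwrite
        }
    }
}

fn main() {
    let mut dst = vec![0.0f32; 2 * 2]; // shape (2, 2)
    let src = [1.0f32, 2.0, 3.0, 4.0]; // shape (2, 2)
    let ids = [0u32, 0]; // both source rows target dst row 0
    index_add_dim0(&mut dst, &src, &ids, 2);
    assert_eq!(dst, [4.0, 6.0, 0.0, 0.0]);
}
```

This mirrors the `dim == 0` branch of the CPU kernel added in this commit.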

@@ -164,26 +164,8 @@ impl Tensor {
             Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
             Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
             Op::IndexSelect(arg, indexes, dim) => {
-                let dim = *dim;
                 let sum_grad = grads.or_insert(arg)?;
-                // TODO: This is very very very inefficient, have some dedicated kernel for this.
-                // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add.html
-                let indexes = indexes.to_vec1::<u32>()?;
-                for (dst_index, src_index) in indexes.iter().enumerate() {
-                    let src_index = *src_index as usize;
-                    let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
-                    let mut pre_dims = arg.dims().to_vec();
-                    pre_dims[dim] = src_index;
-                    let pre_zeros =
-                        Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
-                    let mut post_dims = arg.dims().to_vec();
-                    post_dims[dim] = post_dims[dim] - src_index - 1;
-                    let post_zeros =
-                        Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
-                    let src_grad =
-                        Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
-                    *sum_grad = sum_grad.add(&src_grad)?;
-                }
+                *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
             }
             Op::Embedding(_lhs, _rhs) => {
                 Err(Error::BackwardNotSupported { op: "embedding" })?
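
The rewritten backward pass replaces the per-index zero-padding and `Tensor::cat` construction with a single scatter-add: if `y = x.index_select(ids, dim)` gathers slices, the gradient flows back as `dx[ids[i]] += dy[i]` along `dim`, and repeated indices must accumulate because the same input slice fans out to several output slices. A self-contained check of that rule for a 1D case, with made-up values:

```rust
// Backward of index_select as a scatter-add (1D case, hypothetical numbers).
// Forward: y[i] = x[ids[i]]  =>  Backward: dx[ids[i]] += dy[i].
fn main() {
    let ids = [1usize, 3, 1]; // index 1 is gathered twice
    let dy = [10.0f32, 20.0, 30.0];
    let mut dx = [0.0f32; 4];
    for (i, &id) in ids.iter().enumerate() {
        dx[id] += dy[i];
    }
    assert_eq!(dx, [0.0, 40.0, 0.0, 20.0]); // both uses of x[1] contribute
}
```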

@@ -665,7 +665,7 @@ impl<'a> Map1 for IndexSelect<'a> {
             if index >= src_dim {
                 Err(Error::InvalidIndex {
                     index,
-                    src_size: src_dim,
+                    size: src_dim,
                     op: "index-select",
                 }
                 .bt())?
@@ -680,6 +680,72 @@ impl<'a> Map1 for IndexSelect<'a> {
     }
 }
 
+struct IndexAdd<'a> {
+    ids: &'a [u32],
+    dim: usize,
+}
+
+impl<'a> Map2 for IndexAdd<'a> {
+    const OP: &'static str = "index-add";
+    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
+    // v1, l1 -> self
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+        let dst_len = l1.shape().elem_count();
+        let mut dst = vec![T::zero(); dst_len];
+        copy_strided_src_(v1, &mut dst, 0, l1);
+        let src = match src_l.contiguous_offsets() {
+            None => Err(Error::RequiresContiguous { op: "index-add" })?,
+            Some((o1, o2)) => &src[o1..o2],
+        };
+        let dim = self.dim;
+        let max_idx = l1.dims()[dim];
+        let stride = src_l.stride()[dim];
+        if dim == 0 {
+            for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
+                let dst_idx = dst_idx as usize;
+                if dst_idx >= max_idx {
+                    Err(Error::InvalidIndex {
+                        index: dst_idx,
+                        op: "index-add",
+                        size: max_idx,
+                    })?
+                }
+                let src_idx = src_idx * stride;
+                let dst_idx = dst_idx * stride;
+                let src = &src[src_idx..src_idx + stride];
+                let dst = &mut dst[dst_idx..dst_idx + stride];
+                for (d, &s) in dst.iter_mut().zip(src.iter()) {
+                    *d += s
+                }
+            }
+        } else {
+            let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
+            let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
+            for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
+                let dst_idx = dst_idx as usize;
+                if dst_idx >= max_idx {
+                    Err(Error::InvalidIndex {
+                        index: dst_idx,
+                        op: "index-add",
+                        size: max_idx,
+                    })?
+                }
+                for pre_i in 0..pre_dim {
+                    let pre_i = pre_i * stride;
+                    let pre_src_i = (pre_i + src_idx) * post_dim;
+                    let pre_dst_i = (pre_i + dst_idx) * post_dim;
+                    let src = &src[pre_src_i..pre_src_i + post_dim];
+                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
+                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
+                        *d += s
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Embedding<'a> {
     vocab_size: usize,
     hidden_size: usize,
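
The `dim != 0` branch above carves `post_dim`-long blocks out of the flat buffers. For a contiguous row-major tensor viewed as dims `[pre, n, post]`, element `(pre_i, i, post_i)` lives at flat offset `(pre_i * n + i) * post + post_i`. A standalone sanity check of that convention (this sketch assumes contiguity and does not use the `Layout` type):

```rust
// Row-major offset arithmetic for dims [pre, n, post] (assumes contiguity).
fn offset(pre_i: usize, i: usize, post_i: usize, n: usize, post: usize) -> usize {
    (pre_i * n + i) * post + post_i
}

fn main() {
    // dims [2, 3, 4]: fill with each element's own flat index.
    let data: Vec<usize> = (0..2 * 3 * 4).collect();
    assert_eq!(data[offset(1, 2, 3, 3, 4)], 12 + 8 + 3); // 1*12 + 2*4 + 3
}
```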
@@ -698,7 +764,7 @@ impl<'a> Map1 for Embedding<'a> {
             if index >= self.vocab_size {
                 Err(Error::InvalidIndex {
                     index,
-                    src_size: self.vocab_size,
+                    size: self.vocab_size,
                     op: "take",
                 }
                 .bt())?
@@ -711,12 +777,7 @@ impl<'a> Map1 for Embedding<'a> {
     }
 }
 
-fn copy_strided_src_<T: Copy + std::fmt::Display>(
-    src: &[T],
-    dst: &mut [T],
-    dst_offset: usize,
-    src_l: &Layout,
-) {
+fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
     match src_l.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
             let to_copy = (dst.len() - dst_offset).min(len);
@@ -1534,14 +1595,19 @@ impl BackendStorage for CpuStorage {
     fn index_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
     ) -> Result<Self> {
-        todo!()
+        let ids = ids.as_slice::<u32>()?;
+        let ids = match ids_l.contiguous_offsets() {
+            Some((a, b)) => &ids[a..b],
+            None => Err(Error::RequiresContiguous { op: "index-add" })?,
+        };
+        IndexAdd { ids, dim }.map(self, l, src, src_l)
     }
 
     fn matmul(

@ -112,11 +112,11 @@ pub enum Error {
msg: &'static str, msg: &'static str,
}, },
#[error("{op} invalid index {index} with src dim size {src_size}")] #[error("{op} invalid index {index} with dim size {size}")]
InvalidIndex { InvalidIndex {
op: &'static str, op: &'static str,
index: usize, index: usize,
src_size: usize, size: usize,
}, },
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]

@@ -947,6 +947,30 @@ impl Tensor {
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
+        let source_dims = source.dims();
+        let self_dims = self.dims();
+        let mismatch = if source_dims.len() != self_dims.len() {
+            true
+        } else {
+            let mut mismatch = false;
+            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
+                if i != dim && d1 != d2 {
+                    mismatch = true;
+                    break;
+                }
+            }
+            // The number of elements in `indexes` must match the size of the dimension on
+            // which the add is performed on the source tensor (the index values in
+            // `indexes` refer to the target tensor `self`).
+            mismatch || source_dims[dim] != indexes.shape().r1()?
+        };
+        if mismatch {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "index-add",
+                lhs: self.shape().clone(),
+                rhs: source.shape().clone(),
+            })?
+        }
         let storage = self.storage().index_add(
             self.layout(),
             &indexes.storage(),