mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the gather op. (#219)
* Start adding gather. * Gather cpu implementation + use in simple training. * Add scatter_add for the gradient of gather. * Simple cpu implementation of scatter_add. * Use gather in the simple-training backprop.
This commit is contained in:
@ -40,6 +40,16 @@ pub trait BackendStorage: Sized {
|
||||
) -> Result<Self>;
|
||||
|
||||
/// Backend hook for the embedding op.
/// NOTE(review): the roles of the unnamed layout/storage arguments are not visible in this
/// view — confirm their order against a backend implementation.
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
|
||||
/// Backend hook for the gather op along one dimension.
/// The CPU implementation binds the arguments, in order, as: layout of `self`, index storage,
/// index layout, and the dimension to gather along.
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||
/// Backend hook for the scatter-add op along one dimension.
/// The CPU implementation binds the arguments, in order, as: layout of `self`, index storage,
/// index layout, source storage, source layout, and the dimension to scatter along.
fn scatter_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self>;
|
||||
/// Backend hook for the index-select op along one dimension.
/// NOTE(review): which of the two layout arguments belongs to `self` and which to the indexes
/// is not visible in this view — confirm against a backend implementation.
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn index_add(
|
||||
&self,
|
||||
|
@ -39,6 +39,7 @@ impl Tensor {
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::IndexAdd(t1, t2, t3, _)
|
||||
| Op::ScatterAdd(t1, t2, t3, _)
|
||||
| Op::CustomOp3(t1, t2, t3, _)
|
||||
| Op::WhereCond(t1, t2, t3) => {
|
||||
let (tg, nodes) = walk(t1, nodes, already_seen);
|
||||
@ -56,6 +57,7 @@ impl Tensor {
|
||||
}
|
||||
| Op::CustomOp2(lhs, rhs, _)
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::Gather(lhs, rhs, _)
|
||||
| Op::IndexSelect(lhs, rhs, _)
|
||||
| Op::Embedding(lhs, rhs)
|
||||
| Op::Matmul(lhs, rhs) => {
|
||||
@ -162,6 +164,11 @@ impl Tensor {
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::Gather(arg, indexes, dim) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
|
||||
}
|
||||
Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
|
||||
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
|
||||
Op::IndexSelect(arg, indexes, dim) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -628,6 +628,59 @@ impl Map1 for Affine {
|
||||
}
|
||||
}
|
||||
|
||||
struct Gather<'a> {
|
||||
ids: &'a [u32],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Gather<'a> {
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
};
|
||||
let dim = self.dim;
|
||||
let ids_dims = self.ids_l.dims();
|
||||
let src_dims = src_l.dims();
|
||||
let dst_len: usize = ids_dims.iter().product();
|
||||
let dst_left_len: usize = ids_dims[..dim].iter().product();
|
||||
let dst_dim_len = ids_dims[dim];
|
||||
let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
|
||||
|
||||
let src_dim_len = src_dims[dim];
|
||||
let src_right_len: usize = src_dims[dim + 1..].iter().product();
|
||||
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
for left_i in 0..dst_left_len {
|
||||
let start_src_idx = left_i * src_right_len * src_dim_len;
|
||||
let start_dst_idx = left_i * dst_right_len * dst_dim_len;
|
||||
for i in 0..dst_dim_len {
|
||||
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
||||
for right_i in 0..dst_right_len {
|
||||
let dst_idx = start_dst_idx + right_i;
|
||||
let index = ids[dst_idx] as usize;
|
||||
if index >= src_dim_len {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: src_dim_len,
|
||||
op: "gather",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let src_idx = start_src_idx + index * src_right_len + right_i;
|
||||
dst[dst_idx] = src[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexSelect<'a> {
|
||||
ids: &'a [u32],
|
||||
ids_l: &'a Layout,
|
||||
@ -680,6 +733,63 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ScatterAdd<'a> {
|
||||
ids: &'a [u32],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map2 for ScatterAdd<'a> {
|
||||
const OP: &'static str = "scatter-add";
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let dst_len = l1.shape().elem_count();
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
|
||||
let dim = self.dim;
|
||||
let ids_dims = self.ids_l.dims();
|
||||
let dst_dims = l1.dims();
|
||||
let dst_dim_len = dst_dims[dim];
|
||||
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
|
||||
|
||||
let ids_left_len: usize = ids_dims[..dim].iter().product();
|
||||
let ids_dim_len = ids_dims[dim];
|
||||
let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
|
||||
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
};
|
||||
for left_i in 0..ids_left_len {
|
||||
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
|
||||
let start_dst_idx = left_i * dst_right_len * dst_dim_len;
|
||||
for i in 0..ids_dim_len {
|
||||
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
||||
for right_i in 0..dst_right_len {
|
||||
let ids_idx = start_ids_idx + right_i;
|
||||
let index = ids[ids_idx] as usize;
|
||||
if index >= dst_dim_len {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: dst_dim_len,
|
||||
op: "gather",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
|
||||
dst[dst_idx] += src[ids_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexAdd<'a> {
|
||||
ids: &'a [u32],
|
||||
dim: usize,
|
||||
@ -1593,6 +1703,24 @@ impl BackendStorage for CpuStorage {
|
||||
IndexSelect { ids, ids_l, dim }.map(self, l)
|
||||
}
|
||||
|
||||
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
Gather { ids, ids_l, dim }.map(self, l)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
&self,
|
||||
l: &Layout,
|
||||
|
@ -1064,6 +1064,20 @@ impl BackendStorage for CudaStorage {
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement index-select").into())
|
||||
}
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement gather").into())
|
||||
}
|
||||
fn scatter_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement scatter-add").into())
|
||||
}
|
||||
fn index_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
|
@ -85,6 +85,22 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
|
@ -66,6 +66,8 @@ pub(crate) enum Op {
|
||||
Reduce(Tensor, ReduceOp, Vec<usize>),
|
||||
Matmul(Tensor, Tensor),
|
||||
Embedding(Tensor, Tensor),
|
||||
Gather(Tensor, Tensor, usize),
|
||||
ScatterAdd(Tensor, Tensor, Tensor, usize),
|
||||
IndexSelect(Tensor, Tensor, usize),
|
||||
IndexAdd(Tensor, Tensor, Tensor, usize),
|
||||
WhereCond(Tensor, Tensor, Tensor),
|
||||
|
@ -325,6 +325,51 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn gather(
|
||||
&self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
self.same_device(indexes, "index-add")?;
|
||||
match (self, indexes) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes)) => {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes)) => {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scatter_add(
|
||||
&self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
self.same_device(indexes, "scatter-add")?;
|
||||
self.same_device(source, "scatter-add")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn index_add(
|
||||
&self,
|
||||
l: &Layout,
|
||||
|
@ -945,6 +945,57 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
||||
let source_dims = source.dims();
|
||||
let self_dims = self.dims();
|
||||
let mismatch = if source_dims.len() != self_dims.len() {
|
||||
true
|
||||
} else {
|
||||
let mut mismatch = false;
|
||||
for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
|
||||
if i != dim && d1 != d2 {
|
||||
mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
mismatch
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (self, src)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
})?
|
||||
}
|
||||
if indexes.dims() != source.dims() {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (indexes, src)",
|
||||
lhs: indexes.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
})?
|
||||
}
|
||||
let storage = self.storage().scatter_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::ScatterAdd(
|
||||
self.clone(),
|
||||
indexes.clone(),
|
||||
source.clone(),
|
||||
dim,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-add")?;
|
||||
let source_dims = source.dims();
|
||||
@ -992,6 +1043,40 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "gather")?;
|
||||
let self_dims = self.dims();
|
||||
let indexes_dims = indexes.dims();
|
||||
let mismatch = if indexes_dims.len() != self_dims.len() {
|
||||
true
|
||||
} else {
|
||||
let mut mismatch = false;
|
||||
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
|
||||
if i != dim && d1 != d2 {
|
||||
mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
mismatch
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "gather",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: indexes.shape().clone(),
|
||||
})?
|
||||
}
|
||||
let storage =
|
||||
self.storage()
|
||||
.gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::Gather(self.clone(), indexes.clone(), dim))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, indexes.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-select")?;
|
||||
let indexes_len = match indexes.dims() {
|
||||
|
@ -17,10 +17,11 @@ fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> candle::Result<Tenso
|
||||
Ok(log_sm)
|
||||
}
|
||||
|
||||
// TODO: Once the index_select backprop is efficient enough, switch to using this.
|
||||
fn _nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
|
||||
let b_sz = target.shape().r1()?;
|
||||
inp.index_select(target, 0)?.sum_all()? / b_sz as f64
|
||||
/// Loss built from `inp` and `target`: gathers `inp` along dimension 1 at the `target` indexes,
/// sums everything, then applies `affine(-1 / b_sz, 0.)` where `b_sz` is the size of `target`
/// along dimension 0.
fn nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
    let batch = target.dim(0)?;
    let picked = inp.gather(target, 1)?;
    let total = picked.sum_all()?;
    let scale = -1f64 / batch as f64;
    total.affine(scale, 0.)
}
|
||||
|
||||
pub fn main() -> Result<()> {
|
||||
@ -32,12 +33,7 @@ pub fn main() -> Result<()> {
|
||||
println!("test-labels: {:?}", m.test_labels.shape());
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images;
|
||||
let train_labels = train_labels.to_vec1::<u8>()?;
|
||||
let train_label_mask = train_labels
|
||||
.iter()
|
||||
.flat_map(|l| (0..LABELS).map(|i| f32::from(i == *l as usize)))
|
||||
.collect::<Vec<_>>();
|
||||
let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
|
||||
let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
|
||||
let bs = Var::zeros(LABELS, DType::F32, &dev)?;
|
||||
let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
|
||||
@ -46,9 +42,7 @@ pub fn main() -> Result<()> {
|
||||
for epoch in 1..200 {
|
||||
let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
|
||||
let log_sm = log_softmax(&logits, D::Minus1)?;
|
||||
let loss = (&log_sm * &train_label_mask)?
|
||||
.sum_all()?
|
||||
.affine(-1f64 / train_images.dim(0)? as f64, 0f64)?;
|
||||
let loss = nll_loss(&log_sm, &train_labels)?;
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
|
||||
@ -63,7 +57,7 @@ pub fn main() -> Result<()> {
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
100. * test_accuracy
|
||||
)
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user