Add the gather op. (#219)

* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
Laurent Mazare
2023-07-22 08:21:28 +02:00
committed by GitHub
parent 6eeea1b04e
commit 52c5d8c087
9 changed files with 315 additions and 14 deletions
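Both new ops follow the usual gather/scatter-add semantics along a single dimension. Below is a minimal usage sketch, not part of the diff, assuming the crate is imported as `candle` and that `Tensor::new`, `zeros_like`, and `to_vec2` are available as used elsewhere in the repo at this point:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    // For dim == 1, gather computes out[i][j] = t[i][indexes[i][j]], and the
    // output takes the shape of `indexes`.
    let indexes = Tensor::new(&[[2u32, 0], [0, 1]], &dev)?;
    let gathered = t.gather(&indexes, 1)?;
    assert_eq!(gathered.to_vec2::<f32>()?, [[3., 1.], [4., 5.]]);
    // scatter_add is the adjoint used for gather's gradient: each element of
    // `source` is added back at the position it was gathered from.
    let scattered = t.zeros_like()?.scatter_add(&indexes, &gathered, 1)?;
    assert_eq!(scattered.to_vec2::<f32>()?, [[1., 0., 3.], [4., 5., 0.]]);
    Ok(())
}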

@@ -945,6 +945,57 @@ impl Tensor {
        Ok(from_storage(storage, shape, op, false))
    }
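    /// Returns a new tensor initialized from `self`, where each element of `source` is added at
    /// the position obtained by replacing the index along `dim` with the corresponding value
    /// from `indexes`. `indexes` and `source` must have the same shape, and that shape must
    /// match `self` on every dimension except `dim`.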
    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "scatter-add")?;
        let source_dims = source.dims();
        let self_dims = self.dims();
        let mismatch = if source_dims.len() != self_dims.len() {
            true
        } else {
            let mut mismatch = false;
            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
                if i != dim && d1 != d2 {
                    mismatch = true;
                    break;
                }
            }
            mismatch
        };
        if mismatch {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter-add (self, src)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            })?
        }
        if indexes.dims() != source.dims() {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter-add (indexes, src)",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            })?
        }
        let storage = self.storage().scatter_add(
            self.layout(),
            &indexes.storage(),
            indexes.layout(),
            &source.storage(),
            source.layout(),
            dim,
        )?;
        let op = if indexes.track_op() || self.track_op() {
            Some(Op::ScatterAdd(
                self.clone(),
                indexes.clone(),
                source.clone(),
                dim,
            ))
        } else {
            None
        };
        Ok(from_storage(storage, self.shape(), op, false))
    }
    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-add")?;
        let source_dims = source.dims();
@@ -992,6 +1043,40 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }
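    /// Gathers values along dimension `dim`: the output takes the shape of `indexes`, and for
    /// `dim == 0`, `output[i][j] = self[indexes[i][j]][j]` (similarly for other dimensions).
    /// Its gradient is computed via `scatter_add`.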
    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "gather")?;
        let self_dims = self.dims();
        let indexes_dims = indexes.dims();
        let mismatch = if indexes_dims.len() != self_dims.len() {
            true
        } else {
            let mut mismatch = false;
            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
                if i != dim && d1 != d2 {
                    mismatch = true;
                    break;
                }
            }
            mismatch
        };
        if mismatch {
            Err(Error::ShapeMismatchBinaryOp {
                op: "gather",
                lhs: self.shape().clone(),
                rhs: indexes.shape().clone(),
            })?
        }
        let storage =
            self.storage()
                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
        let op = if indexes.track_op() || self.track_op() {
            Some(Op::Gather(self.clone(), indexes.clone(), dim))
        } else {
            None
        };
        Ok(from_storage(storage, indexes.shape(), op, false))
    }
    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-select")?;
        let indexes_len = match indexes.dims() {