Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 03:54:56 +00:00)
Add the gather op. (#219)
* Start adding gather.
* Gather cpu implementation + use in simple training.
* Add scatter_add for the gradient of gather.
* Simple cpu implementation of scatter_add.
* Use gather in the simple-training backprop.
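As a quick illustration of the two ops this commit adds (a minimal sketch: gather and scatter_add are the new methods, while the example values and the Tensor::new / Device::Cpu / zeros_like setup are ordinary candle usage assumed here, not part of the diff):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Per-row class scores and one target class index per row.
    let scores = Tensor::new(&[[0.1f32, 0.2, 0.7], [0.5, 0.4, 0.1]], &dev)?;
    let targets = Tensor::new(&[[2u32], [0u32]], &dev)?;
    // gather along dim 1: picked[i][j] = scores[i][targets[i][j]], shape (2, 1).
    let picked = scores.gather(&targets, 1)?;
    // scatter_add routes values back to the positions they were gathered
    // from; starting from zeros this is the gradient rule gather needs.
    let grad = scores.zeros_like()?.scatter_add(&targets, &picked, 1)?;
    println!("{picked}\n{grad}");
    Ok(())
}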
@@ -945,6 +945,57 @@ impl Tensor {
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "scatter-add")?;
        let source_dims = source.dims();
        let self_dims = self.dims();
        let mismatch = if source_dims.len() != self_dims.len() {
            true
        } else {
            let mut mismatch = false;
            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
                if i != dim && d1 != d2 {
                    mismatch = true;
                    break;
                }
            }
            mismatch
        };
        if mismatch {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter-add (self, src)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            })?
        }
        if indexes.dims() != source.dims() {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter-add (indexes, src)",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            })?
        }
        let storage = self.storage().scatter_add(
            self.layout(),
            &indexes.storage(),
            indexes.layout(),
            &source.storage(),
            source.layout(),
            dim,
        )?;
        let op = if indexes.track_op() || self.track_op() {
            Some(Op::ScatterAdd(
                self.clone(),
                indexes.clone(),
                source.clone(),
                dim,
            ))
        } else {
            None
        };
        Ok(from_storage(storage, self.shape(), op, false))
    }
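For reference, the semantics match PyTorch's scatter_add_: along dim == 1 of a 2D tensor, self[i][indexes[i][j]] += source[i][j], with every non-scattered dimension required to line up, as the shape check above enforces. A plain-Rust sketch of that rule on nested Vecs (a hypothetical helper for illustration, not part of the diff):

// Reference semantics of scatter_add along dim 1 of a 2D tensor:
// out[i][indexes[i][j]] += source[i][j]. The real kernel operates on
// strided storage; this loop only pins down the arithmetic.
fn scatter_add_dim1(out: &mut [Vec<f32>], indexes: &[Vec<usize>], source: &[Vec<f32>]) {
    for (i, row) in indexes.iter().enumerate() {
        for (j, &idx) in row.iter().enumerate() {
            out[i][idx] += source[i][j];
        }
    }
}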

    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-add")?;
        let source_dims = source.dims();
@@ -992,6 +1043,40 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }

    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "gather")?;
        let self_dims = self.dims();
        let indexes_dims = indexes.dims();
        let mismatch = if indexes_dims.len() != self_dims.len() {
            true
        } else {
            let mut mismatch = false;
            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
                if i != dim && d1 != d2 {
                    mismatch = true;
                    break;
                }
            }
            mismatch
        };
        if mismatch {
            Err(Error::ShapeMismatchBinaryOp {
                op: "gather",
                lhs: self.shape().clone(),
                rhs: indexes.shape().clone(),
            })?
        }
        let storage =
            self.storage()
                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
        let op = if indexes.track_op() || self.track_op() {
            Some(Op::Gather(self.clone(), indexes.clone(), dim))
        } else {
            None
        };
        Ok(from_storage(storage, indexes.shape(), op, false))
    }
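Per the commit message, scatter_add exists to serve as gather's gradient: gather copies self[i][indexes[i][j]] into the output (for dim == 1), so the backward pass routes each output gradient back to the position it was read from, summing collisions. A hedged sketch of that rule (the function name and argument names are illustrative, not candle's actual backprop code):

use candle::{Result, Tensor};

// Backward rule for gather along `dim`: accumulate the upstream gradient
// into a zero tensor at the gathered positions.
fn gather_backward(self_like: &Tensor, indexes: &Tensor, grad_out: &Tensor, dim: usize) -> Result<Tensor> {
    self_like.zeros_like()?.scatter_add(indexes, grad_out, dim)
}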

    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-select")?;
        let indexes_len = match indexes.dims() {