Start adding index-add.

This commit is contained in:
laurent
2023-07-21 20:12:48 +01:00
parent 5cc843550d
commit 27174a82aa
8 changed files with 97 additions and 3 deletions

View File

@ -41,6 +41,15 @@ pub trait BackendStorage: Sized {
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
fn matmul(
&self,

View File

@ -38,7 +38,9 @@ impl Tensor {
nodes
} else if let Some(op) = node.op() {
match op {
Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
Op::IndexAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
@ -160,6 +162,7 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let dim = *dim;
let sum_grad = grads.or_insert(arg)?;

View File

@ -1532,6 +1532,18 @@ impl BackendStorage for CpuStorage {
IndexSelect { ids, ids_l, dim }.map(self, l)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
todo!()
}
fn matmul(
&self,
rhs: &Self,

View File

@ -1064,6 +1064,17 @@ impl BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement index-select").into())
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement index-add").into())
}
fn matmul(
&self,

View File

@ -85,6 +85,17 @@ impl crate::backend::BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn matmul(
&self,

View File

@ -67,6 +67,7 @@ pub(crate) enum Op {
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),
#[allow(dead_code)]

View File

@ -308,7 +308,7 @@ impl Storage {
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.embedding(layout, rhs, rhs_l)?;
Ok(Self::Cpu(storage))
}
@ -325,6 +325,30 @@ impl Storage {
}
}
pub(crate) fn index_add(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "index-add")?;
self.same_device(source, "index-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn index_select(
&self,
rhs: &Self,
@ -334,7 +358,7 @@ impl Storage {
) -> Result<Self> {
self.same_device(rhs, "index-select")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cpu(storage))
}

View File

@ -945,6 +945,29 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
let storage = self.storage().index_add(
self.layout(),
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = if indexes.track_op() || self.track_op() {
Some(Op::IndexAdd(
self.clone(),
indexes.clone(),
source.clone(),
dim,
))
} else {
None
};
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {