Start adding index-add.
@@ -41,6 +41,15 @@ pub trait BackendStorage: Sized {
 
     fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self>;
 
     fn matmul(
         &self,
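The six placeholder arguments of the new trait method mirror the parameter names used by the `Storage::index_add` wrapper further down in this commit. Spelled out for readability (the names here are illustrative; the trait itself uses `_` placeholders):

    fn index_add(
        &self,              // destination storage
        l: &Layout,         // layout of the destination
        indexes: &Self,     // integer indexes to scatter into
        indexes_l: &Layout, // layout of the indexes
        source: &Self,      // values to add at those indexes
        source_l: &Layout,  // layout of the source
        d: usize,           // dimension along which to add
    ) -> Result<Self>;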
@@ -38,7 +38,9 @@ impl Tensor {
                 nodes
             } else if let Some(op) = node.op() {
                 match op {
-                    Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
+                    Op::IndexAdd(t1, t2, t3, _)
+                    | Op::CustomOp3(t1, t2, t3, _)
+                    | Op::WhereCond(t1, t2, t3) => {
                         let (tg, nodes) = walk(t1, nodes, already_seen);
                         track_grad |= tg;
                         let (tg, nodes) = walk(t2, nodes, already_seen);
@@ -160,6 +162,7 @@ impl Tensor {
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
                     Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                    Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
                     Op::IndexSelect(arg, indexes, dim) => {
                         let dim = *dim;
                         let sum_grad = grads.or_insert(arg)?;
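Backward is declared unsupported for now; only graph traversal knows about the new op. For reference, a plausible gradient rule, were it implemented later, would mirror PyTorch's `index_add_`: the gradient flows through to the destination unchanged, while the gradient for `source` gathers the rows that were added into. A sketch in the style of the surrounding match arms (an assumption, not part of this commit):

    Op::IndexAdd(arg, indexes, source, dim) => {
        // Destination: identity, so the incoming gradient passes through.
        let sum_grad = grads.or_insert(arg)?;
        *sum_grad = sum_grad.add(&grad)?;
        // Source: gather the gradient at the scattered positions.
        let source_grad = grad.index_select(indexes, *dim)?;
        let source_sum_grad = grads.or_insert(source)?;
        *source_sum_grad = source_sum_grad.add(&source_grad)?;
    }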
@@ -1532,6 +1532,18 @@ impl BackendStorage for CpuStorage {
         IndexSelect { ids, ids_l, dim }.map(self, l)
     }
 
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     fn matmul(
         &self,
         rhs: &Self,
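The CPU kernel is stubbed with `todo!()`. To make the intended semantics concrete, here is a minimal, self-contained sketch of a scatter-add over dimension 0 of a contiguous buffer; the real kernel additionally has to dispatch on dtype and honor arbitrary `Layout` strides (this is an illustration, not candle code):

    // For each i: dst[indexes[i], :] += source[i, :]
    fn index_add_dim0(dst: &mut [f32], cols: usize, indexes: &[u32], source: &[f32]) {
        for (i, &idx) in indexes.iter().enumerate() {
            let dst_row = &mut dst[idx as usize * cols..(idx as usize + 1) * cols];
            let src_row = &source[i * cols..(i + 1) * cols];
            for (d, s) in dst_row.iter_mut().zip(src_row) {
                *d += s; // accumulate, unlike index_select which copies
            }
        }
    }

Note that repeated indexes accumulate into the same destination row, which is also what makes a parallel CUDA version non-trivial (it needs atomic adds or a segmented reduction).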
@@ -1064,6 +1064,17 @@ impl BackendStorage for CudaStorage {
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(CudaError::InternalError("TODO: implement index-select").into())
     }
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement index-add").into())
+    }
 
     fn matmul(
         &self,
@@ -85,6 +85,17 @@ impl crate::backend::BackendStorage for CudaStorage {
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 
     fn matmul(
         &self,
@@ -67,6 +67,7 @@ pub(crate) enum Op {
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     IndexSelect(Tensor, Tensor, usize),
+    IndexAdd(Tensor, Tensor, Tensor, usize),
     WhereCond(Tensor, Tensor, Tensor),
 
     #[allow(dead_code)]
@@ -308,7 +308,7 @@ impl Storage {
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         self.same_device(rhs, "embedding")?;
         match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                 let storage = lhs.embedding(layout, rhs, rhs_l)?;
                 Ok(Self::Cpu(storage))
             }
@@ -325,6 +325,30 @@ impl Storage {
         }
     }
 
+    pub(crate) fn index_add(
+        &self,
+        l: &Layout,
+        indexes: &Self,
+        indexes_l: &Layout,
+        source: &Self,
+        source_l: &Layout,
+        d: usize,
+    ) -> Result<Self> {
+        self.same_device(indexes, "index-add")?;
+        self.same_device(source, "index-add")?;
+        match (self, indexes, source) {
+            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
+                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
+                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Cuda(storage))
+            }
+            _ => unreachable!(),
+        }
+    }
+
     pub(crate) fn index_select(
         &self,
         rhs: &Self,
@@ -334,7 +358,7 @@ impl Storage {
     ) -> Result<Self> {
         self.same_device(rhs, "index-select")?;
         match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                 Ok(Self::Cpu(storage))
             }
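Note that the `_ => unreachable!()` arm in `Storage::index_add` is safe: the two `same_device` checks above the match already return an error for any mixed CPU/CUDA combination, so only the same-device arms can be reached.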
@@ -945,6 +945,29 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }
 
+    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "index-add")?;
+        let storage = self.storage().index_add(
+            self.layout(),
+            &indexes.storage(),
+            indexes.layout(),
+            &source.storage(),
+            source.layout(),
+            dim,
+        )?;
+        let op = if indexes.track_op() || self.track_op() {
+            Some(Op::IndexAdd(
+                self.clone(),
+                indexes.clone(),
+                source.clone(),
+                dim,
+            ))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-select")?;
         let indexes_len = match indexes.dims() {
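Once a backend kernel replaces the `todo!()`, the new public method is meant to be called like any other tensor op. A hypothetical usage sketch (assuming the crate is imported as `candle` and PyTorch-style semantics, i.e. `t[indexes[i]] += source[i]` along `dim`):

    use candle::{DType, Device, Tensor};

    fn demo() -> candle::Result<()> {
        let t = Tensor::zeros((4, 2), DType::F32, &Device::Cpu)?;
        let source = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
        let indexes = Tensor::new(&[2u32, 0], &Device::Cpu)?;
        // Adds source row 0 into t row 2 and source row 1 into t row 0,
        // so t becomes [[3, 4], [0, 0], [1, 2], [0, 0]].
        let t = t.index_add(&indexes, &source, 0)?;
        println!("{t}");
        Ok(())
    }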