mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Start adding index-add.
This commit is contained in:
@ -41,6 +41,15 @@ pub trait BackendStorage: Sized {
|
|||||||
|
|
||||||
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
|
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
|
||||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self>;
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
|
@ -38,7 +38,9 @@ impl Tensor {
|
|||||||
nodes
|
nodes
|
||||||
} else if let Some(op) = node.op() {
|
} else if let Some(op) = node.op() {
|
||||||
match op {
|
match op {
|
||||||
Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
|
Op::IndexAdd(t1, t2, t3, _)
|
||||||
|
| Op::CustomOp3(t1, t2, t3, _)
|
||||||
|
| Op::WhereCond(t1, t2, t3) => {
|
||||||
let (tg, nodes) = walk(t1, nodes, already_seen);
|
let (tg, nodes) = walk(t1, nodes, already_seen);
|
||||||
track_grad |= tg;
|
track_grad |= tg;
|
||||||
let (tg, nodes) = walk(t2, nodes, already_seen);
|
let (tg, nodes) = walk(t2, nodes, already_seen);
|
||||||
@ -160,6 +162,7 @@ impl Tensor {
|
|||||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||||
}
|
}
|
||||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||||
|
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
|
||||||
Op::IndexSelect(arg, indexes, dim) => {
|
Op::IndexSelect(arg, indexes, dim) => {
|
||||||
let dim = *dim;
|
let dim = *dim;
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
@ -1532,6 +1532,18 @@ impl BackendStorage for CpuStorage {
|
|||||||
IndexSelect { ids, ids_l, dim }.map(self, l)
|
IndexSelect { ids, ids_l, dim }.map(self, l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
|
@ -1064,6 +1064,17 @@ impl BackendStorage for CudaStorage {
|
|||||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||||
Err(CudaError::InternalError("TODO: implement index-select").into())
|
Err(CudaError::InternalError("TODO: implement index-select").into())
|
||||||
}
|
}
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(CudaError::InternalError("TODO: implement index-add").into())
|
||||||
|
}
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
|
@ -85,6 +85,17 @@ impl crate::backend::BackendStorage for CudaStorage {
|
|||||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
|
@ -67,6 +67,7 @@ pub(crate) enum Op {
|
|||||||
Matmul(Tensor, Tensor),
|
Matmul(Tensor, Tensor),
|
||||||
Embedding(Tensor, Tensor),
|
Embedding(Tensor, Tensor),
|
||||||
IndexSelect(Tensor, Tensor, usize),
|
IndexSelect(Tensor, Tensor, usize),
|
||||||
|
IndexAdd(Tensor, Tensor, Tensor, usize),
|
||||||
WhereCond(Tensor, Tensor, Tensor),
|
WhereCond(Tensor, Tensor, Tensor),
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
@ -308,7 +308,7 @@ impl Storage {
|
|||||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
self.same_device(rhs, "embedding")?;
|
self.same_device(rhs, "embedding")?;
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
|
||||||
let storage = lhs.embedding(layout, rhs, rhs_l)?;
|
let storage = lhs.embedding(layout, rhs, rhs_l)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
@ -325,6 +325,30 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn index_add(
|
||||||
|
&self,
|
||||||
|
l: &Layout,
|
||||||
|
indexes: &Self,
|
||||||
|
indexes_l: &Layout,
|
||||||
|
source: &Self,
|
||||||
|
source_l: &Layout,
|
||||||
|
d: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.same_device(indexes, "index-add")?;
|
||||||
|
self.same_device(source, "index-add")?;
|
||||||
|
match (self, indexes, source) {
|
||||||
|
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||||
|
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
|
Ok(Self::Cpu(storage))
|
||||||
|
}
|
||||||
|
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||||
|
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
|
Ok(Self::Cuda(storage))
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn index_select(
|
pub(crate) fn index_select(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
@ -334,7 +358,7 @@ impl Storage {
|
|||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
self.same_device(rhs, "index-select")?;
|
self.same_device(rhs, "index-select")?;
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
|
||||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
|
@ -945,6 +945,29 @@ impl Tensor {
|
|||||||
Ok(from_storage(storage, shape, op, false))
|
Ok(from_storage(storage, shape, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "index-add")?;
|
||||||
|
let storage = self.storage().index_add(
|
||||||
|
self.layout(),
|
||||||
|
&indexes.storage(),
|
||||||
|
indexes.layout(),
|
||||||
|
&source.storage(),
|
||||||
|
source.layout(),
|
||||||
|
dim,
|
||||||
|
)?;
|
||||||
|
let op = if indexes.track_op() || self.track_op() {
|
||||||
|
Some(Op::IndexAdd(
|
||||||
|
self.clone(),
|
||||||
|
indexes.clone(),
|
||||||
|
source.clone(),
|
||||||
|
dim,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||||
let dim = dim.to_index(self.shape(), "index-select")?;
|
let dim = dim.to_index(self.shape(), "index-select")?;
|
||||||
let indexes_len = match indexes.dims() {
|
let indexes_len = match indexes.dims() {
|
||||||
|
Reference in New Issue
Block a user