Start adding index-add.

This commit is contained in:
laurent
2023-07-21 20:12:48 +01:00
parent 5cc843550d
commit 27174a82aa
8 changed files with 97 additions and 3 deletions

View File

@ -308,7 +308,7 @@ impl Storage {
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.embedding(layout, rhs, rhs_l)?;
Ok(Self::Cpu(storage))
}
@ -325,6 +325,30 @@ impl Storage {
}
}
pub(crate) fn index_add(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "index-add")?;
self.same_device(source, "index-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn index_select(
&self,
rhs: &Self,
@ -334,7 +358,7 @@ impl Storage {
) -> Result<Self> {
self.same_device(rhs, "index-select")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cpu(storage))
}