Adapt the cuda bits.

This commit is contained in:
laurent
2023-06-28 15:43:03 +01:00
parent cca699be6c
commit 3f0d9fbb25
5 changed files with 109 additions and 87 deletions

View File

@ -167,26 +167,20 @@ impl Storage {
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "embedding",
op: "where",
}),
}
}
pub(crate) fn embedding(
&self,
layout: &Layout,
rhs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding(layout, rhs, rhs_l)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding(layout, rhs, rhs_l)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {