Rework the embeddings so that it works on non-contiguous weights + factor out some code.

This commit is contained in:
laurent
2023-06-25 17:37:47 +01:00
parent 334524e2c4
commit 817e4b5005
6 changed files with 66 additions and 48 deletions

View File

@ -156,19 +156,20 @@ impl Storage {
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
stride: &[usize],
rhs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
self.same_device(rhs, "embedding")?;
self.same_dtype(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {