Propagate the layout refactoring.

This commit is contained in:
laurent
2023-06-28 13:42:23 +01:00
parent 30b355ccd2
commit 303b853098
5 changed files with 130 additions and 129 deletions

View File

@ -145,7 +145,7 @@ impl Storage {
pub(crate) fn where_cond(
&self,
layout: &Shape,
layout: &Layout,
t: &Self,
layout_t: &Layout,
f: &Self,
@ -171,7 +171,7 @@ impl Storage {
}
}
pub(crate) fn embedding_impl(
pub(crate) fn embedding(
&self,
layout: &Layout,
rhs: &Self,
@ -181,11 +181,11 @@ impl Storage {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
@ -227,15 +227,11 @@ impl Storage {
&self,
dst: &mut Self,
dst_offset: usize,
src_layout: &Layout,
src_l: &Layout,
) -> Result<()> {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => {
src.copy_strided_src(dst, dst_offset, src_layout, src_offset)
}
(Self::Cuda(src), Self::Cuda(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_layout, src_offset)?)
}
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),