Start refactoring the stride.

This commit is contained in:
laurent
2023-06-28 12:57:30 +01:00
parent d461d9d751
commit c1bbbf94f6
5 changed files with 124 additions and 108 deletions

View File

@ -1,4 +1,4 @@
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -53,33 +53,27 @@ impl Storage {
}
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,
stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.affine_impl(shape, stride, mul, add)?;
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.affine_impl(shape, stride, mul, add)?;
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.sum(shape, stride, s)?;
let storage = storage.sum(layout, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.sum(shape, stride, s)?;
let storage = storage.sum(layout, s)?;
Ok(Self::Cuda(storage))
}
}
@ -93,32 +87,28 @@ impl Storage {
Ok(())
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.to_dtype(shape, stride, dtype)?;
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.to_dtype(shape, stride, dtype)?;
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn unary_impl<B: op::UnaryOp>(
&self,
shape: &Shape,
stride: &[usize],
) -> Result<Self> {
pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
let storage = storage.unary_impl::<B>(shape, stride)?;
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.unary_impl::<B>(shape, stride)?;
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage))
}
}
@ -127,19 +117,18 @@ impl Storage {
pub(crate) fn binary_impl<B: op::BinaryOp>(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
lhs_layout: &Layout,
rhs_layout: &Layout,
) -> Result<Self> {
self.same_device(rhs, B::NAME)?;
self.same_dtype(rhs, B::NAME)?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => {
@ -156,23 +145,22 @@ impl Storage {
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
layout: &Shape,
t: &Self,
stride_t: &[usize],
layout_t: &Layout,
f: &Self,
stride_f: &[usize],
layout_f: &Layout,
) -> Result<Self> {
self.same_device(t, "where")?;
self.same_device(f, "where")?;
t.same_dtype(f, "where")?;
match (self, t, f) {
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
@ -185,8 +173,7 @@ impl Storage {
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
stride: &[usize],
layout: &Layout,
rhs: &Self,
hidden_size: usize,
vocab_size: usize,
@ -194,11 +181,11 @@ impl Storage {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {