diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 8c7f23b3..d12db972 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -85,13 +85,15 @@ impl CudaStorage { pub(crate) fn affine_impl( &self, shape: &Shape, - _stride: &[usize], + stride: &[usize], mul: f64, add: f64, ) -> Result { match self { Self::F32(arg) => { - // TODO: Handle the stride. + if !shape.is_contiguous(stride) { + todo!("affine is only implemented for the contiguous case") + } let dev = arg.device(); let module_name = "affine_f32"; if !dev.has_func(module_name, module_name) { diff --git a/src/shape.rs b/src/shape.rs index d626aee6..ebc497cf 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -128,6 +128,20 @@ impl Shape { stride.reverse(); stride } + + pub fn is_contiguous(&self, stride: &[usize]) -> bool { + if self.0.len() != stride.len() { + return false; + } + let mut acc = 1; + for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() { + if stride != acc { + return false; + } + acc *= dim; + } + true + } } #[cfg(test)] diff --git a/src/tensor.rs b/src/tensor.rs index a1262334..02105573 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -310,14 +310,7 @@ impl Tensor { } pub fn is_contiguous(&self) -> bool { - let mut acc = 1; - for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() { - if stride != acc { - return false; - } - acc *= dim; - } - true + self.shape.is_contiguous(&self.stride) } /// Return all the nodes that lead to this value in a topologically sorted vec, the first