Check that the tensor is contiguous before applying the kernel.

This commit is contained in:
laurent
2023-06-21 21:28:59 +01:00
parent 9834151254
commit 7c46de9584
3 changed files with 19 additions and 10 deletions

View File

@ -85,13 +85,15 @@ impl CudaStorage {
pub(crate) fn affine_impl(
&self,
shape: &Shape,
_stride: &[usize],
stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
match self {
Self::F32(arg) => {
// TODO: Handle the stride.
if !shape.is_contiguous(stride) {
todo!("affine is only implemented for the contiguous case")
}
let dev = arg.device();
let module_name = "affine_f32";
if !dev.has_func(module_name, module_name) {

View File

@ -128,6 +128,20 @@ impl Shape {
stride.reverse();
stride
}
pub fn is_contiguous(&self, stride: &[usize]) -> bool {
if self.0.len() != stride.len() {
return false;
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
if stride != acc {
return false;
}
acc *= dim;
}
true
}
}
#[cfg(test)]

View File

@ -310,14 +310,7 @@ impl Tensor {
}
pub fn is_contiguous(&self) -> bool {
let mut acc = 1;
for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
if stride != acc {
return false;
}
acc *= dim;
}
true
self.shape.is_contiguous(&self.stride)
}
/// Return all the nodes that lead to this value in a topologically sorted vec, the first