mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Check that the tensor is contiguous before applying the kernel.
This commit is contained in:
@ -85,13 +85,15 @@ impl CudaStorage {
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
_stride: &[usize],
|
||||
stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
Self::F32(arg) => {
|
||||
// TODO: Handle the stride.
|
||||
if !shape.is_contiguous(stride) {
|
||||
todo!("affine is only implemented for the contiguous case")
|
||||
}
|
||||
let dev = arg.device();
|
||||
let module_name = "affine_f32";
|
||||
if !dev.has_func(module_name, module_name) {
|
||||
|
14
src/shape.rs
14
src/shape.rs
@ -128,6 +128,20 @@ impl Shape {
|
||||
stride.reverse();
|
||||
stride
|
||||
}
|
||||
|
||||
pub fn is_contiguous(&self, stride: &[usize]) -> bool {
|
||||
if self.0.len() != stride.len() {
|
||||
return false;
|
||||
}
|
||||
let mut acc = 1;
|
||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
||||
if stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -310,14 +310,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
let mut acc = 1;
|
||||
for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
|
||||
if stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
}
|
||||
true
|
||||
self.shape.is_contiguous(&self.stride)
|
||||
}
|
||||
|
||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||
|
Reference in New Issue
Block a user