mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Check that the tensor is contiguous before applying the kernel.
This commit is contained in:
@ -85,13 +85,15 @@ impl CudaStorage {
|
|||||||
pub(crate) fn affine_impl(
|
pub(crate) fn affine_impl(
|
||||||
&self,
|
&self,
|
||||||
shape: &Shape,
|
shape: &Shape,
|
||||||
_stride: &[usize],
|
stride: &[usize],
|
||||||
mul: f64,
|
mul: f64,
|
||||||
add: f64,
|
add: f64,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Self::F32(arg) => {
|
Self::F32(arg) => {
|
||||||
// TODO: Handle the stride.
|
if !shape.is_contiguous(stride) {
|
||||||
|
todo!("affine is only implemented for the contiguous case")
|
||||||
|
}
|
||||||
let dev = arg.device();
|
let dev = arg.device();
|
||||||
let module_name = "affine_f32";
|
let module_name = "affine_f32";
|
||||||
if !dev.has_func(module_name, module_name) {
|
if !dev.has_func(module_name, module_name) {
|
||||||
|
14
src/shape.rs
14
src/shape.rs
@ -128,6 +128,20 @@ impl Shape {
|
|||||||
stride.reverse();
|
stride.reverse();
|
||||||
stride
|
stride
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_contiguous(&self, stride: &[usize]) -> bool {
|
||||||
|
if self.0.len() != stride.len() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
let mut acc = 1;
|
||||||
|
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
||||||
|
if stride != acc {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
acc *= dim;
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -310,14 +310,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_contiguous(&self) -> bool {
|
pub fn is_contiguous(&self) -> bool {
|
||||||
let mut acc = 1;
|
self.shape.is_contiguous(&self.stride)
|
||||||
for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
|
|
||||||
if stride != acc {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
acc *= dim;
|
|
||||||
}
|
|
||||||
true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||||
|
Reference in New Issue
Block a user