mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Put back affine strided tests
This commit is contained in:
@ -295,7 +295,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
|||||||
output.read_to_vec::<T>(v.len())
|
output.read_to_vec::<T>(v.len())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _run_affine_strided<T: Clone>(
|
fn run_affine_strided<T: Clone>(
|
||||||
v: &[T],
|
v: &[T],
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
strides: &[usize],
|
strides: &[usize],
|
||||||
@ -314,7 +314,7 @@ fn _run_affine_strided<T: Clone>(
|
|||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
"affine_float",
|
"affine_float_strided",
|
||||||
shape,
|
shape,
|
||||||
&input,
|
&input,
|
||||||
strides,
|
strides,
|
||||||
@ -327,7 +327,8 @@ fn _run_affine_strided<T: Clone>(
|
|||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
output.read_to_vec::<T>(v.len())
|
let len: usize = shape.iter().product();
|
||||||
|
output.read_to_vec::<T>(len)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -345,15 +346,17 @@ fn affine() {
|
|||||||
assert_eq!(result, vec![2.6; 40_000]);
|
assert_eq!(result, vec![2.6; 40_000]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[test]
|
#[test]
|
||||||
// fn affine_strided() {
|
fn affine_strided() {
|
||||||
// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||||
// let mul = 1.5;
|
let mul = 1.5;
|
||||||
// let add = 1.1;
|
let add = 1.1;
|
||||||
// let result = run_affine_(&input, mul, add);
|
let shape = [4];
|
||||||
// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
|
let strides = [2];
|
||||||
|
let result = run_affine_strided(&input, &shape, &strides, mul, add);
|
||||||
// }
|
// 1 on 2
|
||||||
|
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn index_select() {
|
fn index_select() {
|
||||||
|
Reference in New Issue
Block a user