mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Put back affine strided tests
This commit is contained in:
@ -295,7 +295,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
output.read_to_vec::<T>(v.len())
|
||||
}
|
||||
|
||||
fn _run_affine_strided<T: Clone>(
|
||||
fn run_affine_strided<T: Clone>(
|
||||
v: &[T],
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
@ -314,7 +314,7 @@ fn _run_affine_strided<T: Clone>(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"affine_float",
|
||||
"affine_float_strided",
|
||||
shape,
|
||||
&input,
|
||||
strides,
|
||||
@ -327,7 +327,8 @@ fn _run_affine_strided<T: Clone>(
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
output.read_to_vec::<T>(v.len())
|
||||
let len: usize = shape.iter().product();
|
||||
output.read_to_vec::<T>(len)
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -345,15 +346,17 @@ fn affine() {
|
||||
assert_eq!(result, vec![2.6; 40_000]);
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn affine_strided() {
|
||||
// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
// let mul = 1.5;
|
||||
// let add = 1.1;
|
||||
// let result = run_affine_(&input, mul, add);
|
||||
// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
|
||||
|
||||
// }
|
||||
#[test]
|
||||
fn affine_strided() {
|
||||
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
let mul = 1.5;
|
||||
let add = 1.1;
|
||||
let shape = [4];
|
||||
let strides = [2];
|
||||
let result = run_affine_strided(&input, &shape, &strides, mul, add);
|
||||
// 1 on 2
|
||||
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select() {
|
||||
|
Reference in New Issue
Block a user