Tmp state.

This commit is contained in:
Nicolas Patry
2023-11-10 15:35:46 +01:00
committed by Nicolas Patry
parent f710fab02e
commit d46670f7c0
14 changed files with 699 additions and 63 deletions

View File

@ -112,7 +112,13 @@ macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
#[derive(Clone, Copy)]
pub struct Kernel(pub(crate) &'static str);
impl std::fmt::Display for Kernel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
$(
pub mod $name {
use super::Kernel;
@ -124,7 +130,13 @@ macro_rules! ops{
}
pub mod strided {
#[derive(Clone, Copy)]
pub struct Kernel(pub(crate) &'static str);
impl std::fmt::Display for Kernel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
$(
pub mod $name {
use super::Kernel;
@ -859,6 +871,30 @@ mod tests {
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_strided_random() {
let v: Vec<_> = (0..10_000).map(|i| rand::random::<f32>()).collect();
let shape = vec![5_000, 2];
let strides = vec![1, 5_000];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
assert_eq!(
approx(vec![results[1]], 4),
approx(vec![expected[5_000]], 4)
);
assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
assert_eq!(
approx(vec![results[3]], 4),
approx(vec![expected[5_001]], 4)
);
assert_eq!(
approx(vec![results[5_000]], 4),
approx(vec![expected[2_500]], 4)
);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];