Starting to fix some tests.

This commit is contained in:
Nicolas Patry
2023-11-11 01:02:15 +01:00
parent 4f39695465
commit 3ad02147e4
3 changed files with 42 additions and 15 deletions

View File

@ -112,13 +112,7 @@ macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
#[derive(Clone, Copy)]
pub struct Kernel(pub(crate) &'static str);
impl std::fmt::Display for Kernel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
$(
pub mod $name {
use super::Kernel;
@ -127,16 +121,17 @@ macro_rules! ops{
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
}
)+
pub mod copy {
use super::Kernel;
pub const FLOAT: Kernel = Kernel("copy_float");
pub const HALF: Kernel = Kernel("copy_half");
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
pub const U32: Kernel = Kernel("copy_u32");
}
}
pub mod strided {
#[derive(Clone, Copy)]
pub struct Kernel(pub(crate) &'static str);
impl std::fmt::Display for Kernel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
$(
pub mod $name {
use super::Kernel;
@ -145,12 +140,19 @@ macro_rules! ops{
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
}
)+
pub mod copy {
use super::Kernel;
pub const FLOAT: Kernel = Kernel("copy_float_strided");
pub const HALF: Kernel = Kernel("copy_half_strided");
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
pub const U32: Kernel = Kernel("copy_u32_strided");
}
}
};
}
pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round);
}
pub mod binary {
ops!(add, sub, mul, div);