Merge branch 'main' into ivarflakstad/metal-prng

This commit is contained in:
Ivar Flakstad
2024-01-07 11:52:03 +01:00
56 changed files with 2068 additions and 719 deletions

View File

@ -134,6 +134,9 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
}
)+
pub mod copy {
@ -141,6 +144,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel("copy_f32");
pub const HALF: Kernel = Kernel("copy_f16");
pub const BFLOAT: Kernel = Kernel("copy_bf16");
pub const I64: Kernel = Kernel("copy_i64");
pub const U32: Kernel = Kernel("copy_u32");
pub const U8: Kernel = Kernel("copy_u8");
}
@ -154,6 +158,9 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
}
)+
pub mod copy {
@ -161,6 +168,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel("copy_f32_strided");
pub const HALF: Kernel = Kernel("copy_f16_strided");
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
pub const I64: Kernel = Kernel("copy_i64_strided");
pub const U32: Kernel = Kernel("copy_u32_strided");
pub const U8: Kernel = Kernel("copy_u8_strided");
}
@ -169,7 +177,10 @@ macro_rules! ops{
}
pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
recip
);
}
pub mod binary {
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);