mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Starting to fix some tests.
This commit is contained in:
@ -293,6 +293,12 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||||
|
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||||
|
// TODO erf does not exist in metal
|
||||||
|
("ugelu_erf", DType::F32) => contiguous::gelu::FLOAT,
|
||||||
|
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||||
|
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||||
|
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_unary_contiguous(
|
candle_metal_kernels::call_unary_contiguous(
|
||||||
@ -519,7 +525,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
let dtype = self.dtype;
|
let dtype = self.dtype;
|
||||||
let device = self.device();
|
let device = self.device();
|
||||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||||
let out = self.to_cpu_storage().unwrap();
|
|
||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
(DType::U32, DType::F32) => "is_u32_f32",
|
(DType::U32, DType::F32) => "is_u32_f32",
|
||||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||||
@ -690,6 +695,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||||
|
DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
|
||||||
dtype => todo!("copy_strided not implemented for {dtype:?}"),
|
dtype => todo!("copy_strided not implemented for {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_unary_strided(
|
candle_metal_kernels::call_unary_strided(
|
||||||
|
@ -112,13 +112,7 @@ macro_rules! ops{
|
|||||||
($($name:ident),+) => {
|
($($name:ident),+) => {
|
||||||
|
|
||||||
pub mod contiguous {
|
pub mod contiguous {
|
||||||
#[derive(Clone, Copy)]
|
|
||||||
pub struct Kernel(pub(crate) &'static str);
|
pub struct Kernel(pub(crate) &'static str);
|
||||||
impl std::fmt::Display for Kernel {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
@ -127,16 +121,17 @@ macro_rules! ops{
|
|||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
|
pub mod copy {
|
||||||
|
use super::Kernel;
|
||||||
|
pub const FLOAT: Kernel = Kernel("copy_float");
|
||||||
|
pub const HALF: Kernel = Kernel("copy_half");
|
||||||
|
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
|
||||||
|
pub const U32: Kernel = Kernel("copy_u32");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod strided {
|
pub mod strided {
|
||||||
#[derive(Clone, Copy)]
|
|
||||||
pub struct Kernel(pub(crate) &'static str);
|
pub struct Kernel(pub(crate) &'static str);
|
||||||
impl std::fmt::Display for Kernel {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
@ -145,12 +140,19 @@ macro_rules! ops{
|
|||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
|
pub mod copy {
|
||||||
|
use super::Kernel;
|
||||||
|
pub const FLOAT: Kernel = Kernel("copy_float_strided");
|
||||||
|
pub const HALF: Kernel = Kernel("copy_half_strided");
|
||||||
|
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
|
||||||
|
pub const U32: Kernel = Kernel("copy_u32_strided");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
|
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
ops!(add, sub, mul, div);
|
ops!(add, sub, mul, div);
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
#include <metal_math>
|
||||||
|
#
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
@ -18,9 +21,15 @@ METAL_FUNC uint get_strided_index(
|
|||||||
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||||
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||||
|
template <typename T> METAL_FUNC T gelu(T x){
|
||||||
|
T x_sq = x * x;
|
||||||
|
T x_cube = x_sq * x;
|
||||||
|
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||||
|
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||||
|
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanh(beta));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
@ -64,8 +73,14 @@ UNARY_OP(sqrt)
|
|||||||
UNARY_OP(neg)
|
UNARY_OP(neg)
|
||||||
UNARY_OP(exp)
|
UNARY_OP(exp)
|
||||||
UNARY_OP(log)
|
UNARY_OP(log)
|
||||||
|
UNARY_OP(gelu)
|
||||||
|
UNARY_OP(ceil)
|
||||||
|
UNARY_OP(floor)
|
||||||
|
UNARY_OP(round)
|
||||||
UNARY(id, float, copy_float, copy_float_strided)
|
UNARY(id, float, copy_float, copy_float_strided)
|
||||||
UNARY(id, half, copy_half, copy_half_strided)
|
UNARY(id, half, copy_half, copy_half_strided)
|
||||||
|
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||||
|
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
BFLOAT_UNARY_OP(cos)
|
BFLOAT_UNARY_OP(cos)
|
||||||
@ -75,6 +90,10 @@ BFLOAT_UNARY_OP(sqrt)
|
|||||||
BFLOAT_UNARY_OP(neg)
|
BFLOAT_UNARY_OP(neg)
|
||||||
BFLOAT_UNARY_OP(exp)
|
BFLOAT_UNARY_OP(exp)
|
||||||
BFLOAT_UNARY_OP(log)
|
BFLOAT_UNARY_OP(log)
|
||||||
|
BFLOAT_UNARY_OP(gelu)
|
||||||
|
BFLOAT_UNARY_OP(ceil)
|
||||||
|
BFLOAT_UNARY_OP(floor)
|
||||||
|
BFLOAT_UNARY_OP(round)
|
||||||
|
|
||||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||||
#endif
|
#endif
|
||||||
|
Reference in New Issue
Block a user