mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
mlx fp8 gemm
This commit is contained in:
@ -3,8 +3,8 @@ mod benchmarks;
|
||||
use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
|
@ -7,19 +7,25 @@ fn run(a: &Tensor, b: &Tensor) {
|
||||
a.matmul(&b.t().unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 2048;
|
||||
let k = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
||||
let name = match dtype {
|
||||
DType::F32 => "matmul_f32",
|
||||
DType::F16 => "matmul_f16",
|
||||
DType::BF16 => "matmul_bf16",
|
||||
DType::U8 => "matmul_fp8",
|
||||
_ => unimplemented!("{dtype:?} matmul bench not implemented"),
|
||||
};
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
@ -36,8 +42,11 @@ fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
let dtypes = vec![DType::F32, DType::F16, DType::BF16, DType::U8];
|
||||
for device in handler.devices {
|
||||
run_bench(c, &device);
|
||||
for dtype in dtypes.clone() {
|
||||
run_bench(c, &device, dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1513,32 +1513,17 @@ impl BackendStorage for MetalStorage {
|
||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("matmul");
|
||||
if self.dtype == DType::BF16 {
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
candle_metal_kernels::GemmDType::BF16,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
|
||||
let dtype = match self.dtype {
|
||||
// Hijacking the U8 dtype to represent E5M2 fp8
|
||||
DType::U8 => candle_metal_kernels::GemmDType::F8E5M2,
|
||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
||||
dtype => {
|
||||
return Err(MetalError::Message(format!(
|
||||
"mlx matmul doesn't support {dtype:?}"
|
||||
))
|
||||
.into())
|
||||
return Err(
|
||||
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
|
||||
)
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
@ -1556,7 +1541,7 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
|
||||
Ok(Self::new(
|
||||
buffer,
|
||||
self.device.clone(),
|
||||
|
@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
|
||||
|
||||
[dependencies]
|
||||
float8 = "0.2.1"
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_limits>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
@ -18,6 +19,377 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
using namespace metal;
|
||||
|
||||
typedef unsigned char fp8_storage_t;
|
||||
typedef unsigned short int fp8x2_storage_t;
|
||||
typedef unsigned int fp8x4_storage_t;
|
||||
|
||||
// Bit masks used to take a floating-point source value of type T apart
// before it is narrowed to fp8. Only the specialisations below are defined.
//   head_mask     - sign bit plus exponent field
//   mantissa_mask - mantissa field
//   mask          - every bit except the sign bit
template <typename T>
struct _fp8_cast_traits;

// float: 1 sign bit, 8 exponent bits, 23 mantissa bits.
template <>
struct _fp8_cast_traits<float> {
    typedef _fp_encoding_traits<float> traits;
    typedef typename traits::encoding_type encoding_type;
    constexpr static constant encoding_type head_mask = 0xFF800000;
    constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
    constexpr static constant encoding_type mask = 0x7FFFFFFF;
};

// half: 1 sign bit, 5 exponent bits, 10 mantissa bits.
template <>
struct _fp8_cast_traits<half> {
    // Must be the half encoding traits so that encoding_type matches the
    // integer type the half bit pattern is reinterpreted as.
    typedef _fp_encoding_traits<half> traits;
    typedef typename traits::encoding_type encoding_type;
    constexpr static constant encoding_type head_mask = 0xFC00;
    constexpr static constant encoding_type mantissa_mask = 0x3FF;
    constexpr static constant encoding_type mask = 0x7FFF;
};
|
||||
|
||||
// Identifiers for the supported fp8 encodings. The numeric values are used
// as the `int variant` template argument of the fp8 trait structs.
enum _fp8_variant_t {
    E4M3 = 0,      // OCP E4M3
    E5M2 = 1,      // OCP E5M2
    E4M3_FNUZ = 2, // Standard FP8
    E5M2_FNUZ = 3, // BF8
};

// Same numbering with underscore-prefixed enumerator names.
enum fp8_variant_t {
    _E4M3 = 0,      // OCP E4M3
    _E5M2 = 1,      // OCP E5M2
    _E4M3_FNUZ = 2, // Standard FP8
    _E5M2_FNUZ = 3, // BF8
};
|
||||
|
||||
|
||||
// Per-encoding layout description.
//   is_fnuz - true for the *_FNUZ encodings
//   we      - number of exponent bits
//   wm      - number of mantissa bits
template <int variant>
struct _fp8_variant_traits;

template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
    static constant constexpr bool is_fnuz = false;
    static constant constexpr int we = 4;
    static constant constexpr int wm = 3;
};

template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
    static constant constexpr bool is_fnuz = false;
    static constant constexpr int we = 5;
    static constant constexpr int wm = 2;
};

template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
    static constant constexpr bool is_fnuz = true;
    static constant constexpr int we = 4;
    static constant constexpr int wm = 3;
};

template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
    static constant constexpr bool is_fnuz = true;
    static constant constexpr int we = 5;
    static constant constexpr int wm = 2;
};
|
||||
|
||||
|
||||
// Saturation thresholds for converting a source type T to a given fp8
// encoding, expressed as bit patterns of T.
//   ifmax - largest finite magnitude of the fp8 encoding
//           (448 for E4M3, 240 for E4M3_FNUZ, 57344 for both E5M2 forms)
//   ifmin - the negated bound where provided; 0x0 where marked unused
template <typename T, int variant>
struct _fp8_variant_cast_traits;

// ---- float source ----------------------------------------------------------

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x43E00000;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x47600000;
    static constant constexpr traits::encoding_type ifmin = 0xC7600000;
};

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x43700000;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x47600000;
    static constant constexpr traits::encoding_type ifmin = 0xC7600000;
};

// ---- half source -----------------------------------------------------------

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x5F00;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x7B00;
    static constant constexpr traits::encoding_type ifmin = 0xFB00;
};

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x5B80;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x7B00;
    static constant constexpr traits::encoding_type ifmin = 0xFB00;
};
|
||||
|
||||
// TODO: Simplify. No need to support all fp8 variants immediately.
//
// Narrows a float/half value to an 8-bit floating-point encoding.
//
// Template parameters:
//   T       - source type; needs _fp_encoding_traits<T>, _fp8_cast_traits<T>
//             and _fp8_variant_cast_traits<T, variant>.
//   variant - one of the _fp8_variant_t values (defaults to E4M3_FNUZ).
// Parameters:
//   _x    - value to convert.
//   clip  - when true, magnitudes above the fp8 range map to the largest
//           finite code instead of the inf/NaN code.
//   stoch - when true, the low bits of `rng` replace the round-to-nearest-even
//           rounding increment.
//   rng   - random bits, only read when `stoch` is true.
// Returns the fp8 bit pattern (sign in bit 7, then `we` exponent bits and
// `wm` mantissa bits).
template <typename T, int variant = _fp8_variant_t::E4M3_FNUZ>
METAL_FUNC fp8_storage_t cast_to_fp8(T _x, bool clip = false, bool stoch = false, uint rng = 0) {
    typedef _fp_encoding_traits<T> traits;
    typedef typename traits::encoding_type bits;

    typedef _fp8_cast_traits<T> cast_traits;
    typedef _fp8_variant_traits<variant> variant_traits;
    typedef _fp8_variant_cast_traits<T, variant> variant_cast_traits;

    constexpr bool is_fnuz = variant_traits::is_fnuz;
    constexpr int we = variant_traits::we;
    constexpr int wm = variant_traits::wm;
    constexpr int mfmt = traits::exponent_shift;
    constexpr int bias = traits::exponent_bias;
    constexpr bits mask = cast_traits::mask;

    // Split the source bit pattern into sign, exponent and mantissa.
    bits x = as_type<bits>(_x);

    bits head = x & cast_traits::head_mask;
    bits mantissa = x & cast_traits::mantissa_mask;
    int exponent = (head >> traits::exponent_shift) & traits::exponent_max;
    bits sign = head >> traits::sign_shift;

    // Pre-compute the codes returned for overflow (`signed_inf`) and NaN.
    bits signed_inf = 0;
    unsigned int nan = 0;
    if (is_fnuz) {
        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
        nan = 0x80;
    } else {
        if (we == 4) {
            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
        } else {
            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
        }
        nan = (sign << 7) + 0x7f;
    }
    constexpr bits ifmax = variant_cast_traits::ifmax;

    // Deal with inf and NaNs
    if ((x & traits::inf_mask) == traits::inf_mask) {
        if (is_fnuz || we == 4) return nan; // fnuz and OCP E4M3 has no INF
        if (mantissa != 0) return nan; // NaN
        return sign == 0 ? 0x7C : 0xFC; // E5M2 Inf
    }

    // Magnitude beyond the largest finite fp8 value.
    if ((x & mask) > ifmax) {
        return signed_inf;
    }

    if (x == 0) {
        return 0;
    }

    // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
    const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal

    int act_exponent, f8_exponent, exponent_diff;

    if (exponent == 0) {
        // Source value is itself denormal.
        act_exponent = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent - act_exponent;
    } else {
        act_exponent = exponent - bias;
        if (act_exponent <= f8_denormal_act_exponent) {
            // Result will be an fp8 denormal: shift the mantissa right by the gap.
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        } else {
            exponent_diff = 0;
        }
        // Make the implicit leading one explicit.
        mantissa += (1ull << mfmt);
    }

    // True when the discarded bits are exactly half of the kept LSB.
    bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
        (1ull << (mfmt - wm + exponent_diff - 1));

    if (exponent_diff > 0)
        mantissa >>= exponent_diff;
    // NOTE(review): exponent_diff is only assigned non-negative values above for
    // the float/half instantiations, so this branch looks unreachable - confirm
    // before removing.
    else if (exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1ull << mfmt);
    f8_exponent =
        (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);

    // Add the rounding increment into the bits that are about to be dropped.
    unsigned long drop_mask = (1ull << (mfmt - wm)) - 1;
    bool odd =
        mantissa & (1ull << (mfmt - wm));
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

    if (f8_exponent == 0) {
        if ((1ull << mfmt) & mantissa) {
            f8_exponent = 1; // denormal overflow to become normal, promote exponent
        }
    } else {
        if ((1ull << (mfmt + 1)) & mantissa) {
            // Rounding carried out of the mantissa; renormalise.
            mantissa >>= 1;
            f8_exponent++;
        }
    }

    mantissa >>= (mfmt - wm);
    const int max_exp = (1 << we) - 1;
    if (f8_exponent > max_exp) {
        if (clip) {
            mantissa = (1 << wm) - 1;
            f8_exponent = max_exp;
        } else {
            return signed_inf;
        }
    }

    if (f8_exponent == 0 && mantissa == 0) return is_fnuz ? 0 : (sign << 7);
    mantissa &= (1 << wm) - 1;
    return (sign << 7) | (f8_exponent << wm) | mantissa;
}
|
||||
|
||||
// Widens an fp8 bit pattern to half. Only the explicit specialisations are
// defined.
template <int variant>
METAL_FUNC half fp8_to_half(fp8_storage_t x);

// OCP E4M3 -> half.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
    typedef _fp_encoding_traits<half> traits;
    typedef typename traits::encoding_type bits;

    // Magnitude bits all set is the NaN encoding: return a fixed half NaN
    // pattern (sign discarded).
    if ((x & 0x7FU) == 0x7FU) {
        return as_type<half>((bits)0x7FFFU);
    }

    // Place the byte in the top half of a 16-bit word, then move the 4-bit
    // exponent and 3-bit mantissa into the half field positions.
    const bits widened = x << 8U;
    const bits sign = widened & 0x8000U;
    // 0x2000 re-biases the exponent (difference of 8, shifted into bits 10..14).
    bits exponent = (bits)(((widened & 0x7800U) >> 1U) + 0x2000U);
    bits mantissa = (widened & 0x0700U) >> 1U;

    if (exponent == 0x2000U) {
        // fp8 exponent field was zero: either zero or a denormal.
        if (mantissa == 0U) {
            exponent = 0U;
        } else {
            // Normalise: shift until the leading one reaches the implicit-bit
            // position, lowering the exponent one step per shift.
            mantissa = (bits)(mantissa << 1U);
            while ((mantissa & 0x0400U) == 0U) {
                mantissa = (bits)(mantissa << 1U);
                exponent = (bits)(exponent - 0x0400U);
            }
            // Drop the now-implicit leading bit.
            mantissa &= 0x03FFU;
        }
    }

    return as_type<half>((bits)((sign | exponent) | mantissa));
}
|
||||
|
||||
// OCP E5M2 -> half.
//
// E5M2 uses the same sign and 5-bit exponent layout as half, so placing the
// byte in the upper 8 bits of a 16-bit word already yields the half encoding;
// only NaNs are rewritten to one fixed pattern.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
    typedef _fp_encoding_traits<half> traits;
    typedef typename traits::encoding_type bits;

    bits ur = x << 8U;
    // The NaN test has to look at the widened value: the 8-bit input can
    // never exceed 0x7C00, so testing `x` here would make the branch dead.
    if ((ur & 0x7FFFU) > 0x7C00U) {
        // return NaN
        ur = 0x7FFFU;
    }
    return as_type<half>(ur);
}
|
||||
|
||||
// Converts an fp8 bit pattern to T by first widening it to half.
template <typename T, int variant>
METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
    const half widened = fp8_to_half<variant>(x);
    return static_cast<T>(widened);
}
|
||||
|
||||
// Emits two kernels converting T -> fp8 for one encoding:
//   name            - contiguous input, one element per thread
//   name##_strided  - input addressed through get_strided_index
// Threads whose index is not below `dim` do nothing.
#define CAST_TO_FP8_VARIANT(name, T, FP8_VARIANT) \
kernel void name( \
    constant size_t &dim, \
    device const T *input, \
    device fp8_storage_t *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid < dim) { \
        output[tid] = cast_to_fp8<T, FP8_VARIANT>(input[tid]); \
    } \
} \
kernel void name##_strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const T *input, \
    device fp8_storage_t *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid < dim) { \
        const uint src = get_strided_index(tid, num_dims, dims, strides); \
        output[tid] = cast_to_fp8<T, FP8_VARIANT>(input[src]); \
    } \
}

// Emits the E4M3 and E5M2 kernels for source type T; the encoding name is
// appended to `name`.
#define CAST_TO_FP8(name, T) \
CAST_TO_FP8_VARIANT(name##_E4M3, T, E4M3) \
CAST_TO_FP8_VARIANT(name##_E5M2, T, E5M2)
|
||||
|
||||
// Emits two kernels converting fp8 -> T for one encoding:
//   name            - contiguous input, one element per thread
//   name##_strided  - input addressed through get_strided_index
// Threads whose index is not below `dim` do nothing.
#define CAST_FROM_FP8_VARIANT(name, T, FP8_VARIANT) \
kernel void name( \
    constant size_t &dim, \
    device const fp8_storage_t *input, \
    device T *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid < dim) { \
        output[tid] = cast_fp8_to<T, FP8_VARIANT>(input[tid]); \
    } \
} \
kernel void name##_strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const fp8_storage_t *input, \
    device T *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid < dim) { \
        const uint src = get_strided_index(tid, num_dims, dims, strides); \
        output[tid] = cast_fp8_to<T, FP8_VARIANT>(input[src]); \
    } \
}

// Emits the E4M3 and E5M2 kernels for destination type T; the encoding name
// is appended to `name`.
#define CAST_FROM_FP8(name, T) \
CAST_FROM_FP8_VARIANT(name##_E4M3, T, E4M3) \
CAST_FROM_FP8_VARIANT(name##_E5M2, T, E5M2)

// Kernel instantiations: fp8 <-> half and fp8 <-> float.
CAST_FROM_FP8(cast_fp8_f16, half)
CAST_FROM_FP8(cast_fp8_f32, float)
CAST_TO_FP8(cast_f32_fp8, float)
CAST_TO_FP8(cast_f16_fp8, half)
|
||||
|
||||
|
||||
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
|
@ -37,12 +37,19 @@ pub enum DType {
|
||||
I64,
|
||||
U32,
|
||||
U8,
|
||||
FP8(FP8Variant),
|
||||
}
|
||||
|
||||
/// The 8-bit floating-point encodings that [`DType::FP8`] can carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FP8Variant {
    /// 4 exponent bits, 3 mantissa bits.
    E4M3,
    /// 5 exponent bits, 2 mantissa bits.
    E5M2,
}
|
||||
|
||||
impl DType {
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::U8 => 1,
|
||||
Self::U8 | Self::FP8(_) => 1,
|
||||
Self::U32 => 4,
|
||||
Self::I64 => 8,
|
||||
Self::BF16 => 2,
|
||||
@ -311,7 +318,12 @@ impl Kernels {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
.map_err(|e| {
|
||||
// Makes metal errors easier to read.
|
||||
// TODO: Remove.
|
||||
panic!("{}", e);
|
||||
return MetalKernelError::LoadLibraryError(e.to_string());
|
||||
})?
|
||||
};
|
||||
libraries.insert(source, lib.clone());
|
||||
Ok(lib)
|
||||
|
@ -11,6 +11,186 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// FP8 casting
|
||||
|
||||
// Identifiers for the supported fp8 encodings. The numeric values are used
// as the `int variant` template argument of the fp8 trait structs.
enum _fp8_variant_t {
    E4M3 = 0,      // OCP E4M3
    E5M2 = 1,      // OCP E5M2
    E4M3_FNUZ = 2, // Standard FP8
    E5M2_FNUZ = 3, // BF8
};
|
||||
|
||||
// Per-encoding layout description.
//   is_fnuz - true for the *_FNUZ encodings
//   we      - number of exponent bits
//   wm      - number of mantissa bits
template <int variant>
struct _fp8_variant_traits;

template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
    static constant constexpr bool is_fnuz = false;
    static constant constexpr int we = 4;
    static constant constexpr int wm = 3;
};

template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
    static constant constexpr bool is_fnuz = false;
    static constant constexpr int we = 5;
    static constant constexpr int wm = 2;
};

template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
    static constant constexpr bool is_fnuz = true;
    static constant constexpr int we = 4;
    static constant constexpr int wm = 3;
};

template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
    static constant constexpr bool is_fnuz = true;
    static constant constexpr int we = 5;
    static constant constexpr int wm = 2;
};
|
||||
|
||||
|
||||
// Bit masks describing the layout of a floating-point type T, keyed by the
// type that is converted to or from fp8. Only the specialisations below are
// defined.
//   head_mask     - sign bit plus exponent field
//   mantissa_mask - mantissa field
//   mask          - every bit except the sign bit
template <typename T>
struct _fp8_cast_traits;

// float: 1 sign bit, 8 exponent bits, 23 mantissa bits.
template <>
struct _fp8_cast_traits<float> {
    typedef _fp_encoding_traits<float> traits;
    typedef typename traits::encoding_type encoding_type;
    constexpr static constant encoding_type head_mask = 0xFF800000;
    constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
    constexpr static constant encoding_type mask = 0x7FFFFFFF;
};

// half: 1 sign bit, 5 exponent bits, 10 mantissa bits.
template <>
struct _fp8_cast_traits<half> {
    // Must be the half encoding traits so that encoding_type matches the
    // integer type the half bit pattern is reinterpreted as.
    typedef _fp_encoding_traits<half> traits;
    typedef typename traits::encoding_type encoding_type;
    constexpr static constant encoding_type head_mask = 0xFC00;
    constexpr static constant encoding_type mantissa_mask = 0x3FF;
    constexpr static constant encoding_type mask = 0x7FFF;
};
|
||||
|
||||
// Range bounds of each fp8 encoding, expressed as bit patterns of the wider
// type T.
//   ifmax - largest finite magnitude of the fp8 encoding
//           (448 for E4M3, 240 for E4M3_FNUZ, 57344 for both E5M2 forms)
//   ifmin - the negated bound where provided; 0x0 where marked unused
template <typename T, int variant>
struct _fp8_variant_cast_traits;

// ---- float -----------------------------------------------------------------

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x43E00000;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x47600000;
    static constant constexpr traits::encoding_type ifmin = 0xC7600000;
};

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x43700000;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
    using traits = _fp_encoding_traits<float>;
    static constant constexpr traits::encoding_type ifmax = 0x47600000;
    static constant constexpr traits::encoding_type ifmin = 0xC7600000;
};

// ---- half ------------------------------------------------------------------

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x5F00;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x7B00;
    static constant constexpr traits::encoding_type ifmin = 0xFB00;
};

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x5B80;
    static constant constexpr traits::encoding_type ifmin = 0x0; // unused
};

template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
    using traits = _fp_encoding_traits<half>;
    static constant constexpr traits::encoding_type ifmax = 0x7B00;
    static constant constexpr traits::encoding_type ifmin = 0xFB00;
};
|
||||
|
||||
typedef unsigned char fp8_storage_t;
|
||||
typedef unsigned short int fp8x2_storage_t;
|
||||
typedef unsigned int fp8x4_storage_t;
|
||||
|
||||
// Widens an fp8 bit pattern to half. Only the explicit specialisations are
// defined.
template <int variant>
METAL_FUNC half fp8_to_half(fp8_storage_t x);

// OCP E4M3 -> half.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
    typedef _fp_encoding_traits<half> traits;
    typedef typename traits::encoding_type bits;

    // Place the byte in the top half of a 16-bit word, then move the 4-bit
    // exponent and 3-bit mantissa into the half field positions.
    bits ur = x << 8U;
    bits sign = ur & 0x8000U;
    // 0x2000 re-biases the exponent (difference of 8, shifted into bits 10..14).
    bits exponent = (bits)(((ur & 0x7800U) >> 1U) + 0x2000U);
    bits mantissa = (ur & 0x0700U) >> 1U;
    unsigned char absx = 0x7FU & (unsigned char)x;

    if (absx == 0x7FU) { // NaN
        ur = 0x7FFFU; // fp16 canonical NaN, discard sign
    } else {
        if (exponent == 0x2000U) {
            // zero or denormal
            if (mantissa != 0U) {
                // normalize: shift until the leading one reaches the
                // implicit-bit position, lowering the exponent per shift
                mantissa = (bits)(mantissa << 1U);
                while ((mantissa & 0x0400U) == 0U) {
                    mantissa = (bits)(mantissa << 1U);
                    exponent = (bits)(exponent - 0x0400U);
                }
                // discard implicit leading bit
                mantissa &= 0x03FFU;
            } else { // Zero
                exponent = 0U;
            }
        }
        ur = (sign | exponent) | mantissa;
    }

    return as_type<half>(ur);
}
|
||||
|
||||
// OCP E5M2 -> half.
//
// E5M2 uses the same sign and 5-bit exponent layout as half, so placing the
// byte in the upper 8 bits of a 16-bit word already yields the half encoding;
// only NaNs are rewritten to one fixed pattern.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
    typedef _fp_encoding_traits<half> traits;
    typedef typename traits::encoding_type bits;

    bits ur = x << 8U;
    // The NaN test has to look at the widened value: the 8-bit input can
    // never exceed 0x7C00, so testing `x` here would make the branch dead.
    if ((ur & 0x7FFFU) > 0x7C00U) {
        /* If NaN, return canonical NaN */
        ur = 0x7FFFU;
    }
    return as_type<half>(ur);
}
|
||||
|
||||
// Converts an fp8 bit pattern to T by first widening it to half.
template <typename T, int variant>
METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
    const half widened = fp8_to_half<variant>(x);
    return static_cast<T>(widened);
}
|
||||
|
||||
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/params.h#L1
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM param classes
|
||||
@ -199,6 +379,35 @@ struct BlockLoader {
|
||||
// Transforms and Epilogues
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converts a value of type InT to OutT. Only the explicit specialisations
// below are defined; fp8 inputs are decoded as E5M2.
template <typename OutT, typename InT, typename _E = void>
METAL_FUNC OutT accum_cast(InT x);

// float -> float: identity.
template <>
METAL_FUNC float accum_cast<float, float, void>(float x) {
    return x;
}

// half -> float.
template <>
METAL_FUNC float accum_cast<float, half, void>(half x) {
    return float(x);
}

#if defined(__HAVE_BFLOAT__)
// bfloat -> float, only when the toolchain provides bfloat.
template <>
METAL_FUNC float accum_cast<float, bfloat, void>(bfloat x) {
    return float(x);
}
#endif

// fp8 (E5M2) -> float.
template <>
METAL_FUNC float accum_cast<float, fp8_storage_t, void>(fp8_storage_t x) {
    return cast_fp8_to<float, _fp8_variant_t::E5M2>(x);
}

// fp8 (E5M2) -> half.
template <>
METAL_FUNC half accum_cast<half, fp8_storage_t, void>(fp8_storage_t x) {
    return cast_fp8_to<half, _fp8_variant_t::E5M2>(x);
}
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformNone {
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
@ -210,6 +419,17 @@ struct TransformNone {
|
||||
}
|
||||
};
|
||||
|
||||
// TransformNone for fp8 input: decodes the value as E5M2 and converts it to
// OutT. The two-argument overload ignores its second argument.
template <typename OutT>
struct TransformNone<OutT, fp8_storage_t> {
    static METAL_FUNC OutT apply(fp8_storage_t x) {
        return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
    }

    static METAL_FUNC OutT apply(fp8_storage_t x, OutT) {
        return apply(x);
    }
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAdd {
|
||||
TransformAdd(const float, const float) {}
|
||||
@ -223,6 +443,20 @@ struct TransformAdd {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// TransformAdd for fp8 input: decodes the value as E5M2; the two-argument
// overload additionally adds `c`. The constructor arguments are unused.
template <typename OutT>
struct TransformAdd<OutT, fp8_storage_t> {
    TransformAdd(const float, const float) {}

    static METAL_FUNC OutT apply(fp8_storage_t x) {
        return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
    }

    static METAL_FUNC OutT apply(fp8_storage_t x, OutT c) {
        return apply(x) + c;
    }
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAxpby {
|
||||
const float alpha;
|
||||
@ -240,6 +474,24 @@ struct TransformAxpby {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// TransformAxpby for fp8 input: decodes the value as E5M2; the two-argument
// overload returns decoded(x) * alpha + beta * c.
template <typename OutT>
struct TransformAxpby<OutT, fp8_storage_t> {
    const float alpha;
    const float beta;

    TransformAxpby(const float alpha_, const float beta_)
        : alpha(alpha_), beta(beta_) {}

    static METAL_FUNC OutT apply(fp8_storage_t x) {
        return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
    }

    METAL_FUNC OutT apply(fp8_storage_t x, OutT c) const {
        return apply(x) * alpha + (beta * c);
    }
};
|
||||
|
||||
template <typename T>
|
||||
struct AccumHelper {
|
||||
typedef float accum_type;
|
||||
@ -336,7 +588,6 @@ struct BlockMMA {
|
||||
// Adjust for simdgroup and thread location
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of 8
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
@ -345,23 +596,19 @@ struct BlockMMA {
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||
Asimd[i].thread_elements()[0] = accum_cast<AccumType, T>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] = accum_cast<AccumType, T>(As[i * simd_stride_a + jump_a]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||
Bsimd[j].thread_elements()[0] = accum_cast<AccumType, T>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] = accum_cast<AccumType, T>(Bs[j * simd_stride_b + jump_b]);
|
||||
}
|
||||
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
@ -370,7 +617,6 @@ struct BlockMMA {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j_serp],
|
||||
Asimd[i],
|
||||
@ -382,7 +628,7 @@ struct BlockMMA {
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
Bs += tile_stride_b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
@ -1020,12 +1266,13 @@ template <
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
typename AccumType = float,
|
||||
typename ResultType = T>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device T* C [[buffer(2), function_constant(use_out_source)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const device ResultType* C [[buffer(2), function_constant(use_out_source)]],
|
||||
device ResultType* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
@ -1045,7 +1292,7 @@ template <
|
||||
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
ResultType,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
@ -1405,13 +1652,13 @@ template <
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
|
||||
template [[host_name("gemm_" #tname "_" #iname "_" #oname "_" #bm "_" #bn "_" #bk "_" #wm "_" #wn)]] \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float, stype>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const device stype *C [[buffer(2), function_constant(use_out_source)]], \
|
||||
device stype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
@ -1427,14 +1674,17 @@ template <
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
// Expands to four `instantiate_gemm` invocations, one per transpose
// combination of the two inputs: nn (false, false), nt (false, true),
// tn (true, false) and tt (true, true). All remaining arguments, including
// the trailing `stype`, are forwarded unchanged to each invocation.
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
|
||||
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
|
||||
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
|
||||
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
|
||||
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype)
|
||||
|
||||
instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2, float)
|
||||
instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2, bfloat)
|
||||
#endif
|
||||
|
||||
|
||||
instantiate_gemm_transpose_helper(F8E5M2, fp8_storage_t, f16, half, 32, 32, 16, 2, 2, half)
|
||||
|
@ -5,6 +5,7 @@ use std::ffi::c_void;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum GemmDType {
|
||||
F8E5M2,
|
||||
BF16,
|
||||
F16,
|
||||
F32,
|
||||
@ -138,6 +139,10 @@ pub fn call_mlx_gemm(
|
||||
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F8E5M2, false, false) => "gemm_nn_F8E5M2_f16_32_32_16_2_2",
|
||||
(GemmDType::F8E5M2, true, false) => "gemm_tn_F8E5M2_f16_32_32_16_2_2",
|
||||
(GemmDType::F8E5M2, false, true) => "gemm_nt_F8E5M2_f16_32_32_16_2_2",
|
||||
(GemmDType::F8E5M2, true, true) => "gemm_tt_F8E5M2_f16_32_32_16_2_2",
|
||||
};
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
|
@ -41,6 +41,7 @@ pub fn call_arg_sort(
|
||||
|
||||
fn mlx_dtype_str(dtype: DType) -> &'static str {
|
||||
match dtype {
|
||||
DType::FP8(_) => todo!(),
|
||||
DType::U8 => "uint8",
|
||||
DType::U32 => "uint32",
|
||||
DType::I64 => "int64",
|
||||
@ -195,6 +196,7 @@ pub fn multi_block_sort(
|
||||
};
|
||||
// Copy output with appropriate strides
|
||||
let copy_kernel = match dtype {
|
||||
DType::FP8(_) => todo!(),
|
||||
DType::U8 => crate::copy2d::U8,
|
||||
DType::U32 => crate::copy2d::U32,
|
||||
DType::I64 => crate::copy2d::I64,
|
||||
|
@ -1,4 +1,6 @@
|
||||
use super::*;
|
||||
|
||||
use float8::{F8E4M3, F8E5M2};
|
||||
use half::{bf16, f16};
|
||||
use metal::{Buffer, Device, MTLResourceOptions};
|
||||
use rand::prelude::SliceRandom;
|
||||
@ -423,6 +425,38 @@ fn cast_bf16() {
|
||||
assert_eq!(results, v_i64);
|
||||
}
|
||||
|
||||
// Round-trip test for the fp8 cast kernels. The values 1.0 through 8.0 are
// cast from f32 and from f16 to fp8 and back again via `run_cast`, once with
// the E4M3 kernel names and once with the E5M2 kernel names, and every round
// trip is asserted to reproduce its input exactly.
#[test]
|
||||
fn cast_fp8() {
|
||||
let v_f64 = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
|
||||
// NOTE(review): `v_f8` and `v_f16_rt` are computed on the host but are not
// read by any assertion below — confirm whether a comparison against the
// kernel output was intended, or drop them.
let v_f8: Vec<F8E4M3> = v_f32.iter().map(|&v| F8E4M3::from(v)).collect();
|
||||
let v_f16_rt: Vec<f16> = v_f8.iter().map(|&v| v.to_f16()).collect();
|
||||
|
||||
// E4M3
|
||||
// f32 -> fp8 -> f32
|
||||
// The fp8 output of the first cast is collected as `Vec<u8>` and fed
// straight into the reverse cast.
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E4M3");
|
||||
let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E4M3");
|
||||
assert_eq!(v_f32, roundtrip);
|
||||
|
||||
// f16 -> fp8 -> f16
|
||||
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E4M3");
|
||||
let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E4M3");
|
||||
assert_eq!(v_f16, roundtrip);
|
||||
|
||||
// E5M2
|
||||
// f16 -> fp8 -> f16
|
||||
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E5M2");
|
||||
let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E5M2");
|
||||
assert_eq!(v_f16, roundtrip);
|
||||
|
||||
// f32 -> fp8 -> f32
|
||||
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E5M2");
|
||||
let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E5M2");
|
||||
assert_eq!(v_f32, roundtrip);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_u32() {
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
@ -1285,7 +1319,7 @@ fn where_cond_u32_f32() {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_mlx_gemm<T: Clone>(
|
||||
fn run_mlx_gemm<T: Clone, U: Clone>(
|
||||
dtype: GemmDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
@ -1294,7 +1328,7 @@ fn run_mlx_gemm<T: Clone>(
|
||||
rhs: &[T],
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
) -> Vec<U> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
@ -1436,6 +1470,33 @@ fn mlx_gemm() {
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// F8E5M2 sanity test
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs: Vec<u8> = (0..b * m * k)
|
||||
.map(|f| F8E5M2::from_f32(f as f32).to_bits())
|
||||
.collect();
|
||||
let rhs: Vec<u8> = (0..b * n * k)
|
||||
.map(|f| F8E5M2::from_f32(f as f32).to_bits())
|
||||
.collect();
|
||||
let results: Vec<f16> = run_mlx_gemm(
|
||||
GemmDType::F8E5M2,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
);
|
||||
println!("{results:?}");
|
||||
|
||||
assert_eq!(
|
||||
approx_f16(results, 4),
|
||||
vec![20.0, 21.0, 26.0, 31.0, 56.0, 63.0, 80.0, 97.0]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
|
Reference in New Issue
Block a user