mlx fp8 gemm

This commit is contained in:
Ivar Flakstad
2025-05-06 09:55:45 +02:00
parent cf9d7bf24c
commit 6210fbe9d8
10 changed files with 806 additions and 109 deletions

View File

@ -3,8 +3,8 @@ mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::affine::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::where_cond::benches,

View File

@ -7,19 +7,25 @@ fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device) {
fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name("matmul"));
let name = match dtype {
DType::F32 => "matmul_f32",
DType::F16 => "matmul_f16",
DType::BF16 => "matmul_bf16",
DType::U8 => "matmul_fp8",
_ => unimplemented!("{dtype:?} matmul bench not implemented"),
};
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
@ -36,8 +42,11 @@ fn run_bench(c: &mut Criterion, device: &Device) {
/// Registers the matmul benchmark for every available bench device and for
/// each benchmarked dtype.
///
/// `DType::U8` is part of the list because `run_bench` reports it under the
/// `matmul_fp8` name.
///
/// # Panics
/// Panics if the bench device handler cannot be created.
fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        // A by-value array avoids allocating a `Vec` and cloning it per device.
        for dtype in [DType::F32, DType::F16, DType::BF16, DType::U8] {
            run_bench(c, &device, dtype);
        }
    }
}

View File

@ -1513,32 +1513,17 @@ impl BackendStorage for MetalStorage {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let dtype = match self.dtype {
// Hijacking the U8 dtype to represent E5M2 fp8
DType::U8 => candle_metal_kernels::GemmDType::F8E5M2,
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
return Err(
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_mlx_gemm(
@ -1556,7 +1541,7 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(
buffer,
self.device.clone(),

View File

@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
float8 = "0.2.1"
metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"

View File

@ -1,4 +1,5 @@
#include <metal_stdlib>
#include <metal_limits>
METAL_FUNC uint get_strided_index(
uint idx,
@ -18,6 +19,377 @@ METAL_FUNC uint get_strided_index(
using namespace metal;
// Raw storage for one 8-bit float: the fp8 bit pattern is carried in an unsigned byte.
typedef unsigned char fp8_storage_t;
// 16-bit and 32-bit storage aliases. Presumably meant for packed pairs / quads
// of fp8 values; they are not referenced by the code visible here - TODO confirm.
typedef unsigned short int fp8x2_storage_t;
typedef unsigned int fp8x4_storage_t;
// Bit masks over the encoding of the source type T, used by cast_to_fp8 to
// split a value into its fields:
//   head_mask     - sign bit + exponent bits
//   mantissa_mask - fraction bits
//   mask          - every bit except the sign (i.e. the absolute value)
template <typename T>
struct _fp8_cast_traits;

// 32-bit float: 1 sign bit, 8 exponent bits, 23 fraction bits.
template <>
struct _fp8_cast_traits<float> {
  typedef _fp_encoding_traits<float> traits;
  typedef typename traits::encoding_type encoding_type;
  constexpr static constant encoding_type head_mask = 0xFF800000;
  constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
  constexpr static constant encoding_type mask = 0x7FFFFFFF;
};

// 16-bit half: 1 sign bit, 5 exponent bits, 10 fraction bits.
template <>
struct _fp8_cast_traits<half> {
  // Built on the half encoding (not float) so that `traits` and
  // `encoding_type` describe the same type as the masks below.
  typedef _fp_encoding_traits<half> traits;
  typedef typename traits::encoding_type encoding_type;
  constexpr static constant encoding_type head_mask = 0xFC00;
  constexpr static constant encoding_type mantissa_mask = 0x3FF;
  constexpr static constant encoding_type mask = 0x7FFF;
};
// Identifies an fp8 encoding. The enumerators are used as the `int variant`
// template argument of the trait structs and conversion functions below.
enum _fp8_variant_t {
E4M3 = 0, // OCP E4M3
E5M2 = 1, // OCP E5M2
E4M3_FNUZ = 2, // Standard FP8
E5M2_FNUZ = 3, // BF8
};
// Second enum with the same numeric values under underscore-prefixed names.
// NOTE(review): nothing visible in this section refers to it - confirm whether
// it is still needed.
typedef enum fp8_variant_t {
_E4M3 = 0, // OCP E4M3
_E5M2 = 1, // OCP E5M2
_E4M3_FNUZ = 2, // Standard FP8
_E5M2_FNUZ = 3, // BF8
} fp8_variant_t;
// Layout description of each fp8 variant, read by cast_to_fp8:
//   we      - number of exponent bits
//   wm      - number of mantissa bits
//   is_fnuz - selects the FNUZ rules in cast_to_fp8 (exponent bias raised by
//             one, 0x80 used as the NaN encoding, zero always encoded as 0)
template<int variant>
struct _fp8_variant_traits;
// OCP E4M3: 4 exponent bits, 3 mantissa bits.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
constexpr static constant bool is_fnuz = false;
constexpr static constant int we = 4;
constexpr static constant int wm = 3;
};
// OCP E5M2: 5 exponent bits, 2 mantissa bits.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
constexpr static constant bool is_fnuz = false;
constexpr static constant int we = 5;
constexpr static constant int wm = 2;
};
// FNUZ flavour of E4M3.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
constexpr static constant bool is_fnuz = true;
constexpr static constant int we = 4;
constexpr static constant int wm = 3;
};
// FNUZ flavour of E5M2.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
constexpr static constant bool is_fnuz = true;
constexpr static constant int we = 5;
constexpr static constant int wm = 2;
};
// Range limits of each fp8 variant, expressed as bit patterns of the source
// type T (float or half).
//   ifmax - cast_to_fp8 treats any input whose absolute-value bits exceed this
//           pattern as an overflow.
//   ifmin - the sign-flipped counterpart where provided; cast_to_fp8 does not
//           read it.
template<typename T, int variant>
struct _fp8_variant_cast_traits;
// float -> E4M3: 0x43E00000 is the float encoding of 448.0.
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x43E00000;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// float -> E5M2: 0x47600000 / 0xC7600000 are the float encodings of +/-57344.0.
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x47600000;
constexpr static constant traits::encoding_type ifmin = 0xC7600000;
};
// float -> E4M3_FNUZ: 0x43700000 is the float encoding of 240.0.
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x43700000;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// float -> E5M2_FNUZ: same limits as E5M2 (+/-57344.0).
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x47600000;
constexpr static constant traits::encoding_type ifmin = 0xC7600000;
};
// half -> E4M3: 0x5F00 is the half encoding of 448.0.
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x5F00;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// half -> E5M2: 0x7B00 / 0xFB00 are the half encodings of +/-57344.0.
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x7B00;
constexpr static constant traits::encoding_type ifmin = 0xFB00;
};
// half -> E4M3_FNUZ: 0x5B80 is the half encoding of 240.0.
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x5B80;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// half -> E5M2_FNUZ: same limits as E5M2 (+/-57344.0).
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x7B00;
constexpr static constant traits::encoding_type ifmin = 0xFB00;
};
// TODO: Simplify. No need to support all fp8 variants immediately.
// Converts a float/half value to the 8-bit pattern of the requested fp8 variant.
//
// Template parameters:
//   T       - source type; specialisations of the trait structs exist for
//             float and half.
//   variant - a _fp8_variant_t value (defaults to E4M3_FNUZ).
// Parameters:
//   _x    - value to convert.
//   clip  - on overflow, return the largest finite encoding of the variant
//           instead of its Inf/NaN encoding.
//   stoch - use the low bits of `rng` as the rounding increment instead of
//           round-to-nearest-even.
//   rng   - random bits, only read when `stoch` is true.
// Returns the fp8 bit pattern (sign in bit 7, then `we` exponent bits, then
// `wm` mantissa bits).
template <typename T, int variant = _fp8_variant_t::E4M3_FNUZ>
METAL_FUNC fp8_storage_t cast_to_fp8(T _x, bool clip = false, bool stoch = false, uint rng = 0) {
typedef _fp_encoding_traits<T> traits;
typedef typename traits::encoding_type bits;
typedef numeric_limits<bits> limits; // not used below
typedef _fp8_cast_traits<T> cast_traits;
typedef _fp8_variant_traits<variant> variant_traits;
typedef _fp8_variant_cast_traits<T, variant> variant_cast_traits;
constexpr bool is_fnuz = variant_traits::is_fnuz;
constexpr int we = variant_traits::we;
constexpr int wm = variant_traits::wm;
// mfmt: position of the exponent field in T, i.e. the number of fraction bits of T.
constexpr int mfmt = traits::exponent_shift;
constexpr int bias = traits::exponent_bias;
constexpr bits mask = cast_traits::mask;
// Split the source bit pattern into sign / biased exponent / fraction.
bits x = as_type<bits>(_x);
bits head = x & cast_traits::head_mask;
bits mantissa = x & cast_traits::mantissa_mask;
int exponent = (head >> traits::exponent_shift) & traits::exponent_max;
bits sign = head >> traits::sign_shift;
// Encodings returned for overflow (`signed_inf`) and for Inf/NaN inputs (`nan`).
// With `clip` the overflow encoding is the largest finite value of the variant.
bits signed_inf = 0;
unsigned int nan = 0;
if (is_fnuz) {
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
nan = 0x80;
} else {
if (we == 4) {
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
} else {
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
}
constexpr bits ifmax = variant_cast_traits::ifmax;
// Deal with inf and NaNs
if ((x & traits::inf_mask) == traits::inf_mask) {
if (is_fnuz || we == 4) return nan; // fnuz and OCP E4M3 has no INF
if (mantissa != 0) return nan; // NaN
return sign == 0 ? 0x7C : 0xFC; // E5M2 Inf
}
// Magnitude above the largest finite fp8 value: overflow.
if ((x & mask) > ifmax) {
return signed_inf;
}
// Positive zero maps to 0; a negative zero falls through to the code below.
if (x == 0) {
return 0;
}
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent:  unbiased exponent of the input.
// exponent_diff: how far the mantissa must be shifted right so the value can
//                be represented as an fp8 denormal (0 for fp8 normals).
int act_exponent, f8_exponent, exponent_diff;
if (exponent == 0) {
// Input is itself a denormal: no implicit leading one.
act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent - act_exponent;
} else {
act_exponent = exponent - bias;
if (act_exponent <= f8_denormal_act_exponent) {
exponent_diff = f8_denormal_act_exponent - act_exponent;
} else {
exponent_diff = 0;
}
// Make the implicit leading one explicit.
mantissa += (1ull << mfmt);
}
// True when the bits about to be dropped are exactly half of the last kept bit.
// NOTE(review): for very small inputs exponent_diff can be large enough that
// these shift counts reach the width of the shifted type - confirm that this
// is acceptable on the target.
bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
(1ull << (mfmt - wm + exponent_diff - 1));
if (exponent_diff > 0)
mantissa >>= exponent_diff;
else if (exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << mfmt);
f8_exponent =
(act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
// Rounding: add either the random bits or the to-be-dropped bits to the
// mantissa. Adding the dropped bits to themselves carries into the kept part
// exactly when they are at least half; at an exact midpoint with an even kept
// part one less is added so no carry happens (ties go to even).
unsigned long drop_mask = (1ull << (mfmt - wm)) - 1;
bool odd =
mantissa & (1ull << (mfmt - wm));
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
// Rounding may have carried into a new leading bit: fix up the exponent.
if (f8_exponent == 0) {
if ((1ull << mfmt) & mantissa) {
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
} else {
if ((1ull << (mfmt + 1)) & mantissa) {
mantissa >>= 1;
f8_exponent++;
}
}
// Drop the low bits so only the fp8 mantissa (plus leading one) remains.
mantissa >>= (mfmt - wm);
const int max_exp = (1 << we) - 1;
if (f8_exponent > max_exp) {
if (clip) {
mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
} else {
return signed_inf;
}
}
// Result rounded to zero: FNUZ variants always return 0, others keep the sign.
if (f8_exponent == 0 && mantissa == 0) return is_fnuz ? 0 : (sign << 7);
// Strip the leading one and assemble sign | exponent | mantissa.
mantissa &= (1 << wm) - 1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
}
// Widens one fp8 value of the given variant to half. Only the specialisations
// that follow are defined.
template <int variant>
METAL_FUNC half fp8_to_half(fp8_storage_t x);

// E4M3 -> half.
// The fp8 byte is placed in the top byte of a 16-bit word; the 4-bit exponent
// is then moved into half's 5-bit exponent field and re-biased (+0x2000), and
// the 3 mantissa bits are moved to the top of half's fraction field.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
  typedef _fp_encoding_traits<half> traits;
  typedef typename traits::encoding_type bits;

  bits widened = x << 8U;
  bits sign_bits = widened & 0x8000U;
  bits exp_bits = (bits)(((widened & 0x7800U) >> 1U) + 0x2000U);
  bits frac_bits = (widened & 0x0700U) >> 1U;
  unsigned char magnitude = 0x7FU & (unsigned char)x;

  bits out;
  if (magnitude == 0x7FU) {
    // All exponent and mantissa bits set: NaN. Return 0x7FFF, sign discarded.
    out = 0x7FFFU;
  } else {
    if (exp_bits == 0x2000U) {
      // fp8 exponent field was zero: the value is zero or a denormal.
      if (frac_bits == 0U) {
        exp_bits = 0U;
      } else {
        // Shift the fraction up until its leading one reaches the implicit
        // position, lowering the exponent by one step for every extra shift.
        frac_bits = (bits)(frac_bits << 1U);
        while ((frac_bits & 0x0400U) == 0U) {
          frac_bits = (bits)(frac_bits << 1U);
          exp_bits = (bits)(exp_bits - 0x0400U);
        }
        // The leading one is implicit in half: drop it.
        frac_bits &= 0x03FFU;
      }
    }
    out = (sign_bits | exp_bits) | frac_bits;
  }
  return as_type<half>(out);
}
// E5M2 -> half.
// The fp8 byte becomes the top byte of the 16-bit half pattern and the low
// byte is zero, so finite values and infinities need no further work.
// NaN inputs are replaced by the single pattern 0x7FFF.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
  typedef _fp_encoding_traits<half> traits;
  typedef typename traits::encoding_type bits;
  bits ur = x << 8U;
  // The test must look at the widened pattern: the 8-bit input masked with
  // 0x7FFF can never exceed 0x7C00, so testing `x` would never detect a NaN.
  if ((ur & 0x7FFFU) > 0x7C00U) {
    // return NaN
    ur = 0x7FFFU;
  }
  return as_type<half>(ur);
}
// Converts one fp8 value of the given variant to T by first widening it to
// half and then converting the half to T.
template <typename T, int variant>
METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
  const half widened = fp8_to_half<variant>(x);
  return static_cast<T>(widened);
}
// Emits two kernels converting a buffer of T into fp8 values (FP8_VARIANT):
//   name            - contiguous input, one thread per element
//   name##_strided  - input element located through get_strided_index
// Threads whose index is >= dim return without writing.
// The last line of the macro deliberately has no line continuation, so the
// definition ends at the closing brace regardless of what follows it.
#define CAST_TO_FP8_VARIANT(name, T, FP8_VARIANT)                          \
kernel void name(                                                          \
    constant size_t &dim,                                                  \
    device const T *input,                                                 \
    device fp8_storage_t *output,                                          \
    uint tid [[ thread_position_in_grid ]]                                 \
) {                                                                        \
    if (tid >= dim) {                                                      \
        return;                                                            \
    }                                                                      \
    output[tid] = cast_to_fp8<T, FP8_VARIANT>(input[tid]);                 \
}                                                                          \
kernel void name##_strided(                                                \
    constant size_t &dim,                                                  \
    constant size_t &num_dims,                                             \
    constant size_t *dims,                                                 \
    constant size_t *strides,                                              \
    device const T *input,                                                 \
    device fp8_storage_t *output,                                          \
    uint tid [[ thread_position_in_grid ]]                                 \
) {                                                                        \
    if (tid >= dim) {                                                      \
        return;                                                            \
    }                                                                      \
    output[tid] = cast_to_fp8<T, FP8_VARIANT>(                             \
        input[get_strided_index(tid, num_dims, dims, strides)]             \
    );                                                                     \
}

// Instantiates the kernels above for the E4M3 and E5M2 variants; the variant
// name is appended to `name` (name_E4M3, name_E4M3_strided, name_E5M2, ...).
#define CAST_TO_FP8(name, T)                                               \
    CAST_TO_FP8_VARIANT(name##_E4M3, T, E4M3)                              \
    CAST_TO_FP8_VARIANT(name##_E5M2, T, E5M2)
// Emits two kernels converting a buffer of fp8 values (FP8_VARIANT) into T:
//   name            - contiguous input, one thread per element
//   name##_strided  - input element located through get_strided_index
// Threads whose index is >= dim return without writing.
// The last line of the macro deliberately has no line continuation, so the
// definition ends at the closing brace regardless of what follows it.
#define CAST_FROM_FP8_VARIANT(name, T, FP8_VARIANT)                        \
kernel void name(                                                          \
    constant size_t &dim,                                                  \
    device const fp8_storage_t *input,                                     \
    device T *output,                                                      \
    uint tid [[ thread_position_in_grid ]]                                 \
) {                                                                        \
    if (tid >= dim) {                                                      \
        return;                                                            \
    }                                                                      \
    output[tid] = cast_fp8_to<T, FP8_VARIANT>(input[tid]);                 \
}                                                                          \
kernel void name##_strided(                                                \
    constant size_t &dim,                                                  \
    constant size_t &num_dims,                                             \
    constant size_t *dims,                                                 \
    constant size_t *strides,                                              \
    device const fp8_storage_t *input,                                     \
    device T *output,                                                      \
    uint tid [[ thread_position_in_grid ]]                                 \
) {                                                                        \
    if (tid >= dim) {                                                      \
        return;                                                            \
    }                                                                      \
    output[tid] = cast_fp8_to<T, FP8_VARIANT>(                             \
        input[get_strided_index(tid, num_dims, dims, strides)]             \
    );                                                                     \
}

// Instantiates the kernels above for the E4M3 and E5M2 variants; the variant
// name is appended to `name` (name_E4M3, name_E4M3_strided, name_E5M2, ...).
#define CAST_FROM_FP8(name, T)                                             \
    CAST_FROM_FP8_VARIANT(name##_E4M3, T, E4M3)                            \
    CAST_FROM_FP8_VARIANT(name##_E5M2, T, E5M2)
// fp8 cast kernel instantiations. Each line produces four kernels named
// <name>_E4M3, <name>_E4M3_strided, <name>_E5M2 and <name>_E5M2_strided.
// fp8 -> half / float
CAST_FROM_FP8(cast_fp8_f16, half)
CAST_FROM_FP8(cast_fp8_f32, float)
// float / half -> fp8
CAST_TO_FP8(cast_f32_fp8, float)
CAST_TO_FP8(cast_f16_fp8, half)
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \

View File

@ -37,12 +37,19 @@ pub enum DType {
I64,
U32,
U8,
FP8(FP8Variant),
}
/// The 8-bit floating-point encodings that can be carried by `DType::FP8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FP8Variant {
/// Encoding with 4 exponent bits and 3 mantissa bits.
E4M3,
/// Encoding with 5 exponent bits and 2 mantissa bits.
E5M2,
}
impl DType {
fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
Self::U8 | Self::FP8(_) => 1,
Self::U32 => 4,
Self::I64 => 8,
Self::BF16 => 2,
@ -311,7 +318,12 @@ impl Kernels {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
.map_err(|e| {
// Makes metal errors easier to read.
// TODO: Remove.
panic!("{}", e);
return MetalKernelError::LoadLibraryError(e.to_string());
})?
};
libraries.insert(source, lib.clone());
Ok(lib)

View File

@ -11,6 +11,186 @@
using namespace metal;
// FP8 casting
// Identifies an fp8 encoding. The enumerators are used as the `int variant`
// template argument of the trait structs and conversion functions below.
enum _fp8_variant_t {
E4M3 = 0, // OCP E4M3
E5M2 = 1, // OCP E5M2
E4M3_FNUZ = 2, // Standard FP8
E5M2_FNUZ = 3, // BF8
};
// Layout description of each fp8 variant:
//   we      - number of exponent bits
//   wm      - number of mantissa bits
//   is_fnuz - true for the two *_FNUZ variants
template<int variant>
struct _fp8_variant_traits;
// OCP E4M3: 4 exponent bits, 3 mantissa bits.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
constexpr static constant bool is_fnuz = false;
constexpr static constant int we = 4;
constexpr static constant int wm = 3;
};
// OCP E5M2: 5 exponent bits, 2 mantissa bits.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
constexpr static constant bool is_fnuz = false;
constexpr static constant int we = 5;
constexpr static constant int wm = 2;
};
// FNUZ flavour of E4M3.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
constexpr static constant bool is_fnuz = true;
constexpr static constant int we = 4;
constexpr static constant int wm = 3;
};
// FNUZ flavour of E5M2.
template <>
struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
constexpr static constant bool is_fnuz = true;
constexpr static constant int we = 5;
constexpr static constant int wm = 2;
};
// Bit masks over the encoding of the source type T:
//   head_mask     - sign bit + exponent bits
//   mantissa_mask - fraction bits
//   mask          - every bit except the sign (i.e. the absolute value)
template <typename T>
struct _fp8_cast_traits;

// 32-bit float: 1 sign bit, 8 exponent bits, 23 fraction bits.
template <>
struct _fp8_cast_traits<float> {
  typedef _fp_encoding_traits<float> traits;
  typedef typename traits::encoding_type encoding_type;
  constexpr static constant encoding_type head_mask = 0xFF800000;
  constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
  constexpr static constant encoding_type mask = 0x7FFFFFFF;
};

// 16-bit half: 1 sign bit, 5 exponent bits, 10 fraction bits.
template <>
struct _fp8_cast_traits<half> {
  // Built on the half encoding (not float) so that `traits` and
  // `encoding_type` describe the same type as the masks below.
  typedef _fp_encoding_traits<half> traits;
  typedef typename traits::encoding_type encoding_type;
  constexpr static constant encoding_type head_mask = 0xFC00;
  constexpr static constant encoding_type mantissa_mask = 0x3FF;
  constexpr static constant encoding_type mask = 0x7FFF;
};
// Range limits of each fp8 variant, expressed as bit patterns of the source
// type T (float or half).
//   ifmax - bit pattern of the largest magnitude the variant can hold
//   ifmin - the sign-flipped counterpart where provided
// NOTE(review): no code visible in this section reads these constants.
template<typename T, int variant>
struct _fp8_variant_cast_traits;
// float / E4M3: 0x43E00000 is the float encoding of 448.0.
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x43E00000;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// float / E5M2: 0x47600000 / 0xC7600000 are the float encodings of +/-57344.0.
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x47600000;
constexpr static constant traits::encoding_type ifmin = 0xC7600000;
};
// float / E4M3_FNUZ: 0x43700000 is the float encoding of 240.0.
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x43700000;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// float / E5M2_FNUZ: same limits as E5M2 (+/-57344.0).
template <>
struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
typedef _fp_encoding_traits<float> traits;
constexpr static constant traits::encoding_type ifmax = 0x47600000;
constexpr static constant traits::encoding_type ifmin = 0xC7600000;
};
// half / E4M3: 0x5F00 is the half encoding of 448.0.
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x5F00;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// half / E5M2: 0x7B00 / 0xFB00 are the half encodings of +/-57344.0.
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x7B00;
constexpr static constant traits::encoding_type ifmin = 0xFB00;
};
// half / E4M3_FNUZ: 0x5B80 is the half encoding of 240.0.
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x5B80;
constexpr static constant traits::encoding_type ifmin = 0x0; // unused
};
// half / E5M2_FNUZ: same limits as E5M2 (+/-57344.0).
template <>
struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
typedef _fp_encoding_traits<half> traits;
constexpr static constant traits::encoding_type ifmax = 0x7B00;
constexpr static constant traits::encoding_type ifmin = 0xFB00;
};
// Raw storage for one 8-bit float: the fp8 bit pattern is carried in an unsigned byte.
typedef unsigned char fp8_storage_t;
// 16-bit and 32-bit storage aliases. Presumably meant for packed pairs / quads
// of fp8 values; they are not referenced by the code visible here - TODO confirm.
typedef unsigned short int fp8x2_storage_t;
typedef unsigned int fp8x4_storage_t;
// Widens one fp8 value of the given variant to half. Only the specialisations
// that follow are defined.
template <int variant>
METAL_FUNC half fp8_to_half(fp8_storage_t x);

// E4M3 -> half.
// The fp8 byte is placed in the top byte of a 16-bit word; the 4-bit exponent
// is then moved into half's 5-bit exponent field and re-biased (+0x2000), and
// the 3 mantissa bits are moved to the top of half's fraction field.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
  typedef _fp_encoding_traits<half> traits;
  typedef typename traits::encoding_type bits;

  bits ur = x << 8U;
  bits sign = ur & 0x8000U;
  bits exponent = (bits)(((ur & 0x7800U) >> 1U) + 0x2000U);
  bits mantissa = (ur & 0x0700U) >> 1U;
  unsigned char absx = 0x7FU & (unsigned char)x;

  if (absx == 0x7FU) { // NaN
    ur = 0x7FFFU; // fp16 canonical NaN, discard sign
  } else {
    if (exponent == 0x2000U) {
      // fp8 exponent field was zero: zero or denormal
      if (mantissa != 0U) {
        // normalize: shift until the leading one reaches the implicit
        // position, lowering the exponent for every extra shift
        mantissa = (bits)(mantissa << 1U);
        while ((mantissa & 0x0400U) == 0U) {
          mantissa = (bits)(mantissa << 1U);
          exponent = (bits)(exponent - 0x0400U);
        }
        // discard implicit leading bit
        mantissa &= 0x03FFU;
      } else { // Zero
        exponent = 0U;
      }
    }
    ur = (sign | exponent) | mantissa;
  }
  return as_type<half>(ur);
}
// E5M2 -> half.
// The fp8 byte becomes the top byte of the 16-bit half pattern and the low
// byte is zero, so finite values and infinities need no further work.
// NaN inputs are replaced by the single pattern 0x7FFF.
template <>
METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
  typedef _fp_encoding_traits<half> traits;
  typedef typename traits::encoding_type bits;
  bits ur = x << 8U;
  // The test must look at the widened pattern: the 8-bit input masked with
  // 0x7FFF can never exceed 0x7C00, so testing `x` would never detect a NaN.
  if ((ur & 0x7FFFU) > 0x7C00U) {
    /* If NaN, return canonical NaN */
    ur = 0x7FFFU;
  }
  return as_type<half>(ur);
}
// Converts one fp8 value of the given variant to T by first widening it to
// half and then converting the half to T.
template <typename T, int variant>
METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
  const half widened = fp8_to_half<variant>(x);
  return static_cast<T>(widened);
}
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/params.h#L1
///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
@ -199,6 +379,35 @@ struct BlockLoader {
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
// Converts an input element to the accumulator type. Only the specialisations
// below are defined. Plain floating-point inputs use an ordinary conversion,
// while fp8 storage bytes are decoded as E5M2.
template<typename OutT, typename InT, typename _E = void>
METAL_FUNC OutT accum_cast(InT x);

// float -> float: identity.
template<>
METAL_FUNC float accum_cast<float, float, void>(float x) {
  return x;
}

// half -> float.
template<>
METAL_FUNC float accum_cast<float, half, void>(half x) {
  return float(x);
}

#if defined(__HAVE_BFLOAT__)
// bfloat -> float.
template<>
METAL_FUNC float accum_cast<float, bfloat, void>(bfloat x) {
  return float(x);
}
#endif

// fp8 (E5M2) -> float.
template<>
METAL_FUNC float accum_cast<float, fp8_storage_t, void>(fp8_storage_t x) {
  return cast_fp8_to<float, _fp8_variant_t::E5M2>(x);
}

// fp8 (E5M2) -> half.
template<>
METAL_FUNC half accum_cast<half, fp8_storage_t, void>(fp8_storage_t x) {
  return cast_fp8_to<half, _fp8_variant_t::E5M2>(x);
}
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
@ -210,6 +419,17 @@ struct TransformNone {
}
};
// TransformNone for fp8 inputs: decodes the E5M2 byte into OutT. The
// two-argument overload ignores its second argument.
template <typename OutT>
struct TransformNone<OutT, fp8_storage_t> {
  static METAL_FUNC OutT apply(fp8_storage_t x) {
    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
  }

  static METAL_FUNC OutT apply(fp8_storage_t x, OutT) {
    return apply(x);
  }
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
@ -223,6 +443,20 @@ struct TransformAdd {
}
};
// TransformAdd for fp8 inputs: decodes the E5M2 byte into OutT and, in the
// two-argument form, adds `c` to it. The constructor arguments are ignored.
template <typename OutT>
struct TransformAdd<OutT, fp8_storage_t> {
  TransformAdd(const float, const float) {}

  static METAL_FUNC OutT apply(fp8_storage_t x) {
    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
  }

  static METAL_FUNC OutT apply(fp8_storage_t x, OutT c) {
    return apply(x) + c;
  }
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
@ -240,6 +474,24 @@ struct TransformAxpby {
}
};
// TransformAxpby for fp8 inputs: decodes the E5M2 byte into OutT; the
// two-argument form returns decoded(x) * alpha + beta * c.
template <typename OutT>
struct TransformAxpby<OutT, fp8_storage_t> {
  const float alpha;
  const float beta;

  TransformAxpby(const float alpha_, const float beta_)
      : alpha(alpha_), beta(beta_) {}

  static METAL_FUNC OutT apply(fp8_storage_t x) {
    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
  }

  METAL_FUNC OutT apply(fp8_storage_t x, OutT c) const {
    return apply(x) * alpha + (beta * c);
  }
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
@ -336,7 +588,6 @@ struct BlockMMA {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
@ -345,23 +596,19 @@ struct BlockMMA {
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
Asimd[i].thread_elements()[0] = accum_cast<AccumType, T>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] = accum_cast<AccumType, T>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
Bsimd[j].thread_elements()[0] = accum_cast<AccumType, T>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] = accum_cast<AccumType, T>(Bs[j * simd_stride_b + jump_b]);
}
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
@ -370,7 +617,6 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
@ -382,7 +628,7 @@ struct BlockMMA {
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
@ -1020,12 +1266,13 @@ template <
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
typename AccumType = float,
typename ResultType = T>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const device ResultType* C [[buffer(2), function_constant(use_out_source)]],
device ResultType* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
@ -1045,7 +1292,7 @@ template <
using gemm_kernel = GEMMKernel<
T,
T,
ResultType,
BM,
BN,
BK,
@ -1405,13 +1652,13 @@ template <
}
}
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
template [[host_name("gemm_" #tname "_" #iname "_" #oname "_" #bm "_" #bn "_" #bk "_" #wm "_" #wn)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float, stype>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
device itype *D [[buffer(3)]], \
const device stype *C [[buffer(2), function_constant(use_out_source)]], \
device stype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
const constant int* batch_shape [[buffer(6)]], \
@ -1427,14 +1674,17 @@ template <
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype)
instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2, float)
instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2, half)
#if defined(__HAVE_BFLOAT__)
instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2, bfloat)
#endif
instantiate_gemm_transpose_helper(F8E5M2, fp8_storage_t, f16, half, 32, 32, 16, 2, 2, half)

View File

@ -5,6 +5,7 @@ use std::ffi::c_void;
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum GemmDType {
F8E5M2,
BF16,
F16,
F32,
@ -138,6 +139,10 @@ pub fn call_mlx_gemm(
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
(GemmDType::F8E5M2, false, false) => "gemm_nn_F8E5M2_f16_32_32_16_2_2",
(GemmDType::F8E5M2, true, false) => "gemm_tn_F8E5M2_f16_32_32_16_2_2",
(GemmDType::F8E5M2, false, true) => "gemm_nt_F8E5M2_f16_32_32_16_2_2",
(GemmDType::F8E5M2, true, true) => "gemm_tt_F8E5M2_f16_32_32_16_2_2",
};
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
let encoder = ep.encoder();

View File

@ -41,6 +41,7 @@ pub fn call_arg_sort(
fn mlx_dtype_str(dtype: DType) -> &'static str {
match dtype {
DType::FP8(_) => todo!(),
DType::U8 => "uint8",
DType::U32 => "uint32",
DType::I64 => "int64",
@ -195,6 +196,7 @@ pub fn multi_block_sort(
};
// Copy output with appropriate strides
let copy_kernel = match dtype {
DType::FP8(_) => todo!(),
DType::U8 => crate::copy2d::U8,
DType::U32 => crate::copy2d::U32,
DType::I64 => crate::copy2d::I64,

View File

@ -1,4 +1,6 @@
use super::*;
use float8::{F8E4M3, F8E5M2};
use half::{bf16, f16};
use metal::{Buffer, Device, MTLResourceOptions};
use rand::prelude::SliceRandom;
@ -423,6 +425,38 @@ fn cast_bf16() {
assert_eq!(results, v_i64);
}
/// Round-trips small integers (all exactly representable in both fp8
/// variants) through the fp8 cast kernels and checks that nothing is lost.
#[test]
fn cast_fp8() {
    let v_f64 = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    // CPU reference: the same values pushed through the `float8` crate's E4M3 type.
    let v_f8: Vec<F8E4M3> = v_f32.iter().map(|&v| F8E4M3::from(v)).collect();
    let v_f16_rt: Vec<f16> = v_f8.iter().map(|&v| v.to_f16()).collect();

    // E4M3
    // f32 -> fp8 -> f32
    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E4M3");
    let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E4M3");
    assert_eq!(v_f32, roundtrip);

    // f16 -> fp8 -> f16
    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E4M3");
    let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E4M3");
    assert_eq!(v_f16, roundtrip);
    // The kernel round-trip must also agree with the CPU reference.
    assert_eq!(v_f16_rt, roundtrip);

    // E5M2
    // f16 -> fp8 -> f16
    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E5M2");
    let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E5M2");
    assert_eq!(v_f16, roundtrip);

    // f32 -> fp8 -> f32
    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E5M2");
    let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E5M2");
    assert_eq!(v_f32, roundtrip);
}
#[test]
fn cast_u32() {
let v_f64 = [1.0f64, 2.0, 3.0];
@ -1285,7 +1319,7 @@ fn where_cond_u32_f32() {
}
#[allow(clippy::too_many_arguments)]
fn run_mlx_gemm<T: Clone>(
fn run_mlx_gemm<T: Clone, U: Clone>(
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
@ -1294,7 +1328,7 @@ fn run_mlx_gemm<T: Clone>(
rhs: &[T],
rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
@ -1436,6 +1470,33 @@ fn mlx_gemm() {
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
{
// F8E5M2 sanity test
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<u8> = (0..b * m * k)
.map(|f| F8E5M2::from_f32(f as f32).to_bits())
.collect();
let rhs: Vec<u8> = (0..b * n * k)
.map(|f| F8E5M2::from_f32(f as f32).to_bits())
.collect();
let results: Vec<f16> = run_mlx_gemm(
GemmDType::F8E5M2,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
println!("{results:?}");
assert_eq!(
approx_f16(results, 4),
vec![20.0, 21.0, 26.0, 31.0, 56.0, 63.0, 80.0, 97.0]
);
}
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {