Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits: 0.9.1...metal-fp8- (3 commits)

Commits:
5ed764213d
816aeeb7b6
6210fbe9d8
@@ -3,8 +3,8 @@ mod benchmarks;
 use criterion::criterion_main;
 
 criterion_main!(
-    benchmarks::affine::benches,
     benchmarks::matmul::benches,
+    benchmarks::affine::benches,
     benchmarks::random::benches,
     benchmarks::reduce::benches,
     benchmarks::where_cond::benches,
@@ -7,20 +7,27 @@ fn run(a: &Tensor, b: &Tensor) {
     a.matmul(&b.t().unwrap()).unwrap();
 }
 
-fn run_bench(c: &mut Criterion, device: &Device) {
+fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
     let b = 1;
     let m = 1;
     let n = 2048;
     let k = 2048;
 
-    let dtype = DType::F32;
     let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
     let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
 
     let flops = b * m * n * k;
+    let bytes = flops * dtype.size_in_bytes();
 
-    let mut group = c.benchmark_group(device.bench_name("matmul"));
-    group.throughput(Throughput::Bytes(flops as u64));
+    let name = match dtype {
+        DType::F32 => "matmul_f32",
+        DType::U8 => "matmul_fp8",
+        DType::F16 => "matmul_f16",
+        DType::BF16 => "matmul_bf16",
+        _ => unimplemented!("{dtype:?} matmul bench not implemented"),
+    };
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(bytes as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
             let start = Instant::now();
@@ -36,8 +43,11 @@ fn run_bench(c: &mut Criterion, device: &Device) {
 
 fn criterion_benchmark(c: &mut Criterion) {
     let handler = BenchDeviceHandler::new().unwrap();
+    let dtypes = vec![DType::F32, DType::U8, DType::F16, DType::BF16];
     for device in handler.devices {
-        run_bench(c, &device);
+        for dtype in dtypes.clone() {
+            run_bench(c, &device, dtype);
+        }
     }
 }
 
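With the fixed shape used in this bench (b = 1, m = 1, n = 2048, k = 2048), the per-iteration byte count now fed to Throughput::Bytes scales with the element width of the selected dtype. A minimal stand-alone Rust sketch of that arithmetic (the element sizes are the usual widths for these dtypes; nothing in this sketch is part of the diff):

fn main() {
    let (b, m, n, k) = (1usize, 1, 2048, 2048);
    let elems = b * m * n * k; // 4_194_304, the same quantity the bench calls `flops`
    // (bench name, element size in bytes) pairs mirroring the dtypes benchmarked above.
    for (name, size) in [("matmul_f32", 4usize), ("matmul_f16", 2), ("matmul_bf16", 2), ("matmul_fp8", 1)] {
        let bytes = elems * size;
        println!("{name}: {} MiB per iteration", bytes / (1024 * 1024));
    }
    // 16 MiB for f32, 8 MiB for f16/bf16, 4 MiB for the fp8 bench -- the values that
    // Criterion's Throughput::Bytes reports after this change.
}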
@@ -1513,50 +1513,35 @@ impl BackendStorage for MetalStorage {
         let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
-        if self.dtype == DType::BF16 {
-            candle_metal_kernels::call_mlx_gemm(
-                &self.device.device,
-                &command_buffer,
-                &self.device.kernels,
-                candle_metal_kernels::GemmDType::BF16,
-                (b, m, n, k),
-                lhs_l.stride(),
-                lhs_l.start_offset() * self.dtype.size_in_bytes(),
-                &self.buffer,
-                rhs_l.stride(),
-                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &rhs.buffer,
-                &buffer,
-            )
-            .map_err(MetalError::from)?;
-        } else {
-            let dtype = match self.dtype {
-                DType::F32 => candle_metal_kernels::GemmDType::F32,
-                DType::F16 => candle_metal_kernels::GemmDType::F16,
-                DType::BF16 => candle_metal_kernels::GemmDType::BF16,
-                dtype => {
-                    return Err(MetalError::Message(format!(
-                        "mlx matmul doesn't support {dtype:?}"
-                    ))
-                    .into())
-                }
-            };
-            candle_metal_kernels::call_mlx_gemm(
-                &self.device.device,
-                &command_buffer,
-                &self.device.kernels,
-                dtype,
-                (b, m, n, k),
-                lhs_l.stride(),
-                lhs_l.start_offset() * self.dtype.size_in_bytes(),
-                &self.buffer,
-                rhs_l.stride(),
-                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &rhs.buffer,
-                &buffer,
-            )
-            .map_err(MetalError::from)?;
-        }
+        let dtype = match self.dtype {
+            // Hijacking the U8 dtype to represent E5M2 fp8
+            DType::U8 => candle_metal_kernels::GemmDType::F8E5M2,
+            DType::F32 => candle_metal_kernels::GemmDType::F32,
+            DType::F16 => candle_metal_kernels::GemmDType::F16,
+            DType::BF16 => candle_metal_kernels::GemmDType::BF16,
+            dtype => {
+                return Err(
+                    MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
+                )
+            }
+        };
+        candle_metal_kernels::call_mlx_gemm(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            dtype,
+            (b, m, n, k),
+            lhs_l.stride(),
+            lhs_l.start_offset() * self.dtype.size_in_bytes(),
+            &self.buffer,
+            rhs_l.stride(),
+            rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+            &rhs.buffer,
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
         Ok(Self::new(
             buffer,
             self.device.clone(),
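The U8 arm above means the Metal matmul path has no first-class FP8 dtype yet: callers are expected to hand the backend a U8 tensor whose bytes already hold raw E5M2 bit patterns. A small host-side sketch of producing such a payload with the float8 crate added in the Cargo.toml change below (the helper name is hypothetical; only F8E5M2::from_f32 and to_bits come from the diff):

use float8::F8E5M2;

/// Pack f32 values into raw E5M2 bytes, i.e. the payload that the backend
/// above treats as DType::U8 / GemmDType::F8E5M2.
fn to_e5m2_bytes(values: &[f32]) -> Vec<u8> {
    values.iter().map(|&v| F8E5M2::from_f32(v).to_bits()).collect()
}

fn main() {
    // 1.0 encodes as 0x3C in E5M2 (sign 0, exponent 01111, mantissa 00).
    assert_eq!(to_e5m2_bytes(&[1.0]), vec![0x3c]);
    println!("{:02x?}", to_e5m2_bytes(&[0.0, 0.5, 1.0, 2.0]));
}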
@@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0"
 
 
 [dependencies]
+float8 = "0.2.1"
 metal = { version = "0.27.0", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
@@ -1,4 +1,5 @@
 #include <metal_stdlib>
+#include <metal_limits>
 
 METAL_FUNC uint get_strided_index(
     uint idx,
@@ -18,6 +19,377 @@ METAL_FUNC uint get_strided_index(
 
 using namespace metal;
 
+typedef unsigned char fp8_storage_t;
+typedef unsigned short int fp8x2_storage_t;
+typedef unsigned int fp8x4_storage_t;
+
+template <typename T>
+struct _fp8_cast_traits;
+
+template <>
+struct _fp8_cast_traits<float> {
+    typedef _fp_encoding_traits<float> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFF800000;
+    constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
+    constexpr static constant encoding_type mask = 0x7FFFFFFF;
+};
+
+template <>
+struct _fp8_cast_traits<half> {
+    typedef _fp_encoding_traits<float> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFC00;
+    constexpr static constant encoding_type mantissa_mask = 0x3FF;
+    constexpr static constant encoding_type mask = 0x7FFF;
+};
+
+enum _fp8_variant_t {
+    E4M3 = 0,      // OCP E4M3
+    E5M2 = 1,      // OCP E5M2
+    E4M3_FNUZ = 2, // Standard FP8
+    E5M2_FNUZ = 3, // BF8
+};
+
+typedef enum fp8_variant_t {
+    _E4M3 = 0,      // OCP E4M3
+    _E5M2 = 1,      // OCP E5M2
+    _E4M3_FNUZ = 2, // Standard FP8
+    _E5M2_FNUZ = 3, // BF8
+} fp8_variant_t;
+
+template<int variant>
+struct _fp8_variant_traits;
+
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+
+template<typename T, int variant>
+struct _fp8_variant_cast_traits;
+
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43E00000;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000;
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43700000;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000;
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5F00;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00;
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5B80;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00;
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
+
+// TODO: Simplify. No need to support all fp8 variants immediately.
+template <typename T, int variant = _fp8_variant_t::E4M3_FNUZ>
+METAL_FUNC fp8_storage_t cast_to_fp8(T _x, bool clip = false, bool stoch = false, uint rng = 0) {
+    typedef _fp_encoding_traits<T> traits;
+    typedef typename traits::encoding_type bits;
+    typedef numeric_limits<bits> limits;
+
+    typedef _fp8_cast_traits<T> cast_traits;
+    typedef _fp8_variant_traits<variant> variant_traits;
+    typedef _fp8_variant_cast_traits<T, variant> variant_cast_traits;
+
+    constexpr bool is_fnuz = variant_traits::is_fnuz;
+    constexpr int we = variant_traits::we;
+    constexpr int wm = variant_traits::wm;
+    constexpr int mfmt = traits::exponent_shift;
+    constexpr int bias = traits::exponent_bias;
+    constexpr bits mask = cast_traits::mask;
+
+    bits x = as_type<bits>(_x);
+
+    bits head = x & cast_traits::head_mask;
+    bits mantissa = x & cast_traits::mantissa_mask;
+    int exponent = (head >> traits::exponent_shift) & traits::exponent_max;
+    bits sign = head >> traits::sign_shift;
+
+    bits signed_inf = 0;
+    unsigned int nan = 0;
+    if (is_fnuz) {
+        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
+        nan = 0x80;
+    } else {
+        if (we == 4) {
+            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
+        } else {
+            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
+        }
+        nan = (sign << 7) + 0x7f;
+    }
+    constexpr bits ifmax = variant_cast_traits::ifmax;
+
+    // Deal with inf and NaNs
+    if ((x & traits::inf_mask) == traits::inf_mask) {
+        if (is_fnuz || we == 4) return nan; // fnuz and OCP E4M3 has no INF
+        if (mantissa != 0) return nan; // NaN
+        return sign == 0 ? 0x7C : 0xFC; // E5M2 Inf
+    }
+
+    if ((x & mask) > ifmax) {
+        return signed_inf;
+    }
+
+    if (x == 0) {
+        return 0;
+    }
+
+    // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
+    const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
+    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
+
+    int act_exponent, f8_exponent, exponent_diff;
+
+    if (exponent == 0) {
+        act_exponent = exponent - bias + 1;
+        exponent_diff = f8_denormal_act_exponent - act_exponent;
+    } else {
+        act_exponent = exponent - bias;
+        if (act_exponent <= f8_denormal_act_exponent) {
+            exponent_diff = f8_denormal_act_exponent - act_exponent;
+        } else {
+            exponent_diff = 0;
+        }
+        mantissa += (1ull << mfmt);
+    }
+
+    bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
+        (1ull << (mfmt - wm + exponent_diff - 1));
+
+    if (exponent_diff > 0)
+        mantissa >>= exponent_diff;
+    else if (exponent_diff == -1)
+        mantissa <<= -exponent_diff;
+    bool implicit_one = mantissa & (1ull << mfmt);
+    f8_exponent =
+        (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
+
+    unsigned long drop_mask = (1ull << (mfmt - wm)) - 1;
+    bool odd =
+        mantissa & (1ull << (mfmt - wm));
+    mantissa +=
+        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
+
+    if (f8_exponent == 0) {
+        if ((1ull << mfmt) & mantissa) {
+            f8_exponent = 1; // denormal overflow to become normal, promote exponent
+        }
+    } else {
+        if ((1ull << (mfmt + 1)) & mantissa) {
+            mantissa >>= 1;
+            f8_exponent++;
+        }
+    }
+
+    mantissa >>= (mfmt - wm);
+    const int max_exp = (1 << we) - 1;
+    if (f8_exponent > max_exp) {
+        if (clip) {
+            mantissa = (1 << wm) - 1;
+            f8_exponent = max_exp;
+        } else {
+            return signed_inf;
+        }
+    }
+
+    if (f8_exponent == 0 && mantissa == 0) return is_fnuz ? 0 : (sign << 7);
+    mantissa &= (1 << wm) - 1;
+    return (sign << 7) | (f8_exponent << wm) | mantissa;
+}
+
+template <int variant>
+METAL_FUNC half fp8_to_half(fp8_storage_t x);
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type bits;
+
+    bits ur = x << 8U;
+    bits sign = ur & 0x8000U;
+    bits exponent = (bits)(((ur & 0x7800U) >> 1U) + 0x2000U);
+    bits mantissa = (ur & 0x0700U) >> 1U;
+    unsigned char absx = 0x7FU & (unsigned char)x;
+
+    if (absx == 0x7FU) {
+        // return NaN
+        ur = 0x7FFFU;
+    } else if (exponent == 0x2000U) {
+        if (mantissa != 0U) {
+            // normalize
+            mantissa = (bits)(mantissa << 1U);
+            while ((mantissa & 0x0400U) == 0U) {
+                mantissa = (bits)(mantissa << 1U);
+                exponent = (bits)(exponent - 0x0400U);
+            }
+            // discard implicit leading bit
+            mantissa &= 0x03FFU;
+        } else {
+            // zero
+            exponent = 0U;
+        }
+
+        ur = (sign | exponent) | mantissa;
+    } else {
+        ur = (sign | exponent) | mantissa;
+    }
+
+    return as_type<half>(ur);
+}
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type bits;
+
+    bits ur = x << 8U;
+    if ((x & 0x7FFFU) > 0x7C00U) {
+        // return NaN
+        ur = 0x7FFFU;
+    }
+    return as_type<half>(ur);
+}
+
+template <typename T, int variant>
+METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
+    return static_cast<T>(fp8_to_half<variant>(x));
+}
+
+#define CAST_TO_FP8(name, T) \
+    CAST_TO_FP8_VARIANT(name##_E4M3, T, E4M3) \
+    CAST_TO_FP8_VARIANT(name##_E5M2, T, E5M2)
+
+#define CAST_TO_FP8_VARIANT(name, T, FP8_VARIANT) \
+kernel void name( \
+    constant size_t &dim, \
+    device const T *input, \
+    device fp8_storage_t *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_to_fp8<T, FP8_VARIANT>(input[tid]); \
+} \
+kernel void name##_strided( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const T *input, \
+    device fp8_storage_t *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_to_fp8<T, FP8_VARIANT>( \
+        input[get_strided_index(tid, num_dims, dims, strides)] \
+    ); \
+} \
+
+#define CAST_FROM_FP8(name, T) \
+    CAST_FROM_FP8_VARIANT(name##_E4M3, T, E4M3) \
+    CAST_FROM_FP8_VARIANT(name##_E5M2, T, E5M2)
+
+#define CAST_FROM_FP8_VARIANT(name, T, FP8_VARIANT) \
+kernel void name( \
+    constant size_t &dim, \
+    device const fp8_storage_t *input, \
+    device T *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_fp8_to<T, FP8_VARIANT>(input[tid]); \
+} \
+kernel void name##_strided( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const fp8_storage_t *input, \
+    device T *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_fp8_to<T, FP8_VARIANT>( \
+        input[get_strided_index(tid, num_dims, dims, strides)] \
+    ); \
+} \
+
+CAST_FROM_FP8(cast_fp8_f16, half)
+CAST_FROM_FP8(cast_fp8_f32, float)
+CAST_TO_FP8(cast_f32_fp8, float)
+CAST_TO_FP8(cast_f16_fp8, half)
+
+
 #define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
 kernel void FN_NAME( \
     constant size_t &dim, \
@@ -128,4 +500,4 @@ CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
 CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
 CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
 CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
 #endif
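The E5M2 half of the decode above is essentially a byte shift: E5M2 shares the 5-bit exponent layout of IEEE half, so widening the byte into the top of a 16-bit word yields the right half value, and only NaN needs canonicalising. A host-side Rust analogue of that specialisation, using the half crate the tests already depend on (the function name is mine, not part of the diff):

use half::f16;

/// Analogue of fp8_to_half<E5M2> above: widen the byte into the top of a u16
/// and reinterpret it as an IEEE half, canonicalising NaN.
fn e5m2_to_f16(x: u8) -> f16 {
    let mut bits = (x as u16) << 8;
    // Exponent all-ones with a non-zero mantissa is NaN; map it to the f16 canonical NaN.
    if (bits & 0x7fff) > 0x7c00 {
        bits = 0x7fff;
    }
    f16::from_bits(bits)
}

fn main() {
    assert_eq!(e5m2_to_f16(0x3c), f16::from_f32(1.0));
    assert_eq!(e5m2_to_f16(0x40), f16::from_f32(2.0));
    assert_eq!(e5m2_to_f16(0xc2), f16::from_f32(-3.0));
    println!("E5M2 -> f16 widening works as expected");
}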
@@ -37,12 +37,19 @@ pub enum DType {
     I64,
     U32,
     U8,
+    FP8(FP8Variant),
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum FP8Variant {
+    E4M3,
+    E5M2,
 }
 
 impl DType {
     fn size_in_bytes(&self) -> usize {
         match self {
-            Self::U8 => 1,
+            Self::U8 | Self::FP8(_) => 1,
             Self::U32 => 4,
             Self::I64 => 8,
             Self::BF16 => 2,
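In practice the new FP8Variant only has to select between the _E4M3 and _E5M2 kernel families generated by the cast.metal macros above (cast_f32_fp8_E4M3, cast_fp8_f16_E5M2, and so on). A hypothetical helper showing that mapping; neither the function nor its callers appear in this diff:

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FP8Variant {
    E4M3,
    E5M2,
}

/// Build the name of one of the cast kernels added in cast.metal,
/// e.g. ("f32", "fp8", E4M3) -> "cast_f32_fp8_E4M3".
fn cast_kernel_name(from: &str, to: &str, variant: FP8Variant) -> String {
    let suffix = match variant {
        FP8Variant::E4M3 => "E4M3",
        FP8Variant::E5M2 => "E5M2",
    };
    format!("cast_{from}_{to}_{suffix}")
}

fn main() {
    assert_eq!(cast_kernel_name("f32", "fp8", FP8Variant::E4M3), "cast_f32_fp8_E4M3");
    assert_eq!(cast_kernel_name("fp8", "f16", FP8Variant::E5M2), "cast_fp8_f16_E5M2");
}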
@@ -311,7 +318,12 @@ impl Kernels {
             let source_content = self.get_library_source(source);
             device
                 .new_library_with_source(source_content, &CompileOptions::new())
-                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                .map_err(|e| {
+                    // Makes metal errors easier to read.
+                    // TODO: Remove.
+                    panic!("{}", e);
+                    return MetalKernelError::LoadLibraryError(e.to_string());
+                })?
         };
         libraries.insert(source, lib.clone());
         Ok(lib)
@@ -11,6 +11,186 @@
 
 using namespace metal;
 
+// FP8 casting
+
+enum _fp8_variant_t {
+    E4M3 = 0,      // OCP E4M3
+    E5M2 = 1,      // OCP E5M2
+    E4M3_FNUZ = 2, // Standard FP8
+    E5M2_FNUZ = 3, // BF8
+};
+
+template<int variant>
+struct _fp8_variant_traits;
+
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+
+template <typename T>
+struct _fp8_cast_traits;
+
+template <>
+struct _fp8_cast_traits<float> {
+    typedef _fp_encoding_traits<float> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFF800000;
+    constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
+    constexpr static constant encoding_type mask = 0x7FFFFFFF;
+};
+
+template <>
+struct _fp8_cast_traits<half> {
+    typedef _fp_encoding_traits<float> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFC00;
+    constexpr static constant encoding_type mantissa_mask = 0x3FF;
+    constexpr static constant encoding_type mask = 0x7FFF;
+};
+
+template<typename T, int variant>
+struct _fp8_variant_cast_traits;
+
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43E00000;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000;
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43700000;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000;
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5F00;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00;
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5B80;
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00;
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
+
+typedef unsigned char fp8_storage_t;
+typedef unsigned short int fp8x2_storage_t;
+typedef unsigned int fp8x4_storage_t;
+
+template <int variant>
+METAL_FUNC half fp8_to_half(fp8_storage_t x);
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef _fp8_cast_traits<half> cast_traits;
+    typedef typename traits::encoding_type bits;
+    typedef numeric_limits<bits> limits;
+
+    typedef _fp8_variant_traits<_fp8_variant_t::E4M3> variant_traits;
+    typedef _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> variant_cast_traits;
+
+    bits ur = x << 8U;
+    bits sign = ur & 0x8000U;
+    bits exponent = (bits)(((ur & 0x7800U) >> 1U) + 0x2000U);
+    bits mantissa = (ur & 0x0700U) >> 1U;
+    unsigned char absx = 0x7FU & (unsigned char)x;
+
+    if (absx == 0x7FU) { // NaN
+        ur = 0x7FFFU; // fp16 canonical NaN, discard sign
+    } else if (exponent == 0x2000U) {
+        // zero or denormal
+        if (mantissa != 0U) {
+            // normalize
+            mantissa = (bits)(mantissa << 1U);
+            while ((mantissa & 0x0400U) == 0U) {
+                mantissa = (bits)(mantissa << 1U);
+                exponent = (bits)(exponent - 0x0400U);
+            }
+            // discard implicit leading bit
+            mantissa &= 0x03FFU;
+        } else { // Zero
+            exponent = 0U;
+        }
+
+        ur = (sign | exponent) | mantissa;
+    } else {
+        ur = (sign | exponent) | mantissa;
+    }
+
+    return as_type<half>(ur);
+}
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef _fp8_cast_traits<half> cast_traits;
+    typedef typename traits::encoding_type bits;
+    typedef numeric_limits<bits> limits;
+
+    bits ur = x << 8U;
+    if ((x & 0x7FFFU) > 0x7C00U) {
+        /* If NaN, return canonical NaN */
+        ur = 0x7FFFU;
+    }
+    return as_type<half>(ur);
+}
+
+template <typename T, int variant>
+METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
+    return static_cast<T>(fp8_to_half<variant>(x));
+}
+
 // https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/params.h#L1
 ///////////////////////////////////////////////////////////////////////////////
 // GEMM param classes
@@ -199,14 +379,76 @@ struct BlockLoader {
 // Transforms and Epilogues
 ///////////////////////////////////////////////////////////////////////////////
 
+template<typename OutT, typename InT>
+METAL_FUNC OutT mlx_cast(InT x);
+
+template<>
+METAL_FUNC float mlx_cast<float, float>(float x) {
+  return x;
+}
+
+template<>
+METAL_FUNC half mlx_cast<half, half>(half x) {
+  return x;
+}
+template<>
+METAL_FUNC float mlx_cast<float, half>(half x) {
+  return static_cast<float>(x);
+}
+template<>
+METAL_FUNC half mlx_cast<half, float>(float x) {
+  return static_cast<half>(x);
+}
+
+#if defined(__HAVE_BFLOAT__)
+template<>
+METAL_FUNC bfloat mlx_cast<bfloat, bfloat>(bfloat x) {
+  return x;
+}
+template<>
+METAL_FUNC float mlx_cast<float, bfloat>(bfloat x) {
+  return static_cast<float>(x);
+}
+template<>
+METAL_FUNC bfloat mlx_cast<bfloat, float>(float x) {
+  return static_cast<bfloat>(x);
+}
+#endif
+
+template<>
+METAL_FUNC fp8_storage_t mlx_cast<fp8_storage_t, fp8_storage_t>(fp8_storage_t x) {
+  return x;
+}
+
+template<>
+METAL_FUNC float mlx_cast<float, fp8_storage_t>(fp8_storage_t x) {
+  return cast_fp8_to<float, E5M2>(x);
+}
+
+//template<>
+//METAL_FUNC fp8_storage_t mlx_cast<fp8_storage_t, float>(float x) {
+//  return cast_to_f8<float, E5M2>(x);
+//}
+
+template<>
+METAL_FUNC half mlx_cast<half, fp8_storage_t>(fp8_storage_t x) {
+  return cast_fp8_to<half, E5M2>(x);
+}
+
+//template<>
+//METAL_FUNC fp8_storage_t mlx_cast<fp8_storage_t, half>(half x) {
+//  return cast_to_f8<float, E5M2>(x);
+//}
+
 template <typename OutT, typename InT>
 struct TransformNone {
   static METAL_FUNC OutT apply(InT x) {
-    return static_cast<OutT>(x);
+    return mlx_cast<OutT>(x);
   }
 
   static METAL_FUNC OutT apply(InT x, OutT) {
-    return static_cast<OutT>(x);
+    return mlx_cast<OutT>(x);
   }
 };
 
@@ -215,11 +457,11 @@ struct TransformAdd {
   TransformAdd(const float, const float) {}
 
   static METAL_FUNC OutT apply(InT x) {
-    return static_cast<OutT>(x);
+    return mlx_cast<OutT>(x);
   }
 
   static METAL_FUNC OutT apply(InT x, OutT c) {
-    return static_cast<OutT>(x) + c;
+    return mlx_cast<OutT>(x) + c;
   }
 };
 
@@ -232,11 +474,11 @@ struct TransformAxpby {
       : alpha(alpha_), beta(beta_) {}
 
   static METAL_FUNC OutT apply(InT x) {
-    return static_cast<OutT>(x);
+    return mlx_cast<OutT>(x);
   }
 
   METAL_FUNC OutT apply(InT x, OutT c) const {
-    return static_cast<OutT>(x * alpha + (beta * c));
+    return mlx_cast<OutT>(x * alpha + (beta * c));
   }
 };
 
@@ -336,53 +578,47 @@ struct BlockMMA {
     // Adjust for simdgroup and thread location
     As += As_offset;
     Bs += Bs_offset;
 
     // Iterate over BK in blocks of 8
     STEEL_PRAGMA_UNROLL
     for (short kk = 0; kk < BK; kk += 8) {
       simdgroup_barrier(mem_flags::mem_none);
 
       // Load elements from threadgroup A as simdgroup matrices
       STEEL_PRAGMA_UNROLL
       for (short i = 0; i < TM; i++) {
-        Asimd[i].thread_elements()[0] =
-            static_cast<AccumType>(As[i * simd_stride_a + 0]);
-        Asimd[i].thread_elements()[1] =
-            static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-
-      // Load elements from threadgroup B as simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < TN; j++) {
-        Bsimd[j].thread_elements()[0] =
-            static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
-        Bsimd[j].thread_elements()[1] =
-            static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-
-      // Multiply and accumulate into result simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short i = 0; i < TM; i++) {
-        STEEL_PRAGMA_UNROLL
-        for (short j = 0; j < TN; j++) {
-          short j_serp = (i % 2) ? (TN - 1 - j) : j;
-
-          simdgroup_multiply_accumulate(
-              results[i * TN + j_serp],
-              Asimd[i],
-              Bsimd[j_serp],
-              results[i * TN + j_serp]);
-        }
-      }
-
-      // Progress to next simdgroup tile
-      As += tile_stride_a;
-      Bs += tile_stride_b;
-    }
+        Asimd[i].thread_elements()[0] = mlx_cast<AccumType>(As[i * simd_stride_a + 0]);
+        Asimd[i].thread_elements()[1] = mlx_cast<AccumType>(As[i * simd_stride_a + jump_a]);
+      }
+      simdgroup_barrier(mem_flags::mem_none);
+
+      // Load elements from threadgroup B as simdgroup matrices
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        Bsimd[j].thread_elements()[0] = mlx_cast<AccumType>(Bs[j * simd_stride_b + 0]);
+        Bsimd[j].thread_elements()[1] = mlx_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
+      }
+      simdgroup_barrier(mem_flags::mem_none);
+
+      // Multiply and accumulate into result simdgroup matrices
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < TM; i++) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; j++) {
+          short j_serp = (i % 2) ? (TN - 1 - j) : j;
+          simdgroup_multiply_accumulate(
+              results[i * TN + j_serp],
+              Asimd[i],
+              Bsimd[j_serp],
+              results[i * TN + j_serp]);
+        }
+      }
+
+      // Progress to next simdgroup tile
+      As += tile_stride_a;
+      Bs += tile_stride_b;
+    }
   }
 
   /* Store results from simdgroup_matrix results into device memory */
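Routing the threadgroup loads through mlx_cast means an fp8 operand is widened to AccumType before simdgroup_multiply_accumulate ever sees it, so the MMA itself still runs at float/half precision. A scalar Rust reference of that decode-then-accumulate scheme, assuming E5M2 inputs and reusing the byte-shift decode from the earlier sketch (all names here are mine, not from the diff):

use half::f16;

// E5M2 shares the IEEE half exponent layout, so decoding is a widening shift
// (same trick as the earlier sketch).
fn e5m2_to_f32(x: u8) -> f32 {
    f16::from_bits((x as u16) << 8).to_f32()
}

// Encode an f32 as E5M2 by truncating half bits; exact for the powers of two used below.
fn enc(v: f32) -> u8 {
    (f16::from_f32(v).to_bits() >> 8) as u8
}

/// Scalar reference GEMM over raw E5M2 bytes: every element is widened to f32
/// before the multiply-accumulate, mirroring what mlx_cast does per simdgroup load.
fn gemm_e5m2_ref(m: usize, n: usize, k: usize, a: &[u8], b: &[u8]) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += e5m2_to_f32(a[i * k + p]) * e5m2_to_f32(b[p * n + j]);
            }
            out[i * n + j] = acc;
        }
    }
    out
}

fn main() {
    // (2 x 3) * (3 x 4) with b filled with 1.0: each output element is the sum of one row of a.
    let a: Vec<u8> = [1.0f32, 2.0, 4.0, 8.0, 16.0, 32.0].map(enc).to_vec();
    let b: Vec<u8> = vec![enc(1.0); 12];
    println!("{:?}", gemm_e5m2_ref(2, 4, 3, &a, &b));
    // -> [7.0, 7.0, 7.0, 7.0, 56.0, 56.0, 56.0, 56.0]
}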
@@ -1020,12 +1256,13 @@ template <
     int WN,
     bool transpose_a,
     bool transpose_b,
-    typename AccumType = float>
+    typename AccumType = float,
+    typename ResultType = T>
 [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
     const device T* A [[buffer(0)]],
     const device T* B [[buffer(1)]],
-    const device T* C [[buffer(2), function_constant(use_out_source)]],
-    device T* D [[buffer(3)]],
+    const device ResultType* C [[buffer(2), function_constant(use_out_source)]],
+    device ResultType* D [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
     const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
     const constant int* batch_shape [[buffer(6)]],
|
|||||||
|
|
||||||
using gemm_kernel = GEMMKernel<
|
using gemm_kernel = GEMMKernel<
|
||||||
T,
|
T,
|
||||||
T,
|
ResultType,
|
||||||
BM,
|
BM,
|
||||||
BN,
|
BN,
|
||||||
BK,
|
BK,
|
||||||
@@ -1326,7 +1563,7 @@ template <
               addmm_params->fdc,
               short2(tgp_bn, tgp_bm),
               epilogue_op_add);
         }
       }
 
       // Store results to device memory
@@ -1396,7 +1633,7 @@ template <
               addmm_params->fdc,
               short2(tgp_bn, tgp_bm),
               epilogue_op_add);
         }
       }
 
       // Store results to device memory
@@ -1405,13 +1642,13 @@ template <
   }
 }
 
-#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
   template [[host_name("gemm_" #tname "_" #iname "_" #oname "_" #bm "_" #bn "_" #bk "_" #wm "_" #wn)]] \
-  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
+  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float, stype>( \
       const device itype *A [[buffer(0)]], \
       const device itype *B [[buffer(1)]], \
-      const device itype *C [[buffer(2), function_constant(use_out_source)]], \
-      device itype *D [[buffer(3)]], \
+      const device stype *C [[buffer(2), function_constant(use_out_source)]], \
+      device stype *D [[buffer(3)]], \
       const constant GEMMParams* params [[buffer(4)]], \
       const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
       const constant int* batch_shape [[buffer(6)]], \
@@ -1427,14 +1664,17 @@ template <
       uint3 tid [[threadgroup_position_in_grid]], \
       uint3 lid [[thread_position_in_threadgroup]]);
 
-#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+  instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+  instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+  instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+  instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype)
 
-instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2)
-instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2, float)
+instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2, half)
 #if defined(__HAVE_BFLOAT__)
-instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2, bfloat)
 #endif
+
+
+instantiate_gemm_transpose_helper(F8E5M2, fp8_storage_t, f16, half, 32, 32, 16, 2, 2, half)
@@ -5,6 +5,7 @@ use std::ffi::c_void;
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
 pub enum GemmDType {
+    F8E5M2,
     BF16,
     F16,
     F32,
@@ -138,6 +139,10 @@ pub fn call_mlx_gemm(
         (GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
         (GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
         (GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, false, false) => "gemm_nn_F8E5M2_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, true, false) => "gemm_tn_F8E5M2_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, false, true) => "gemm_nt_F8E5M2_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, true, true) => "gemm_tt_F8E5M2_f16_32_32_16_2_2",
     };
     let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
     let encoder = ep.encoder();
@@ -41,6 +41,7 @@ pub fn call_arg_sort(
 
 fn mlx_dtype_str(dtype: DType) -> &'static str {
     match dtype {
+        DType::FP8(_) => todo!(),
         DType::U8 => "uint8",
         DType::U32 => "uint32",
         DType::I64 => "int64",
@@ -195,6 +196,7 @@ pub fn multi_block_sort(
     };
     // Copy output with appropriate strides
     let copy_kernel = match dtype {
+        DType::FP8(_) => todo!(),
         DType::U8 => crate::copy2d::U8,
         DType::U32 => crate::copy2d::U32,
         DType::I64 => crate::copy2d::I64,
@@ -1,4 +1,6 @@
 use super::*;
+
+use float8::{F8E4M3, F8E5M2};
 use half::{bf16, f16};
 use metal::{Buffer, Device, MTLResourceOptions};
 use rand::prelude::SliceRandom;
@@ -423,6 +425,38 @@ fn cast_bf16() {
     assert_eq!(results, v_i64);
 }
 
+#[test]
+fn cast_fp8() {
+    let v_f64 = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+
+    let v_f8: Vec<F8E4M3> = v_f32.iter().map(|&v| F8E4M3::from(v)).collect();
+    let v_f16_rt: Vec<f16> = v_f8.iter().map(|&v| v.to_f16()).collect();
+
+    // E4M3
+    // f32 -> fp8 -> f32
+    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E4M3");
+    let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E4M3");
+    assert_eq!(v_f32, roundtrip);
+
+    // f16 -> fp8 -> f16
+    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E4M3");
+    let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E4M3");
+    assert_eq!(v_f16, roundtrip);
+
+    // E5M2
+    // f16 -> fp8 -> f16
+    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E5M2");
+    let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E5M2");
+    assert_eq!(v_f16, roundtrip);
+
+    // f32 -> fp8 -> f32
+    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E5M2");
+    let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E5M2");
+    assert_eq!(v_f32, roundtrip);
+}
+
 #[test]
 fn cast_u32() {
     let v_f64 = [1.0f64, 2.0, 3.0];
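The exact-equality assertions in cast_fp8 hold because 1.0 through 8.0 are exactly representable in both E4M3 (4 exponent / 3 mantissa bits) and E5M2 (5 exponent / 2 mantissa bits), so these particular round trips are lossless. A quick host-side check with the float8 crate, assuming it exposes to_f32 alongside the From<f32>/from_f32/to_bits/to_f16 calls used in the test:

use float8::{F8E4M3, F8E5M2};

fn main() {
    for v in 1..=8 {
        let v = v as f32;
        // to_f32 assumed available, analogous to the to_f16 call used in the test above.
        assert_eq!(F8E4M3::from(v).to_f32(), v);
        assert_eq!(F8E5M2::from_f32(v).to_f32(), v);
    }
    // 9.0 is not representable in E5M2 (only 2 mantissa bits), so it rounds.
    assert_ne!(F8E5M2::from_f32(9.0).to_f32(), 9.0);
    println!("1..=8 round-trip exactly in both fp8 variants");
}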
@@ -1285,7 +1319,7 @@ fn where_cond_u32_f32() {
 }
 
 #[allow(clippy::too_many_arguments)]
-fn run_mlx_gemm<T: Clone>(
+fn run_mlx_gemm<T: Clone, U: Clone>(
     dtype: GemmDType,
     (b, m, n, k): (usize, usize, usize, usize),
     lhs: &[T],
@@ -1294,7 +1328,7 @@ fn run_mlx_gemm<T: Clone>(
     rhs: &[T],
     rhs_stride: &[usize],
     rhs_offset: usize,
-) -> Vec<T> {
+) -> Vec<U> {
     let device = device();
     let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
@@ -1436,6 +1470,33 @@ fn mlx_gemm() {
             vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
         );
     }
+
+    {
+        // F8E5M2 sanity test
+        let (b, m, n, k) = (1, 2, 4, 3);
+        let lhs: Vec<u8> = (0..b * m * k)
+            .map(|f| F8E5M2::from_f32(f as f32).to_bits())
+            .collect();
+        let rhs: Vec<u8> = (0..b * n * k)
+            .map(|f| F8E5M2::from_f32(f as f32).to_bits())
+            .collect();
+        let results: Vec<f16> = run_mlx_gemm(
+            GemmDType::F8E5M2,
+            (b, m, n, k),
+            &lhs,
+            &[m * k, k, 1],
+            0,
+            &rhs,
+            &[n * k, n, 1],
+            0,
+        );
+        println!("{results:?}");
+
+        assert_eq!(
+            approx_f16(results, 4),
+            vec![20.0, 21.0, 26.0, 31.0, 56.0, 63.0, 80.0, 97.0]
+        );
+    }
 }
 
 fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
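The F8E5M2 expectation differs from the earlier f32 expectation only because of input quantisation: of the rhs values 0..=11, 9 rounds down to 8 and 11 rounds up to 12 under E5M2's 2-mantissa-bit, round-to-nearest-even conversion (which is what the GPU result above implies), while every lhs value 0..=5 is exact. A host-side reproduction of the expected vector, assuming the float8 crate exposes to_f32; the quantise-then-multiply logic is mine, not part of the test:

use float8::F8E5M2;

// Quantise through E5M2 the same way the test builds its inputs.
// to_f32 is assumed to exist alongside the from_f32/to_bits calls used in the test.
fn q(v: f32) -> f32 {
    F8E5M2::from_f32(v).to_f32()
}

fn main() {
    let (m, n, k) = (2usize, 4, 3);
    let lhs: Vec<f32> = (0..m * k).map(|v| q(v as f32)).collect(); // 0..=5 are exact
    let rhs: Vec<f32> = (0..k * n).map(|v| q(v as f32)).collect(); // 9 -> 8.0, 11 -> 12.0
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                out[i * n + j] += lhs[i * k + p] * rhs[p * n + j];
            }
        }
    }
    // Matches the fp8 expectation above instead of the exact [20, 23, 26, 29, ...].
    assert_eq!(out, vec![20.0, 21.0, 26.0, 31.0, 56.0, 63.0, 80.0, 97.0]);
    println!("{out:?}");
}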