diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
index 9cb1cf8b..a3252914 100644
--- a/candle-core/benches/bench_main.rs
+++ b/candle-core/benches/bench_main.rs
@@ -3,8 +3,8 @@ mod benchmarks;
 use criterion::criterion_main;
 
 criterion_main!(
-    benchmarks::affine::benches,
     benchmarks::matmul::benches,
+    benchmarks::affine::benches,
    benchmarks::random::benches,
     benchmarks::reduce::benches,
     benchmarks::where_cond::benches,
diff --git a/candle-core/benches/benchmarks/matmul.rs b/candle-core/benches/benchmarks/matmul.rs
index 9d67e642..b8073485 100644
--- a/candle-core/benches/benchmarks/matmul.rs
+++ b/candle-core/benches/benchmarks/matmul.rs
@@ -7,19 +7,25 @@ fn run(a: &Tensor, b: &Tensor) {
     a.matmul(&b.t().unwrap()).unwrap();
 }
 
-fn run_bench(c: &mut Criterion, device: &Device) {
+fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
     let b = 1;
     let m = 1;
     let n = 2048;
     let k = 2048;
 
-    let dtype = DType::F32;
     let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
     let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
 
     let flops = b * m * n * k;
 
-    let mut group = c.benchmark_group(device.bench_name("matmul"));
+    let name = match dtype {
+        DType::F32 => "matmul_f32",
+        DType::F16 => "matmul_f16",
+        DType::BF16 => "matmul_bf16",
+        DType::U8 => "matmul_fp8",
+        _ => unimplemented!("{dtype:?} matmul bench not implemented"),
+    };
+    let mut group = c.benchmark_group(device.bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
@@ -36,8 +42,11 @@ fn run_bench(c: &mut Criterion, device: &Device) {
 
 fn criterion_benchmark(c: &mut Criterion) {
     let handler = BenchDeviceHandler::new().unwrap();
+    let dtypes = vec![DType::F32, DType::F16, DType::BF16, DType::U8];
     for device in handler.devices {
-        run_bench(c, &device);
+        for dtype in dtypes.clone() {
+            run_bench(c, &device, dtype);
+        }
     }
 }
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 433188cf..378b5f47 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1513,50 +1513,35 @@ impl BackendStorage for MetalStorage {
         let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
-        if self.dtype == DType::BF16 {
-            candle_metal_kernels::call_mlx_gemm(
-                &self.device.device,
-                &command_buffer,
-                &self.device.kernels,
-                candle_metal_kernels::GemmDType::BF16,
-                (b, m, n, k),
-                lhs_l.stride(),
-                lhs_l.start_offset() * self.dtype.size_in_bytes(),
-                &self.buffer,
-                rhs_l.stride(),
-                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &rhs.buffer,
-                &buffer,
-            )
-            .map_err(MetalError::from)?;
-        } else {
-            let dtype = match self.dtype {
-                DType::F32 => candle_metal_kernels::GemmDType::F32,
-                DType::F16 => candle_metal_kernels::GemmDType::F16,
-                DType::BF16 => candle_metal_kernels::GemmDType::BF16,
-                dtype => {
-                    return Err(MetalError::Message(format!(
-                        "mlx matmul doesn't support {dtype:?}"
-                    ))
-                    .into())
-                }
-            };
-            candle_metal_kernels::call_mlx_gemm(
-                &self.device.device,
-                &command_buffer,
-                &self.device.kernels,
-                dtype,
-                (b, m, n, k),
-                lhs_l.stride(),
-                lhs_l.start_offset() * self.dtype.size_in_bytes(),
-                &self.buffer,
-                rhs_l.stride(),
-                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &rhs.buffer,
-                &buffer,
-            )
-            .map_err(MetalError::from)?;
-        }
+
+        let dtype = match self.dtype {
+            // Hijacking the U8 dtype to represent E5M2 fp8
+            DType::U8 => candle_metal_kernels::GemmDType::F8E5M2,
+            DType::F32 => candle_metal_kernels::GemmDType::F32,
+            DType::F16 => candle_metal_kernels::GemmDType::F16,
+            DType::BF16 => candle_metal_kernels::GemmDType::BF16,
+            dtype => {
+                return Err(
+                    MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
+                )
+            }
+        };
+        candle_metal_kernels::call_mlx_gemm(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            dtype,
+            (b, m, n, k),
+            lhs_l.stride(),
+            lhs_l.start_offset() * self.dtype.size_in_bytes(),
+            &self.buffer,
+            rhs_l.stride(),
+            rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+            &rhs.buffer,
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+
         Ok(Self::new(
             buffer,
             self.device.clone(),
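candle-core has no first-class fp8 dtype yet, so U8 tensors are what reach the F8E5M2 path above. A minimal host-side sketch of what the new bench ends up exercising, assuming the same U8-as-E5M2 convention (illustrative only, not part of the patch):

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let device = Device::new_metal(0)?;
        // Every U8 element is read as an E5M2 byte; 0x00 encodes 0.0.
        let lhs = Tensor::zeros((1, 1, 2048), DType::U8, &device)?;
        let rhs = Tensor::zeros((1, 2048, 2048), DType::U8, &device)?;
        // Dispatched to the gemm_*_F8E5M2_f16_* pipelines via the match above.
        let out = lhs.matmul(&rhs.t()?)?;
        println!("{:?}", out.shape());
        Ok(())
    }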
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 156a1962..e69f4ff3 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0"
 
 [dependencies]
+float8 = "0.2.1"
 metal = { version = "0.27.0", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index 2af3fdce..bde7383f 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -1,4 +1,5 @@
 #include <metal_stdlib>
+#include <metal_limits>
 
 METAL_FUNC uint get_strided_index(
     uint idx,
@@ -18,6 +19,377 @@ METAL_FUNC uint get_strided_index(
 
 using namespace metal;
 
+typedef unsigned char fp8_storage_t;
+typedef unsigned short int fp8x2_storage_t;
+typedef unsigned int fp8x4_storage_t;
+
+template <typename T>
+struct _fp8_cast_traits;
+
+template <>
+struct _fp8_cast_traits<float> {
+    typedef _fp_encoding_traits<float> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFF800000;
+    constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
+    constexpr static constant encoding_type mask = 0x7FFFFFFF;
+};
+
+template <>
+struct _fp8_cast_traits<half> {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFC00;
+    constexpr static constant encoding_type mantissa_mask = 0x3FF;
+    constexpr static constant encoding_type mask = 0x7FFF;
+};
+
+enum _fp8_variant_t {
+    E4M3 = 0,      // OCP E4M3
+    E5M2 = 1,      // OCP E5M2
+    E4M3_FNUZ = 2, // Standard FP8
+    E5M2_FNUZ = 3, // BF8
+};
+
+template <_fp8_variant_t variant>
+struct _fp8_variant_traits;
+
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+
+template <typename T, _fp8_variant_t variant>
+struct _fp8_variant_cast_traits;
+
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43E00000; // 448.0f
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000; // 57344.0f
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43700000; // 240.0f
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000;
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5F00; // 448.0h
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00; // 57344.0h
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5B80; // 240.0h
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00;
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
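The ifmax/ifmin constants are the largest finite value of each fp8 variant, pre-encoded in the source type's bit pattern so that saturation is a single integer compare. A quick host-side check of the f32 thresholds (my own sketch, not in the patch):

    // Verify the f32 saturation thresholds used above.
    fn main() {
        assert_eq!(448.0f32.to_bits(), 0x43E0_0000);      // E4M3 max
        assert_eq!(57344.0f32.to_bits(), 0x4760_0000);    // E5M2 max
        assert_eq!(240.0f32.to_bits(), 0x4370_0000);      // E4M3_FNUZ max
        assert_eq!((-57344.0f32).to_bits(), 0xC760_0000); // E5M2 min
    }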
+
+// TODO: Simplify. No need to support all fp8 variants immediately.
+template <typename T, _fp8_variant_t variant>
+METAL_FUNC fp8_storage_t cast_to_fp8(T _x, bool clip = false, bool stoch = false, uint rng = 0) {
+    typedef _fp_encoding_traits<T> traits;
+    typedef typename traits::encoding_type bits;
+    typedef numeric_limits<T> limits;
+
+    typedef _fp8_cast_traits<T> cast_traits;
+    typedef _fp8_variant_traits<variant> variant_traits;
+    typedef _fp8_variant_cast_traits<T, variant> variant_cast_traits;
+
+    constexpr bool is_fnuz = variant_traits::is_fnuz;
+    constexpr int we = variant_traits::we;
+    constexpr int wm = variant_traits::wm;
+    constexpr int mfmt = traits::exponent_shift;
+    constexpr int bias = traits::exponent_bias;
+    constexpr bits mask = cast_traits::mask;
+
+    bits x = as_type<bits>(_x);
+
+    bits head = x & cast_traits::head_mask;
+    bits mantissa = x & cast_traits::mantissa_mask;
+    int exponent = (head >> traits::exponent_shift) & traits::exponent_max;
+    bits sign = head >> traits::sign_shift;
+
+    bits signed_inf = 0;
+    unsigned int nan = 0;
+    if (is_fnuz) {
+        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
+        nan = 0x80;
+    } else {
+        if (we == 4) {
+            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
+        } else {
+            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
+        }
+        nan = (sign << 7) + 0x7f;
+    }
+    constexpr bits ifmax = variant_cast_traits::ifmax;
+
+    // Deal with inf and NaNs
+    if ((x & traits::inf_mask) == traits::inf_mask) {
+        if (is_fnuz || we == 4) return nan; // fnuz and OCP E4M3 have no inf
+        if (mantissa != 0) return nan;      // NaN
+        return sign == 0 ? 0x7C : 0xFC;     // E5M2 inf
+    }
+
+    if ((x & mask) > ifmax) {
+        return signed_inf;
+    }
+
+    if (x == 0) {
+        return 0;
+    }
+
+    // For IEEE bias mode, the bias is 2^(k-1) - 1 where k is the width of the exponent bits
+    const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
+    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
+
+    int act_exponent, f8_exponent, exponent_diff;
+
+    if (exponent == 0) {
+        act_exponent = exponent - bias + 1;
+        exponent_diff = f8_denormal_act_exponent - act_exponent;
+    } else {
+        act_exponent = exponent - bias;
+        if (act_exponent <= f8_denormal_act_exponent) {
+            exponent_diff = f8_denormal_act_exponent - act_exponent;
+        } else {
+            exponent_diff = 0;
+        }
+        mantissa += (1ull << mfmt);
+    }
+
+    bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
+                    (1ull << (mfmt - wm + exponent_diff - 1));
+
+    if (exponent_diff > 0)
+        mantissa >>= exponent_diff;
+    else if (exponent_diff == -1)
+        mantissa <<= -exponent_diff;
+    bool implicit_one = mantissa & (1ull << mfmt);
+    f8_exponent =
+        (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
+
+    unsigned long drop_mask = (1ull << (mfmt - wm)) - 1;
+    bool odd = mantissa & (1ull << (mfmt - wm));
+    mantissa +=
+        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
+
+    if (f8_exponent == 0) {
+        if ((1ull << mfmt) & mantissa) {
+            f8_exponent = 1; // denormal overflowed to normal, promote the exponent
+        }
+    } else {
+        if ((1ull << (mfmt + 1)) & mantissa) {
+            mantissa >>= 1;
+            f8_exponent++;
+        }
+    }
+
+    mantissa >>= (mfmt - wm);
+    const int max_exp = (1 << we) - 1;
+    if (f8_exponent > max_exp) {
+        if (clip) {
+            mantissa = (1 << wm) - 1;
+            f8_exponent = max_exp;
+        } else {
+            return signed_inf;
+        }
+    }
+
+    if (f8_exponent == 0 && mantissa == 0) return is_fnuz ? 0 : (sign << 7);
+    mantissa &= (1 << wm) - 1;
+    return (sign << 7) | (f8_exponent << wm) | mantissa;
+}
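The midpoint/odd dance above implements round-to-nearest-even, with optional stochastic rounding via rng. The effect is easiest to see at a tie, assuming the float8 crate's default conversion rounds the same way (the gemm test expectations at the end of this patch are consistent with that):

    use float8::F8E5M2;

    fn main() {
        // Between 16 and 32 the E5M2 step is 4: 16, 20, 24, 28.
        // 18 is halfway between 16 (mantissa 00) and 20 (mantissa 01): ties go to even.
        assert_eq!(F8E5M2::from_f32(18.0).to_f32(), 16.0);
        // 22 is halfway between 20 (01) and 24 (10): again to even.
        assert_eq!(F8E5M2::from_f32(22.0).to_f32(), 24.0);
    }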
+
+template <_fp8_variant_t variant>
+METAL_FUNC half fp8_to_half(fp8_storage_t x);
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type bits;
+
+    bits ur = x << 8U;
+    bits sign = ur & 0x8000U;
+    bits exponent = (bits)(((ur & 0x7800U) >> 1U) + 0x2000U);
+    bits mantissa = (ur & 0x0700U) >> 1U;
+    unsigned char absx = 0x7FU & (unsigned char)x;
+
+    if (absx == 0x7FU) {
+        // return NaN
+        ur = 0x7FFFU;
+    } else if (exponent == 0x2000U) {
+        if (mantissa != 0U) {
+            // normalize
+            mantissa = (bits)(mantissa << 1U);
+            while ((mantissa & 0x0400U) == 0U) {
+                mantissa = (bits)(mantissa << 1U);
+                exponent = (bits)(exponent - 0x0400U);
+            }
+            // discard implicit leading bit
+            mantissa &= 0x03FFU;
+        } else {
+            // zero
+            exponent = 0U;
+        }
+
+        ur = (sign | exponent) | mantissa;
+    } else {
+        ur = (sign | exponent) | mantissa;
+    }
+
+    return as_type<half>(ur);
+}
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type bits;
+
+    bits ur = x << 8U;
+    if ((ur & 0x7FFFU) > 0x7C00U) {
+        // return NaN
+        ur = 0x7FFFU;
+    }
+    return as_type<half>(ur);
+}
+
+template <typename T, _fp8_variant_t variant>
+METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
+    return static_cast<T>(fp8_to_half<variant>(x));
+}
+
+#define CAST_TO_FP8(name, T) \
+CAST_TO_FP8_VARIANT(name##_E4M3, T, E4M3) \
+CAST_TO_FP8_VARIANT(name##_E5M2, T, E5M2)
+
+#define CAST_TO_FP8_VARIANT(name, T, FP8_VARIANT) \
+kernel void name( \
+    constant size_t &dim, \
+    device const T *input, \
+    device fp8_storage_t *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_to_fp8<T, FP8_VARIANT>(input[tid]); \
+} \
+kernel void name##_strided( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const T *input, \
+    device fp8_storage_t *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_to_fp8<T, FP8_VARIANT>( \
+        input[get_strided_index(tid, num_dims, dims, strides)] \
+    ); \
+} \
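Decoding E5M2 is just widening: an E5M2 byte is exactly the top byte of an IEEE f16, so fp8_to_half<E5M2> shifts left by 8 and canonicalizes NaN. A host-side Rust equivalent using the half crate (a sketch mirroring the kernel above):

    use half::f16;

    // E5M2 is the high byte of an IEEE f16: widen by shifting left 8.
    fn e5m2_to_f16(x: u8) -> f16 {
        let mut ur = (x as u16) << 8;
        if (ur & 0x7FFF) > 0x7C00 {
            ur = 0x7FFF; // canonical NaN
        }
        f16::from_bits(ur)
    }

    fn main() {
        assert_eq!(e5m2_to_f16(0x42).to_f32(), 3.0); // 0x4200 == 3.0f16
    }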
+
+#define CAST_FROM_FP8(name, T) \
+CAST_FROM_FP8_VARIANT(name##_E4M3, T, E4M3) \
+CAST_FROM_FP8_VARIANT(name##_E5M2, T, E5M2)
+
+#define CAST_FROM_FP8_VARIANT(name, T, FP8_VARIANT) \
+kernel void name( \
+    constant size_t &dim, \
+    device const fp8_storage_t *input, \
+    device T *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_fp8_to<T, FP8_VARIANT>(input[tid]); \
+} \
+kernel void name##_strided( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const fp8_storage_t *input, \
+    device T *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= dim) { \
+        return; \
+    } \
+    output[tid] = cast_fp8_to<T, FP8_VARIANT>( \
+        input[get_strided_index(tid, num_dims, dims, strides)] \
+    ); \
+} \
+
+CAST_FROM_FP8(cast_fp8_f16, half)
+CAST_FROM_FP8(cast_fp8_f32, float)
+CAST_TO_FP8(cast_f32_fp8, float)
+CAST_TO_FP8(cast_f16_fp8, half)
+
 #define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
 kernel void FN_NAME( \
     constant size_t &dim, \
@@ -128,4 +500,4 @@ CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
 CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
 CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
 CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
-#endif
\ No newline at end of file
+#endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 6de44f9c..84f1c9b8 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -37,12 +37,19 @@ pub enum DType {
     I64,
     U32,
     U8,
+    FP8(FP8Variant),
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum FP8Variant {
+    E4M3,
+    E5M2,
 }
 
 impl DType {
     fn size_in_bytes(&self) -> usize {
         match self {
-            Self::U8 => 1,
+            Self::U8 | Self::FP8(_) => 1,
             Self::U32 => 4,
             Self::I64 => 8,
             Self::BF16 => 2,
@@ -311,7 +318,12 @@ impl Kernels {
             let source_content = self.get_library_source(source);
             device
                 .new_library_with_source(source_content, &CompileOptions::new())
-                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                .map_err(|e| {
+                    // Makes metal compile errors easier to read while developing.
+                    // TODO: Remove and map to MetalKernelError::LoadLibraryError(e.to_string()).
+                    panic!("{}", e)
+                })?
         };
         libraries.insert(source, lib.clone());
         Ok(lib)
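The new DType::FP8(FP8Variant) maps straightforwardly onto the kernel-name suffixes generated by CAST_TO_FP8/CAST_FROM_FP8 in cast.metal. A hypothetical dispatch helper (the function name and placement are mine, not part of the patch):

    use candle_metal_kernels::FP8Variant;

    // Hypothetical helper: pick the cast kernel generated by CAST_TO_FP8 above.
    fn cast_to_fp8_kernel_name(from: &str, variant: FP8Variant) -> String {
        let suffix = match variant {
            FP8Variant::E4M3 => "E4M3",
            FP8Variant::E5M2 => "E5M2",
        };
        format!("cast_{from}_fp8_{suffix}")
    }

    fn main() {
        assert_eq!(
            cast_to_fp8_kernel_name("f32", FP8Variant::E5M2),
            "cast_f32_fp8_E5M2"
        );
    }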
diff --git a/candle-metal-kernels/src/mlx_gemm.metal b/candle-metal-kernels/src/mlx_gemm.metal
index 1b5cad92..aa7a6438 100644
--- a/candle-metal-kernels/src/mlx_gemm.metal
+++ b/candle-metal-kernels/src/mlx_gemm.metal
@@ -11,6 +11,186 @@ using namespace metal;
 
+// FP8 casting
+
+enum _fp8_variant_t {
+    E4M3 = 0,      // OCP E4M3
+    E5M2 = 1,      // OCP E5M2
+    E4M3_FNUZ = 2, // Standard FP8
+    E5M2_FNUZ = 3, // BF8
+};
+
+template <_fp8_variant_t variant>
+struct _fp8_variant_traits;
+
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2> {
+    constexpr static constant bool is_fnuz = false;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E4M3_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 4;
+    constexpr static constant int wm = 3;
+};
+template <>
+struct _fp8_variant_traits<_fp8_variant_t::E5M2_FNUZ> {
+    constexpr static constant bool is_fnuz = true;
+    constexpr static constant int we = 5;
+    constexpr static constant int wm = 2;
+};
+
+template <typename T>
+struct _fp8_cast_traits;
+
+template <>
+struct _fp8_cast_traits<float> {
+    typedef _fp_encoding_traits<float> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFF800000;
+    constexpr static constant encoding_type mantissa_mask = 0x7FFFFF;
+    constexpr static constant encoding_type mask = 0x7FFFFFFF;
+};
+
+template <>
+struct _fp8_cast_traits<half> {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type encoding_type;
+    constexpr static constant encoding_type head_mask = 0xFC00;
+    constexpr static constant encoding_type mantissa_mask = 0x3FF;
+    constexpr static constant encoding_type mask = 0x7FFF;
+};
+
+template <typename T, _fp8_variant_t variant>
+struct _fp8_variant_cast_traits;
+
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43E00000; // 448.0f
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000; // 57344.0f
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x43700000; // 240.0f
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<float, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<float> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x47600000;
+    constexpr static constant traits::encoding_type ifmin = 0xC7600000;
+};
+
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5F00; // 448.0h
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00; // 57344.0h
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E4M3_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x5B80; // 240.0h
+    constexpr static constant traits::encoding_type ifmin = 0x0; // unused
+};
+template <>
+struct _fp8_variant_cast_traits<half, _fp8_variant_t::E5M2_FNUZ> {
+    typedef _fp_encoding_traits<half> traits;
+    constexpr static constant traits::encoding_type ifmax = 0x7B00;
+    constexpr static constant traits::encoding_type ifmin = 0xFB00;
+};
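Each .metal source compiles standalone, hence the fp8 machinery repeated from cast.metal. Functionally, the fp8 GEMM below dequantizes each operand byte to the accumulator type and then runs the usual simdgroup MMA. A scalar Rust reference of that contract (a sketch; it assumes float8's from_bits, the inverse of the to_bits used in the tests):

    use float8::F8E5M2;

    // Dequantize each byte (what accum_cast does below), accumulate in f32.
    fn gemm_ref(m: usize, n: usize, k: usize, a: &[u8], b: &[u8]) -> Vec<f32> {
        let mut d = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                for p in 0..k {
                    let x = F8E5M2::from_bits(a[i * k + p]).to_f32();
                    let y = F8E5M2::from_bits(b[p * n + j]).to_f32();
                    d[i * n + j] += x * y;
                }
            }
        }
        d
    }

    fn main() {
        let a = [F8E5M2::from_f32(1.0).to_bits(), F8E5M2::from_f32(2.0).to_bits()];
        let b = [F8E5M2::from_f32(3.0).to_bits(), F8E5M2::from_f32(4.0).to_bits()];
        assert_eq!(gemm_ref(1, 1, 2, &a, &b), vec![11.0]);
    }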
+
+typedef unsigned char fp8_storage_t;
+typedef unsigned short int fp8x2_storage_t;
+typedef unsigned int fp8x4_storage_t;
+
+template <_fp8_variant_t variant>
+METAL_FUNC half fp8_to_half(fp8_storage_t x);
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E4M3>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type bits;
+
+    bits ur = x << 8U;
+    bits sign = ur & 0x8000U;
+    bits exponent = (bits)(((ur & 0x7800U) >> 1U) + 0x2000U);
+    bits mantissa = (ur & 0x0700U) >> 1U;
+    unsigned char absx = 0x7FU & (unsigned char)x;
+
+    if (absx == 0x7FU) { // NaN
+        ur = 0x7FFFU; // fp16 canonical NaN, discard sign
+    } else if (exponent == 0x2000U) {
+        // zero or denormal
+        if (mantissa != 0U) {
+            // normalize
+            mantissa = (bits)(mantissa << 1U);
+            while ((mantissa & 0x0400U) == 0U) {
+                mantissa = (bits)(mantissa << 1U);
+                exponent = (bits)(exponent - 0x0400U);
+            }
+            // discard implicit leading bit
+            mantissa &= 0x03FFU;
+        } else { // zero
+            exponent = 0U;
+        }
+
+        ur = (sign | exponent) | mantissa;
+    } else {
+        ur = (sign | exponent) | mantissa;
+    }
+
+    return as_type<half>(ur);
+}
+
+template <>
+METAL_FUNC half fp8_to_half<_fp8_variant_t::E5M2>(fp8_storage_t x) {
+    typedef _fp_encoding_traits<half> traits;
+    typedef typename traits::encoding_type bits;
+
+    bits ur = x << 8U;
+    if ((ur & 0x7FFFU) > 0x7C00U) {
+        /* if NaN, return canonical NaN */
+        ur = 0x7FFFU;
+    }
+    return as_type<half>(ur);
+}
+
+template <typename T, _fp8_variant_t variant>
+METAL_FUNC T cast_fp8_to(fp8_storage_t x) {
+    return static_cast<T>(fp8_to_half<variant>(x));
+}
+
 // https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/params.h#L1
 ///////////////////////////////////////////////////////////////////////////////
 // GEMM param classes
 ///////////////////////////////////////////////////////////////////////////////
@@ -199,6 +379,35 @@ struct BlockLoader {
 // Transforms and Epilogues
 ///////////////////////////////////////////////////////////////////////////////
 
+template <typename OutT, typename InT>
+METAL_FUNC OutT accum_cast(InT x);
+
+template <>
+METAL_FUNC float accum_cast<float, float>(float x) {
+    return x;
+}
+template <>
+METAL_FUNC float accum_cast<float, half>(half x) {
+    return static_cast<float>(x);
+}
+
+#if defined(__HAVE_BFLOAT__)
+template <>
+METAL_FUNC float accum_cast<float, bfloat>(bfloat x) {
+    return static_cast<float>(x);
+}
+#endif
+
+template <>
+METAL_FUNC float accum_cast<float, fp8_storage_t>(fp8_storage_t x) {
+    return cast_fp8_to<float, _fp8_variant_t::E5M2>(x);
+}
+
+template <>
+METAL_FUNC half accum_cast<half, fp8_storage_t>(fp8_storage_t x) {
+    return cast_fp8_to<half, _fp8_variant_t::E5M2>(x);
+}
+
 template <typename OutT, typename InT>
 struct TransformNone {
   static METAL_FUNC OutT apply(InT x) {
@@ -210,6 +419,17 @@ struct TransformNone {
   }
 };
 
+template <typename OutT>
+struct TransformNone<OutT, fp8_storage_t> {
+  static METAL_FUNC OutT apply(fp8_storage_t x) {
+    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
+  }
+
+  static METAL_FUNC OutT apply(fp8_storage_t x, OutT) {
+    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
+  }
+};
+
 template <typename OutT, typename InT>
 struct TransformAdd {
   TransformAdd(const float, const float) {}
@@ -223,6 +443,20 @@ struct TransformAdd {
   }
 };
 
+template <typename OutT>
+struct TransformAdd<OutT, fp8_storage_t> {
+  TransformAdd(const float, const float) {}
+
+  static METAL_FUNC OutT apply(fp8_storage_t x) {
+    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
+  }
+
+  static METAL_FUNC OutT apply(fp8_storage_t x, OutT c) {
+    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x) + c;
+  }
+};
+
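The three Transform* epilogues differ only in what they do with the previously stored C value once x has been dequantized from fp8. Their per-element contract, as a Rust sketch (not the Metal code):

    fn none(x: f32) -> f32 { x }
    fn add(x: f32, c: f32) -> f32 { x + c }
    fn axpby(x: f32, c: f32, alpha: f32, beta: f32) -> f32 { x * alpha + beta * c }

    fn main() {
        let (x, c) = (2.0, 10.0);
        assert_eq!(none(x), 2.0);
        assert_eq!(add(x, c), 12.0);
        assert_eq!(axpby(x, c, 0.5, 2.0), 21.0);
    }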
 template <typename OutT, typename InT>
 struct TransformAxpby {
   const float alpha;
@@ -240,6 +474,24 @@ struct TransformAxpby {
   }
 };
 
+template <typename OutT>
+struct TransformAxpby<OutT, fp8_storage_t> {
+  const float alpha;
+  const float beta;
+
+  TransformAxpby(const float alpha_, const float beta_)
+      : alpha(alpha_), beta(beta_) {}
+
+  static METAL_FUNC OutT apply(fp8_storage_t x) {
+    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x);
+  }
+
+  METAL_FUNC OutT apply(fp8_storage_t x, OutT c) const {
+    return cast_fp8_to<OutT, _fp8_variant_t::E5M2>(x) * alpha + (beta * c);
+  }
+};
+
 template <typename T>
 struct AccumHelper {
   typedef float accum_type;
@@ -336,53 +588,47 @@ struct BlockMMA {
     // Adjust for simdgroup and thread location
     As += As_offset;
     Bs += Bs_offset;
 
-    // Iterate over BK in blocks of 8
     STEEL_PRAGMA_UNROLL
     for (short kk = 0; kk < BK; kk += 8) {
-      simdgroup_barrier(mem_flags::mem_none);
+        simdgroup_barrier(mem_flags::mem_none);
 
-      // Load elements from threadgroup A as simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short i = 0; i < TM; i++) {
-        Asimd[i].thread_elements()[0] =
-            static_cast<AccumType>(As[i * simd_stride_a + 0]);
-        Asimd[i].thread_elements()[1] =
-            static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
-      }
+        // Load elements from threadgroup A as simdgroup matrices
+        STEEL_PRAGMA_UNROLL
+        for (short i = 0; i < TM; i++) {
+            Asimd[i].thread_elements()[0] = accum_cast<AccumType>(As[i * simd_stride_a + 0]);
+            Asimd[i].thread_elements()[1] = accum_cast<AccumType>(As[i * simd_stride_a + jump_a]);
+        }
 
-      simdgroup_barrier(mem_flags::mem_none);
+        simdgroup_barrier(mem_flags::mem_none);
 
-      // Load elements from threadgroup B as simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < TN; j++) {
-        Bsimd[j].thread_elements()[0] =
-            static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
-        Bsimd[j].thread_elements()[1] =
-            static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
-      }
+        // Load elements from threadgroup B as simdgroup matrices
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; j++) {
+            Bsimd[j].thread_elements()[0] = accum_cast<AccumType>(Bs[j * simd_stride_b + 0]);
+            Bsimd[j].thread_elements()[1] = accum_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
+        }
 
-      simdgroup_barrier(mem_flags::mem_none);
+        simdgroup_barrier(mem_flags::mem_none);
 
-      // Multiply and accumulate into result simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short i = 0; i < TM; i++) {
+        // Multiply and accumulate into result simdgroup matrices
         STEEL_PRAGMA_UNROLL
-        for (short j = 0; j < TN; j++) {
-          short j_serp = (i % 2) ? (TN - 1 - j) : j;
-
-          simdgroup_multiply_accumulate(
-              results[i * TN + j_serp],
-              Asimd[i],
-              Bsimd[j_serp],
-              results[i * TN + j_serp]);
+        for (short i = 0; i < TM; i++) {
+            STEEL_PRAGMA_UNROLL
+            for (short j = 0; j < TN; j++) {
+                short j_serp = (i % 2) ? (TN - 1 - j) : j;
+                simdgroup_multiply_accumulate(
+                    results[i * TN + j_serp],
+                    Asimd[i],
+                    Bsimd[j_serp],
+                    results[i * TN + j_serp]);
+            }
         }
-      }
 
-      // Progress to next simdgroup tile
-      As += tile_stride_a;
-      Bs += tile_stride_b;
-    }
+        // Progress to next simdgroup tile
+        As += tile_stride_a;
+        Bs += tile_stride_b;
+    }
 
     /* Store results from simdgroup_matrix results into device memory */
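The j_serp index makes odd result rows walk the column tiles in reverse, so consecutive iterations touch the B tile that was just used. The traversal order it produces, illustrated in Rust:

    // Traversal order produced by j_serp for TM = 2, TN = 4.
    fn main() {
        let (tm, tn) = (2usize, 4usize);
        for i in 0..tm {
            let cols: Vec<usize> = (0..tn)
                .map(|j| if i % 2 == 1 { tn - 1 - j } else { j })
                .collect();
            println!("row {i}: {cols:?}"); // row 0: [0, 1, 2, 3], row 1: [3, 2, 1, 0]
        }
    }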
@@ -1020,12 +1266,13 @@ template <
     int WN,
     bool transpose_a,
     bool transpose_b,
-    typename AccumType = float>
+    typename AccumType = float,
+    typename ResultType = T>
 [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
     const device T* A [[buffer(0)]],
     const device T* B [[buffer(1)]],
-    const device T* C [[buffer(2), function_constant(use_out_source)]],
-    device T* D [[buffer(3)]],
+    const device ResultType* C [[buffer(2), function_constant(use_out_source)]],
+    device ResultType* D [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
     const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
     const constant int* batch_shape [[buffer(6)]],
@@ -1045,7 +1292,7 @@ template <
 
   using gemm_kernel = GEMMKernel<
       T,
-      T,
+      ResultType,
       BM,
       BN,
       BK,
@@ -1326,7 +1573,7 @@ template <
             addmm_params->fdc,
             short2(tgp_bn, tgp_bm),
             epilogue_op_add);
-      } 
+      }
     }
 
     // Store results to device memory
@@ -1396,7 +1643,7 @@ template <
             addmm_params->fdc,
             short2(tgp_bn, tgp_bm),
             epilogue_op_add);
-      } 
+      }
     }
 
     // Store results to device memory
@@ -1405,13 +1652,13 @@ template <
   }
 }
 
-#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
   template [[host_name("gemm_" #tname "_" #iname "_" #oname "_" #bm "_" #bn "_" #bk "_" #wm "_" #wn)]] \
-  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b>( \
+  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float, stype>( \
       const device itype *A [[buffer(0)]], \
       const device itype *B [[buffer(1)]], \
-      const device itype *C [[buffer(2), function_constant(use_out_source)]], \
-      device itype *D [[buffer(3)]], \
+      const device stype *C [[buffer(2), function_constant(use_out_source)]], \
+      device stype *D [[buffer(3)]], \
       const constant GEMMParams* params [[buffer(4)]], \
       const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
       const constant int* batch_shape [[buffer(6)]], \
@@ -1427,14 +1674,17 @@ template <
       uint3 tid [[threadgroup_position_in_grid]], \
       uint3 lid [[thread_position_in_threadgroup]]);
 
-#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-    instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-    instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-    instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-    instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+    instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+    instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+    instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn, stype) \
+    instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn, stype)
 
-instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2)
-instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2, float)
+instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2, half)
 #if defined(__HAVE_BFLOAT__)
-instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2, bfloat)
 #endif
+
+instantiate_gemm_transpose_helper(F8E5M2, fp8_storage_t, f16, half, 32, 32, 16, 2, 2, half)
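The host picks pipelines purely by these host_name strings, which follow "gemm_{transpose}_{in}_{out}_{bm}_{bn}_{bk}_{wm}_{wn}". A sketch of the naming scheme that the Rust side below relies on:

    fn gemm_kernel_name(ta: bool, tb: bool, iname: &str, oname: &str) -> String {
        let t = match (ta, tb) {
            (false, false) => "nn",
            (true, false) => "tn",
            (false, true) => "nt",
            (true, true) => "tt",
        };
        format!("gemm_{t}_{iname}_{oname}_32_32_16_2_2")
    }

    fn main() {
        assert_eq!(
            gemm_kernel_name(false, false, "F8E5M2", "f16"),
            "gemm_nn_F8E5M2_f16_32_32_16_2_2"
        );
    }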
diff --git a/candle-metal-kernels/src/mlx_gemm.rs b/candle-metal-kernels/src/mlx_gemm.rs
index ee4292c3..b3daf3d5 100644
--- a/candle-metal-kernels/src/mlx_gemm.rs
+++ b/candle-metal-kernels/src/mlx_gemm.rs
@@ -5,6 +5,7 @@ use std::ffi::c_void;
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
 pub enum GemmDType {
+    F8E5M2,
     BF16,
     F16,
     F32,
@@ -138,6 +139,10 @@ pub fn call_mlx_gemm(
         (GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
         (GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
         (GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, false, false) => "gemm_nn_F8E5M2_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, true, false) => "gemm_tn_F8E5M2_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, false, true) => "gemm_nt_F8E5M2_f16_32_32_16_2_2",
+        (GemmDType::F8E5M2, true, true) => "gemm_tt_F8E5M2_f16_32_32_16_2_2",
     };
     let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
     let encoder = ep.encoder();
diff --git a/candle-metal-kernels/src/sort.rs b/candle-metal-kernels/src/sort.rs
index e4140eb3..663462a0 100644
--- a/candle-metal-kernels/src/sort.rs
+++ b/candle-metal-kernels/src/sort.rs
@@ -41,6 +41,7 @@ pub fn call_arg_sort(
 
 fn mlx_dtype_str(dtype: DType) -> &'static str {
     match dtype {
+        DType::FP8(_) => todo!(),
         DType::U8 => "uint8",
         DType::U32 => "uint32",
         DType::I64 => "int64",
@@ -195,6 +196,7 @@ pub fn multi_block_sort(
     };
     // Copy output with appropriate strides
     let copy_kernel = match dtype {
+        DType::FP8(_) => todo!(),
        DType::U8 => crate::copy2d::U8,
         DType::U32 => crate::copy2d::U32,
         DType::I64 => crate::copy2d::I64,
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 21ade21c..b996820e 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,4 +1,6 @@
 use super::*;
+
+use float8::{F8E4M3, F8E5M2};
 use half::{bf16, f16};
 use metal::{Buffer, Device, MTLResourceOptions};
 use rand::prelude::SliceRandom;
@@ -423,6 +425,38 @@ fn cast_bf16() {
     assert_eq!(results, v_i64);
 }
 
+#[test]
+fn cast_fp8() {
+    let v_f64 = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+
+    let v_f8: Vec<F8E4M3> = v_f32.iter().map(|&v| F8E4M3::from(v)).collect();
+    let v_f16_rt: Vec<f16> = v_f8.iter().map(|&v| v.to_f16()).collect();
+
+    // E4M3
+    // f32 -> fp8 -> f32
+    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E4M3");
+    let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E4M3");
+    assert_eq!(v_f32, roundtrip);
+
+    // f16 -> fp8 -> f16
+    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E4M3");
+    let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E4M3");
+    assert_eq!(v_f16, roundtrip);
+
+    // E5M2
+    // f16 -> fp8 -> f16
+    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_fp8_E5M2");
+    let roundtrip: Vec<f16> = run_cast(&results, "cast_fp8_f16_E5M2");
+    assert_eq!(v_f16, roundtrip);
+
+    // f32 -> fp8 -> f32
+    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_fp8_E5M2");
+    let roundtrip: Vec<f32> = run_cast(&results, "cast_fp8_f32_E5M2");
+    assert_eq!(v_f32, roundtrip);
+}
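The roundtrip assertions can use exact equality because every test value is exactly representable in both OCP formats. A host-side check of that premise (sketch, assuming the float8 crate's conversions):

    use float8::{F8E4M3, F8E5M2};

    fn main() {
        // 1..=8 are exact in both formats, so == is safe above.
        for v in 1..=8 {
            let f = v as f32;
            assert_eq!(F8E4M3::from(f).to_f32(), f);
            assert_eq!(F8E5M2::from_f32(f).to_f32(), f);
        }
    }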
+
 #[test]
 fn cast_u32() {
     let v_f64 = [1.0f64, 2.0, 3.0];
@@ -1285,7 +1319,7 @@ fn where_cond_u32_f32() {
 }
 
 #[allow(clippy::too_many_arguments)]
-fn run_mlx_gemm<T>(
+fn run_mlx_gemm<T, S>(
     dtype: GemmDType,
     (b, m, n, k): (usize, usize, usize, usize),
     lhs: &[T],
     lhs_stride: &[usize],
     lhs_offset: usize,
     rhs: &[T],
     rhs_stride: &[usize],
     rhs_offset: usize,
-) -> Vec<T> {
+) -> Vec<S> {
     let device = device();
     let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
@@ -1436,6 +1470,33 @@ fn mlx_gemm() {
             vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
         );
     }
+
+    {
+        // F8E5M2 sanity test
+        let (b, m, n, k) = (1, 2, 4, 3);
+        let lhs: Vec<u8> = (0..b * m * k)
+            .map(|f| F8E5M2::from_f32(f as f32).to_bits())
+            .collect();
+        let rhs: Vec<u8> = (0..b * n * k)
+            .map(|f| F8E5M2::from_f32(f as f32).to_bits())
+            .collect();
+        let results: Vec<f16> = run_mlx_gemm(
+            GemmDType::F8E5M2,
+            (b, m, n, k),
+            &lhs,
+            &[m * k, k, 1],
+            0,
+            &rhs,
+            &[n * k, n, 1],
+            0,
+        );
+
+        assert_eq!(
+            approx_f16(results, 4),
+            vec![20.0, 21.0, 26.0, 31.0, 56.0, 63.0, 80.0, 97.0]
+        );
+    }
 }
 
 fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<f32> {
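Why the fp8 expectations differ from the f32 case ([20, 23, 26, 29, ...] becomes [20, 21, 26, 31, ...]): the inputs 0..12 are quantized to E5M2 first, and above 8 the E5M2 step is 2, so 9 rounds down to 8 (tie to even) and 11 rounds up to 12. A host-side verification of one entry (my sketch, not part of the patch):

    use float8::F8E5M2;

    fn main() {
        let q: Vec<f32> = (0..12).map(|v| F8E5M2::from_f32(v as f32).to_f32()).collect();
        assert_eq!(q[9], 8.0);
        assert_eq!(q[11], 12.0);
        // Row 0, column 1 of the test: 0*q[1] + 1*q[5] + 2*q[9] = 5 + 16 = 21.
        assert_eq!(0.0 * q[1] + 1.0 * q[5] + 2.0 * q[9], 21.0);
    }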