diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index d9fc7526..92a04917 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -48,4 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] name = "bench_main" harness = false - diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 5d72bd68..aa2898ff 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -590,14 +590,26 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::BF16) => "cast_u32_bf16", + (DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::I64) => "cast_u8_i64", + (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::F32, DType::F16) => "cast_f32_f16", - (DType::F16, DType::F32) => "cast_f16_f32", - (DType::I64, DType::F32) => "cast_i64_f32", (DType::F32, DType::BF16) => "cast_f32_bf16", + + (DType::I64, DType::F32) => "cast_i64_f32", + + (DType::F16, DType::BF16) => "cast_f16_bf16", + (DType::F16, DType::F32) => "cast_f16_f32", + + (DType::BF16, DType::U8) => "cast_bf16_u8", + (DType::BF16, DType::U32) => "cast_bf16_u32", + (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") } @@ -1131,8 +1143,12 @@ impl BackendStorage for MetalStorage { let device = self.device(); let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::BF16) => "is_u8_bf16", + (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", + (DType::U32, DType::BF16) => "is_u32_bf16", + (left, right) => { crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") } @@ -1322,6 +1338,7 @@ impl MetalStorage { ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), @@ -1332,6 +1349,18 @@ impl MetalStorage { ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + + ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype), + ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype), + ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype), + ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype), + ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), @@ -1342,6 +1371,7 @@ impl MetalStorage { ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), + ("add", DType::U32) => (contiguous::add::U32, self.dtype), ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), @@ -1352,6 +1382,7 @@ impl MetalStorage { ("lt", DType::U32) => (contiguous::lt::U32, DType::U8), ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), ("gt", DType::U32) => (contiguous::gt::U32, DType::U8), + ("add", DType::U8) => (contiguous::add::U8, self.dtype), ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), @@ -1362,6 +1393,7 @@ impl MetalStorage { ("lt", DType::U8) => (contiguous::lt::U8, DType::U8), ("ge", DType::U8) => (contiguous::ge::U8, DType::U8), ("gt", DType::U8) => (contiguous::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") } @@ -1395,6 +1427,7 @@ impl MetalStorage { ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), @@ -1407,6 +1440,20 @@ impl MetalStorage { ("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + + ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype), + ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype), + ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype), + ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype), + ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype), + ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype), + ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype), @@ -1419,6 +1466,7 @@ impl MetalStorage { ("lt", DType::I64) => (strided::lt::I64, DType::U8), ("ge", DType::I64) => (strided::ge::I64, DType::U8), ("gt", DType::I64) => (strided::gt::I64, DType::U8), + ("badd", DType::U32) => (strided::add::U32, self.dtype), ("bsub", DType::U32) => (strided::sub::U32, self.dtype), ("bmul", DType::U32) => (strided::mul::U32, self.dtype), @@ -1431,6 +1479,7 @@ impl MetalStorage { ("lt", DType::U32) => (strided::lt::U32, DType::U8), ("ge", DType::U32) => (strided::ge::U32, DType::U8), ("gt", DType::U32) => (strided::gt::U32, DType::U8), + ("badd", DType::U8) => (strided::add::U8, self.dtype), ("bsub", DType::U8) => (strided::sub::U8, self.dtype), ("bmul", DType::U8) => (strided::mul::U8, self.dtype), @@ -1443,6 +1492,7 @@ impl MetalStorage { ("lt", DType::U8) => (strided::lt::U8, DType::U8), ("ge", DType::U8) => (strided::ge::U8, DType::U8), ("gt", DType::U8) => (strided::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal strided binary {name} {dtype:?} not implemented") } diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 441d2e88..187cb4fd 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"] categories = ["science"] license = "MIT OR Apache-2.0" + [dependencies] -metal = { version = "0.27.0", features = ["mps"]} +metal = { version = "0.27.0", features = ["mps"] } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" [dev-dependencies] -half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +half = { version = "2.3.1", features = [ + "num-traits", + "use-intrinsics", + "rand_distr", +] } rand = "0.8.5" diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 5aacac4a..e08931cf 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -28,7 +28,7 @@ kernel void FN_NAME( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[tid]); \ + output[tid] = static_cast(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ + output[tid] = static_cast(input[get_strided_index(tid, num_dims, dims, strides)]); \ +} \ + +#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(static_cast(input[tid])); \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(static_cast(input[get_strided_index(tid, num_dims, dims, strides)])); \ } \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) @@ -59,6 +86,15 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif #if defined(__HAVE_BFLOAT__) +#if __METAL_VERSION__ >= 310 +CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) + +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) + +CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) +CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 32f3f410..2a57bdbb 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -174,6 +174,9 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half) #if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u32_bf16, uint32_t, bfloat) +INDEX_OP(is_u8_bf16, uint8_t, bfloat) + INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index c955abca..87f8ac45 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Device, MTLResourceOptions}; +use metal::{Buffer, Device, MTLResourceOptions}; fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { let ptr = buffer.contents() as *const T; @@ -248,6 +248,34 @@ fn binary_add_f32() { assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); } +#[test] +fn binary_ops_bf16() { + let lhs: Vec = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect(); + let rhs: Vec = [4.2f32, 5.5f32, 6.91f32] + .into_iter() + .map(bf16::from_f32) + .collect(); + + macro_rules! binary_op { + ($opname:ident, $opexpr:expr) => {{ + let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT); + let expected: Vec = lhs + .iter() + .zip(rhs.iter()) + .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y)) + .collect(); + assert_eq!(results, expected); + }}; + } + + binary_op!(add, |x, y| x + y); + binary_op!(sub, |x, y| x - y); + binary_op!(mul, |x, y| x * y); + binary_op!(div, |x, y| x / y); + binary_op!(min, |x: bf16, y| x.min(y)); + binary_op!(max, |x: bf16, y| x.max(y)); +} + fn cast(v: &[T], name: &'static str) -> Vec { let device = device(); let fence = device.new_fence(); @@ -296,6 +324,89 @@ fn cast_u32_f32() { assert_eq!(results, vec![1.0f32; 10_000]); } +#[test] +fn it_cast_bf16_u32() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_u32"); + let expected: Vec = (1..=3).map(|v| v as u32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f32() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_f32"); + let expected: Vec = (1..=3).map(|v| v as f32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u8_bf16() { + let input: Vec = (1..=3).map(|v| v as u8).collect(); + + let output: Vec = cast(&input, "cast_u8_bf16"); + let expected: Vec = input + .iter() + .map(|v| bf16::from_f32(*v as f32)) + .collect::>(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u32_bf16() { + let input: Vec = (1..=3).map(|v| v as u32).collect(); + + let output: Vec = cast(&input, "cast_u32_bf16"); + let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f32_bf16() { + let input: Vec = (1..=3).map(|v| v as f32).collect(); + + let output: Vec = cast(&input, "cast_f32_bf16"); + let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_u8() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_u8"); + let expected: Vec = input.iter().map(|v| v.to_f32() as u8).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f16() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_f16"); + let expected: Vec = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f16_bf16() { + let input: Vec = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_f16_bf16"); + let expected: Vec = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let device = device(); let fence = device.new_fence(); @@ -396,14 +507,14 @@ fn index_select() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [2, 5]; let ids = [0u32, 1, 0]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] @@ -419,20 +530,46 @@ fn index_select_f16() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); assert_eq!( approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] ); } +#[test] +fn index_select_is_u32_bf16() { + let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_is_u8_bf16() { + let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u8, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + #[test] fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] @@ -444,6 +581,7 @@ fn run_index_select( shape: &[usize], ids: &[I], dim: usize, + name: &'static str, ) -> Vec { let device = Device::system_default().expect("no device found"); @@ -457,12 +595,6 @@ fn run_index_select( let dst_el = ids.len() * left_size * right_size; let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - let name = match core::mem::size_of::() { - 4 => "is_u32_f32", - 2 => "is_u32_f16", - _ => unimplemented!(), - }; - let fence = device.new_fence(); let kernels = Kernels::new(fence); call_index_select(