From e316cb699743b5d45ab4a1067057b8f6d8687a02 Mon Sep 17 00:00:00 2001 From: Thomas Santerre Date: Sun, 17 Mar 2024 15:55:11 -0400 Subject: [PATCH] add support for casting between all datatypes (#1860) --- candle-core/src/metal_backend.rs | 29 +++- candle-metal-kernels/src/cast.metal | 57 +++++-- candle-metal-kernels/src/tests.rs | 226 +++++++++++++++++++--------- 3 files changed, 220 insertions(+), 92 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 2e07cce5..a6513b1c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -609,28 +609,41 @@ impl BackendStorage for MetalStorage { let command_buffer = device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { - (DType::U32, DType::F32) => "cast_u32_f32", - (DType::U32, DType::U8) => "cast_u32_u8", - (DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::BF16) => "cast_u32_bf16", + (DType::U32, DType::F16) => "cast_u32_f16", + (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::U8) => "cast_u32_u8", - (DType::U8, DType::U32) => "cast_u8_u32", + (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::U8, DType::F16) => "cast_u8_f16", (DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::I64) => "cast_u8_i64", - (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::U8, DType::U32) => "cast_u8_u32", - (DType::F32, DType::F16) => "cast_f32_f16", (DType::F32, DType::BF16) => "cast_f32_bf16", + (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I64) => "cast_f32_i64", + (DType::F32, DType::U32) => "cast_f32_u32", + (DType::F32, DType::U8) => "cast_f32_u8", + (DType::I64, DType::BF16) => "cast_i64_bf16", + (DType::I64, DType::F16) => "cast_i64_f16", (DType::I64, DType::F32) => "cast_i64_f32", + (DType::I64, DType::U32) => "cast_i64_u32", + (DType::I64, DType::U8) => "cast_i64_u8", (DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I64) => "cast_f16_i64", + (DType::F16, DType::U32) => "cast_f16_u32", + (DType::F16, DType::U8) => "cast_f16_u8", - (DType::BF16, DType::U8) => "cast_bf16_u8", - (DType::BF16, DType::U32) => "cast_bf16_u32", (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I64) => "cast_bf16_i64", + (DType::BF16, DType::U32) => "cast_bf16_u32", + (DType::BF16, DType::U8) => "cast_bf16_u8", (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 9aead139..2af3fdce 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \ output[tid] = static_cast(static_cast(input[get_strided_index(tid, num_dims, dims, strides)])); \ } \ +// u32 CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) -CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) -CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) -CAST(cast_f16_f32, cast_f16_f32_strided, half, float) -CAST(cast_f32_f16, cast_f32_f16_strided, float, half) - +CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) #if __METAL_VERSION__ >= 220 -CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) -CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) #endif +// u8 +CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) +CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) +#if __METAL_VERSION__ >= 220 +CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +#endif + +// f16 +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) +CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) +#endif + +// i64 +CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) +CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) +CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) +#endif + +// f32 +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) +CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) +CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) +#if defined(__HAVE_BFLOAT__) +CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) +#endif + +// bf16 #if defined(__HAVE_BFLOAT__) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) -CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) -CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) -CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) - CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) -CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) #endif \ No newline at end of file diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b47fff6a..b2f1d723 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -292,7 +292,7 @@ fn binary_ops_bf16() { binary_op!(max, |x: bf16, y| x.max(y)); } -fn cast(v: &[T], name: &'static str) -> Vec { +fn run_cast(v: &[T], name: &'static str) -> Vec { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); @@ -319,107 +319,189 @@ fn cast(v: &[T], name: &'static str) -> Vec { } #[test] -fn cast_u32_f32() { - let v = vec![1u32, 2, 3]; - let results = cast(&v, "cast_u32_f32"); - let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); - assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); - assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); +fn cast_f32() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let v = vec![1.0f32, 2.0, 3.0]; - let input: Vec = v.iter().map(|v| f16::from_f32(*v)).collect(); - let results: Vec = cast(&input, "cast_f16_f32"); - assert_eq!(results, vec![1.0f32, 2.0, 3.0]); + // f32 -> f16 + let results: Vec = run_cast(&v_f32, "cast_f32_f16"); + assert_eq!(results, v_f16); - let v = vec![1.0f32; 10_000]; - let input: Vec = v.iter().map(|v| f16::from_f32(*v)).collect(); - let results: Vec = cast(&input, "cast_f16_f32"); - assert_eq!(results.len(), 10_000); - assert_eq!(&results[..10], vec![1.0f32; 10]); - assert_eq!(results, vec![1.0f32; 10_000]); + // f32 -> bf16 + let results: Vec = run_cast(&v_f32, "cast_f32_bf16"); + assert_eq!(results, v_bf16); + + // f32 -> u32 + let results: Vec = run_cast(&v_f32, "cast_f32_u32"); + assert_eq!(results, v_u32); + + // f32 -> u8 + let results: Vec = run_cast(&v_f32, "cast_f32_u8"); + assert_eq!(results, v_u8); + + // f32 -> i64 + let results: Vec = run_cast(&v_f32, "cast_f32_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_bf16_u32() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); +fn cast_f16() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_bf16_u32"); - let expected: Vec = (1..=3).map(|v| v as u32).collect(); + // f16 -> f32 + let results: Vec = run_cast(&v_f16, "cast_f16_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // f16 -> bf16 + let results: Vec = run_cast(&v_f16, "cast_f16_bf16"); + assert_eq!(results, v_bf16); + + // f16 -> u32 + let results: Vec = run_cast(&v_f16, "cast_f16_u32"); + assert_eq!(results, v_u32); + + // f16 -> u8 + let results: Vec = run_cast(&v_f16, "cast_f16_u8"); + assert_eq!(results, v_u8); + + // f16 -> i64 + let results: Vec = run_cast(&v_f16, "cast_f16_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_bf16_f32() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); +fn cast_bf16() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_bf16_f32"); - let expected: Vec = (1..=3).map(|v| v as f32).collect(); + // bf16 -> f32 + let results: Vec = run_cast(&v_bf16, "cast_bf16_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // bf16 -> f16 + let results: Vec = run_cast(&v_bf16, "cast_bf16_f16"); + assert_eq!(results, v_f16); + + // bf16 -> u32 + let results: Vec = run_cast(&v_bf16, "cast_bf16_u32"); + assert_eq!(results, v_u32); + + // bf16 -> u8 + let results: Vec = run_cast(&v_bf16, "cast_bf16_u8"); + assert_eq!(results, v_u8); + + // bf16 -> i64 + let results: Vec = run_cast(&v_bf16, "cast_bf16_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_u8_bf16() { - let input: Vec = (1..=3).map(|v| v as u8).collect(); +fn cast_u32() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_u8_bf16"); - let expected: Vec = input - .iter() - .map(|v| bf16::from_f32(*v as f32)) - .collect::>(); + // u32 -> f32 + let results: Vec = run_cast(&v_u32, "cast_u32_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // u32 -> f16 + let results: Vec = run_cast(&v_u32, "cast_u32_f16"); + assert_eq!(results, v_f16); + + // u32 -> bf16 + let results: Vec = run_cast(&v_u32, "cast_u32_bf16"); + assert_eq!(results, v_bf16); + + // u32 -> u8 + let results: Vec = run_cast(&v_u32, "cast_u32_u8"); + assert_eq!(results, v_u8); + + // u32 -> i64 + let results: Vec = run_cast(&v_u32, "cast_u32_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_u32_bf16() { - let input: Vec = (1..=3).map(|v| v as u32).collect(); +fn cast_u8() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_u32_bf16"); - let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + // u8 -> f32 + let results: Vec = run_cast(&v_u8, "cast_u8_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // u8 -> f16 + let results: Vec = run_cast(&v_u8, "cast_u8_f16"); + assert_eq!(results, v_f16); + + // u8 -> bf16 + let results: Vec = run_cast(&v_u8, "cast_u8_bf16"); + assert_eq!(results, v_bf16); + + // u8 -> u32 + let results: Vec = run_cast(&v_u8, "cast_u8_u32"); + assert_eq!(results, v_u32); + + // u8 -> i64 + let results: Vec = run_cast(&v_u8, "cast_u8_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_f32_bf16() { - let input: Vec = (1..=3).map(|v| v as f32).collect(); +fn cast_i64() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_f32_bf16"); - let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + // i64 -> f32 + let results: Vec = run_cast(&v_i64, "cast_i64_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); -} + // i64 -> f16 + let results: Vec = run_cast(&v_i64, "cast_i64_f16"); + assert_eq!(results, v_f16); -#[test] -fn it_cast_bf16_u8() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + // i64 -> bf16 + let results: Vec = run_cast(&v_i64, "cast_i64_bf16"); + assert_eq!(results, v_bf16); - let output: Vec = cast(&input, "cast_bf16_u8"); - let expected: Vec = input.iter().map(|v| v.to_f32() as u8).collect(); + // i64 -> u32 + let results: Vec = run_cast(&v_i64, "cast_i64_u32"); + assert_eq!(results, v_u32); - assert_eq!(output, expected); -} - -#[test] -fn it_cast_bf16_f16() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); - - let output: Vec = cast(&input, "cast_bf16_f16"); - let expected: Vec = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); - - assert_eq!(output, expected); -} - -#[test] -fn it_cast_f16_bf16() { - let input: Vec = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); - - let output: Vec = cast(&input, "cast_f16_bf16"); - let expected: Vec = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); - - assert_eq!(output, expected); + // i64 -> u8 + let results: Vec = run_cast(&v_i64, "cast_i64_u8"); + assert_eq!(results, v_u8); } fn run_affine(v: &[T], mul: f64, add: f64) -> Vec {