add support for casting between all datatypes (#1860)

This commit is contained in:
Thomas Santerre
2024-03-17 15:55:11 -04:00
committed by GitHub
parent ce9fbc3682
commit e316cb6997
3 changed files with 220 additions and 92 deletions

View File

@ -609,28 +609,41 @@ impl BackendStorage for MetalStorage {
let command_buffer = device.command_buffer()?; let command_buffer = device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
let kernel_name = match (self.dtype, dtype) { let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::U8, DType::F16) => "cast_u8_f16",
(DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::BF16) => "cast_u8_bf16", (DType::U8, DType::U32) => "cast_u8_u32",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::BF16) => "cast_f32_bf16", (DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::I64) => "cast_f32_i64",
(DType::F32, DType::U32) => "cast_f32_u32",
(DType::F32, DType::U8) => "cast_f32_u8",
(DType::I64, DType::BF16) => "cast_i64_bf16",
(DType::I64, DType::F16) => "cast_i64_f16",
(DType::I64, DType::F32) => "cast_i64_f32", (DType::I64, DType::F32) => "cast_i64_f32",
(DType::I64, DType::U32) => "cast_i64_u32",
(DType::I64, DType::U8) => "cast_i64_u8",
(DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32", (DType::F16, DType::F32) => "cast_f16_f32",
(DType::F16, DType::I64) => "cast_f16_i64",
(DType::F16, DType::U32) => "cast_f16_u32",
(DType::F16, DType::U8) => "cast_f16_u8",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32", (DType::BF16, DType::F32) => "cast_bf16_f32",
(DType::BF16, DType::I64) => "cast_bf16_i64",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(left, right) => { (left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")

View File

@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \ } \
// u32
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 220 #if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
#endif #endif
// u8
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
#endif
// f16
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
// i64
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif
// f32
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif
// bf16
#if defined(__HAVE_BFLOAT__) #if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif #endif

View File

@ -292,7 +292,7 @@ fn binary_ops_bf16() {
binary_op!(max, |x: bf16, y| x.max(y)); binary_op!(max, |x: bf16, y| x.max(y));
} }
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device(); let device = device();
let kernels = Kernels::new(); let kernels = Kernels::new();
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
@ -319,107 +319,189 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
} }
#[test] #[test]
fn cast_u32_f32() { fn cast_f32() {
let v = vec![1u32, 2, 3]; let v_f64 = vec![1.0f64, 2.0, 3.0];
let results = cast(&v, "cast_u32_f32"); let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let v = vec![1.0f32, 2.0, 3.0]; // f32 -> f16
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
let results: Vec<f32> = cast(&input, "cast_f16_f32"); assert_eq!(results, v_f16);
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32; 10_000]; // f32 -> bf16
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
let results: Vec<f32> = cast(&input, "cast_f16_f32"); assert_eq!(results, v_bf16);
assert_eq!(results.len(), 10_000);
assert_eq!(&results[..10], vec![1.0f32; 10]); // f32 -> u32
assert_eq!(results, vec![1.0f32; 10_000]); let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
assert_eq!(results, v_u32);
// f32 -> u8
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
assert_eq!(results, v_u8);
// f32 -> i64
let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
assert_eq!(results, v_i64);
} }
#[test] #[test]
fn it_cast_bf16_u32() { fn cast_f16() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<u32> = cast(&input, "cast_bf16_u32"); // f16 -> f32
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect(); let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected); // f16 -> bf16
let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
assert_eq!(results, v_bf16);
// f16 -> u32
let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
assert_eq!(results, v_u32);
// f16 -> u8
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
assert_eq!(results, v_u8);
// f16 -> i64
let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
assert_eq!(results, v_i64);
} }
#[test] #[test]
fn it_cast_bf16_f32() { fn cast_bf16() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<f32> = cast(&input, "cast_bf16_f32"); // bf16 -> f32
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect(); let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected); // bf16 -> f16
let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
assert_eq!(results, v_f16);
// bf16 -> u32
let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
assert_eq!(results, v_u32);
// bf16 -> u8
let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
assert_eq!(results, v_u8);
// bf16 -> i64
let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
assert_eq!(results, v_i64);
} }
#[test] #[test]
fn it_cast_u8_bf16() { fn cast_u32() {
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect(); let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_u8_bf16"); // u32 -> f32
let expected: Vec<bf16> = input let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
.iter() assert_eq!(results, v_f32);
.map(|v| bf16::from_f32(*v as f32))
.collect::<Vec<_>>();
assert_eq!(output, expected); // u32 -> f16
let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
assert_eq!(results, v_f16);
// u32 -> bf16
let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
assert_eq!(results, v_bf16);
// u32 -> u8
let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
assert_eq!(results, v_u8);
// u32 -> i64
let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
assert_eq!(results, v_i64);
} }
#[test] #[test]
fn it_cast_u32_bf16() { fn cast_u8() {
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect(); let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_u32_bf16"); // u8 -> f32
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected); // u8 -> f16
let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
assert_eq!(results, v_f16);
// u8 -> bf16
let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
assert_eq!(results, v_bf16);
// u8 -> u32
let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
assert_eq!(results, v_u32);
// u8 -> i64
let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
assert_eq!(results, v_i64);
} }
#[test] #[test]
fn it_cast_f32_bf16() { fn cast_i64() {
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect(); let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_f32_bf16"); // i64 -> f32
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected); // i64 -> f16
} let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
assert_eq!(results, v_f16);
#[test] // i64 -> bf16
fn it_cast_bf16_u8() { let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); assert_eq!(results, v_bf16);
let output: Vec<u8> = cast(&input, "cast_bf16_u8"); // i64 -> u32
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect(); let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
assert_eq!(results, v_u32);
assert_eq!(output, expected); // i64 -> u8
} let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
assert_eq!(results, v_u8);
#[test]
fn it_cast_bf16_f16() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_f16_bf16() {
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
} }
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {