add support for casting between all datatypes (#1860)

This commit is contained in:
Thomas Santerre
2024-03-17 15:55:11 -04:00
committed by GitHub
parent ce9fbc3682
commit e316cb6997
3 changed files with 220 additions and 92 deletions

View File

@ -292,7 +292,7 @@ fn binary_ops_bf16() {
binary_op!(max, |x: bf16, y| x.max(y));
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
@ -319,107 +319,189 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
}
#[test]
fn cast_u32_f32() {
let v = vec![1u32, 2, 3];
let results = cast(&v, "cast_u32_f32");
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
fn cast_f32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let v = vec![1.0f32, 2.0, 3.0];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
// f32 -> f16
let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
assert_eq!(results, v_f16);
let v = vec![1.0f32; 10_000];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results.len(), 10_000);
assert_eq!(&results[..10], vec![1.0f32; 10]);
assert_eq!(results, vec![1.0f32; 10_000]);
// f32 -> bf16
let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
assert_eq!(results, v_bf16);
// f32 -> u32
let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
assert_eq!(results, v_u32);
// f32 -> u8
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
assert_eq!(results, v_u8);
// f32 -> i64
let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_bf16_u32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
fn cast_f16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
// f16 -> f32
let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// f16 -> bf16
let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
assert_eq!(results, v_bf16);
// f16 -> u32
let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
assert_eq!(results, v_u32);
// f16 -> u8
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
assert_eq!(results, v_u8);
// f16 -> i64
let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_bf16_f32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
fn cast_bf16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
// bf16 -> f32
let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// bf16 -> f16
let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
assert_eq!(results, v_f16);
// bf16 -> u32
let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
assert_eq!(results, v_u32);
// bf16 -> u8
let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
assert_eq!(results, v_u8);
// bf16 -> i64
let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_u8_bf16() {
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
fn cast_u32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
let expected: Vec<bf16> = input
.iter()
.map(|v| bf16::from_f32(*v as f32))
.collect::<Vec<_>>();
// u32 -> f32
let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// u32 -> f16
let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
assert_eq!(results, v_f16);
// u32 -> bf16
let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
assert_eq!(results, v_bf16);
// u32 -> u8
let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
assert_eq!(results, v_u8);
// u32 -> i64
let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_u32_bf16() {
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
fn cast_u8() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
// u8 -> f32
let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// u8 -> f16
let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
assert_eq!(results, v_f16);
// u8 -> bf16
let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
assert_eq!(results, v_bf16);
// u8 -> u32
let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
assert_eq!(results, v_u32);
// u8 -> i64
let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_f32_bf16() {
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
fn cast_i64() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
// i64 -> f32
let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
}
// i64 -> f16
let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
assert_eq!(results, v_f16);
#[test]
fn it_cast_bf16_u8() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
// i64 -> bf16
let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
assert_eq!(results, v_bf16);
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
// i64 -> u32
let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
assert_eq!(results, v_u32);
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_f16() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_f16_bf16() {
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
// i64 -> u8
let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
assert_eq!(results, v_u8);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {