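//! Tests for the candle Metal kernels. Each test builds small input buffers,
//! dispatches a kernel through a helper function, and compares the GPU output
//! against a reference computed on the CPU in plain Rust.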
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
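
/// Copies `n` elements of type `T` out of a Metal buffer into a `Vec`.
/// The caller must ensure the buffer actually holds at least `n` valid
/// values of type `T`.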
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
    let ptr = buffer.contents() as *const T;
    assert!(!ptr.is_null());
    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
    slice.to_vec()
}
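
/// Allocates a managed-storage Metal buffer initialized with a copy of `data`.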
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
    let options = MTLResourceOptions::StorageModeManaged;
    let ptr = data.as_ptr() as *const c_void;
    let size = (data.len() * std::mem::size_of::<T>()) as u64;
    device.new_buffer_with_data(ptr, size, options)
}

fn device() -> Device {
    Device::system_default().unwrap()
}

fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
    let b = 10f32.powi(digits);
    v.iter().map(|t| f32::round(t * b) / b).collect()
}

fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
    let b = 10f32.powi(digits);
    v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}

fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
    let b = 10f32.powi(digits);
    v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
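
/// Dispatches a contiguous unary kernel on `v` and reads back the result.
/// This is the common test pattern: create a command queue and command
/// buffer, encode the kernel call, commit, wait for completion, then copy
/// the output back to the host.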
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    call_unary_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        &input,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, v.len())
}
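
/// Dispatches a contiguous binary kernel over two equally sized slices.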
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;
    let left = new_buffer(&device, x);
    let right = new_buffer(&device, y);
    let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
    call_binary_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        x.len(),
        &left,
        &right,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, x.len())
}
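
/// Dispatches a strided unary kernel, letting tests exercise non-contiguous
/// layouts (transposed or offset views) described by `shape` and `strides`.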
fn run_strided<T: Clone>(
    v: &[T],
    kernel: unary::strided::Kernel,
    shape: &[usize],
    strides: &[usize],
    offset: usize,
) -> Vec<T> {
    let device = device();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    let kernels = Kernels::new();
    call_unary_strided(
        &device,
        command_buffer,
        &kernels,
        kernel,
        shape,
        &input,
        strides,
        offset,
        &output,
        0,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, v.len())
}

#[test]
fn cos_f32() {
    let v = vec![1.0f32, 2.0, 3.0];
    let results = run(&v, unary::contiguous::cos::FLOAT);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
    assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);

    let v = vec![1.0f32; 10_000];
    let results = run(&v, unary::contiguous::cos::FLOAT);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}

#[test]
fn cos_f32_strided() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![6];
    let strides = vec![1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Contiguous
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Transposed
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![1, 3];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Very large
    let v = vec![1.0f32; 10_000];
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}

#[test]
fn cos_strided_random() {
    let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
    let shape = vec![5_000, 2];
    let strides = vec![1, 5_000];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
    assert_eq!(
        approx(vec![results[1]], 4),
        approx(vec![expected[5_000]], 4)
    );
    assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
    assert_eq!(
        approx(vec![results[3]], 4),
        approx(vec![expected[5_001]], 4)
    );
    assert_eq!(
        approx(vec![results[5_000]], 4),
        approx(vec![expected[2_500]], 4)
    );
}

#[test]
fn gelu_f16() {
    let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
    let results = run(&v, unary::contiguous::gelu::HALF);
    assert_eq!(approx_f16(results, 2), expected);
}

#[test]
fn gelu_f32() {
    let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
    let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
    let results = run(&v, unary::contiguous::gelu::FLOAT);
    assert_eq!(approx(results, 3), expected);
}

#[test]
fn silu_f16() {
    let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
    let results = run(&v, unary::contiguous::silu::HALF);
    assert_eq!(approx_f16(results, 2), expected);
}

#[test]
fn silu_f32() {
    let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
    let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
    let results = run(&v, unary::contiguous::silu::FLOAT);
    assert_eq!(approx(results, 3), expected);
}

#[test]
fn binary_add_f32() {
    let left = vec![1.0f32, 2.0, 3.0];
    let right = vec![2.0f32, 3.1, 4.2];
    let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
    let expected: Vec<_> = left
        .iter()
        .zip(right.iter())
        .map(|(&x, &y)| x + y)
        .collect();
    assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
    assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}

#[test]
fn binary_ops_bf16() {
    let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
    let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
        .into_iter()
        .map(bf16::from_f32)
        .collect();

    macro_rules! binary_op {
        ($opname:ident, $opexpr:expr) => {{
            let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
            let expected: Vec<bf16> = lhs
                .iter()
                .zip(rhs.iter())
                .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
                .collect();
            assert_eq!(results, expected);
        }};
    }

    binary_op!(add, |x, y| x + y);
    binary_op!(sub, |x, y| x - y);
    binary_op!(mul, |x, y| x * y);
    binary_op!(div, |x, y| x / y);
    binary_op!(min, |x: bf16, y| x.min(y));
    binary_op!(max, |x: bf16, y| x.max(y));
}
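
/// Dispatches a contiguous cast kernel from element type `T` to `U`; the
/// output buffer is sized for `U`.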
fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let options = MTLResourceOptions::StorageModeManaged;
    let size = (v.len() * std::mem::size_of::<U>()) as u64;
    let output = device.new_buffer(size, options);

    call_cast_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        &input,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, v.len())
}

#[test]
fn cast_f32() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();

    // f32 -> f16
    let results: Vec<f16> = run_cast(&v_f32, "cast_f32_f16");
    assert_eq!(results, v_f16);

    // f32 -> bf16
    let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
    assert_eq!(results, v_bf16);

    // f32 -> u32
    let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
    assert_eq!(results, v_u32);

    // f32 -> u8
    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
    assert_eq!(results, v_u8);

    // f32 -> i64
    let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
    assert_eq!(results, v_i64);
}

#[test]
fn cast_f16() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();

    // f16 -> f32
    let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
    assert_eq!(results, v_f32);

    // f16 -> bf16
    let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
    assert_eq!(results, v_bf16);

    // f16 -> u32
    let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
    assert_eq!(results, v_u32);

    // f16 -> u8
    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
    assert_eq!(results, v_u8);

    // f16 -> i64
    let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
    assert_eq!(results, v_i64);
}

#[test]
fn cast_bf16() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();

    // bf16 -> f32
    let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
    assert_eq!(results, v_f32);

    // bf16 -> f16
    let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
    assert_eq!(results, v_f16);

    // bf16 -> u32
    let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
    assert_eq!(results, v_u32);

    // bf16 -> u8
    let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
    assert_eq!(results, v_u8);

    // bf16 -> i64
    let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
    assert_eq!(results, v_i64);
}

#[test]
fn cast_u32() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();

    // u32 -> f32
    let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
    assert_eq!(results, v_f32);

    // u32 -> f16
    let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
    assert_eq!(results, v_f16);

    // u32 -> bf16
    let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
    assert_eq!(results, v_bf16);

    // u32 -> u8
    let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
    assert_eq!(results, v_u8);

    // u32 -> i64
    let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
    assert_eq!(results, v_i64);
}

#[test]
fn cast_u8() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();

    // u8 -> f32
    let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
    assert_eq!(results, v_f32);

    // u8 -> f16
    let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
    assert_eq!(results, v_f16);

    // u8 -> bf16
    let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
    assert_eq!(results, v_bf16);

    // u8 -> u32
    let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
    assert_eq!(results, v_u32);

    // u8 -> i64
    let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
    assert_eq!(results, v_i64);
}

#[test]
fn cast_i64() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();

    // i64 -> f32
    let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
    assert_eq!(results, v_f32);

    // i64 -> f16
    let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
    assert_eq!(results, v_f16);

    // i64 -> bf16
    let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
    assert_eq!(results, v_bf16);

    // i64 -> u32
    let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
    assert_eq!(results, v_u32);

    // i64 -> u8
    let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
    assert_eq!(results, v_u8);
}
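
/// Dispatches the `affine_f32` kernel, computing `x * mul + add` elementwise.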
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);

    let size = v.len();

    call_affine(
        &device,
        command_buffer,
        &kernels,
        "affine_f32",
        size,
        &input,
        &output,
        mul as f32,
        add as f32,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, v.len())
}

fn run_affine_strided<T: Clone>(
    v: &[T],
    shape: &[usize],
    strides: &[usize],
    mul: f64,
    add: f64,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);

    call_affine_strided(
        &device,
        command_buffer,
        &kernels,
        "affine_f32_strided",
        shape,
        &input,
        strides,
        0,
        &output,
        mul as f32,
        add as f32,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    let len: usize = shape.iter().product();
    read_to_vec(&output, len)
}

#[test]
fn affine() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mul = 1.5;
    let add = 1.1;
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);

    let input = [1.0f32; 40_000];
    let mul = 1.5;
    let add = 1.1;
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6; 40_000]);
}

#[test]
fn affine_strided() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mul = 1.5;
    let add = 1.1;
    let shape = [4];
    let strides = [2];
    let result = run_affine_strided(&input, &shape, &strides, mul, add);
    // Only every other input element is visited (stride 2).
    assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}

#[test]
fn index_select() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let stride = [2, 1];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
    assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);

    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [2, 5];
    let stride = [1, 2];
    let ids = [0u32, 1, 0];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
    );
}

#[test]
fn index_select_strided() {
    let embedding = (0..16).map(|x| x as f32).collect::<Vec<_>>();
    let shape = [2, 2];
    let stride = [2, 4];
    let ids = [0u32];
    let dim = 0;
    let result = run_index_select_strided(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
    assert_eq!(result, vec![0.0, 4.0]);
}

#[test]
fn index_select_f16() {
    let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        .into_iter()
        .map(f16::from_f32)
        .collect();
    let shape = [5, 2];
    let stride = [2, 1];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f16");
    assert_eq!(
        approx_f16(result, 4),
        vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
    );
}

#[test]
fn index_select_is_u32_bf16() {
    let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
    let shape = [5, 2];
    let stride = [2, 1];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_bf16");
    assert_eq!(
        approx_bf16(result, 4),
        vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
    );
}

#[test]
fn index_select_is_u8_bf16() {
    let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
    let shape = [5, 2];
    let stride = [2, 1];
    let ids = [0u8, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u8_bf16");
    assert_eq!(
        approx_bf16(result, 4),
        vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
    );
}

#[test]
fn index_select_dim1() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let stride = [2, 1];
    let ids = [0u32, 1, 0];
    let dim = 1;
    let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
    );
}
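
/// Dispatches an index-select (embedding lookup) kernel: slices of
/// `embeddings` along `dim` are gathered according to `ids`. The `true` flag
/// passed to `call_index_select` appears to select the contiguous code path;
/// `run_index_select_strided` below passes `false`.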
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
    embeddings: &[T],
    shape: &[usize],
    stride: &[usize],
    ids: &[I],
    dim: usize,
    name: &'static str,
) -> Vec<T> {
    let device = Device::system_default().expect("no device found");

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, embeddings);
    let ids_buffer = new_buffer(&device, ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let dst_el = ids.len() * left_size * right_size;
    // NB: the destination buffer is allocated as f32 elements, which is large
    // enough for every element type exercised by these tests.
    let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);

    let kernels = Kernels::new();
    call_index_select(
        &device,
        &command_buffer,
        &kernels,
        name,
        shape,
        ids.len(),
        dim,
        true,
        shape,
        stride,
        &embeddings_buffer,
        0,
        &ids_buffer,
        0,
        &dst_buffer,
    )
    .unwrap();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&dst_buffer, dst_el)
}

fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
    embeddings: &[T],
    shape: &[usize],
    stride: &[usize],
    ids: &[I],
    dim: usize,
    name: &'static str,
) -> Vec<T> {
    let device = Device::system_default().expect("no device found");

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, embeddings);
    let ids_buffer = new_buffer(&device, ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let dst_el = ids.len() * left_size * right_size;
    let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);

    let kernels = Kernels::new();
    call_index_select(
        &device,
        &command_buffer,
        &kernels,
        name,
        shape,
        ids.len(),
        dim,
        false,
        shape,
        stride,
        &embeddings_buffer,
        0,
        &ids_buffer,
        0,
        &dst_buffer,
    )
    .unwrap();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&dst_buffer, dst_el)
}

#[test]
fn cos_f16() {
    let v: Vec<f16> = [1.0f32, 2.0, 3.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let results = run(&v, unary::contiguous::cos::HALF);
    let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
    assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
    assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}
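
/// Dispatches a strided reduction kernel over `v`, producing `out_length`
/// partial results (e.g. `out_length == 1` reduces the whole slice).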
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);

    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
    let dims = vec![v.len()];
    let strides = vec![1];
    call_reduce_strided(
        &device,
        command_buffer,
        &kernels,
        name,
        &dims,
        &strides,
        out_length,
        &input,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, out_length)
}
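
/// Dispatches a softmax kernel over the last dimension: `v` is treated as a
/// batch of rows of length `last_dim`.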
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    call_last_softmax(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        last_dim,
        &input,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, v.len())
}

#[test]
fn reduce_sum() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 1;

    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
    assert_eq!(approx(results, 4), vec![21.0]);
}

#[test]
fn reduce_sum2() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 2;

    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}

#[test]
fn softmax() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_f32");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    let last_dim = 4096;
    let n = 200;
    let mut v = vec![0.0; n * last_dim];
    for i in 0..n {
        v[i * last_dim] = 20.0;
    }
    let results = run_softmax(&v, last_dim, "softmax_f32");
    let results = approx(results, 4);
    assert_eq!(
        results.iter().map(|&s| s.round() as usize).sum::<usize>(),
        n
    );
    assert_eq!(results[0], 1.0);
    assert_eq!(results[1], 0.0);
    assert_eq!(results[last_dim], 1.0);
    assert_eq!(results[2 * last_dim], 1.0);

    let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_f32");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 3;
    let results = run_softmax(&v, last_dim, "softmax_f32");
    assert_eq!(
        approx(results, 4),
        vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
    );

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_f16");
    assert_eq!(
        approx_f16(results, 4),
        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
    );

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_bf16");
    assert_eq!(
        approx_bf16(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
    );
}
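
/// Dispatches a strided `where` (select) kernel: for each element, picks from
/// `left_true` where `cond` is non-zero and from `right_false` otherwise.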
fn run_where_cond<I: Clone, T: Clone>(
    shape: &[usize],
    cond: &[I],
    (cond_stride, cond_offset): (Vec<usize>, usize),
    left_true: &[T],
    (left_stride, left_offset): (Vec<usize>, usize),
    right_false: &[T],
    (_right_stride, _right_offset): (Vec<usize>, usize),
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;

    let length = cond.len();
    let cond = device.new_buffer_with_data(
        cond.as_ptr() as *const core::ffi::c_void,
        std::mem::size_of_val(cond) as u64,
        options,
    );
    let left = device.new_buffer_with_data(
        left_true.as_ptr() as *const core::ffi::c_void,
        (length * core::mem::size_of::<T>()) as u64,
        options,
    );
    let right = device.new_buffer_with_data(
        right_false.as_ptr() as *const core::ffi::c_void,
        (length * core::mem::size_of::<T>()) as u64,
        options,
    );

    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
    call_where_cond_strided(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        &cond,
        (&cond_stride, cond_offset),
        &left,
        (&left_stride, left_offset),
        &right,
        // Note: the `_right_stride`/`_right_offset` parameters are currently
        // unused; the cond layout is passed for the right operand as well.
        (&cond_stride, cond_offset),
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, length)
}

#[test]
fn where_cond() {
    let shape = vec![6];
    let cond = vec![0u8, 1, 0, 0, 1, 1];
    let cond_l = (vec![1], 0);
    let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let left_l = (vec![1], 0);
    let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
    let right_l = (vec![1], 0);
    let results = run_where_cond(
        &shape,
        &cond,
        cond_l,
        &left_true,
        left_l,
        &right_false,
        right_l,
        "where_u8_f32",
    );
    assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
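
/// Dispatches the `sgemm` kernel for a batched matrix multiply of shape
/// `(b, m, k) x (b, k, n) -> (b, m, n)`, with explicit strides for both
/// operands. The offsets are byte offsets (see the OFFSET case in the
/// `gemm` test below).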
fn run_gemm<T: Clone>(
    (b, m, n, k): (usize, usize, usize, usize),
    lhs: &[T],
    lhs_stride: Vec<usize>,
    lhs_offset: usize,
    rhs: &[T],
    rhs_stride: Vec<usize>,
    rhs_offset: usize,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;

    let lhs = device.new_buffer_with_data(
        lhs.as_ptr() as *const core::ffi::c_void,
        std::mem::size_of_val(lhs) as u64,
        options,
    );
    let rhs = device.new_buffer_with_data(
        rhs.as_ptr() as *const core::ffi::c_void,
        std::mem::size_of_val(rhs) as u64,
        options,
    );
    let length = b * m * n;
    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
    call_gemm(
        &device,
        command_buffer,
        &kernels,
        "sgemm",
        (b, m, n, k),
        &lhs_stride,
        lhs_offset,
        &lhs,
        &rhs_stride,
        rhs_offset,
        &rhs,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, length)
}

#[test]
fn gemm() {
    let (b, m, n, k) = (1, 2, 4, 3);
    let lhs_stride = vec![m * k, k, 1];
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs_stride = vec![n * k, n, 1];
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
    assert_eq!(
        approx(results, 4),
        vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
    );

    let (b, m, n, k) = (2, 2, 4, 3);
    let lhs_stride = vec![m * k, k, 1];
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs_stride = vec![n * k, n, 1];
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
    assert_eq!(
        approx(results, 4),
        vec![
            20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
            518.0, 548.0, 578.0
        ]
    );

    // OFFSET
    let (b, m, n, k) = (2, 2, 4, 3);
    let lhs_stride = vec![m * k, k, 1];
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs_stride = vec![n * k, n, 1];
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    // Manually set batch_size = 1 and offset the rhs by 12 elements * 4 bytes
    // (the size of an f32), i.e. skip its first batch.
    let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
    assert_eq!(
        approx(results, 4),
        vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
    );
}
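
/// Dispatches a random-number kernel seeded from a single `u32` buffer.
/// Kernel names starting with `rand_uniform` sample uniformly in `[a, b]`;
/// anything else is treated as a normal distribution with mean `a` and
/// standard deviation `b`.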
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);

    let seed = device.new_buffer_with_data(
        &seed as *const u32 as *const core::ffi::c_void,
        std::mem::size_of::<u32>() as NSUInteger,
        options,
    );

    if name.starts_with("rand_uniform") {
        call_random_uniform(
            &device,
            command_buffer,
            &kernels,
            name,
            a,
            b,
            length,
            &seed,
            &output,
        )
        .unwrap();
    } else {
        call_random_normal(
            &device,
            command_buffer,
            &kernels,
            name,
            a,
            b,
            length,
            &seed,
            &output,
        )
        .unwrap();
    }
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, length)
}

#[test]
fn random() {
    fn calc_mean(data: &[f32]) -> f32 {
        let sum = data.iter().sum::<f32>();
        let count = data.len();
        assert!(count > 0);
        sum / count as f32
    }

    fn calc_stddev(data: &[f32]) -> f32 {
        let mean = calc_mean(data);
        let count = data.len();
        assert!(count > 0);

        let variance = data
            .iter()
            .map(|value| {
                let diff = mean - *value;
                diff * diff
            })
            .sum::<f32>()
            / count as f32;

        variance.sqrt()
    }

    let shape = vec![1024, 10];

    let length = shape.iter().product::<usize>();
    let seed = 299792458;

    let min = -30.0;
    let max = 30.0;
    let mean = 100.0;
    let stddev = 50.0;

    macro_rules! validate_random {
        ($type:ty) => {
            let results: Vec<f32> = run_random::<$type>(
                concat!("rand_uniform_", stringify!($type)),
                seed,
                length,
                min,
                max,
            )
            .into_iter()
            .map(f32::from)
            .collect();
            results.iter().for_each(|v| {
                assert!(*v >= min && *v <= max);
            });
            assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);

            let results: Vec<f32> = run_random::<$type>(
                concat!("rand_normal_", stringify!($type)),
                seed,
                length,
                mean,
                stddev,
            )
            .into_iter()
            .map(f32::from)
            .collect();
            assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
            assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
        };
    }

    validate_random!(f32);
    validate_random!(f16);
    validate_random!(bf16);
}
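
/// Dispatches a scatter-add kernel: `input[i]` is accumulated into the output
/// position selected by `ids[i]` along `dim`. The output buffer has the same
/// byte size as the input; the expected values in the test below assume that
/// destination slots receiving no contribution read back as zero.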
fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
    input: &[T],
    ids: &[I],
    shape: &[usize],
    dim: usize,
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;
    let input_buffer = new_buffer(&device, input);
    let ids_buffer = new_buffer(&device, ids);
    let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
    call_scatter_add(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        shape,
        dim,
        &input_buffer,
        0,
        &ids_buffer,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, input.len())
}

#[test]
fn scatter_add() {
    let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3];
    let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3];
    let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3];

    let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
    let input_f16 = input_f32
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let input_bf16 = input_f32
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();

    let output_dim1_f32 = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0];
    let output_dim1_f16 = output_dim1_f32
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let output_dim1_bf16 = output_dim1_f32
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();

    let output_dim2_f32 = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0];
    let output_dim2_f16 = output_dim2_f32
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let output_dim2_bf16 = output_dim2_f32
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();

    for (shape, output_f32, output_f16, output_bf16) in [
        (vec![8], output_dim1_f32, output_dim1_f16, output_dim1_bf16),
        (
            vec![4, 2],
            output_dim2_f32,
            output_dim2_f16,
            output_dim2_bf16,
        ),
    ] {
        for results in [
            run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"),
            run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"),
            run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"),
        ] {
            assert_eq!(results, output_f32);
        }
        for results in [
            run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"),
            run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"),
            run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"),
        ] {
            assert_eq!(results, output_f16);
        }
        for results in [
            run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"),
            run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"),
            run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"),
        ] {
            assert_eq!(results, output_bf16);
        }
    }
}
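
/// Dispatches an index-add kernel: values from `right` are accumulated into
/// `left` at the positions given by `indices` along `dim`. The output buffer
/// starts as a copy of `left`.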
fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
    left: &[T],
    right: &[T],
    indices: &[I],
    shape: &[usize],
    dim: usize,
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input_buffer = new_buffer(&device, right);
    let output = new_buffer(&device, left);
    let indices_buffer = new_buffer(&device, indices);
    call_index_add(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        shape,
        shape,
        dim,
        &input_buffer,
        0,
        &indices_buffer,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, left.len())
}

#[test]
fn index_add() {
    let left = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let right = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0];
    let indices = vec![0u32, 1, 0, 1, 0, 1];
    let shape = vec![6];

    // u32, f32
    {
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f32");
        assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // u32, f16
    {
        let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
        let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f16");
        assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // u32, bf16
    {
        let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
        let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_bf16");
        assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // u8, f32
    {
        let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f32");
        assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // u8, f16
    {
        let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
        let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
        let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f16");
        assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // u8, bf16
    {
        let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
        let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
        let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_bf16");
        assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // i64, f32
    {
        let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f32");
        assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // i64, f16
    {
        let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
        let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
        let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f16");
        assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // i64, bf16
    {
        let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
        let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
        let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
        let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_bf16");
        assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
    }
}
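
/// Dispatches a 2D pooling kernel over a rank-4 input (batch, channels, and
/// two spatial dimensions), with the usual output size
/// `out = (in - kernel) / stride + 1` along each spatial dimension.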
fn run_pool2d<T: Clone>(
    v: &[T],
    (w_k, h_k): (usize, usize),
    (w_stride, h_stride): (usize, usize),
    shape: &[usize],
    strides: &[usize],
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let out_w = (shape[2] - w_k) / w_stride + 1;
    let out_h = (shape[3] - h_k) / h_stride + 1;
    let dst_el = out_w * out_h * shape[0] * shape[1];
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, &vec![0.0f32; dst_el]);
    let kernels = Kernels::new();
    call_pool2d(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        strides,
        out_w,
        out_h,
        w_k,
        h_k,
        w_stride,
        h_stride,
        &input,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, dst_el)
}

#[test]
fn max_pool2d_f32() {
    // kernel 2 stride 1
    let v: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_f32",
    );
    let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0];
    assert_eq!(results, expected);

    // kernel 2 stride 2
    let v: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 2;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_f32",
    );
    let expected = vec![5.0, 7.0, 13.0, 15.0];
    assert_eq!(results, expected);
}

#[test]
fn max_pool2d_f16() {
    // kernel 2 stride 1
    let v: Vec<f16> = (0..16).map(|v| f16::from_f32(v as f32)).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_f16",
    );
    let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);

    // kernel 2 stride 2
    let v: Vec<f16> = (0..16).map(|v| f16::from_f32(v as f32)).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 2;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_f16",
    );
    let expected = vec![5.0, 7.0, 13.0, 15.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn max_pool2d_bf16() {
    // kernel 2 stride 1
    let v: Vec<bf16> = (0..16).map(|v| bf16::from_f32(v as f32)).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_bf16",
    );
    let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);

    // kernel 2 stride 2
    let v: Vec<bf16> = (0..16).map(|v| bf16::from_f32(v as f32)).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 2;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_bf16",
    );
    let expected = vec![5.0, 7.0, 13.0, 15.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn max_pool2d_u8() {
    // kernel 2 stride 1
    let v: Vec<u8> = (0..16).map(|v| v as u8).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_u8",
    );
    let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15];
    assert_eq!(results, expected);

    // kernel 2 stride 2
    let v: Vec<u8> = (0..16).map(|v| v as u8).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 2;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_u8",
    );
    let expected = vec![5, 7, 13, 15];
    assert_eq!(results, expected);
}

#[test]
fn max_pool2d_u32() {
    // kernel 2 stride 1
    let v: Vec<u32> = (0..16).map(|v| v as u32).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_u32",
    );
    let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15];
    assert_eq!(results, expected);

    // kernel 2 stride 2
    let v: Vec<u32> = (0..16).map(|v| v as u32).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 2;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "max_pool2d_u32",
    );
    let expected = vec![5, 7, 13, 15];
    assert_eq!(results, expected);
}

#[test]
fn avg_pool2d_f32() {
    // kernel 2 stride 1
    let v: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "avg_pool2d_f32",
    );
    let expected = vec![
        2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
    ];
    assert_eq!(results, expected);
}

#[test]
fn avg_pool2d_f16() {
    // kernel 2 stride 1
    let v: Vec<f16> = (0..16).map(|v| f16::from_f32(v as f32)).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "avg_pool2d_f16",
    );
    let expected = vec![
        2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
    ]
    .iter()
    .map(|v| f16::from_f32(*v))
    .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn avg_pool2d_bf16() {
    // kernel 2 stride 1
    let v: Vec<bf16> = (0..16).map(|v| bf16::from_f32(v as f32)).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "avg_pool2d_bf16",
    );
    let expected = vec![
        2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
    ]
    .iter()
    .map(|v| bf16::from_f32(*v))
    .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn avg_pool2d_u8() {
    // kernel 2 stride 1
    let v: Vec<u8> = (0..16).map(|v| v as u8).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "avg_pool2d_u8",
    );
    let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
    assert_eq!(results, expected);
}

#[test]
fn avg_pool2d_u32() {
    // kernel 2 stride 1
    let v: Vec<u32> = (0..16).map(|v| v as u32).collect();
    let shape = vec![1, 1, 4, 4];
    let strides = vec![16, 16, 4, 1];
    let kernel = 2;
    let stride = 1;
    let results = run_pool2d(
        &v,
        (kernel, kernel),
        (stride, stride),
        &shape,
        &strides,
        "avg_pool2d_u32",
    );
    let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
    assert_eq!(results, expected);
}
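
/// Dispatches a 1D transposed-convolution kernel. The output length follows
/// the standard formula:
/// `l_out = (l_in - 1) * stride - 2 * padding + dilation * (k - 1) + out_padding + 1`.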
fn run_conv_transpose1d<T: Clone>(
    input: &[T],
    input_shape: &[usize],
    input_stride: &[usize],
    kernel: &[T],
    kernel_shape: &[usize],
    kernel_stride: &[usize],
    dilation: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let c_out = kernel_shape[1];
    let k_size = kernel_shape[2];
    let b_size = input_shape[0];
    let l_in = input_shape[2];
    let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1;
    let dst_el = c_out * l_out * b_size;

    let input = new_buffer(&device, input);
    let kernel = new_buffer(&device, kernel);
    let output = new_buffer(&device, &vec![0.0f32; dst_el]);
    let kernels = Kernels::new();

    call_conv_transpose1d(
        &device,
        command_buffer,
        &kernels,
        name,
        dilation,
        stride,
        padding,
        out_padding,
        c_out,
        l_out,
        b_size,
        input_shape,
        input_stride,
        kernel_shape,
        kernel_stride,
        &input,
        0,
        &kernel,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, dst_el)
}

#[test]
fn conv_transpose1d_f32() {
    let input = vec![1.0f32, 2.0, 3.0, 4.0];
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel = vec![1.0f32, 2.0, 3.0, 4.0];
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_f32",
    );

    let expected = vec![1., 4., 10., 20., 25., 24., 16.];
    assert_eq!(results, expected);
}

#[test]
fn conv_transpose1d_f16() {
    let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_f16",
    );

    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn conv_transpose1d_bf16() {
    let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect();
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect();
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_bf16",
    );

    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn conv_transpose1d_u8() {
    let input: Vec<u8> = vec![1, 2, 3, 4];
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<u8> = vec![1, 2, 3, 4];
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_u8",
    );

    let expected = vec![1, 4, 10, 20, 25, 24, 16];
    assert_eq!(results, expected);
}

#[test]
fn conv_transpose1d_u32() {
    let input: Vec<u32> = vec![1, 2, 3, 4];
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<u32> = vec![1, 2, 3, 4];
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_u32",
    );

    let expected = vec![1, 4, 10, 20, 25, 24, 16];
    assert_eq!(results, expected);
}