Compare commits

...

27 Commits

Author SHA1 Message Date
7e49e0af96 Tmp for allocator. 2023-11-16 12:50:41 +01:00
181d2299b2 TMp. 2023-11-16 11:41:06 +01:00
2801541e5f new_owned -> new()..to_owned(). 2023-11-16 11:07:56 +01:00
4289984d32 Remove some prints. 2023-11-13 14:51:40 +01:00
1471f98f0b BF16 metal fix. 2023-11-13 14:44:20 +01:00
dd4a40f1c0 Fixes + cache compute_pipeline_state. 2023-11-13 14:33:16 +01:00
79845bd93b Working version for llama2-c. 2023-11-13 12:36:27 +01:00
6071797450 Add erf. 2023-11-11 18:22:16 +01:00
b58b247323 Putting back f16 index select. 2023-11-11 17:43:35 +01:00
3900091e75 All tests are panicking instead of random failure. 2023-11-11 17:43:35 +01:00
54355ff997 Adding some half kernels. 2023-11-11 17:43:35 +01:00
e02f1912bb Reusing a single buffer (for now) to speed things up. 2023-11-11 17:43:35 +01:00
a52b71686b Going back on remote metal-rs. 2023-11-11 17:43:35 +01:00
7adfb70dff Few fixes. 2023-11-11 17:43:35 +01:00
3ad02147e4 Starting to fix some tests. 2023-11-11 17:43:34 +01:00
4f39695465 Missing new test. 2023-11-11 17:42:53 +01:00
4cf4844c9d Adding the test scaffolding. 2023-11-11 17:27:19 +01:00
d840838e95 Cleanup fixed a few ops removed debugging scaffolding. 2023-11-11 17:18:00 +01:00
61a070fdd1 Debugging rope. 2023-11-11 17:18:00 +01:00
e35669647d Fixed matmul (display still broken without casting back to CPU first? ) 2023-11-11 17:18:00 +01:00
53e8b7ee3e Tmp state. 2023-11-11 17:18:00 +01:00
cc26cce23c Fixing the kernels + launches to make them faster.
Cool work by @ivarflakstad

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-11 17:18:00 +01:00
02c2ec2c71 Adding indexing.
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-11 17:18:00 +01:00
9a2784b8ab Refactor to simplify our lives for settings the params in the encoder. 2023-11-11 17:18:00 +01:00
0f652f0e3d Adding the actual backend 2023-11-11 17:18:00 +01:00
ddee9dc1dd Remove tracing. 2023-11-11 17:18:00 +01:00
fc9bb7784a Metal part 1 - Scaffolding for metal. 2023-11-11 17:18:00 +01:00
27 changed files with 3739 additions and 53 deletions

View File

@ -13,6 +13,7 @@ members = [
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
resolver = "2"
@ -60,7 +61,10 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"
[profile.release-with-debug]
inherits = "release"

View File

@ -13,6 +13,7 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@ -29,6 +30,8 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -40,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]

View File

@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal,
Metal { gpu_id: usize },
}
#[derive(Debug, Clone)]
@ -146,6 +146,7 @@ impl Device {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}

View File

@ -14,7 +14,9 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
_ => todo!(),
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(f, "Tensor[")?;
@ -477,7 +479,9 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal => todo!(),
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(

View File

@ -53,6 +53,8 @@ mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;

File diff suppressed because it is too large Load Diff

View File

@ -1859,7 +1859,14 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.

View File

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,6 +15,12 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}

View File

@ -563,14 +563,35 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
conv1d_small,
conv1d_small_cpu,
conv1d_small_gpu,
conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
conv2d_non_square_gpu
conv2d_non_square_gpu,
conv2d_non_square_metal
);
test_device!(
conv2d_small,
conv2d_small_cpu,
conv2d_small_gpu,
conv2d_small_metal
);
test_device!(
conv2d_smaller,
conv2d_smaller_cpu,
conv2d_smaller_gpu,
conv2d_smaller_metal
);
test_device!(
conv2d_grad,
conv2d_grad_cpu,
conv2d_grad_gpu,
conv2_grad_metal
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);

View File

@ -315,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
test_device!(
simple_grad,
simple_grad_cpu,
simple_grad_gpu,
simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);

View File

@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(())
}
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
#[test]
fn strided_blocks() -> Result<()> {

View File

@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu
avg_pool2d_pytorch_gpu,
avg_pool2d_pytorch_metal
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
upsample_nearest2d_gpu
upsample_nearest2d_gpu,
upsample_nearest2d_metal
);

View File

@ -1070,35 +1070,60 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
test_device!(var, var_cpu, var_gpu);
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381

View File

@ -0,0 +1,21 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
edition = "2021"
description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"

View File

@ -0,0 +1,3 @@
# candle-metal-kernels
This crate contains Metal kernels used from candle.

View File

@ -0,0 +1,61 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif

View File

@ -0,0 +1,72 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *left_strides, \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif

View File

@ -0,0 +1,53 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint i [[ thread_position_in_grid ]] \
) { \
if (i >= dim) { \
return; \
} \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -0,0 +1,103 @@
#include <metal_stdlib>
using namespace metal;
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = (gid / right_size) % ids_size; \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint gid [[ thread_position_in_grid ]] \
) {
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,139 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 256;
# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint blockDim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = src[idx]; \
shared_memory[tid] = FN; \
idx += blockDim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
kernel void softmax_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = -INFINITY;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
shared_memory[tid] = max(shared_memory[tid], src[idx]);
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
float max = shared_memory[0];
shared_memory[tid] = 0;
// Restart
idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
const float val = exp(src[idx] - max);
dst[idx] = val;
shared_memory[tid] += val;
idx += blockDim;
}
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
const float inv_acc = 1/shared_memory[0];
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += blockDim;
}
}
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)

View File

@ -0,0 +1,57 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID_TYPENAME *ids, \
device const TYPENAME *t, \
device const TYPENAME *f, \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)

View File

@ -0,0 +1,126 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
float a1 = 0.254829592;
float a2 = -0.284496736;
float a3 = 1.421413741;
float a4 = -1.453152027;
float a5 = 1.061405429;
float p = 0.3275911;
// Save the sign of x
int sign = 1;
if (x < 0)
sign = -1;
x = fabs(x);
// A&S formula 7.1.26
float t = 1.0/(1.0 + p*x);
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
template <typename T> METAL_FUNC T gelu(T x){
T x_sq = x * x;
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif

View File

@ -0,0 +1,76 @@
use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_affine_bench(&device, &kernels, &f32_1k);
run_affine_bench(&device, &kernels, &f32_10k);
run_affine_bench(&device, &kernels, &f32_100k);
}
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
let mul: f32 = 1.2345;
let add: f32 = 2.3456;
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_affine(
&device,
command_buffer,
&kernels,
"affine_float",
v.len(),
&input,
&mut output,
mul,
add,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
"affine",
v.len(),
iterations,
total_time,
total_time / iterations
);
}

View File

@ -0,0 +1,182 @@
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
binary::contiguous::add::FLOAT,
binary::contiguous::sub::FLOAT,
binary::contiguous::mul::FLOAT,
binary::contiguous::div::FLOAT,
];
let f32_skernels = [
binary::strided::add::FLOAT,
binary::strided::sub::FLOAT,
binary::strided::mul::FLOAT,
binary::strided::div::FLOAT,
];
let f16_ckernels = [
binary::contiguous::add::HALF,
binary::contiguous::sub::HALF,
binary::contiguous::mul::HALF,
binary::contiguous::div::HALF,
];
let f16_skernels = [
binary::strided::add::HALF,
binary::strided::sub::HALF,
binary::strided::mul::HALF,
binary::strided::div::HALF,
];
let bf16_ckernels = [
binary::contiguous::add::BFLOAT,
binary::contiguous::sub::BFLOAT,
binary::contiguous::mul::BFLOAT,
binary::contiguous::div::BFLOAT,
];
let bf16_skernels = [
binary::strided::add::BFLOAT,
binary::strided::sub::BFLOAT,
binary::strided::mul::BFLOAT,
binary::strided::div::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_binary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [binary::contiguous::Kernel; 4],
strided: [binary::strided::Kernel; 4],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&input,
&strides,
offset,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -0,0 +1,84 @@
use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let contiguous_kernels = ["cast_u32_f32"];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}
fn run_cast_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: &[&'static str],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_cast_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided?
}

View File

@ -0,0 +1,197 @@
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
unary::contiguous::sin::FLOAT,
unary::contiguous::cos::FLOAT,
unary::contiguous::exp::FLOAT,
unary::contiguous::sqr::FLOAT,
unary::contiguous::sqrt::FLOAT,
unary::contiguous::neg::FLOAT,
unary::contiguous::copy::FLOAT,
];
let f32_skernels = [
unary::strided::sin::FLOAT,
unary::strided::cos::FLOAT,
unary::strided::exp::FLOAT,
unary::strided::sqr::FLOAT,
unary::strided::sqrt::FLOAT,
unary::strided::neg::FLOAT,
unary::strided::copy::FLOAT,
];
let f16_ckernels = [
unary::contiguous::sin::HALF,
unary::contiguous::cos::HALF,
unary::contiguous::exp::HALF,
unary::contiguous::sqr::HALF,
unary::contiguous::sqrt::HALF,
unary::contiguous::neg::HALF,
unary::contiguous::copy::HALF,
];
let f16_skernels = [
unary::strided::sin::HALF,
unary::strided::cos::HALF,
unary::strided::exp::HALF,
unary::strided::sqr::HALF,
unary::strided::sqrt::HALF,
unary::strided::neg::HALF,
unary::strided::copy::HALF,
];
let bf16_ckernels = [
unary::contiguous::sin::BFLOAT,
unary::contiguous::cos::BFLOAT,
unary::contiguous::exp::BFLOAT,
unary::contiguous::sqr::BFLOAT,
unary::contiguous::sqrt::BFLOAT,
unary::contiguous::neg::BFLOAT,
unary::contiguous::copy::BFLOAT,
];
let bf16_skernels = [
unary::strided::sin::BFLOAT,
unary::strided::cos::BFLOAT,
unary::strided::exp::BFLOAT,
unary::strided::sqr::BFLOAT,
unary::strided::sqrt::BFLOAT,
unary::strided::neg::BFLOAT,
unary::strided::copy::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_unary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [unary::contiguous::Kernel; 7],
strided: [unary::strided::Kernel; 7],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&mut output,
0,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}