Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)

Compare commits: 0.9.0-alph ... tmp5 (27 commits)
Commits (SHA1):
7e49e0af96
181d2299b2
2801541e5f
4289984d32
1471f98f0b
dd4a40f1c0
79845bd93b
6071797450
b58b247323
3900091e75
54355ff997
e02f1912bb
a52b71686b
7adfb70dff
3ad02147e4
4f39695465
4cf4844c9d
d840838e95
61a070fdd1
e35669647d
53e8b7ee3e
cc26cce23c
02c2ec2c71
9a2784b8ab
0f652f0e3d
ddee9dc1dd
fc9bb7784a
Cargo.toml

@@ -13,6 +13,7 @@ members = [
exclude = [
    "candle-flash-attn",
    "candle-kernels",
    "candle-metal-kernels",
    "candle-onnx",
]
resolver = "2"
@@ -60,7 +61,10 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"

[profile.release-with-debug]
inherits = "release"
candle-core/Cargo.toml

@@ -13,6 +13,7 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@@ -29,6 +30,8 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
@@ -40,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
candle-core/src/device.rs

@@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
    Cpu,
    Cuda { gpu_id: usize },
    Metal,
    Metal { gpu_id: usize },
}

#[derive(Debug, Clone)]
@@ -146,6 +146,7 @@ impl Device {
        match (self, rhs) {
            (Self::Cpu, Self::Cpu) => true,
            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
            _ => false,
        }
    }
candle-core/src/display.rs

@@ -14,7 +14,9 @@ impl Tensor {
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            _ => todo!(),
            crate::DeviceLocation::Metal { gpu_id } => {
                format!(", metal:{}", gpu_id)
            }
        };

        write!(f, "Tensor[")?;
@@ -477,7 +479,9 @@ impl std::fmt::Display for Tensor {
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            crate::DeviceLocation::Metal => todo!(),
            crate::DeviceLocation::Metal { gpu_id } => {
                format!(", metal:{}", gpu_id)
            }
        };

        write!(
candle-core/src/lib.rs

@@ -53,6 +53,8 @@ mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
candle-core/src/metal_backend.rs (new file, 1028 lines)
File diff suppressed because it is too large.
candle-core/src/tensor.rs

@@ -1859,7 +1859,14 @@ impl Tensor {
            (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
            }
            (Storage::Cpu(storage), Device::Metal(metal)) => {
                Storage::Metal(metal.storage_from_cpu_storage(storage)?)
            }
            (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
            (Storage::Metal(storage), Device::Cpu) => {
                println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
                Storage::Cpu(storage.to_cpu_storage()?)
            }
            (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                // are the same.
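The (Cpu, Metal) and (Metal, Cpu) arms above are what Tensor::to_device dispatches through. A minimal sketch of how this surfaces to user code, assuming candle-core is built with the metal feature on a machine with a Metal GPU:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Create a tensor on the CPU, then round-trip it through the Metal device.
    let cpu = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    let metal = Device::new_metal(0)?; // gpu_id 0, mirroring Device::new_cuda(0)
    let on_gpu = cpu.to_device(&metal)?; // hits the (Cpu, Metal) arm above
    let back = on_gpu.to_device(&Device::Cpu)?; // hits the (Metal, Cpu) arm
    println!("{back}");
    Ok(())
}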
candle-core/src/test_utils.rs

@@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
    // TODO: Switch to generating the two last arguments automatically once concat_idents is
    // stable. https://github.com/rust-lang/rust/issues/29599
    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
@@ -15,6 +15,12 @@ macro_rules! test_device {
        fn $test_cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }

        #[cfg(feature = "metal")]
        #[test]
        fn $test_metal() -> Result<()> {
            $fn_name(&Device::new_metal(0)?)
        }
    };
}
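For reference, a sketch of what an invocation such as test_device!(cat, cat_cpu, cat_gpu, cat_metal) now expands to (the cuda arm's #[cfg(feature = "cuda")] attribute is elided in the hunk above but implied by the metal arm):

#[test]
fn cat_cpu() -> Result<()> {
    cat(&Device::Cpu)
}

#[cfg(feature = "cuda")]
#[test]
fn cat_gpu() -> Result<()> {
    cat(&Device::new_cuda(0)?)
}

#[cfg(feature = "metal")]
#[test]
fn cat_metal() -> Result<()> {
    cat(&Device::new_metal(0)?)
}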
candle-core/tests/conv_tests.rs

@@ -563,14 +563,35 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
    Ok(())
}

test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
    conv1d_small,
    conv1d_small_cpu,
    conv1d_small_gpu,
    conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(
    conv2d_non_square,
    conv2d_non_square_cpu,
    conv2d_non_square_gpu
    conv2d_non_square_gpu,
    conv2d_non_square_metal
);
test_device!(
    conv2d_small,
    conv2d_small_cpu,
    conv2d_small_gpu,
    conv2d_small_metal
);
test_device!(
    conv2d_smaller,
    conv2d_smaller_cpu,
    conv2d_smaller_gpu,
    conv2d_smaller_metal
);
test_device!(
    conv2d_grad,
    conv2d_grad_cpu,
    conv2d_grad_gpu,
    conv2_grad_metal
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
candle-core/tests/grad_tests.rs

@@ -315,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
    Ok(())
}

test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
test_device!(
    simple_grad,
    simple_grad_cpu,
    simple_grad_gpu,
    simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
    matmul_grad,
    matmul_grad_cpu,
    matmul_grad_gpu,
    matmul_grad_metal
);
test_device!(
    grad_descent,
    grad_descent_cpu,
    grad_descent_gpu,
    grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
    binary_grad,
    binary_grad_cpu,
    binary_grad_gpu,
    binary_grad_metal
);
candle-core/tests/layout_tests.rs

@@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
    Ok(())
}

test_device!(contiguous, contiguous_cpu, contiguous_gpu);
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);

#[test]
fn strided_blocks() -> Result<()> {
candle-core/tests/pool_tests.rs

@@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
    Ok(())
}

test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(
    avg_pool2d_pytorch,
    avg_pool2d_pytorch_cpu,
    avg_pool2d_pytorch_gpu
    avg_pool2d_pytorch_gpu,
    avg_pool2d_pytorch_metal
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(
    upsample_nearest2d,
    upsample_nearest2d_cpu,
    upsample_nearest2d_gpu
    upsample_nearest2d_gpu,
    upsample_nearest2d_metal
);
candle-core/tests/tensor_tests.rs

@@ -1070,35 +1070,60 @@ fn randn(device: &Device) -> Result<()> {
    Ok(())
}

test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
test_device!(var, var_cpu, var_gpu);
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
    broadcast_matmul,
    broadcast_matmul_cpu,
    broadcast_matmul_gpu,
    broadcast_matmul_metal
);
test_device!(
    broadcasting,
    broadcasting_cpu,
    broadcasting_gpu,
    broadcasting_metal
);
test_device!(
    index_select,
    index_select_cpu,
    index_select_gpu,
    index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
    scatter_add,
    scatter_add_cpu,
    scatter_add_gpu,
    scatter_add_metal
);
test_device!(
    slice_scatter,
    slice_scatter_cpu,
    slice_scatter_gpu,
    slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
candle-metal-kernels/Cargo.toml (new file)

@@ -0,0 +1,21 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
edition = "2021"

description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"
candle-metal-kernels/README.md (new file)

@@ -0,0 +1,3 @@
# candle-metal-kernels

This crate contains Metal kernels used from candle.
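A minimal sketch of driving one of these kernels from Rust, patterned on the benchmark files under tmp/ below (the same Kernels::new / call_unary_contiguous API and the unary::contiguous kernel table); it assumes a macOS host with a Metal device:

use candle_metal_kernels::{call_unary_contiguous, unary, Kernels};
use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device");
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    // Input buffer holding four f32 values; output buffer of the same size.
    let v: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0];
    let options = MTLResourceOptions::StorageModeManaged;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v.as_slice()) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v.as_slice()) as u64, options);

    // Launch the contiguous cos kernel (cos_float in unary.metal).
    call_unary_contiguous(
        &device,
        &command_buffer,
        &kernels,
        unary::contiguous::cos::FLOAT,
        v.len(),
        &input,
        &mut output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
}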
candle-metal-kernels/src/affine.metal (new file)

@@ -0,0 +1,61 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

using namespace metal;

#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    constant float &mul, \
    constant float &add, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    const TYPENAME m = TYPENAME(mul); \
    const TYPENAME a = TYPENAME(add); \
    output[id] = input[id] * m + a; \
} \
kernel void FN_NAME##_strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant float &mul, \
    constant float &add, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    const TYPENAME m = TYPENAME(mul); \
    const TYPENAME a = TYPENAME(add); \
    output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
} \

AFFINE(affine_float, float)
AFFINE(affine_half, half)


#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif
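get_strided_index, repeated at the top of each .metal file, maps a linear element index to the storage offset of an arbitrarily strided layout by peeling off coordinates from the innermost dimension outwards. A Rust re-statement of the same loop with a worked example (illustrative only, not part of the crate):

/// Map a linear index into an offset for a tensor with the given dims/strides,
/// walking dimensions from innermost (last) to outermost, exactly like the
/// Metal helper above.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for d in (0..dims.len()).rev() {
        strided_i += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided_i
}

fn main() {
    // A 2x3 view stored transposed: dims [2, 3], strides [1, 2].
    // Linear index 4 -> coordinates (1, 1) -> offset 1*1 + 1*2 = 3.
    assert_eq!(get_strided_index(4, &[2, 3], &[1, 2]), 3);
}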
candle-metal-kernels/src/binary.metal (new file)

@@ -0,0 +1,72 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

using namespace metal;

#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    TYPENAME x = left[thread_position_in_grid]; \
    TYPENAME y = right[thread_position_in_grid]; \
    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *left_strides, \
    constant size_t *right_strides, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}

#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);

#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);


BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)

#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif
candle-metal-kernels/src/cast.metal (new file)

@@ -0,0 +1,53 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

using namespace metal;

#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint i [[ thread_position_in_grid ]] \
) { \
    if (i >= dim) { \
        return; \
    } \
    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \

CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)

#if __METAL_VERSION__ >= 310
#endif
candle-metal-kernels/src/indexing.metal (new file)

@@ -0,0 +1,103 @@
#include <metal_stdlib>
using namespace metal;

#define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
    constant size_t &left_size, \
    constant size_t &src_dim_size, \
    constant size_t &right_size, \
    constant size_t &ids_size, \
    const device TYPENAME *input, \
    const device INDEX_TYPENAME *input_ids, \
    device TYPENAME *output, \
    uint gid [[ thread_position_in_grid ]] \
) { \
    if (gid >= dst_size) { \
        return; \
    } \
    const size_t id_i = (gid / right_size) % ids_size; \
    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
    const size_t right_rank_i = gid % right_size; \
    const size_t left_rank_i = gid / right_size / ids_size; \
    /* \
    // Force prevent out of bounds indexing \
    // since there doesn't seem to be a good way to force crash \
    // No need to check for zero we're only allowing unsized. \
    */ \
    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
    output[gid] = input[src_i]; \
}


template <typename T, typename I>
void index_add(
    device I *ids [[buffer(0)]],
    device T *inp [[buffer(1)]],
    device T *out [[buffer(2)]],

    constant uint &ids_dim_size,
    constant uint &left_size,
    constant uint &dst_dim_size,
    constant uint &right_size,

    uint gid [[ thread_position_in_grid ]] \
) {

    if (gid >= left_size * right_size) {
        return;
    }

    const uint i = gid;
    const uint pre = i / right_size;
    const uint post = i % right_size;

    for (uint j = 0; j < ids_dim_size; j++) {
        const uint idx = ids[j];
        const uint src_i = (pre * ids_dim_size + j) * right_size + post;
        const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
        out[dst_i] += inp[src_i];
    }
}

#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    device INDEX_TYPENAME *ids [[buffer(0)]], \
    device TYPENAME *inp [[buffer(1)]], \
    device TYPENAME *out [[buffer(2)]], \
    constant uint &ids_dim_size, \
    constant uint &left_size, \
    constant uint &dst_dim_size, \
    constant uint &right_size, \
    uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \


INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)


#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif

IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)

IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)

IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)

IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)
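index_add accumulates rows of inp into out at the positions named by ids: each thread owns one (pre, post) pair and loops over the index list. A small CPU model of the same addressing in Rust (illustrative only, not part of the crate):

/// CPU model of the index_add kernel's addressing scheme.
fn index_add(
    ids: &[usize],
    inp: &[f32],
    out: &mut [f32],
    ids_dim_size: usize,
    left_size: usize,
    dst_dim_size: usize,
    right_size: usize,
) {
    // One iteration of the outer loop corresponds to one GPU thread (gid).
    for gid in 0..left_size * right_size {
        let pre = gid / right_size;
        let post = gid % right_size;
        for (j, &idx) in ids.iter().enumerate().take(ids_dim_size) {
            let src_i = (pre * ids_dim_size + j) * right_size + post;
            let dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];
        }
    }
}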
candle-metal-kernels/src/lib.rs (new file, 1389 lines)
File diff suppressed because it is too large.
candle-metal-kernels/src/reduce.metal (new file)

@@ -0,0 +1,139 @@
#include <metal_stdlib>
using namespace metal;

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

constant int THREADGROUP_SIZE = 256;

#define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const TYPENAME *src, \
    device TYPENAME *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint blockDim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = 0; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        */ \
        TYPENAME x = shared_memory[tid]; \
        TYPENAME y = src[idx]; \
        shared_memory[tid] = FN; \
        idx += blockDim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
    \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            TYPENAME x = shared_memory[tid]; \
            TYPENAME y = shared_memory[tid + s]; \
            shared_memory[tid] = FN; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    dst[dst_id] = shared_memory[0]; \
} \

kernel void softmax_float(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const float *src,
    device float *dst,
    uint id [[ thread_position_in_grid ]],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint blockDim [[ threads_per_threadgroup ]]
) {

    threadgroup float shared_memory[THREADGROUP_SIZE];

    shared_memory[tid] = -INFINITY;
    // Elements summed in this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        shared_memory[tid] = max(shared_memory[tid], src[idx]);
        idx += blockDim;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    float max = shared_memory[0];

    shared_memory[tid] = 0;

    // Restart
    idx = start_idx + tid;
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        const float val = exp(src[idx] - max);
        dst[idx] = val;
        shared_memory[tid] += val;
        idx += blockDim;
    }
    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    const float inv_acc = 1/shared_memory[0];
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += blockDim;
    }
}


REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
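softmax_float above is the standard numerically stable three-pass scheme: a threadgroup reduction for the maximum, a pass that exponentiates and accumulates the sum, and a final normalization pass. In math form, for a block of elements x_1, ..., x_n:

    m = \max_i x_i, \qquad y_i = e^{x_i - m}, \qquad \mathrm{softmax}(x)_i = \frac{y_i}{\sum_j y_j}

Subtracting m does not change the result (the shift cancels in the ratio) but keeps every exponent non-positive, so the exp cannot overflow.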
candle-metal-kernels/src/ternary.metal (new file)

@@ -0,0 +1,57 @@
#include <metal_stdlib>
#
using namespace metal;

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}


#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &numel, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t *strides_t, \
    constant size_t *strides_f, \
    device const ID_TYPENAME *ids, \
    device const TYPENAME *t, \
    device const TYPENAME *f, \
    device TYPENAME *out, \
    uint i [[ thread_position_in_grid ]] \
) { \
    uint strided_i = get_strided_index(i, num_dims, dims, strides); \
    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
    out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \

// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)

WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)
candle-metal-kernels/src/unary.metal (new file)

@@ -0,0 +1,126 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){
    float x = (float) in;
    // constants
    float a1 = 0.254829592;
    float a2 = -0.284496736;
    float a3 = 1.421413741;
    float a4 = -1.453152027;
    float a5 = 1.061405429;
    float p = 0.3275911;

    // Save the sign of x
    int sign = 1;
    if (x < 0)
        sign = -1;
    x = fabs(x);

    // A&S formula 7.1.26
    float t = 1.0/(1.0 + p*x);
    float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);

    return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
template <typename T> METAL_FUNC T gelu(T x){
    T x_sq = x * x;
    T x_cube = x_sq * x;
    T alpha = x + static_cast<T>(0.044715) * x_cube;
    T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}


#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}

#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);

#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);


UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)

#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)

UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif
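Two of the templates above encode well-known approximations. erf is the Abramowitz & Stegun formula 7.1.26 polynomial fit (the same constants a1..a5 and p as in the code), and gelu is the common tanh approximation; with t = 1/(1 + p|x|):

    \mathrm{erf}(x) \approx \mathrm{sign}(x)\left[1 - (a_1 t + a_2 t^2 + a_3 t^3 + a_4 t^4 + a_5 t^5)\, e^{-x^2}\right]

    \mathrm{gelu}(x) = \tfrac{1}{2}\, x \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\, x^3\bigr)\right)\right)

The constant \sqrt{2/\pi} is built in the code as M_2_SQRTPI_F * M_SQRT1_2_F, i.e. (2/\sqrt{\pi}) \cdot (1/\sqrt{2}).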
candle-metal-kernels/tmp/affine.rs (new file)

@@ -0,0 +1,76 @@
use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_affine_bench(&device, &kernels, &f32_1k);
    run_affine_bench(&device, &kernels, &f32_10k);
    run_affine_bench(&device, &kernels, &f32_100k);
}

fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 10000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    let mul: f32 = 1.2345;
    let add: f32 = 2.3456;
    let total_time = autoreleasepool(|| {
        let command_buffer = command_queue.new_command_buffer();
        let start = Instant::now();
        for _ in 0..iterations {
            call_affine(
                &device,
                command_buffer,
                &kernels,
                "affine_float",
                v.len(),
                &input,
                &mut output,
                mul,
                add,
            )
            .unwrap();
        }
        command_buffer.commit();
        command_buffer.wait_until_completed();

        start.elapsed()
    });
    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
        type_name::<T>().split("::").last().unwrap(),
        "affine",
        v.len(),
        iterations,
        total_time,
        total_time / iterations
    );
}
candle-metal-kernels/tmp/binary.rs (new file)

@@ -0,0 +1,182 @@
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
    let f16_1k = f16_map(&f32_1k);
    let f16_10k = f16_map(&f32_10k);
    let f16_100k = f16_map(&f32_100k);

    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
    let bf16_1k = bf16_map(&f32_1k);
    let bf16_10k = bf16_map(&f32_10k);
    let bf16_100k = bf16_map(&f32_100k);

    let f32_ckernels = [
        binary::contiguous::add::FLOAT,
        binary::contiguous::sub::FLOAT,
        binary::contiguous::mul::FLOAT,
        binary::contiguous::div::FLOAT,
    ];
    let f32_skernels = [
        binary::strided::add::FLOAT,
        binary::strided::sub::FLOAT,
        binary::strided::mul::FLOAT,
        binary::strided::div::FLOAT,
    ];
    let f16_ckernels = [
        binary::contiguous::add::HALF,
        binary::contiguous::sub::HALF,
        binary::contiguous::mul::HALF,
        binary::contiguous::div::HALF,
    ];
    let f16_skernels = [
        binary::strided::add::HALF,
        binary::strided::sub::HALF,
        binary::strided::mul::HALF,
        binary::strided::div::HALF,
    ];
    let bf16_ckernels = [
        binary::contiguous::add::BFLOAT,
        binary::contiguous::sub::BFLOAT,
        binary::contiguous::mul::BFLOAT,
        binary::contiguous::div::BFLOAT,
    ];
    let bf16_skernels = [
        binary::strided::add::BFLOAT,
        binary::strided::sub::BFLOAT,
        binary::strided::mul::BFLOAT,
        binary::strided::div::BFLOAT,
    ];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
    run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
    run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);

    // f16
    run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
    run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
    run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);

    // bf16
    run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
    run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
    run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}

fn run_binary_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: [binary::contiguous::Kernel; 4],
    strided: [binary::strided::Kernel; 4],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    for kernel_name in strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_strided(
                    device,
                    command_buffer,
                    &kernels,
                    kernel_name,
                    &shape,
                    &input,
                    &strides,
                    offset,
                    &input,
                    &strides,
                    offset,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });

        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }
}
candle-metal-kernels/tmp/cast.rs (new file)

@@ -0,0 +1,84 @@
use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let contiguous_kernels = ["cast_u32_f32"];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
    run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
    run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}

fn run_cast_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: &[&'static str],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_cast_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided?
}
candle-metal-kernels/tmp/unary.rs (new file)

@@ -0,0 +1,197 @@
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
    let f16_1k = f16_map(&f32_1k);
    let f16_10k = f16_map(&f32_10k);
    let f16_100k = f16_map(&f32_100k);

    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
    let bf16_1k = bf16_map(&f32_1k);
    let bf16_10k = bf16_map(&f32_10k);
    let bf16_100k = bf16_map(&f32_100k);

    let f32_ckernels = [
        unary::contiguous::sin::FLOAT,
        unary::contiguous::cos::FLOAT,
        unary::contiguous::exp::FLOAT,
        unary::contiguous::sqr::FLOAT,
        unary::contiguous::sqrt::FLOAT,
        unary::contiguous::neg::FLOAT,
        unary::contiguous::copy::FLOAT,
    ];
    let f32_skernels = [
        unary::strided::sin::FLOAT,
        unary::strided::cos::FLOAT,
        unary::strided::exp::FLOAT,
        unary::strided::sqr::FLOAT,
        unary::strided::sqrt::FLOAT,
        unary::strided::neg::FLOAT,
        unary::strided::copy::FLOAT,
    ];
    let f16_ckernels = [
        unary::contiguous::sin::HALF,
        unary::contiguous::cos::HALF,
        unary::contiguous::exp::HALF,
        unary::contiguous::sqr::HALF,
        unary::contiguous::sqrt::HALF,
        unary::contiguous::neg::HALF,
        unary::contiguous::copy::HALF,
    ];
    let f16_skernels = [
        unary::strided::sin::HALF,
        unary::strided::cos::HALF,
        unary::strided::exp::HALF,
        unary::strided::sqr::HALF,
        unary::strided::sqrt::HALF,
        unary::strided::neg::HALF,
        unary::strided::copy::HALF,
    ];
    let bf16_ckernels = [
        unary::contiguous::sin::BFLOAT,
        unary::contiguous::cos::BFLOAT,
        unary::contiguous::exp::BFLOAT,
        unary::contiguous::sqr::BFLOAT,
        unary::contiguous::sqrt::BFLOAT,
        unary::contiguous::neg::BFLOAT,
        unary::contiguous::copy::BFLOAT,
    ];
    let bf16_skernels = [
        unary::strided::sin::BFLOAT,
        unary::strided::cos::BFLOAT,
        unary::strided::exp::BFLOAT,
        unary::strided::sqr::BFLOAT,
        unary::strided::sqrt::BFLOAT,
        unary::strided::neg::BFLOAT,
        unary::strided::copy::BFLOAT,
    ];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);

    // f16
    run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);

    // bf16
    run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}

fn run_unary_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: [unary::contiguous::Kernel; 7],
    strided: [unary::strided::Kernel; 7],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 10000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_unary_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.0,
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    for kernel_name in &strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_unary_strided(
                    device,
                    command_buffer,
                    &kernels,
                    kernel_name,
                    &shape,
                    &input,
                    &strides,
                    offset,
                    &mut output,
                    0,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });

        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.0,
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }
}