mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Compare commits
26 Commits
0.9.1
...
metal_heap
Author | SHA1 | Date | |
---|---|---|---|
e8c1c31245 | |||
51f05e997d | |||
4289984d32 | |||
1471f98f0b | |||
dd4a40f1c0 | |||
79845bd93b | |||
6071797450 | |||
b58b247323 | |||
3900091e75 | |||
54355ff997 | |||
e02f1912bb | |||
a52b71686b | |||
7adfb70dff | |||
3ad02147e4 | |||
4f39695465 | |||
4cf4844c9d | |||
d840838e95 | |||
61a070fdd1 | |||
e35669647d | |||
53e8b7ee3e | |||
cc26cce23c | |||
02c2ec2c71 | |||
9a2784b8ab | |||
0f652f0e3d | |||
ddee9dc1dd | |||
fc9bb7784a |
@ -13,6 +13,7 @@ members = [
|
||||
exclude = [
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
"candle-metal-kernels",
|
||||
"candle-onnx",
|
||||
]
|
||||
resolver = "2"
|
||||
@ -60,7 +61,8 @@ tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
metal = { path = "../metal-rs", features = ["mps"] }
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
@ -13,6 +13,7 @@ readme = "README.md"
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
metal = { workspace = true, optional = true}
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
@ -40,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
metal = ["dep:metal"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
|
@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
pub enum DeviceLocation {
|
||||
Cpu,
|
||||
Cuda { gpu_id: usize },
|
||||
Metal,
|
||||
Metal { gpu_id: usize },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -146,6 +146,7 @@ impl Device {
|
||||
match (self, rhs) {
|
||||
(Self::Cpu, Self::Cpu) => true,
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
@ -14,7 +14,9 @@ impl Tensor {
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
_ => todo!(),
|
||||
crate::DeviceLocation::Metal { gpu_id } => {
|
||||
format!(", metal:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(f, "Tensor[")?;
|
||||
@ -477,7 +479,9 @@ impl std::fmt::Display for Tensor {
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
crate::DeviceLocation::Metal => todo!(),
|
||||
crate::DeviceLocation::Metal { gpu_id } => {
|
||||
format!(", metal:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(
|
||||
|
@ -53,6 +53,8 @@ mod dummy_metal_backend;
|
||||
pub mod error;
|
||||
mod indexer;
|
||||
pub mod layout;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal_backend;
|
||||
#[cfg(feature = "mkl")]
|
||||
mod mkl;
|
||||
pub mod npy;
|
||||
|
997
candle-core/src/metal_backend.rs
Normal file
997
candle-core/src/metal_backend.rs
Normal file
@ -0,0 +1,997 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use half::f16;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, HeapDescriptor, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalError {
|
||||
#[error("{0}")]
|
||||
Message(String),
|
||||
#[error(transparent)]
|
||||
KernelError(#[from] candle_metal_kernels::MetalKernelError),
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs_stride: Vec<usize>,
|
||||
mnk: (usize, usize, usize),
|
||||
},
|
||||
}
|
||||
|
||||
impl From<String> for MetalError {
|
||||
fn from(e: String) -> Self {
|
||||
MetalError::Message(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
command_queue: metal::CommandQueue,
|
||||
heap: metal::Heap,
|
||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "MetalDevice({:?})", self.device.registry_id())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for MetalDevice {
|
||||
type Target = metal::DeviceRef;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.device
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalDevice {
|
||||
pub fn id(&self) -> NSUInteger {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
|
||||
self.command_buffer.read().unwrap()
|
||||
}
|
||||
|
||||
pub fn commit_wait_until_completed(&self) {
|
||||
let mut old = self.command_buffer.try_write().unwrap();
|
||||
let status = old.status();
|
||||
use metal::MTLCommandBufferStatus::{
|
||||
Committed, Completed, Enqueued, Error, NotEnqueued, Scheduled,
|
||||
};
|
||||
// match old.status() {}
|
||||
if old.status() == metal::MTLCommandBufferStatus::Completed {
|
||||
return;
|
||||
}
|
||||
old.commit();
|
||||
old.wait_until_completed();
|
||||
// let count = old.retain_count();
|
||||
// println!("Count {count:?}");
|
||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
|
||||
*old = command_buffer;
|
||||
// let count = old.retain_count();
|
||||
// // println!("Count after {count:?}");
|
||||
// old.release();
|
||||
// let count = old.retain_count();
|
||||
// println!("Count after release {count:?}");
|
||||
// self.command_buffer.replace_with(|_| command_buffer)
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
&self.kernels
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
// println!("Creating buffer {size}");
|
||||
let buffer = self
|
||||
.heap
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeShared)
|
||||
.expect("New buffer");
|
||||
// println!("{:?}", self.heap.used_size());
|
||||
buffer
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
|
||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||
let option = metal::MTLResourceOptions::StorageModeShared;
|
||||
// println!("Creating data buffer {size}");
|
||||
self.device
|
||||
.new_buffer_with_data(data.as_ptr() as *const core::ffi::c_void, size, option)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalStorage {
|
||||
buffer: metal::Buffer,
|
||||
device: MetalDevice,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl BackendStorage for MetalStorage {
|
||||
type Device = MetalDevice;
|
||||
|
||||
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
|
||||
fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn device(&self) -> &Self::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
self.device.commit_wait_until_completed();
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize),
|
||||
)),
|
||||
DType::U32 => Ok(CpuStorage::U32(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
||||
)),
|
||||
DType::I64 => Ok(CpuStorage::I64(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
||||
)),
|
||||
DType::F16 => Ok(CpuStorage::F16(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
||||
)),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
||||
)),
|
||||
DType::F32 => Ok(CpuStorage::F32(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
||||
)),
|
||||
DType::F64 => Ok(CpuStorage::F64(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
|
||||
let shape = layout.shape();
|
||||
let el = shape.elem_count();
|
||||
let dtype = self.dtype;
|
||||
|
||||
let mut buffer = device.new_buffer(el, self.dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float",
|
||||
DType::F16 => "affine_half",
|
||||
dtype => todo!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
el,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float_strided",
|
||||
DType::F16 => "affine_half_strided",
|
||||
dtype => todo!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
assert!(sum_dims.len() == 1);
|
||||
assert!(sum_dims[0] == layout.shape().rank() - 1);
|
||||
assert!(layout.is_contiguous());
|
||||
assert!(layout.start_offset() == 0);
|
||||
let device = self.device.clone();
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
let src_el: usize = src_dims.iter().product();
|
||||
// Source dims and strides with the sum dims at the end.
|
||||
let mut dims = vec![];
|
||||
let mut stride = vec![];
|
||||
let mut dst_el: usize = 1;
|
||||
for (dim_idx, &d) in src_dims.iter().enumerate() {
|
||||
if !sum_dims.contains(&dim_idx) {
|
||||
dst_el *= d;
|
||||
dims.push(d);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
}
|
||||
for &dim_idx in sum_dims.iter() {
|
||||
dims.push(src_dims[dim_idx]);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
||||
_ => todo!("Reduce op for non float"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
src_el,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_cast_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32_strided",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_cast_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
{
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||
("ulog", DType::F16) => contiguous::log::HALF,
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||
("usin", DType::F32) => strided::sin::FLOAT,
|
||||
("usqr", DType::F32) => strided::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => strided::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => strided::neg::FLOAT,
|
||||
("uexp", DType::F32) => strided::exp::FLOAT,
|
||||
("ulog", DType::F32) => strided::log::FLOAT,
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
("usqrt", DType::F16) => strided::sqrt::HALF,
|
||||
("uneg", DType::F16) => strided::neg::HALF,
|
||||
("uexp", DType::F16) => strided::exp::HALF,
|
||||
("ulog", DType::F16) => strided::log::HALF,
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
0,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
{
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("add", DType::F32) => contiguous::add::FLOAT,
|
||||
("badd", DType::F32) => contiguous::add::FLOAT,
|
||||
("sub", DType::F32) => contiguous::sub::FLOAT,
|
||||
("bsub", DType::F32) => contiguous::sub::FLOAT,
|
||||
("mul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("div", DType::F32) => contiguous::div::FLOAT,
|
||||
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
||||
("add", DType::F16) => contiguous::add::HALF,
|
||||
("badd", DType::F16) => contiguous::add::HALF,
|
||||
("sub", DType::F16) => contiguous::sub::HALF,
|
||||
("bsub", DType::F16) => contiguous::sub::HALF,
|
||||
("mul", DType::F16) => contiguous::mul::HALF,
|
||||
("bmul", DType::F16) => contiguous::mul::HALF,
|
||||
("div", DType::F16) => contiguous::div::HALF,
|
||||
("bdiv", DType::F16) => contiguous::div::HALF,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
use candle_metal_kernels::binary::strided;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("badd", DType::F32) => strided::add::FLOAT,
|
||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||
("badd", DType::F16) => strided::add::HALF,
|
||||
("bsub", DType::F16) => strided::sub::HALF,
|
||||
("bmul", DType::F16) => strided::mul::HALF,
|
||||
("bdiv", DType::F16) => strided::div::HALF,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn where_cond(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
t: &Self,
|
||||
t_l: &Layout,
|
||||
f: &Self,
|
||||
f_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device.clone();
|
||||
let shape = t_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let dtype = t.dtype;
|
||||
let mut buffer = self.device.new_buffer(el, dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
"where_u8_f32",
|
||||
dims,
|
||||
&self.buffer,
|
||||
(
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
),
|
||||
&t.buffer,
|
||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConv1D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
assert!(src_l.is_contiguous());
|
||||
assert!(src_l.start_offset() == 0);
|
||||
assert!(ids_l.is_contiguous());
|
||||
assert!(ids_l.start_offset() == 0);
|
||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let ids_el = ids_l.shape().elem_count();
|
||||
let dst_el = ids_el * left_size * right_size;
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||
};
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_index_select(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
src_l.dims(),
|
||||
ids_el,
|
||||
dim,
|
||||
&self.buffer,
|
||||
&ids.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
// Create descriptors
|
||||
use metal::mps::matrix::*;
|
||||
|
||||
let (type_id, size) = match self.dtype {
|
||||
DType::F32 => (
|
||||
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
||||
core::mem::size_of::<f32>() as NSUInteger,
|
||||
),
|
||||
DType::F16 => (
|
||||
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
||||
core::mem::size_of::<f16>() as NSUInteger,
|
||||
),
|
||||
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
|
||||
};
|
||||
|
||||
let elem_count = b * m * n;
|
||||
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// The a tensor has dims batching, k, n (rhs)
|
||||
let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
false
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
false
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
|
||||
let b = b as NSUInteger;
|
||||
let m = m as NSUInteger;
|
||||
let n = n as NSUInteger;
|
||||
let k = k as NSUInteger;
|
||||
|
||||
let left_descriptor = if transpose_left {
|
||||
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
||||
};
|
||||
let right_descriptor = if transpose_right {
|
||||
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
||||
};
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
|
||||
{
|
||||
let command_buffer = self.device.command_buffer();
|
||||
for bi in 0..b {
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&self.buffer,
|
||||
(bi * stride_left + lhs_l.start_offset() as u64) * size,
|
||||
&left_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&rhs.buffer,
|
||||
(bi * stride_right + rhs_l.start_offset() as u64) * size,
|
||||
&right_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&out_buffer,
|
||||
bi * m * n * size,
|
||||
&result_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let command_buffer = self.device.command_buffer();
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||
DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
|
||||
dtype => todo!("copy_strided not implemented for {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
kernel_name,
|
||||
src_l.dims(),
|
||||
&self.buffer,
|
||||
src_l.stride(),
|
||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut dst.buffer,
|
||||
dst_offset * dst.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalStorage {
|
||||
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
|
||||
Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
type Storage = MetalStorage;
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
|
||||
let descriptor = HeapDescriptor::new();
|
||||
let mut size =
|
||||
device.heap_buffer_size_and_align(100_000_000, MTLResourceOptions::StorageModeShared);
|
||||
size.size += (size.size & (size.align - 1)) + size.align;
|
||||
descriptor.set_size(size.size);
|
||||
descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
|
||||
let heap = device.new_heap(&descriptor);
|
||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
Ok(Self {
|
||||
device,
|
||||
heap,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||
todo!("set_seed")
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Metal {
|
||||
gpu_id: self.registry_id() as usize,
|
||||
}
|
||||
}
|
||||
|
||||
fn same_device(&self, rhs: &Self) -> bool {
|
||||
self.device.registry_id() == rhs.device.registry_id()
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype);
|
||||
Ok(MetalStorage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
// TODO Is there a faster way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let buffer = match storage {
|
||||
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||
};
|
||||
Ok(Self::Storage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype: storage.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_uniform(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn rand_normal(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
}
|
@ -1859,7 +1859,14 @@ impl Tensor {
|
||||
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
|
||||
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Metal(metal)) => {
|
||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Metal(storage), Device::Cpu) => {
|
||||
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
|
||||
Storage::Cpu(storage.to_cpu_storage()?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||
// are the same.
|
||||
|
@ -4,7 +4,7 @@ use crate::{Result, Tensor};
|
||||
macro_rules! test_device {
|
||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
|
||||
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
|
||||
#[test]
|
||||
fn $test_cpu() -> Result<()> {
|
||||
$fn_name(&Device::Cpu)
|
||||
@ -15,6 +15,12 @@ macro_rules! test_device {
|
||||
fn $test_cuda() -> Result<()> {
|
||||
$fn_name(&Device::new_cuda(0)?)
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
#[test]
|
||||
fn $test_metal() -> Result<()> {
|
||||
$fn_name(&Device::new_metal(0)?)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -563,14 +563,35 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
||||
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
|
||||
test_device!(
|
||||
conv1d_small,
|
||||
conv1d_small_cpu,
|
||||
conv1d_small_gpu,
|
||||
conv1d_small_metal
|
||||
);
|
||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
|
||||
test_device!(
|
||||
conv2d_non_square,
|
||||
conv2d_non_square_cpu,
|
||||
conv2d_non_square_gpu
|
||||
conv2d_non_square_gpu,
|
||||
conv2d_non_square_metal
|
||||
);
|
||||
test_device!(
|
||||
conv2d_small,
|
||||
conv2d_small_cpu,
|
||||
conv2d_small_gpu,
|
||||
conv2d_small_metal
|
||||
);
|
||||
test_device!(
|
||||
conv2d_smaller,
|
||||
conv2d_smaller_cpu,
|
||||
conv2d_smaller_gpu,
|
||||
conv2d_smaller_metal
|
||||
);
|
||||
test_device!(
|
||||
conv2d_grad,
|
||||
conv2d_grad_cpu,
|
||||
conv2d_grad_gpu,
|
||||
conv2_grad_metal
|
||||
);
|
||||
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
||||
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
||||
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
|
||||
|
@ -315,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
|
||||
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
|
||||
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
|
||||
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
|
||||
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
|
||||
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
|
||||
test_device!(
|
||||
simple_grad,
|
||||
simple_grad_cpu,
|
||||
simple_grad_gpu,
|
||||
simple_grad_metal
|
||||
);
|
||||
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
|
||||
test_device!(
|
||||
matmul_grad,
|
||||
matmul_grad_cpu,
|
||||
matmul_grad_gpu,
|
||||
matmul_grad_metal
|
||||
);
|
||||
test_device!(
|
||||
grad_descent,
|
||||
grad_descent_cpu,
|
||||
grad_descent_gpu,
|
||||
grad_descent_metal
|
||||
);
|
||||
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
|
||||
test_device!(
|
||||
binary_grad,
|
||||
binary_grad_cpu,
|
||||
binary_grad_gpu,
|
||||
binary_grad_metal
|
||||
);
|
||||
|
@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
|
||||
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
|
||||
|
||||
#[test]
|
||||
fn strided_blocks() -> Result<()> {
|
||||
|
@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
|
||||
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
|
||||
test_device!(
|
||||
avg_pool2d_pytorch,
|
||||
avg_pool2d_pytorch_cpu,
|
||||
avg_pool2d_pytorch_gpu
|
||||
avg_pool2d_pytorch_gpu,
|
||||
avg_pool2d_pytorch_metal
|
||||
);
|
||||
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
|
||||
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
|
||||
test_device!(
|
||||
upsample_nearest2d,
|
||||
upsample_nearest2d_cpu,
|
||||
upsample_nearest2d_gpu
|
||||
upsample_nearest2d_gpu,
|
||||
upsample_nearest2d_metal
|
||||
);
|
||||
|
@ -1070,35 +1070,60 @@ fn randn(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||
test_device!(ones, ones_cpu, ones_gpu);
|
||||
test_device!(arange, arange_cpu, arange_gpu);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
|
||||
test_device!(cat, cat_cpu, cat_gpu);
|
||||
test_device!(sum, sum_cpu, sum_gpu);
|
||||
test_device!(min, min_cpu, min_gpu);
|
||||
test_device!(max, max_cpu, max_gpu);
|
||||
test_device!(argmax, argmax_cpu, argmax_gpu);
|
||||
test_device!(argmin, argmin_cpu, argmin_gpu);
|
||||
test_device!(transpose, transpose_cpu, transpose_gpu);
|
||||
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
|
||||
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
|
||||
test_device!(index_select, index_select_cpu, index_select_gpu);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
||||
test_device!(gather, gather_cpu, gather_gpu);
|
||||
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
||||
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
|
||||
test_device!(randn, randn_cpu, randn_gpu);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu);
|
||||
test_device!(var, var_cpu, var_gpu);
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||
test_device!(max, max_cpu, max_gpu, max_metal);
|
||||
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
|
||||
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
|
||||
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
|
||||
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
||||
test_device!(
|
||||
broadcast_matmul,
|
||||
broadcast_matmul_cpu,
|
||||
broadcast_matmul_gpu,
|
||||
broadcast_matmul_metal
|
||||
);
|
||||
test_device!(
|
||||
broadcasting,
|
||||
broadcasting_cpu,
|
||||
broadcasting_gpu,
|
||||
broadcasting_metal
|
||||
);
|
||||
test_device!(
|
||||
index_select,
|
||||
index_select_cpu,
|
||||
index_select_gpu,
|
||||
index_select_metal
|
||||
);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
|
||||
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
|
||||
test_device!(
|
||||
scatter_add,
|
||||
scatter_add_cpu,
|
||||
scatter_add_gpu,
|
||||
scatter_add_metal
|
||||
);
|
||||
test_device!(
|
||||
slice_scatter,
|
||||
slice_scatter_cpu,
|
||||
slice_scatter_gpu,
|
||||
slice_scatter_metal
|
||||
);
|
||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||
|
||||
// There was originally a bug on the CPU implementation for randn
|
||||
// https://github.com/huggingface/candle/issues/381
|
||||
|
21
candle-metal-kernels/Cargo.toml
Normal file
21
candle-metal-kernels/Cargo.toml
Normal file
@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
metal = { path = "../../metal-rs", features = ["mps"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
||||
[dev-dependencies]
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
rand = "0.8.5"
|
3
candle-metal-kernels/README.md
Normal file
3
candle-metal-kernels/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# candle-metal-kernels
|
||||
|
||||
This crate contains Metal kernels used from candle.
|
61
candle-metal-kernels/src/affine.metal
Normal file
61
candle-metal-kernels/src/affine.metal
Normal file
@ -0,0 +1,61 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define AFFINE(FN_NAME, TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME m = TYPENAME(mul); \
|
||||
const TYPENAME a = TYPENAME(add); \
|
||||
output[id] = input[id] * m + a; \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME m = TYPENAME(mul); \
|
||||
const TYPENAME a = TYPENAME(add); \
|
||||
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
|
||||
} \
|
||||
|
||||
AFFINE(affine_float, float)
|
||||
AFFINE(affine_half, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
AFFINE(affine_bfloat, bfloat);
|
||||
#endif
|
72
candle-metal-kernels/src/binary.metal
Normal file
72
candle-metal-kernels/src/binary.metal
Normal file
@ -0,0 +1,72 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const TYPENAME *left, \
|
||||
device const TYPENAME *right, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
TYPENAME x = left[thread_position_in_grid]; \
|
||||
TYPENAME y = right[thread_position_in_grid]; \
|
||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *left_strides, \
|
||||
constant size_t *right_strides, \
|
||||
device const TYPENAME *left, \
|
||||
device const TYPENAME *right, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
|
||||
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
|
||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
||||
}
|
||||
|
||||
#define BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
|
||||
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
|
||||
|
||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
BINARY_OP(x * y, mul)
|
||||
BINARY_OP(x / y, div)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
BFLOAT_BINARY_OP(x * y, mul)
|
||||
BFLOAT_BINARY_OP(x / y, div)
|
||||
#endif
|
53
candle-metal-kernels/src/cast.metal
Normal file
53
candle-metal-kernels/src/cast.metal
Normal file
@ -0,0 +1,53 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#endif
|
103
candle-metal-kernels/src/indexing.metal
Normal file
103
candle-metal-kernels/src/indexing.metal
Normal file
@ -0,0 +1,103 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &dst_size, \
|
||||
constant size_t &left_size, \
|
||||
constant size_t &src_dim_size, \
|
||||
constant size_t &right_size, \
|
||||
constant size_t &ids_size, \
|
||||
const device TYPENAME *input, \
|
||||
const device INDEX_TYPENAME *input_ids, \
|
||||
device TYPENAME *output, \
|
||||
uint gid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (gid >= dst_size) { \
|
||||
return; \
|
||||
} \
|
||||
const size_t id_i = (gid / right_size) % ids_size; \
|
||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
||||
const size_t right_rank_i = gid % right_size; \
|
||||
const size_t left_rank_i = gid / right_size / ids_size; \
|
||||
/* \
|
||||
// Force prevent out of bounds indexing \
|
||||
// since there doesn't seem to be a good way to force crash \
|
||||
// No need to check for zero we're only allowing unsized. \
|
||||
*/ \
|
||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
|
||||
output[gid] = input[src_i]; \
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename I>
|
||||
void index_add(
|
||||
device I *ids [[buffer(0)]],
|
||||
device T *inp [[buffer(1)]],
|
||||
device T *out [[buffer(2)]],
|
||||
|
||||
constant uint &ids_dim_size,
|
||||
constant uint &left_size,
|
||||
constant uint &dst_dim_size,
|
||||
constant uint &right_size,
|
||||
|
||||
uint gid [[ thread_position_in_grid ]] \
|
||||
) {
|
||||
|
||||
if (gid >= left_size * right_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i = gid;
|
||||
const uint pre = i / right_size;
|
||||
const uint post = i % right_size;
|
||||
|
||||
for (uint j = 0; j < ids_dim_size; j++) {
|
||||
const uint idx = ids[j];
|
||||
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
device INDEX_TYPENAME *ids [[buffer(0)]], \
|
||||
device TYPENAME *inp [[buffer(1)]], \
|
||||
device TYPENAME *out [[buffer(2)]], \
|
||||
constant uint &ids_dim_size, \
|
||||
constant uint &left_size, \
|
||||
constant uint &dst_dim_size, \
|
||||
constant uint &right_size, \
|
||||
uint gid [[ thread_position_in_grid ]] \
|
||||
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
|
||||
|
||||
|
||||
INDEX_OP(is_u32_f32, uint, float)
|
||||
INDEX_OP(is_u32_f16, uint, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
IA_OP(bfloat, int64_t, ia_i64_bf16)
|
||||
IA_OP(bfloat, uint32_t, ia_u32_bf16)
|
||||
IA_OP(bfloat, uint8_t, ia_u8_bf16)
|
||||
#endif
|
||||
|
||||
IA_OP(half, uint32_t, ia_u32_f16)
|
||||
IA_OP(half, uint8_t, ia_u8_f16)
|
||||
|
||||
IA_OP(float, int64_t, ia_i64_f32)
|
||||
IA_OP(uint8_t, int64_t, ia_i64_u8)
|
||||
IA_OP(int64_t, int64_t, ia_i64_i64)
|
||||
IA_OP(uint32_t, int64_t, ia_i64_u32)
|
||||
|
||||
IA_OP(float, uint32_t, ia_u32_f32)
|
||||
IA_OP(uint8_t, uint32_t, ia_u32_u8)
|
||||
IA_OP(int64_t, uint32_t, ia_u32_i64)
|
||||
IA_OP(uint32_t, uint32_t, ia_u32_u32)
|
||||
|
||||
IA_OP(float, uint8_t, ia_u8_f32)
|
||||
IA_OP(uint8_t, uint8_t, ia_u8_u8)
|
||||
IA_OP(uint32_t, uint8_t, ia_u8_u32)
|
||||
IA_OP(int64_t, uint8_t, ia_u8_i64)
|
1387
candle-metal-kernels/src/lib.rs
Normal file
1387
candle-metal-kernels/src/lib.rs
Normal file
File diff suppressed because it is too large
Load Diff
139
candle-metal-kernels/src/reduce.metal
Normal file
139
candle-metal-kernels/src/reduce.metal
Normal file
@ -0,0 +1,139 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 256;
|
||||
|
||||
# define REDUCE(FN, NAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const TYPENAME *src, \
|
||||
device TYPENAME *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint blockDim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = 0; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
*/ \
|
||||
TYPENAME x = shared_memory[tid]; \
|
||||
TYPENAME y = src[idx]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += blockDim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
TYPENAME x = shared_memory[tid]; \
|
||||
TYPENAME y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
kernel void softmax_float(
|
||||
constant size_t &src_numel,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const float *src,
|
||||
device float *dst,
|
||||
uint id [[ thread_position_in_grid ]],
|
||||
uint tid [[ thread_index_in_threadgroup ]],
|
||||
uint dst_id [[ threadgroup_position_in_grid ]],
|
||||
uint blockDim [[ threads_per_threadgroup ]]
|
||||
) {
|
||||
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE];
|
||||
|
||||
shared_memory[tid] = -INFINITY;
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
|
||||
size_t idx = start_idx + tid;
|
||||
|
||||
while (idx < stop_idx) {
|
||||
// TODO: Fast version for the contiguous case.
|
||||
shared_memory[tid] = max(shared_memory[tid], src[idx]);
|
||||
idx += blockDim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// reduction in shared memory
|
||||
for (uint s = blockDim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
float max = shared_memory[0];
|
||||
|
||||
shared_memory[tid] = 0;
|
||||
|
||||
// Restart
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
// TODO: Fast version for the contiguous case.
|
||||
const float val = exp(src[idx] - max);
|
||||
dst[idx] = val;
|
||||
shared_memory[tid] += val;
|
||||
idx += blockDim;
|
||||
}
|
||||
// reduction in shared memory
|
||||
for (uint s = blockDim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] += shared_memory[tid + s];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
const float inv_acc = 1/shared_memory[0];
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
dst[idx] *= inv_acc;
|
||||
idx += blockDim;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_float, float)
|
||||
REDUCE(x * y, fast_mul_float, float)
|
||||
REDUCE(max(x, y), fast_max_float, float)
|
57
candle-metal-kernels/src/ternary.metal
Normal file
57
candle-metal-kernels/src/ternary.metal
Normal file
@ -0,0 +1,57 @@
|
||||
#include <metal_stdlib>
|
||||
#
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
|
||||
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID_TYPENAME *ids, \
|
||||
device const TYPENAME *t, \
|
||||
device const TYPENAME *f, \
|
||||
device TYPENAME *out ,\
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
|
||||
} \
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
|
||||
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
|
||||
// WHERE_OP(int64_t, int64_t, where_i64_i64)
|
||||
//
|
||||
// WHERE_OP(float, uint32_t, where_u32_f32)
|
||||
// WHERE_OP(double, uint32_t, where_u32_f64)
|
||||
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
|
||||
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
// WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
126
candle-metal-kernels/src/unary.metal
Normal file
126
candle-metal-kernels/src/unary.metal
Normal file
@ -0,0 +1,126 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_math>
|
||||
#
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||
template <typename T> METAL_FUNC T erf(T in){
|
||||
float x = (float) in;
|
||||
// constants
|
||||
float a1 = 0.254829592;
|
||||
float a2 = -0.284496736;
|
||||
float a3 = 1.421413741;
|
||||
float a4 = -1.453152027;
|
||||
float a5 = 1.061405429;
|
||||
float p = 0.3275911;
|
||||
|
||||
// Save the sign of x
|
||||
int sign = 1;
|
||||
if (x < 0)
|
||||
sign = -1;
|
||||
x = fabs(x);
|
||||
|
||||
// A&S formula 7.1.26
|
||||
float t = 1.0/(1.0 + p*x);
|
||||
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
|
||||
|
||||
return T(sign*y);
|
||||
}
|
||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
|
||||
template <typename T> METAL_FUNC T gelu(T x){
|
||||
T x_sq = x * x;
|
||||
T x_cube = x_sq * x;
|
||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
|
||||
}
|
||||
|
||||
#define UNARY_OP(NAME) \
|
||||
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
|
||||
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
|
||||
|
||||
#define BFLOAT_UNARY_OP(NAME) \
|
||||
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||
|
||||
|
||||
UNARY_OP(cos)
|
||||
UNARY_OP(sin)
|
||||
UNARY_OP(sqr)
|
||||
UNARY_OP(sqrt)
|
||||
UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
UNARY_OP(log)
|
||||
UNARY_OP(gelu)
|
||||
UNARY_OP(ceil)
|
||||
UNARY_OP(floor)
|
||||
UNARY_OP(round)
|
||||
UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY(id, float, copy_float, copy_float_strided)
|
||||
UNARY(id, half, copy_half, copy_half_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
BFLOAT_UNARY_OP(sqr)
|
||||
BFLOAT_UNARY_OP(sqrt)
|
||||
BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
BFLOAT_UNARY_OP(round)
|
||||
BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
|
||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||
#endif
|
76
candle-metal-kernels/tmp/affine.rs
Normal file
76
candle-metal-kernels/tmp/affine.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use candle_metal_kernels::{call_affine, Kernels};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_affine_bench(&device, &kernels, &f32_1k);
|
||||
run_affine_bench(&device, &kernels, &f32_10k);
|
||||
run_affine_bench(&device, &kernels, &f32_100k);
|
||||
}
|
||||
|
||||
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 10000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
let mul: f32 = 1.2345;
|
||||
let add: f32 = 2.3456;
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_affine(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"affine_float",
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
mul,
|
||||
add,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
"affine",
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
182
candle-metal-kernels/tmp/binary.rs
Normal file
182
candle-metal-kernels/tmp/binary.rs
Normal file
@ -0,0 +1,182 @@
|
||||
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
|
||||
use half::{bf16, f16};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let f16_1k = f16_map(&f32_1k);
|
||||
let f16_10k = f16_map(&f32_10k);
|
||||
let f16_100k = f16_map(&f32_100k);
|
||||
|
||||
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let bf16_1k = bf16_map(&f32_1k);
|
||||
let bf16_10k = bf16_map(&f32_10k);
|
||||
let bf16_100k = bf16_map(&f32_100k);
|
||||
|
||||
let f32_ckernels = [
|
||||
binary::contiguous::add::FLOAT,
|
||||
binary::contiguous::sub::FLOAT,
|
||||
binary::contiguous::mul::FLOAT,
|
||||
binary::contiguous::div::FLOAT,
|
||||
];
|
||||
let f32_skernels = [
|
||||
binary::strided::add::FLOAT,
|
||||
binary::strided::sub::FLOAT,
|
||||
binary::strided::mul::FLOAT,
|
||||
binary::strided::div::FLOAT,
|
||||
];
|
||||
let f16_ckernels = [
|
||||
binary::contiguous::add::HALF,
|
||||
binary::contiguous::sub::HALF,
|
||||
binary::contiguous::mul::HALF,
|
||||
binary::contiguous::div::HALF,
|
||||
];
|
||||
let f16_skernels = [
|
||||
binary::strided::add::HALF,
|
||||
binary::strided::sub::HALF,
|
||||
binary::strided::mul::HALF,
|
||||
binary::strided::div::HALF,
|
||||
];
|
||||
let bf16_ckernels = [
|
||||
binary::contiguous::add::BFLOAT,
|
||||
binary::contiguous::sub::BFLOAT,
|
||||
binary::contiguous::mul::BFLOAT,
|
||||
binary::contiguous::div::BFLOAT,
|
||||
];
|
||||
let bf16_skernels = [
|
||||
binary::strided::add::BFLOAT,
|
||||
binary::strided::sub::BFLOAT,
|
||||
binary::strided::mul::BFLOAT,
|
||||
binary::strided::div::BFLOAT,
|
||||
];
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
|
||||
run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
|
||||
run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
|
||||
|
||||
// f16
|
||||
run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
|
||||
run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
|
||||
run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
|
||||
|
||||
// bf16
|
||||
run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
|
||||
run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
|
||||
run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
|
||||
}
|
||||
|
||||
fn run_binary_bench<T: Clone>(
|
||||
device: &Device,
|
||||
kernels: &Kernels,
|
||||
v: &[T],
|
||||
contiguous: [binary::contiguous::Kernel; 4],
|
||||
strided: [binary::strided::Kernel; 4],
|
||||
) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 1000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
// Contiguous
|
||||
for kernel_name in contiguous {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_binary_contiguous(
|
||||
device,
|
||||
&command_buffer,
|
||||
kernels,
|
||||
kernel_name,
|
||||
v.len(),
|
||||
&input,
|
||||
&input,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
|
||||
// Strided
|
||||
let shape = vec![2, 5_000];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
for kernel_name in strided {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_binary_strided(
|
||||
device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
kernel_name,
|
||||
&shape,
|
||||
&input,
|
||||
&strides,
|
||||
offset,
|
||||
&input,
|
||||
&strides,
|
||||
offset,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
}
|
84
candle-metal-kernels/tmp/cast.rs
Normal file
84
candle-metal-kernels/tmp/cast.rs
Normal file
@ -0,0 +1,84 @@
|
||||
use candle_metal_kernels::{call_cast_contiguous, Kernels};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let contiguous_kernels = ["cast_u32_f32"];
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
|
||||
run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
|
||||
run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
|
||||
}
|
||||
|
||||
fn run_cast_bench<T: Clone>(
|
||||
device: &Device,
|
||||
kernels: &Kernels,
|
||||
v: &[T],
|
||||
contiguous: &[&'static str],
|
||||
) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 1000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
// Contiguous
|
||||
for kernel_name in contiguous {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_cast_contiguous(
|
||||
device,
|
||||
&command_buffer,
|
||||
kernels,
|
||||
kernel_name,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
|
||||
// Strided?
|
||||
}
|
197
candle-metal-kernels/tmp/unary.rs
Normal file
197
candle-metal-kernels/tmp/unary.rs
Normal file
@ -0,0 +1,197 @@
|
||||
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
|
||||
use half::{bf16, f16};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let f16_1k = f16_map(&f32_1k);
|
||||
let f16_10k = f16_map(&f32_10k);
|
||||
let f16_100k = f16_map(&f32_100k);
|
||||
|
||||
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let bf16_1k = bf16_map(&f32_1k);
|
||||
let bf16_10k = bf16_map(&f32_10k);
|
||||
let bf16_100k = bf16_map(&f32_100k);
|
||||
|
||||
let f32_ckernels = [
|
||||
unary::contiguous::sin::FLOAT,
|
||||
unary::contiguous::cos::FLOAT,
|
||||
unary::contiguous::exp::FLOAT,
|
||||
unary::contiguous::sqr::FLOAT,
|
||||
unary::contiguous::sqrt::FLOAT,
|
||||
unary::contiguous::neg::FLOAT,
|
||||
unary::contiguous::copy::FLOAT,
|
||||
];
|
||||
let f32_skernels = [
|
||||
unary::strided::sin::FLOAT,
|
||||
unary::strided::cos::FLOAT,
|
||||
unary::strided::exp::FLOAT,
|
||||
unary::strided::sqr::FLOAT,
|
||||
unary::strided::sqrt::FLOAT,
|
||||
unary::strided::neg::FLOAT,
|
||||
unary::strided::copy::FLOAT,
|
||||
];
|
||||
let f16_ckernels = [
|
||||
unary::contiguous::sin::HALF,
|
||||
unary::contiguous::cos::HALF,
|
||||
unary::contiguous::exp::HALF,
|
||||
unary::contiguous::sqr::HALF,
|
||||
unary::contiguous::sqrt::HALF,
|
||||
unary::contiguous::neg::HALF,
|
||||
unary::contiguous::copy::HALF,
|
||||
];
|
||||
let f16_skernels = [
|
||||
unary::strided::sin::HALF,
|
||||
unary::strided::cos::HALF,
|
||||
unary::strided::exp::HALF,
|
||||
unary::strided::sqr::HALF,
|
||||
unary::strided::sqrt::HALF,
|
||||
unary::strided::neg::HALF,
|
||||
unary::strided::copy::HALF,
|
||||
];
|
||||
let bf16_ckernels = [
|
||||
unary::contiguous::sin::BFLOAT,
|
||||
unary::contiguous::cos::BFLOAT,
|
||||
unary::contiguous::exp::BFLOAT,
|
||||
unary::contiguous::sqr::BFLOAT,
|
||||
unary::contiguous::sqrt::BFLOAT,
|
||||
unary::contiguous::neg::BFLOAT,
|
||||
unary::contiguous::copy::BFLOAT,
|
||||
];
|
||||
let bf16_skernels = [
|
||||
unary::strided::sin::BFLOAT,
|
||||
unary::strided::cos::BFLOAT,
|
||||
unary::strided::exp::BFLOAT,
|
||||
unary::strided::sqr::BFLOAT,
|
||||
unary::strided::sqrt::BFLOAT,
|
||||
unary::strided::neg::BFLOAT,
|
||||
unary::strided::copy::BFLOAT,
|
||||
];
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
|
||||
run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
|
||||
run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
|
||||
|
||||
// f16
|
||||
run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
|
||||
run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
|
||||
run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
|
||||
|
||||
// bf16
|
||||
run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
|
||||
run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
|
||||
run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
|
||||
}
|
||||
|
||||
fn run_unary_bench<T: Clone>(
|
||||
device: &Device,
|
||||
kernels: &Kernels,
|
||||
v: &[T],
|
||||
contiguous: [unary::contiguous::Kernel; 7],
|
||||
strided: [unary::strided::Kernel; 7],
|
||||
) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 10000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
// Contiguous
|
||||
for kernel_name in contiguous {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_unary_contiguous(
|
||||
device,
|
||||
&command_buffer,
|
||||
kernels,
|
||||
kernel_name,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.0,
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
|
||||
// Strided
|
||||
let shape = vec![2, 5_000];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
for kernel_name in &strided {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_unary_strided(
|
||||
device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
kernel_name,
|
||||
&shape,
|
||||
&input,
|
||||
&strides,
|
||||
offset,
|
||||
&mut output,
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.0,
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user