Adding the actual backend

This commit is contained in:
Nicolas Patry
2023-11-09 19:53:14 +01:00
committed by Nicolas Patry
parent 976ad9f9c2
commit 39406a6721
13 changed files with 2612 additions and 1 deletions

View File

@ -13,6 +13,7 @@ members = [
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
resolver = "2"

View File

@ -13,6 +13,7 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@ -40,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal"]
metal = ["dep:metal", "dep:candle-metal-kernels"]

View File

@ -0,0 +1,821 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::{void_ptr, Kernels, Source};
use core::mem;
use half::{bf16, f16};
use metal;
use metal::mps::matrix::encode_gemm;
use metal::mps::Float32;
use metal::{Buffer, CommandQueue, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
use std::sync::Arc;
use tracing::debug;
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
#[error(transparent)]
KernelError(#[from] candle_metal_kernels::MetalKernelError),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
#[derive(Clone)]
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
kernels: Arc<candle_metal_kernels::Kernels>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.device.registry_id())
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
// pub fn metal_device(&self) -> &metal::DeviceRef {
// self.device.as_ref()
// }
pub fn id(&self) -> u64 {
self.registry_id()
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64;
// debug!("Allocate 1 - buffer size {size}");
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
buffer: metal::Buffer,
device: MetalDevice,
dtype: DType,
}
impl BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Ok(self.clone())
}
fn dtype(&self) -> DType {
self.dtype
}
fn device(&self) -> &Self::Device {
&self.device
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self.dtype {
DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
dtype => todo!("Unsupported dtype {dtype:?}"),
}
}
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let dtype = self.dtype;
assert!(layout.is_contiguous());
assert_eq!(dtype, DType::F32);
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_affine(
&device.device,
&command_buffer,
&device.kernels,
el,
&self.buffer,
&mut buffer,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
return Ok(Self {
buffer,
device: device.clone(),
dtype,
});
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
todo!()
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
todo!()
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
// debug!("TODO reduce_op {op:?} {sum_dims:?}");
assert!(sum_dims.len() == 1);
assert!(sum_dims[0] == layout.shape().rank() - 1);
assert!(layout.is_contiguous());
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
let mut dst_el: usize = 1;
for (dim_idx, &d) in src_dims.iter().enumerate() {
if !sum_dims.contains(&dim_idx) {
dst_el *= d;
dims.push(d);
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
let el_to_sum_per_block = src_el / dst_el;
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
_ => todo!("Reduce op for non float"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let mut buffer = device.new_buffer(dst_el, dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
src_el,
dst_el,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
Ok(Self {
buffer,
device,
dtype,
})
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
todo!()
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!(
"TODO Implement the kernel calling cast {:?}-{:?}",
self.dtype,
dtype
);
}
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"cast {:?} - {:?} - {:?}",
dtype,
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
// TODO remove
// return Ok(Self {
// buffer,
// device: device.clone(),
// dtype,
// });
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Unary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = lhs_l.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
("badd", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
("bsub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::binary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("badd", DType::F32) => strided::add::FLOAT,
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
&lhs_l.stride(),
lhs_l.start_offset(),
&rhs.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
&mut buffer,
)
.map_err(MetalError::from)?;
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Binary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn where_cond(
&self,
layout: &Layout,
t: &Self,
t_l: &Layout,
f: &Self,
f_l: &Layout,
) -> Result<Self> {
let device = self.device.clone();
let shape = t_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
let mut buffer = self.device.new_buffer(el, dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
"where_u8_f32",
&dims,
&self.buffer,
(layout.stride(), layout.start_offset()),
&t.buffer,
(&t_l.stride(), t_l.start_offset()),
&f.buffer,
(&f_l.stride(), f_l.start_offset()),
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
Ok(Self {
buffer,
device,
dtype,
})
}
fn conv1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv1D,
) -> Result<Self> {
todo!()
}
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv2D,
) -> Result<Self> {
todo!()
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose2D,
) -> Result<Self> {
todo!()
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!()
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!()
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
todo!()
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
todo!()
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
todo!()
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
todo!()
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
debug!(
"TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
self.buffer.length(),
ids.buffer.length(),
);
let src = self;
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
// let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
// let src = match src_l.contiguous_offsets() {
// Some((o1, o2)) => src.slice(o1..o2),
// None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
// };
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_size = src_l.dims()[dim];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
// todo!()
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
todo!()
}
fn matmul(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let transpose_left = false;
let transpose_right = !rhs_l.is_contiguous();
let alpha = 1.0;
let beta = 0.0;
self.matmul_generic(
rhs,
(b, m, n, k),
lhs_l,
rhs_l,
transpose_left,
transpose_right,
alpha,
beta,
)
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
if src_l.is_contiguous() {
let command_buffer = self.device.command_queue.new_command_buffer();
let blip = command_buffer.new_blit_command_encoder();
blip.copy_from_buffer(
&self.buffer,
src_l.start_offset() as u64,
&dst.buffer,
dst_offset as u64,
self.buffer.length(),
);
} else {
let command_buffer = self.device.command_queue.new_command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
dtype => todo!("copy_strided not implemented for {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
&src_l.stride(),
src_l.start_offset(),
&mut dst.buffer,
dst_offset,
)
.map_err(MetalError::from)?;
command_buffer.commit();
}
Ok(())
}
}
impl MetalStorage {
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
Self {
buffer,
device,
dtype,
}
}
pub(crate) fn matmul_generic(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
transpose_left: bool,
transpose_right: bool,
alpha: f64,
beta: f64,
) -> Result<Self> {
let elem_count = b * m * n;
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
if b != 1 {
debug!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet");
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
debug!(
"TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
lhs_l.is_contiguous(),
rhs_l.is_contiguous(),
rhs_l
);
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
}
debug!("TODO GEMM");
let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>(
&self.device,
&command_buffer,
transpose_left,
transpose_right,
&self.buffer,
&rhs.buffer,
&mut out_buffer,
m as NSUInteger,
n as NSUInteger,
k as NSUInteger,
alpha as f32,
beta as f32,
Some(b as NSUInteger),
)
.map_err(MetalError::from)?;
command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})
}
_ => todo!("Unimplemented matmul for this pair"),
}
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue();
// let command_buffer = _command_queue.new_owned_command_buffer();
let kernels = Arc::new(Kernels::new());
Ok(Self {
device,
command_queue,
// command_buffer,
kernels,
})
}
fn set_seed(&self, _seed: u64) -> Result<()> {
todo!("set_seed")
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal
}
fn same_device(&self, rhs: &Self) -> bool {
self.device.registry_id() == rhs.device.registry_id()
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
option,
),
};
// debug!("Allocate 2 - buffer size {}", buffer.length());
Ok(Self::Storage {
buffer,
device: self.clone(),
dtype: storage.dtype(),
})
}
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
}

View File

@ -0,0 +1,19 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
edition = "2021"
description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }

View File

@ -0,0 +1,3 @@
# candle-metal-kernels
This crate contains Metal kernels used from candle.

View File

@ -0,0 +1,46 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = input[i] * m + a; \
} \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif

View File

@ -0,0 +1,78 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[i]; \
TYPENAME y = right[i]; \
output[i] = OUT_TYPENAME(FN); \
} \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *left_strides, \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \
output[i] = OUT_TYPENAME(FN); \
} \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif

View File

@ -0,0 +1,58 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[i]); \
} \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
}
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -0,0 +1,75 @@
#include <metal_stdlib>
using namespace metal;
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
) {
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,124 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 256;
kernel void fast_sum_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = 0;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
shared_memory[tid] += src[idx];
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
dst[dst_id] = shared_memory[0];
}
kernel void softmax_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = -INFINITY;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
shared_memory[tid] = max(shared_memory[tid], src[idx]);
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
float max = shared_memory[0];
shared_memory[tid] = 0;
// Restart
idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
const float val = exp(src[idx] - max);
dst[idx] = val;
shared_memory[tid] += val;
idx += blockDim;
}
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
const float inv_acc = 1/shared_memory[0];
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += blockDim;
}
}

View File

@ -0,0 +1,57 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID_TYPENAME *ids, \
device const TYPENAME *t, \
device const TYPENAME *f, \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)

View File

@ -0,0 +1,82 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }
using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[i])); \
} \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
} \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
#endif