Compare commits

...

8 Commits

31 changed files with 3784 additions and 18 deletions

View File

@ -13,6 +13,7 @@ members = [
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
resolver = "2"
@ -59,6 +60,8 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
[profile.release-with-debug]
inherits = "release"

View File

@ -13,6 +13,8 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -39,3 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]

View File

@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
pub trait NdArray {
@ -128,10 +130,15 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
@ -139,6 +146,7 @@ impl Device {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@ -147,21 +155,20 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
matches!(self, Self::Cpu)
}
pub fn is_cuda(&self) -> bool {
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@ -194,6 +201,11 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
}
}
}
@ -228,6 +240,10 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
}
}
}
@ -250,6 +266,10 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -263,6 +283,10 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -274,6 +298,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
@ -285,6 +314,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
}

View File

@ -14,6 +14,9 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(f, "Tensor[")?;
@ -476,6 +479,9 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(

View File

@ -0,0 +1,223 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}
impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}
impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, Shape};
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -152,6 +152,9 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,
#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },
@ -159,6 +162,9 @@ pub enum Error {
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -49,9 +49,12 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
@ -87,6 +90,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

View File

@ -0,0 +1,859 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use core::mem;
use half::{bf16, f16};
use metal;
use metal::mps::matrix::encode_gemm;
use metal::mps::Float32;
use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::Arc;
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
#[error(transparent)]
KernelError(#[from] candle_metal_kernels::MetalKernelError),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
#[derive(Clone)]
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
kernels: Arc<candle_metal_kernels::Kernels>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.device.registry_id())
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
// pub fn metal_device(&self) -> &metal::DeviceRef {
// self.device.as_ref()
// }
pub fn id(&self) -> NSUInteger {
self.registry_id()
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
// debug!("Allocate 1 - buffer size {size}");
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
buffer: metal::Buffer,
device: MetalDevice,
dtype: DType,
}
impl BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Ok(self.clone())
}
fn dtype(&self) -> DType {
self.dtype
}
fn device(&self) -> &Self::Device {
&self.device
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
// TODO Is this necessary
// self.buffer.synchronize();
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
)),
DType::U32 => Ok(CpuStorage::U32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
DType::I64 => Ok(CpuStorage::I64(
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
)),
DType::F16 => Ok(CpuStorage::F16(
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
)),
DType::BF16 => Ok(CpuStorage::BF16(
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
)),
DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
DType::F64 => Ok(CpuStorage::F64(
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
)),
}
}
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let el = shape.elem_count();
let dtype = self.dtype;
assert!(layout.is_contiguous());
assert_eq!(dtype, DType::F32);
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_affine(
&device.device,
&command_buffer,
&device.kernels,
el,
&self.buffer,
&mut buffer,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
return Ok(Self {
buffer,
device: device.clone(),
dtype,
});
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
todo!()
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
todo!()
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
// debug!("TODO reduce_op {op:?} {sum_dims:?}");
assert!(sum_dims.len() == 1);
assert!(sum_dims[0] == layout.shape().rank() - 1);
assert!(layout.is_contiguous());
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
let mut dst_el: usize = 1;
for (dim_idx, &d) in src_dims.iter().enumerate() {
if !sum_dims.contains(&dim_idx) {
dst_el *= d;
dims.push(d);
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
_ => todo!("Reduce op for non float"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let mut buffer = device.new_buffer(dst_el, dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
src_el,
dst_el,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_completed();
Ok(Self {
buffer,
device,
dtype,
})
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
todo!()
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!(
"TODO Implement the kernel calling cast {:?}-{:?}",
self.dtype,
dtype
);
}
command_buffer.commit();
command_buffer.wait_until_completed();
// command_buffer.wait_until_scheduled();
// debug!(
// "cast {:?} - {:?} - {:?}",
// dtype,
// self.buffer.length(),
// buffer.length()
// );
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
}
command_buffer.commit();
command_buffer.wait_until_completed();
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
("badd", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
("bsub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::binary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("badd", DType::F32) => strided::add::FLOAT,
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
&lhs_l.stride(),
lhs_l.start_offset(),
&rhs.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
&mut buffer,
)
.map_err(MetalError::from)?;
}
command_buffer.commit();
command_buffer.wait_until_completed();
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn where_cond(
&self,
layout: &Layout,
t: &Self,
t_l: &Layout,
f: &Self,
f_l: &Layout,
) -> Result<Self> {
let device = self.device.clone();
let shape = t_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
let mut buffer = self.device.new_buffer(el, dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
"where_u8_f32",
&dims,
&self.buffer,
(layout.stride(), layout.start_offset()),
&t.buffer,
(&t_l.stride(), t_l.start_offset()),
&f.buffer,
(&f_l.stride(), f_l.start_offset()),
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_completed();
Ok(Self {
buffer,
device,
dtype,
})
}
fn conv1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv1D,
) -> Result<Self> {
todo!()
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose1D,
) -> Result<Self> {
todo!()
}
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv2D,
) -> Result<Self> {
todo!()
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose2D,
) -> Result<Self> {
todo!()
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!()
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!()
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
todo!()
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
todo!()
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
todo!()
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
todo!()
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
assert!(src_l.is_contiguous());
assert!(ids_l.is_contiguous());
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.shape().elem_count();
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let mut buffer = device.new_buffer(dst_el, dtype);
let out = self.to_cpu_storage().unwrap();
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(left, right) => todo!("index select metal {left:?} {right:?}"),
};
let command_buffer = self.device.command_queue.new_command_buffer();
// println!("INDEX SELECT");
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
ids_el,
dim,
&self.buffer,
&ids.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_completed();
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
todo!()
}
fn matmul(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
use metal::mps::matrix::*;
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
let size = core::mem::size_of::<f32>() as NSUInteger;
let elem_count = b * m * n;
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, k, n (rhs)
let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
true
} else {
Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
true
} else {
Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
// println!("{transpose_left} {transpose_right}");
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id)
} else {
MatrixDescriptor::init_single(m, k, k * size, type_id)
};
let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id)
} else {
MatrixDescriptor::init_single(k, n, n * size, type_id)
};
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer
let command_buffer = self.device.command_queue.new_command_buffer();
matrix_multiplication.encode_to_command_buffer(
command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
command_buffer.commit();
command_buffer.wait_until_completed();
// let left = self.buffer.read_to_vec::<f32>(10);
// let right = rhs.buffer.read_to_vec::<f32>(10);
// let out = out_buffer.read_to_vec::<f32>(40);
// todo!("Out {left:?} {right:?} {out:?}");
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let command_buffer = self.device.command_queue.new_command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
dtype => todo!("copy_strided not implemented for {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
&src_l.stride(),
src_l.start_offset(),
&mut dst.buffer,
dst_offset,
)
.map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_completed();
// todo!("Output {:?}", dst.buffer.read_to_vec::<f32>(10));
// }
Ok(())
}
}
impl MetalStorage {
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
Self {
buffer,
device,
dtype,
}
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue();
// let command_buffer = _command_queue.new_owned_command_buffer();
let kernels = Arc::new(Kernels::new());
Ok(Self {
device,
command_queue,
// command_buffer,
kernels,
})
}
fn set_seed(&self, _seed: u64) -> Result<()> {
todo!("set_seed")
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize,
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.device.registry_id() == rhs.device.registry_id()
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
option,
),
};
// TODO is that necessary ?
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
// debug!("Allocate 2 - buffer size {}", buffer.length());
Ok(Self::Storage {
buffer,
device: self.clone(),
dtype: storage.dtype(),
})
}
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
}

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -184,6 +184,18 @@ pub trait CustomOp1 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
@ -219,6 +231,20 @@ pub trait CustomOp2 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
@ -261,6 +287,22 @@ pub trait CustomOp3 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,

View File

@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage),
Metal(MetalStorage),
}
impl Storage {
@ -18,6 +19,10 @@ impl Storage {
let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.try_clone(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -25,6 +30,7 @@ impl Storage {
match self {
Self::Cpu(_) => Device::Cpu,
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
Self::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
@ -32,6 +38,7 @@ impl Storage {
match self {
Self::Cpu(storage) => storage.dtype(),
Self::Cuda(storage) => storage.dtype(),
Self::Metal(storage) => storage.dtype(),
}
}
@ -65,6 +72,10 @@ impl Storage {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Metal(storage))
}
}
}
@ -78,6 +89,10 @@ impl Storage {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -91,6 +106,10 @@ impl Storage {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -112,6 +131,10 @@ impl Storage {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -135,6 +158,10 @@ impl Storage {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Metal(storage))
}
}
}
@ -148,6 +175,10 @@ impl Storage {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Metal(storage))
}
}
}
@ -161,6 +192,10 @@ impl Storage {
let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape))
}
Self::Metal(storage) => {
let (storage, shape) = c.metal_fwd(storage, l)?;
Ok((Self::Metal(storage), shape))
}
}
}
@ -181,6 +216,10 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
@ -205,6 +244,10 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
@ -219,6 +262,10 @@ impl Storage {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -239,6 +286,10 @@ impl Storage {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -270,6 +321,10 @@ impl Storage {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -324,6 +379,10 @@ impl Storage {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -351,6 +410,10 @@ impl Storage {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -375,6 +438,10 @@ impl Storage {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
@ -393,6 +460,10 @@ impl Storage {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
@ -406,6 +477,10 @@ impl Storage {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Metal(storage))
}
}
}
@ -419,6 +494,10 @@ impl Storage {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Metal(storage))
}
}
}
@ -442,6 +521,10 @@ impl Storage {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Metal(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -468,6 +551,10 @@ impl Storage {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -492,6 +579,10 @@ impl Storage {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -516,6 +607,10 @@ impl Storage {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -537,6 +632,10 @@ impl Storage {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -564,6 +663,10 @@ impl Storage {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -583,6 +686,9 @@ impl Storage {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),

View File

@ -6,7 +6,7 @@ use crate::op::{
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
/// Unique identifier for tensors.
@ -157,6 +157,8 @@ pub(crate) fn from_storage<S: Into<Shape>>(
) -> Tensor {
let dtype = storage.dtype();
let device = storage.device();
let shape = shape.into();
// println!("{:?} {storage:?}", shape);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(RwLock::new(storage)),
@ -166,7 +168,11 @@ pub(crate) fn from_storage<S: Into<Shape>>(
dtype,
device,
};
Tensor(Arc::new(tensor_))
let result = Tensor(Arc::new(tensor_));
// todo!(" from_storage");
// let result = result.to_device(&Device::Cpu).unwrap();
// todo!(" {result}");
result
}
impl Tensor {
@ -529,6 +535,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1454,6 +1461,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1484,6 +1492,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1524,6 +1533,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1841,7 +1851,11 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
@ -1849,6 +1863,9 @@ impl Tensor {
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
_ => {
bail!("not implemented yet")
}
};
let op = BackpropOp::new1(self, Op::ToDevice);
let tensor_ = Tensor_ {

View File

@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
}
pub fn metal_is_available() -> bool {
cfg!(feature = "metal")
}
pub fn with_avx() -> bool {
cfg!(target_feature = "avx")
}

View File

@ -329,14 +329,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.get_ids()
.to_vec();
println!("{tokens:?}");
let start_gen = std::time::Instant::now();
for index in 0.. {
for index in 0..1 {
if tokens.len() >= config.seq_len {
break;
}
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
// println!("Input {}", input);
// println!("Input {}", input.to_device(&candle::Device::Cpu)?);
let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {

View File

@ -2,17 +2,28 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
let device = Device::cuda_if_available(0)?;
if !device.is_cuda() {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(device)
Ok(Device::Cpu)
}
}

View File

@ -0,0 +1,21 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
edition = "2021"
description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"

View File

@ -0,0 +1,3 @@
# candle-metal-kernels
This crate contains Metal kernels used from candle.

View File

@ -0,0 +1,75 @@
use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_affine_bench(&device, &kernels, &f32_1k);
run_affine_bench(&device, &kernels, &f32_10k);
run_affine_bench(&device, &kernels, &f32_100k);
}
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
let mul: f32 = 1.2345;
let add: f32 = 2.3456;
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_affine(
&device,
command_buffer,
&kernels,
v.len(),
&input,
&mut output,
mul,
add,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
"affine",
v.len(),
iterations,
total_time,
total_time / iterations
);
}

View File

@ -0,0 +1,182 @@
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
binary::contiguous::add::FLOAT,
binary::contiguous::sub::FLOAT,
binary::contiguous::mul::FLOAT,
binary::contiguous::div::FLOAT,
];
let f32_skernels = [
binary::strided::add::FLOAT,
binary::strided::sub::FLOAT,
binary::strided::mul::FLOAT,
binary::strided::div::FLOAT,
];
let f16_ckernels = [
binary::contiguous::add::HALF,
binary::contiguous::sub::HALF,
binary::contiguous::mul::HALF,
binary::contiguous::div::HALF,
];
let f16_skernels = [
binary::strided::add::HALF,
binary::strided::sub::HALF,
binary::strided::mul::HALF,
binary::strided::div::HALF,
];
let bf16_ckernels = [
binary::contiguous::add::BFLOAT,
binary::contiguous::sub::BFLOAT,
binary::contiguous::mul::BFLOAT,
binary::contiguous::div::BFLOAT,
];
let bf16_skernels = [
binary::strided::add::BFLOAT,
binary::strided::sub::BFLOAT,
binary::strided::mul::BFLOAT,
binary::strided::div::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_binary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [binary::contiguous::Kernel; 4],
strided: [binary::strided::Kernel; 4],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&input,
&strides,
offset,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -0,0 +1,84 @@
use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let contiguous_kernels = ["cast_u32_f32"];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}
fn run_cast_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: &[&'static str],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_cast_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided?
}

View File

@ -0,0 +1,197 @@
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
unary::contiguous::sin::FLOAT,
unary::contiguous::cos::FLOAT,
unary::contiguous::exp::FLOAT,
unary::contiguous::sqr::FLOAT,
unary::contiguous::sqrt::FLOAT,
unary::contiguous::neg::FLOAT,
unary::contiguous::copy::FLOAT,
];
let f32_skernels = [
unary::strided::sin::FLOAT,
unary::strided::cos::FLOAT,
unary::strided::exp::FLOAT,
unary::strided::sqr::FLOAT,
unary::strided::sqrt::FLOAT,
unary::strided::neg::FLOAT,
unary::strided::copy::FLOAT,
];
let f16_ckernels = [
unary::contiguous::sin::HALF,
unary::contiguous::cos::HALF,
unary::contiguous::exp::HALF,
unary::contiguous::sqr::HALF,
unary::contiguous::sqrt::HALF,
unary::contiguous::neg::HALF,
unary::contiguous::copy::HALF,
];
let f16_skernels = [
unary::strided::sin::HALF,
unary::strided::cos::HALF,
unary::strided::exp::HALF,
unary::strided::sqr::HALF,
unary::strided::sqrt::HALF,
unary::strided::neg::HALF,
unary::strided::copy::HALF,
];
let bf16_ckernels = [
unary::contiguous::sin::BFLOAT,
unary::contiguous::cos::BFLOAT,
unary::contiguous::exp::BFLOAT,
unary::contiguous::sqr::BFLOAT,
unary::contiguous::sqrt::BFLOAT,
unary::contiguous::neg::BFLOAT,
unary::contiguous::copy::BFLOAT,
];
let bf16_skernels = [
unary::strided::sin::BFLOAT,
unary::strided::cos::BFLOAT,
unary::strided::exp::BFLOAT,
unary::strided::sqr::BFLOAT,
unary::strided::sqrt::BFLOAT,
unary::strided::neg::BFLOAT,
unary::strided::copy::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_unary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [unary::contiguous::Kernel; 7],
strided: [unary::strided::Kernel; 7],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&mut output,
0,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -0,0 +1,43 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif

View File

@ -0,0 +1,72 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *left_strides, \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif

View File

@ -0,0 +1,51 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint i [[ thread_position_in_grid ]] \
) { \
if (i >= dim) { \
return; \
} \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -0,0 +1,102 @@
#include <metal_stdlib>
using namespace metal;
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = gid / right_size / left_size; \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid % left_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
output[gid] = input[src_i]; \
}
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint gid [[ thread_position_in_grid ]] \
) {
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
INDEX_OP(is_u32_f32, uint, float)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,139 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 256;
# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint blockDim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = src[idx]; \
shared_memory[tid] = FN; \
idx += blockDim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
kernel void softmax_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = -INFINITY;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
shared_memory[tid] = max(shared_memory[tid], src[idx]);
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
float max = shared_memory[0];
shared_memory[tid] = 0;
// Restart
idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
const float val = exp(src[idx] - max);
dst[idx] = val;
shared_memory[tid] += val;
idx += blockDim;
}
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
const float inv_acc = 1/shared_memory[0];
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += blockDim;
}
}
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)

View File

@ -0,0 +1,57 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID_TYPENAME *ids, \
device const TYPENAME *t, \
device const TYPENAME *f, \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)

View File

@ -0,0 +1,78 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }
using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif

View File

@ -9,6 +9,7 @@ pub struct Embedding {
impl Embedding {
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
// todo!("Embedding {embeddings}");
Self {
embeddings,
hidden_size,

View File

@ -71,11 +71,13 @@ impl PyDType {
}
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
Cpu,
Cuda,
Metal,
}
impl PyDevice {
@ -83,6 +85,7 @@ impl PyDevice {
match device {
Device::Cpu => Self::Cpu,
Device::Cuda(_) => Self::Cuda,
Device::Metal(_) => Self::Metal,
}
}
@ -98,6 +101,15 @@ impl PyDevice {
*device = Some(d.clone());
Ok(d)
}
Self::Metal => {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).map_err(wrap_err)?;
*device = Some(d.clone());
Ok(d)
}
}
}
}
@ -119,6 +131,7 @@ impl ToPyObject for PyDevice {
let str = match self {
PyDevice::Cpu => "cpu",
PyDevice::Cuda => "cuda",
PyDevice::Metal => "metal",
};
str.to_object(py)
}

View File

@ -156,6 +156,7 @@ impl CausalSelfAttention {
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
todo!("X {x1}");
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@ -173,6 +174,7 @@ impl CausalSelfAttention {
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
todo!("X {q}");
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {
@ -295,6 +297,7 @@ impl Block {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
todo!("---X {}", x);
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
@ -327,6 +330,7 @@ impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
//println!("Embeddings {}", self.wte.embeddings());
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
}