Compare commits

..

21 Commits

Author SHA1 Message Date
eb24875856 Reworked affine and it works ? No idea how it's different. 2023-11-08 02:37:20 +01:00
3f662e54cd Reworked affine and it works ? No idea how it's different. 2023-11-08 02:34:08 +01:00
480a3e22e6 Adding cast + binary kernels. 2023-11-07 23:45:53 +01:00
0c24a885a6 Updated everything and output a trace. 2023-11-07 21:12:42 +01:00
76d3116f5d Broken metal ? 2023-11-07 14:20:13 +01:00
1367e0278b pesky bfloat type 2023-11-07 10:26:59 +01:00
7ff17d92b3 Finished the unary
- Added proper kernel type check (through modules + macro)
- split contiguous and strided into 2 different kernels
- Verified on long range + strided values.
2023-11-06 23:12:12 +01:00
cd68c96803 Going overbounds will break other kernels running from other threads. 2023-11-06 17:29:58 +01:00
4d87305c48 Float -> half / bfloat conversion in unary 2023-11-06 17:09:39 +01:00
677495f9b8 Working but failing tests because of threadgroup. 2023-11-06 17:04:47 +01:00
dedc8c3656 Writing unary as macro instead, protecting bfloat type with proper metal version. 2023-11-06 15:36:48 +01:00
63cce76b84 Improve metal kernel loading and associated errors 2023-11-06 09:48:18 +01:00
634a4e7168 BlitEncoder added to affine for copying buffer contents quickly. 2023-11-06 08:23:36 +01:00
8124d1003f Affine metal kernel works. Need to extract buffer contents based on layout offset (like CudaSlice.slice) for candle intergration 2023-11-06 04:46:56 +01:00
6d4c8c0707 Use metal encode_gemm 2023-11-06 03:27:22 +01:00
e6d33a8efb Remove unused utils.metal 2023-11-06 03:26:21 +01:00
c921cc3784 Add Arc to metalstorage buffer for quick cloning 2023-11-04 09:03:23 +01:00
d4d6850c78 Impl index_add via template for all types 2023-11-04 08:46:08 +01:00
e708d35e7f index_add works 2023-11-03 21:12:52 +01:00
0794e70a19 Debugging index_add. 2023-11-03 12:09:05 +01:00
f57e3164ae Implemented cos for now. 2023-11-03 01:24:51 +01:00
16 changed files with 1728 additions and 212 deletions

View File

@ -55,8 +55,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
[profile.release-with-debug]
inherits = "release"

View File

@ -12,6 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
tracing = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
@ -30,7 +31,6 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, Shape};
use crate::{metal_backend, DType, DeviceLocation, Layout, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -163,7 +163,7 @@ pub enum Error {
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(String),
Metal(#[from] metal_backend::MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -1,28 +1,44 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::bail;
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::{void_ptr, Kernels, Source};
use core::mem;
use half::{bf16, f16};
use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::MTLResourceOptions;
use metal::mps::matrix::encode_gemm;
use metal::mps::Float32;
use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
use std::sync::Arc;
use tracing::debug;
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("metal error")]
Metal,
#[error("{0}")]
Message(String),
#[error(transparent)]
KernelError(#[from] candle_metal_kernels::MetalKernelError),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
impl MetalError {
fn msg<S: AsRef<str>>(msg: S) -> Self {
MetalError::Message(msg.as_ref().to_string())
}
}
#[derive(Clone)]
pub struct MetalDevice {
device: metal::Device,
_command_queue: metal::CommandQueue,
command_buffer: metal::CommandBuffer,
command_queue: metal::CommandQueue,
kernels: Arc<candle_metal_kernels::Kernels>,
}
impl std::fmt::Debug for MetalDevice {
@ -40,13 +56,20 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
pub fn metal_device(&self) -> &metal::DeviceRef {
self.device.as_ref()
}
// pub fn metal_device(&self) -> &metal::DeviceRef {
// self.device.as_ref()
// }
pub fn id(&self) -> u64 {
self.registry_id()
}
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64;
// debug!("Allocate 1 - buffer size {size}");
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
}
#[derive(Debug, Clone)]
@ -72,20 +95,44 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self.dtype{
DType::F32 => {
// self.buffer.read_to_vec(self.buffer.length() as usize / 4);
let mut buffer = vec![0.0; 32000];
buffer[0] = 1.0;
Ok(CpuStorage::F32(buffer))},
dtype => todo!("Unsupported dtype {dtype:?}")
match self.dtype {
DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
dtype => todo!("Unsupported dtype {dtype:?}"),
}
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
println!("TODO Affine");
Ok(self.clone())
// todo!()
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let dtype = self.dtype;
assert!(layout.is_contiguous());
assert_eq!(dtype, DType::F32);
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_affine(
&device.device,
&command_buffer,
&device.kernels,
el,
&self.buffer,
&mut buffer,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
return Ok(Self {
buffer,
device: device.clone(),
dtype,
});
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@ -96,10 +143,78 @@ buffer[0] = 1.0;
todo!()
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
println!("TODO reduce_op");
Ok(self.clone())
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
debug!("TODO reduce_op {op:?}");
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
let mut dst_el: usize = 1;
for (dim_idx, &d) in src_dims.iter().enumerate() {
if !sum_dims.contains(&dim_idx) {
dst_el *= d;
dims.push(d);
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// let el_to_sum_per_block = src_el / dst_el;
// // The reduction loop requires the shared array to be properly initialized and for
// // this we want the number of threads to be a power of two.
// let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
// let cfg = LaunchConfig {
// // TODO: Maybe use grid_y if the output is too large?
// // TODO: Specialized implementation when reducing on no or all dimensions or when
// // reducing only aggregate a small number of elements together.
// grid_dim: (dst_el as u32, 1, 1),
// block_dim: (block_dim as u32, 1, 1),
// shared_mem_bytes: 0,
// };
// let ds = dev
// .htod_copy([dims.as_slice(), stride.as_slice()].concat())
// .w()?;
// let src = &src.slice(layout.start_offset()..);
// let (name, check_empty, return_index) = match self.1 {
// ReduceOp::Sum => ("fast_sum", false, false),
// ReduceOp::Min => ("fast_min", true, false),
// ReduceOp::Max => ("fast_max", true, false),
// ReduceOp::ArgMin => ("fast_argmin", true, true),
// ReduceOp::ArgMax => ("fast_argmax", true, true),
// };
// if check_empty && layout.shape().elem_count() == 0 {
// Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
// }
// let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// if return_index {
// // SAFETY: filled in by the follow up kernel.
// let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
// let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// // SAFETY: ffi.
// unsafe { func.launch(cfg, params) }.w()?;
// Ok(S::U32(out))
// } else {
// // SAFETY: filled in by the follow up kernel.
// let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
// let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// // SAFETY: ffi.
// unsafe { func.launch(cfg, params) }.w()?;
// Ok(wrap(out))
// }
// Ok(self.clone())
// todo!()
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@ -107,24 +222,197 @@ buffer[0] = 1.0;
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
let device = self.device();
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!(
"TODO Implement the kernel calling cast {:?}-{:?}",
self.dtype,
dtype
);
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"cast {:?} - {:?} - {:?} - {:?}",
dtype,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
// todo!()
// TODO
println!("TODO {:?}", B::NAME);
Ok(self.clone())
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
// TODO remove
// return Ok(Self {
// buffer,
// device: device.clone(),
// dtype,
// });
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Unary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
println!("TODO Binary {:?}", B::NAME);
Ok(self.clone())
// todo!()
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = lhs_l.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
("badd", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
("bsub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::binary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("badd", DType::F32) => strided::add::FLOAT,
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
&lhs_l.stride(),
lhs_l.start_offset(),
&rhs.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
&mut buffer,
)
.map_err(MetalError::from)?;
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Binary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
println!("TODO where_cond");
debug!("TODO where_cond");
Ok(rhs.clone())
// todo!()
}
@ -191,9 +479,33 @@ buffer[0] = 1.0;
todo!()
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
println!("TODO Index select");
Ok(self.clone())
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
debug!(
"TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
self.buffer.length(),
ids.buffer.length(),
);
let src = self;
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
// let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
// let src = match src_l.contiguous_offsets() {
// Some((o1, o2)) => src.slice(o1..o2),
// None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
// };
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_size = src_l.dims()[dim];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
// todo!()
}
@ -232,8 +544,46 @@ buffer[0] = 1.0;
)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
println!("TODO Copy strided");
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
if src_l.is_contiguous() {
let command_buffer = self.device.command_queue.new_command_buffer();
let blip = command_buffer.new_blit_command_encoder();
blip.copy_from_buffer(
&self.buffer,
src_l.start_offset() as u64,
&dst.buffer,
dst_offset as u64,
self.buffer.length(),
);
} else {
let command_buffer = self.device.command_queue.new_command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
dtype => todo!("copy_strided not implemented for {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
&src_l.stride(),
src_l.start_offset(),
&mut dst.buffer,
dst_offset,
)
.map_err(MetalError::from)?;
command_buffer.commit();
}
Ok(())
}
}
@ -275,15 +625,9 @@ impl MetalStorage {
let elem_count = b * m * n;
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
let span= tracing::span!(tracing::Level::TRACE, "metal alloc matmul");
let _enter = span.enter();
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
if b != 1 {
println!("TODO implement batched matmul for B={b}");
debug!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet");
return Ok(Self {
buffer: out_buffer,
@ -292,66 +636,40 @@ impl MetalStorage {
});
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
debug!(
"TODO non contiguous matmul yet {:?} {:?}",
lhs_l.is_contiguous(),
rhs_l.is_contiguous()
);
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
}
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
let m: u64 = m.try_into().expect("usize should fit u64");
let n: u64 = n.try_into().expect("usize should fit u64");
let k: u64 = k.try_into().expect("usize should fit u64");
// Create descriptors
let left_descriptor =
MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
let right_descriptor =
MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
let result_descriptor =
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
println!("lhs {:?} {m} {k}", self.buffer.length());
println!("rhs {:?} {k} {n}", rhs.buffer.length());
println!("out {:?} {m} {n}", out_buffer.length());
// Create matrix objects
let left_matrix =
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
.expect("Failed to create left matrix");
let right_matrix =
Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
.expect("Failed to create left matrix");
let result_matrix =
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
.expect("Failed to create left matrix");
println!("lhs {:?}", lhs_l.shape());
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
debug!("GEMM");
let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>(
&self.device,
&command_buffer,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
&self.buffer,
&rhs.buffer,
&mut out_buffer,
m as NSUInteger,
n as NSUInteger,
k as NSUInteger,
alpha as f32,
beta as f32,
Some(b as NSUInteger),
)
.expect("Failed to create matrix multiplication kernel");
.map_err(MetalError::from)?;
command_buffer.commit();
// command_buffer.wait_until_scheduled();
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&self.device.command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
@ -363,26 +681,31 @@ impl MetalStorage {
}
}
impl MetalDevice{
pub fn flush(&mut self){
self.command_buffer.commit();
self.command_buffer.wait_until_completed();
self.command_buffer = self._command_queue.new_owned_command_buffer();
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let _command_queue = device.new_command_queue();
let command_buffer = _command_queue.new_owned_command_buffer();
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue();
// let command_buffer = _command_queue.new_owned_command_buffer();
let kernels = Arc::new(Kernels::new());
Ok(Self {
device,
_command_queue,
command_buffer,
command_queue,
// command_buffer,
kernels,
})
}
@ -411,48 +734,45 @@ impl BackendDevice for MetalDevice {
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
let span= tracing::span!(tracing::Level::TRACE, "metal alloc");
let _enter = span.enter();
let buffer = self.device.new_buffer(4, option);
// let buffer = match storage {
// CpuStorage::U8(storage) => self.device.new_buffer_with_data(
// storage.as_ptr() as *const core::ffi::c_void,
// (storage.len() * mem::size_of::<u8>()) as u64,
// option,
// ),
// CpuStorage::U32(storage) => self.device.new_buffer_with_data(
// storage.as_ptr() as *const core::ffi::c_void,
// (storage.len() * mem::size_of::<u32>()) as u64,
// option,
// ),
// CpuStorage::I64(storage) => self.device.new_buffer_with_data(
// storage.as_ptr() as *const core::ffi::c_void,
// (storage.len() * mem::size_of::<i64>()) as u64,
// option,
// ),
// CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
// storage.as_ptr() as *const core::ffi::c_void,
// (storage.len() * mem::size_of::<bf16>()) as u64,
// option,
// ),
// CpuStorage::F16(storage) => self.device.new_buffer_with_data(
// storage.as_ptr() as *const core::ffi::c_void,
// (storage.len() * mem::size_of::<f16>()) as u64,
// option,
// ),
// CpuStorage::F32(storage) => self.device.new_buffer_with_data(
// storage.as_ptr() as *const core::ffi::c_void,
// (storage.len() * mem::size_of::<f32>()) as u64,
// option,
// ),
// CpuStorage::F64(storage) => self.device.new_buffer_with_data(
// storage.as_ptr() as *const core::ffi::c_void,
// (storage.len() * mem::size_of::<f64>()) as u64,
// option,
// ),
// };
let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
option,
),
};
// debug!("Allocate 2 - buffer size {}", buffer.length());
Ok(Self::Storage {
buffer,
device: self.clone(),
@ -460,13 +780,25 @@ impl BackendDevice for MetalDevice {
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)

View File

@ -1,4 +1,5 @@
use crate::{Device, Result, Shape, Tensor};
use tracing::debug;
#[cfg(target_feature = "avx")]
pub mod avx;
@ -321,7 +322,7 @@ impl crate::CustomOp1 for QTensor {
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
println!("TODO qmatmul");
debug!("TODO qmatmul");
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
@ -349,12 +350,9 @@ impl crate::CustomOp1 for QTensor {
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device{
Ok((
device.storage_from_cpu_storage(&cpu_storage)?,
dst_shape,
))
}else{
if let Device::Metal(device) = &self.device {
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
} else {
crate::bail!("qtensor not on metal device")
}
}

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::{Tensor};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;
@ -232,12 +232,13 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
let mut device = candle_examples::device(false)?;
let device = candle_examples::device(false)?;
let temperature = if args.temperature == 0. {
None
} else {
Some(args.temperature)
};
tracing_subscriber::fmt::init();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -372,7 +373,8 @@ fn main() -> anyhow::Result<()> {
let logits = logits.squeeze(0)?;
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
logits_processor.sample(&logits)?
// logits_processor.sample(&logits)?
15043
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
@ -384,23 +386,21 @@ fn main() -> anyhow::Result<()> {
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
if let candle::Device::Metal(device) = &mut device{
device.flush()
}
let logits = logits.squeeze(0)?;
// let logits = if args.repeat_penalty == 1. {
// logits
// } else {
// let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
// candle_transformers::utils::apply_repeat_penalty(
// &logits,
// args.repeat_penalty,
// &all_tokens[start_at..],
// )?
// };
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
next_token = logits_processor.sample(&logits)?;
// let logits = logits.ones_like()?;
// next_token = logits_processor.sample(&logits)?;
let next_token = 15043;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
if next_token == eos_token {

View File

@ -10,3 +10,8 @@ license.workspace = true
[dependencies]
metal = { workspace = true }
once_cell = "1.18.0"
thiserror = { workspace = true }
[dev-dependencies]
half = { workspace = true }

View File

@ -0,0 +1,44 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = input[i] * mul + add; \
} \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif

View File

@ -0,0 +1,78 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[i]; \
TYPENAME y = right[i]; \
output[i] = OUT_TYPENAME(FN); \
} \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *left_strides, \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \
output[i] = OUT_TYPENAME(FN); \
} \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif

View File

@ -0,0 +1,58 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[i]); \
} \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
}
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -0,0 +1,75 @@
#include <metal_stdlib>
using namespace metal;
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
) {
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

View File

@ -1 +1,862 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
MTLSize,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Indexing,
Unary,
Binary,
Cast,
}
macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
pub struct Kernel(pub(crate) &'static str);
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
}
)+
}
pub mod strided {
pub struct Kernel(pub(crate) &'static str);
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
}
)+
}
};
}
pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, copy);
}
pub mod binary {
ops!(add, sub, mul, div);
}
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
// let mut l = HashMap::new();
// l.insert("affine", AFFINE);
// l.insert("indexing", INDEXING);
// l.insert("unary", UNARY);
// l
// });
//
#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
#[error("Could not lock kernel map: {0}")]
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
#[error("Error while loading function: {0}")]
LoadFunctionError(String),
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
fn from(e: std::sync::PoisonError<T>) -> Self {
Self::LockError(e.to_string())
}
}
type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
type Functions = KernelMap<Function>;
#[derive(Debug)]
pub struct Kernels {
libraries: RwLock<Libraries>,
funcs: RwLock<Functions>,
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
let funcs = RwLock::new(Functions::new());
Self { libraries, funcs }
}
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
// let kernels = Self::new();
// kernels.load_libraries(device)?;
// Ok(kernels)
// }
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
// for name in LIBRARY_SOURCES.keys() {
// self.load_library(device, name)?;
// }
// Ok(())
// }
fn get_library_source(&self, source: Source) -> &'static str {
// LIBRARY_SOURCES.get(name).cloned()
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Binary => BINARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
}
}
pub fn load_library(
&self,
device: &Device,
source: Source,
) -> Result<Library, MetalKernelError> {
let mut libraries = self.libraries.write()?;
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let source_content = self.get_library_source(source);
let lib = device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
libraries.insert(source, lib.clone());
Ok(lib)
}
}
pub fn load_function(
&self,
device: &Device,
source: Source,
name: &'static str,
) -> Result<Function, MetalKernelError> {
let mut funcs = self.funcs.write()?;
if let Some(func) = funcs.get(name) {
Ok(func.clone())
} else {
let func = self
.load_library(device, source)?
.get_function(name, None)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
funcs.insert(name, func.clone());
Ok(func)
}
}
}
pub fn call_unary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&length));
encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: &Buffer,
strides: &[usize],
offset: usize,
output: &mut Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Unary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let num_dims: usize = shape.len() as usize;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
encoder.set_bytes(
2,
(shape.len() * std::mem::size_of::<usize>()) as u64,
shape.as_ptr() as *const c_void,
);
encoder.set_bytes(
3,
(strides.len() * std::mem::size_of::<usize>()) as u64,
strides.as_ptr() as *const c_void,
);
encoder.set_buffer(4, Some(&input), offset as u64);
encoder.set_buffer(5, Some(&output), output_offset as u64);
let width = output.length();
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_binary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: binary::contiguous::Kernel,
length: usize,
left: &Buffer,
right: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&length));
encoder.set_buffer(1, Some(&left), 0);
encoder.set_buffer(2, Some(&right), 0);
encoder.set_buffer(3, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_binary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: binary::strided::Kernel,
shape: &[usize],
left_input: &Buffer,
left_strides: &[usize],
left_offset: usize,
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Binary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let num_dims: usize = shape.len() as usize;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
encoder.set_bytes(
2,
(shape.len() * std::mem::size_of::<usize>()) as u64,
shape.as_ptr() as *const c_void,
);
encoder.set_bytes(
3,
(left_strides.len() * std::mem::size_of::<usize>()) as u64,
left_strides.as_ptr() as *const c_void,
);
encoder.set_bytes(
4,
(right_strides.len() * std::mem::size_of::<usize>()) as u64,
right_strides.as_ptr() as *const c_void,
);
encoder.set_buffer(5, Some(&left_input), left_offset as u64);
encoder.set_buffer(6, Some(&right_input), right_offset as u64);
encoder.set_buffer(7, Some(&output), 0);
let width = output.length();
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_cast_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&length));
encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn void_ptr<T>(v: &T) -> *const c_void {
(v as *const T).cast()
}
pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
size: usize,
input: &Buffer,
output: &mut Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, "affine_float")?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
encoder.set_bytes(1, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
encoder.set_bytes(2, core::mem::size_of::<f32>() as u64, void_ptr(&add));
encoder.set_buffer(3, Some(&input), 0);
encoder.set_buffer(4, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use half::f16;
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
use std::mem;
fn device() -> Device {
Device::system_default().unwrap()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect()
}
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
call_unary_contiguous(
&device,
&command_buffer,
&kernels,
name,
v.len(),
&input,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = device.new_buffer_with_data(
x.as_ptr() as *const core::ffi::c_void,
(x.len() * core::mem::size_of::<T>()) as u64,
options,
);
let right = device.new_buffer_with_data(
y.as_ptr() as *const core::ffi::c_void,
(y.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((x.len() * core::mem::size_of::<T>()) as u64, options);
call_binary_contiguous(
&device,
&command_buffer,
&kernels,
name,
x.len(),
&left,
&right,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(x.len())
}
fn run_strided<T: Clone>(
v: &[T],
kernel: unary::strided::Kernel,
shape: &[usize],
strides: &[usize],
offset: usize,
) -> Vec<T> {
let device = device();
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let kernels = Kernels::new();
call_unary_strided(
&device,
&command_buffer,
&kernels,
kernel,
shape,
&input,
strides,
offset,
&mut output,
0,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn cos_f32() {
let v = vec![1.0f32, 2.0, 3.0];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_f32_strided() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
// Shape = [6], strides = [1];
let shape = vec![6];
let strides = vec![1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Contiguous
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Transposed
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![1, 3];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Very large
let v = vec![1.0f32; 10_000];
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
let expected: Vec<_> = left
.iter()
.zip(right.iter())
.map(|(&x, &y)| x + y)
.collect();
assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<U>()) as u64, options);
call_cast_contiguous(
&device,
&command_buffer,
&kernels,
name,
v.len(),
&input,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<U>(v.len())
}
#[test]
fn cast_u32_f32() {
let v = vec![1u32, 2, 3];
let results = cast(&v, "cast_u32_f32");
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let size = v.len();
call_affine(
&device,
&command_buffer,
&kernels,
size,
&input,
&mut output,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn affine() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
let input = [1.0f32; 40_000];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6; 40_000]);
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);
encoder.set_buffer(0, Some(&index_buffer), 0);
encoder.set_buffer(1, Some(&inputs_buffer), 0);
encoder.set_buffer(2, Some(&outputs_buffer), 0);
encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
encoder.set_bytes(4, 4, void_ptr(&left_size));
encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
encoder.set_bytes(6, 4, void_ptr(&right_size));
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs_buffer.read_to_vec::<f32>(right.len());
assert_eq!(result, expected);
}
#[test]
fn cos_f16() {
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
}
}

View File

@ -0,0 +1,82 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }
using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[i])); \
} \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
} \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
#endif

View File

@ -14,6 +14,7 @@ accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }

View File

@ -1,5 +1,6 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
use tracing::debug;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@ -198,7 +199,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
println!("TODO softmax-last-dim");
debug!("TODO softmax-last-dim");
Ok((storage.clone(), layout.shape().clone()))
}
}

View File

@ -2,7 +2,7 @@ use std::collections::HashMap;
use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, IndexOp, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
@ -79,8 +79,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cos");
let _enter = span.enter();
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let cos = self
.cos
@ -90,37 +88,21 @@ impl LayerWeights {
.sin
.narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad");
let _enter = span.enter();
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
drop(_enter);
// This mimics the llama.cpp behavior.
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
// The resulting y0 and y1 are also interleaved with:
// y0 = x0*cos - x1*sin
// y1 = x0*sin + x1*cos
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-reshape");
let _enter = span.enter();
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad-mul");
let _enter = span.enter();
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cat");
let _enter = span.enter();
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-flatten");
let _enter = span.enter();
let rope = rope.flatten_from(D::Minus2)?;
drop(_enter);
Ok(rope)
}
@ -214,15 +196,15 @@ fn precomput_freqs_cis(
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
let idx_theta = Tensor::new(range.as_slice(), device)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// TODO This change avoids allocating on Metal and then casting since allocating directly on
// CPU as f32 seems just as fast
// let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
// .to_dtype(DType::F32)?
// let idx_theta = Tensor::new(range.as_slice(), device)?
// .reshape((MAX_SEQ_LEN, 1))?
// .matmul(&theta.reshape((1, theta.elem_count()))?)?;
// TODO This change avoids allocating on Metal and then casting since allocating directly on
// CPU as f32 seems just as fast
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))