Refactor to simplify our lives for settings the params in the encoder.

This commit is contained in:
Nicolas Patry
2023-11-10 01:24:49 +01:00
committed by Nicolas Patry
parent 39406a6721
commit df6814f34e
6 changed files with 339 additions and 255 deletions

View File

@ -146,6 +146,7 @@ impl Device {
match (self, rhs) { match (self, rhs) {
(Self::Cpu, Self::Cpu) => true, (Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs), (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false, _ => false,
} }
} }

View File

@ -53,6 +53,8 @@ mod dummy_metal_backend;
pub mod error; pub mod error;
mod indexer; mod indexer;
pub mod layout; pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
mod mkl; mod mkl;
pub mod npy; pub mod npy;

View File

@ -1,17 +1,16 @@
use crate::backend::{BackendDevice, BackendStorage}; use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape}; use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels; use candle_metal_kernels;
use candle_metal_kernels::{void_ptr, Kernels, Source}; use candle_metal_kernels::Kernels;
use core::mem; use core::mem;
use half::{bf16, f16}; use half::{bf16, f16};
use metal; use metal;
use metal::mps::matrix::encode_gemm; use metal::mps::matrix::encode_gemm;
use metal::mps::Float32; use metal::mps::Float32;
use metal::{Buffer, CommandQueue, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger}; use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::Arc; use std::sync::Arc;
use tracing::debug;
/// Metal related errors /// Metal related errors
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -113,7 +112,6 @@ impl BackendStorage for MetalStorage {
let device = self.device().clone(); let device = self.device().clone();
let shape = layout.shape(); let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count(); let el = shape.elem_count();
let dtype = self.dtype; let dtype = self.dtype;
@ -174,10 +172,8 @@ impl BackendStorage for MetalStorage {
stride.push(src_stride[dim_idx]); stride.push(src_stride[dim_idx]);
} }
let el_to_sum_per_block = src_el / dst_el;
// The reduction loop requires the shared array to be properly initialized and for // The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two. // this we want the number of threads to be a power of two.
let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
let (name, check_empty, return_index) = match (op, self.dtype) { let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
@ -219,13 +215,10 @@ impl BackendStorage for MetalStorage {
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> { fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device(); let device = self.device();
let shape = layout.shape(); let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer(); let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() { if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (self.dtype, dtype) { let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::F32) => "cast_u32_f32",
(left, right) => todo!("to dtype {left:?} - {right:?}"), (left, right) => todo!("to dtype {left:?} - {right:?}"),
@ -250,12 +243,12 @@ impl BackendStorage for MetalStorage {
command_buffer.commit(); command_buffer.commit();
// command_buffer.wait_until_scheduled(); // command_buffer.wait_until_scheduled();
debug!( // debug!(
"cast {:?} - {:?} - {:?}", // "cast {:?} - {:?} - {:?}",
dtype, // dtype,
self.buffer.length(), // self.buffer.length(),
buffer.length() // buffer.length()
); // );
Ok(Self { Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -267,15 +260,8 @@ impl BackendStorage for MetalStorage {
let device = self.device(); let device = self.device();
let dtype = self.dtype; let dtype = self.dtype;
let shape = layout.shape(); let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
// TODO remove
// return Ok(Self {
// buffer,
// device: device.clone(),
// dtype,
// });
let command_buffer = device.command_queue.new_command_buffer(); let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() { if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous; use candle_metal_kernels::unary::contiguous;
@ -302,17 +288,7 @@ impl BackendStorage for MetalStorage {
} else { } else {
todo!("TODO Implement the kernel calling {}", B::KERNEL); todo!("TODO Implement the kernel calling {}", B::KERNEL);
} }
let start = std::time::Instant::now();
command_buffer.commit(); command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Unary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self { Ok(Self {
buffer, buffer,
@ -330,7 +306,6 @@ impl BackendStorage for MetalStorage {
let device = self.device(); let device = self.device();
let dtype = self.dtype; let dtype = self.dtype;
let shape = lhs_l.shape(); let shape = lhs_l.shape();
let dims = shape.dims();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer(); let command_buffer = device.command_queue.new_command_buffer();
@ -385,17 +360,7 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
let start = std::time::Instant::now();
command_buffer.commit(); command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Binary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self { Ok(Self {
buffer, buffer,
@ -452,6 +417,16 @@ impl BackendStorage for MetalStorage {
todo!() todo!()
} }
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose1D,
) -> Result<Self> {
todo!()
}
fn conv2d( fn conv2d(
&self, &self,
_l: &Layout, _l: &Layout,
@ -504,34 +479,28 @@ impl BackendStorage for MetalStorage {
todo!() todo!()
} }
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { fn index_select(
debug!( &self,
"TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}", _ids: &Self,
self.buffer.length(), _src_l: &Layout,
ids.buffer.length(), _ids_l: &Layout,
); _dim: usize,
let src = self; ) -> Result<Self> {
let ids_shape = ids_l.shape(); todo!("Index select");
let ids_dims = ids_shape.dims(); // let ids_shape = ids_l.shape();
// let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; // let left_size: usize = src_l.dims()[..dim].iter().product();
// let src = match src_l.contiguous_offsets() { // let right_size: usize = src_l.dims()[dim + 1..].iter().product();
// Some((o1, o2)) => src.slice(o1..o2), // let src_dim_size = src_l.dims()[dim];
// None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?, // let ids_dim_size = ids_shape.elem_count();
// }; // let dst_el = ids_shape.elem_count() * left_size * right_size;
let left_size: usize = src_l.dims()[..dim].iter().product(); // let dtype = self.dtype;
let right_size: usize = src_l.dims()[dim + 1..].iter().product(); // let device = self.device();
let src_dim_size = src_l.dims()[dim]; // let buffer = device.new_buffer(dst_el, dtype);
let ids_dim_size = ids_shape.elem_count(); // Ok(Self {
let dst_el = ids_shape.elem_count() * left_size * right_size; // buffer,
let dtype = self.dtype; // device: device.clone(),
let device = self.device(); // dtype,
let buffer = device.new_buffer(dst_el, dtype); // })
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
// todo!()
} }
fn index_add( fn index_add(
@ -571,7 +540,6 @@ impl BackendStorage for MetalStorage {
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape(); let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count(); let el_count = src_shape.elem_count();
if el_count == 0 { if el_count == 0 {
return Ok(()); return Ok(());
@ -637,7 +605,7 @@ impl MetalStorage {
(DType::F32, DType::F32) => { (DType::F32, DType::F32) => {
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype); let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
if b != 1 { if b != 1 {
debug!("TODO implement batched matmul for B={b}"); // debug!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet"); // bail!("Didn't implemented strided matmul yet");
return Ok(Self { return Ok(Self {
buffer: out_buffer, buffer: out_buffer,
@ -646,12 +614,12 @@ impl MetalStorage {
}); });
} }
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
debug!( // debug!(
"TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
lhs_l.is_contiguous(), // lhs_l.is_contiguous(),
rhs_l.is_contiguous(), // rhs_l.is_contiguous(),
rhs_l // rhs_l
); // );
return Ok(Self { return Ok(Self {
buffer: out_buffer, buffer: out_buffer,
device: self.device.clone(), device: self.device.clone(),
@ -659,7 +627,7 @@ impl MetalStorage {
}); });
} }
debug!("TODO GEMM"); // debug!("TODO GEMM");
let command_buffer = self.device.command_queue.new_command_buffer(); let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>( encode_gemm::<Float32, Float32, Float32>(
&self.device, &self.device,

View File

@ -1859,7 +1859,11 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => { (Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
} }
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => { (Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same. // are the same.

View File

@ -1,6 +1,39 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
kernel void is_u32_f32(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device float *input,
const device uint *input_ids,
device float *output,
uint gid [[ thread_position_in_grid ]]
) {
if (gid >= dst_size) {
return;
}
const size_t id_i = gid / right_size / left_size;
const size_t right_rank_i = gid % right_size;
const size_t left_rank_i = gid % left_size;
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1));
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i;
output[gid] = input[src_i];
}
template <typename T, typename I> template <typename T, typename I>
void index_add( void index_add(
device I *ids [[buffer(0)]], device I *ids [[buffer(0)]],

View File

@ -1,7 +1,7 @@
#![allow(clippy::too_many_arguments)] #![allow(clippy::too_many_arguments)]
use metal::{ use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
MTLSize, Device, Function, Library, MTLSize,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
@ -15,6 +15,70 @@ const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal"); const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal"); const REDUCE: &str = include_str!("reduce.metal");
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
($type:ty) => {
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);
}
}
};
}
primitive!(usize);
primitive!(u32);
primitive!(f32);
impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
(core::mem::size_of::<T>() * data.len()) as u64,
data.as_ptr() as *const T as *const c_void,
);
}
}
impl EncoderParam for &Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
impl EncoderParam for &mut Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&mut Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
macro_rules! set_params {
($encoder:ident, ($($param:expr),+)) => (
let mut _index = 0;
$(
set_param($encoder, _index, $param);
_index += 1;
)*
);
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source { pub enum Source {
Affine, Affine,
@ -191,9 +255,7 @@ pub fn call_unary_contiguous(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&length)); set_params!(encoder, (length, input, output));
encoder.set_buffer(1, Some(input), 0);
encoder.set_buffer(2, Some(output), 0);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: 1, width: 1,
@ -239,24 +301,19 @@ pub fn call_unary_strided(
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length)); set_params!(
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims)); encoder,
encoder.set_bytes( (
2, length,
std::mem::size_of_val(shape) as u64, num_dims,
shape.as_ptr() as *const c_void, shape,
); strides,
encoder.set_bytes( (input, offset),
3, (output, output_offset)
std::mem::size_of_val(strides) as u64, )
strides.as_ptr() as *const c_void,
); );
encoder.set_buffer(4, Some(input), offset as u64); let width: usize = shape.iter().product();
encoder.set_buffer(5, Some(output), output_offset as u64);
let width = output.length();
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: 1, width: 1,
height: 1, height: 1,
@ -264,7 +321,7 @@ pub fn call_unary_strided(
}; };
let thread_group_size = MTLSize { let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
height: 1, height: 1,
depth: 1, depth: 1,
}; };
@ -299,10 +356,7 @@ pub fn call_binary_contiguous(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&length)); set_params!(encoder, (length, left, right, output));
encoder.set_buffer(1, Some(left), 0);
encoder.set_buffer(2, Some(right), 0);
encoder.set_buffer(3, Some(output), 0);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: 1, width: 1,
@ -348,32 +402,24 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
encoder.set_bytes(
2,
std::mem::size_of_val(shape) as u64,
shape.as_ptr() as *const c_void,
);
encoder.set_bytes(
3,
std::mem::size_of_val(left_strides) as u64,
left_strides.as_ptr() as *const c_void,
);
encoder.set_bytes(
4,
std::mem::size_of_val(right_strides) as u64,
right_strides.as_ptr() as *const c_void,
);
encoder.set_buffer(5, Some(left_input), left_offset as u64); set_params!(
encoder.set_buffer(6, Some(right_input), right_offset as u64); encoder,
encoder.set_buffer(7, Some(output), 0); (
length,
let width = output.length(); num_dims,
shape,
left_strides,
right_strides,
(left_input, left_offset),
(right_input, right_offset),
output
)
);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: 1, width: 1,
@ -382,7 +428,7 @@ pub fn call_binary_strided(
}; };
let thread_group_size = MTLSize { let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
height: 1, height: 1,
depth: 1, depth: 1,
}; };
@ -416,9 +462,7 @@ pub fn call_cast_contiguous(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&length)); set_params!(encoder, (length, input, output));
encoder.set_buffer(1, Some(input), 0);
encoder.set_buffer(2, Some(output), 0);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: 1, width: 1,
@ -463,14 +507,7 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length)); set_params!(encoder, (length, elements_to_sum, input, output));
encoder.set_bytes(
1,
core::mem::size_of::<usize>() as u64,
void_ptr(&elements_to_sum),
);
encoder.set_buffer(2, Some(input), 0);
encoder.set_buffer(3, Some(output), 0);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: out_length as u64, width: out_length as u64,
@ -518,14 +555,7 @@ pub fn call_last_softmax(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length)); set_params!(encoder, (length, elements_to_sum, input, output));
encoder.set_bytes(
1,
core::mem::size_of::<usize>() as u64,
void_ptr(&elements_to_sum),
);
encoder.set_buffer(2, Some(input), 0);
encoder.set_buffer(3, Some(output), 0);
let out_length = length / elements_to_sum; let out_length = length / elements_to_sum;
@ -553,10 +583,6 @@ pub fn call_last_softmax(
Ok(()) Ok(())
} }
pub fn void_ptr<T>(v: &T) -> *const c_void {
(v as *const T).cast()
}
pub fn call_affine( pub fn call_affine(
device: &Device, device: &Device,
command_buffer: &CommandBufferRef, command_buffer: &CommandBufferRef,
@ -580,11 +606,7 @@ pub fn call_affine(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size)); set_params!(encoder, (size, mul, add, input, output));
encoder.set_bytes(1, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
encoder.set_bytes(2, core::mem::size_of::<f32>() as u64, void_ptr(&add));
encoder.set_buffer(3, Some(input), 0);
encoder.set_buffer(4, Some(output), 0);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: 1, width: 1,
@ -632,36 +654,23 @@ pub fn call_where_cond_strided(
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size)); let rank = shape.len();
encoder.set_bytes(
1, set_params!(
core::mem::size_of::<usize>() as u64, encoder,
void_ptr(&shape.len()), (
size,
rank,
shape,
cond_stride,
left_stride,
right_stride,
(cond, cond_offset),
(left, left_offset),
(right, right_offset),
output
)
); );
encoder.set_bytes(
2,
std::mem::size_of_val(shape) as u64,
shape.as_ptr() as *const c_void,
);
encoder.set_bytes(
3,
std::mem::size_of_val(cond_stride) as u64,
cond_stride.as_ptr() as *const c_void,
);
encoder.set_bytes(
4,
std::mem::size_of_val(left_stride) as u64,
left_stride.as_ptr() as *const c_void,
);
encoder.set_bytes(
5,
std::mem::size_of_val(right_stride) as u64,
right_stride.as_ptr() as *const c_void,
);
encoder.set_buffer(6, Some(cond), cond_offset as u64);
encoder.set_buffer(7, Some(left), left_offset as u64);
encoder.set_buffer(8, Some(right), right_offset as u64);
encoder.set_buffer(9, Some(output), 0);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: 1, width: 1,
@ -686,7 +695,13 @@ mod tests {
use super::*; use super::*;
use half::f16; use half::f16;
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
use std::mem;
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options)
}
fn device() -> Device { fn device() -> Device {
Device::system_default().unwrap() Device::system_default().unwrap()
@ -707,13 +722,8 @@ mod tests {
let kernels = Kernels::new(); let kernels = Kernels::new();
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let input = new_buffer(&device, v);
let input = device.new_buffer_with_data( let mut output = new_buffer(&device, v);
v.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
call_unary_contiguous( call_unary_contiguous(
&device, &device,
command_buffer, command_buffer,
@ -735,16 +745,8 @@ mod tests {
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let left = device.new_buffer_with_data( let left = new_buffer(&device, x);
x.as_ptr() as *const core::ffi::c_void, let right = new_buffer(&device, y);
std::mem::size_of_val(x) as u64,
options,
);
let right = device.new_buffer_with_data(
y.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(y) as u64,
options,
);
let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous( call_binary_contiguous(
&device, &device,
@ -770,15 +772,10 @@ mod tests {
offset: usize, offset: usize,
) -> Vec<T> { ) -> Vec<T> {
let device = device(); let device = device();
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = device.new_buffer_with_data( let input = new_buffer(&device, v);
v.as_ptr() as *const core::ffi::c_void, let mut output = new_buffer(&device, v);
std::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
let kernels = Kernels::new(); let kernels = Kernels::new();
call_unary_strided( call_unary_strided(
&device, &device,
@ -893,13 +890,9 @@ mod tests {
let kernels = Kernels::new(); let kernels = Kernels::new();
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let input = new_buffer(&device, v);
let input = device.new_buffer_with_data( let mut output = new_buffer(&device, v);
v.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<U>()) as u64, options);
call_cast_contiguous( call_cast_contiguous(
&device, &device,
command_buffer, command_buffer,
@ -935,14 +928,9 @@ mod tests {
let kernels = Kernels::new(); let kernels = Kernels::new();
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data( let input = new_buffer(&device, v);
v.as_ptr() as *const core::ffi::c_void, let mut output = new_buffer(&device, v);
std::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
let size = v.len(); let size = v.len();
@ -978,6 +966,104 @@ mod tests {
assert_eq!(result, vec![2.6; 40_000]); assert_eq!(result, vec![2.6; 40_000]);
} }
#[test]
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
}
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
embeddings: &[T],
shape: &[usize],
ids: &[I],
dim: usize,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids.len() * left_size * right_size;
let ids_size = ids.len();
let function = library.get_function("is_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
ids_size,
&embeddings_buffer,
&ids_buffer,
&mut dst_buffer
)
);
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
let grid_size = MTLSize {
width: (dst_el as u64 + width - 1) / width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
println!("{width:?} - {:?}", grid_size);
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
dst_buffer.read_to_vec::<T>(dst_el)
}
#[test] #[test]
fn index_add() { fn index_add() {
let device = Device::system_default().expect("no device found"); let device = Device::system_default().expect("no device found");
@ -997,31 +1083,29 @@ mod tests {
let pipeline = device let pipeline = device
.new_compute_pipeline_state_with_function(&function) .new_compute_pipeline_state_with_function(&function)
.unwrap(); .unwrap();
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options); let index_buffer = new_buffer(&device, &index);
let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options); let inputs_buffer = new_buffer(&device, &left);
let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options); let outputs_buffer = new_buffer(&device, &right);
encoder.set_buffer(0, Some(&index_buffer), 0); set_params!(
encoder.set_buffer(1, Some(&inputs_buffer), 0); encoder,
encoder.set_buffer(2, Some(&outputs_buffer), 0); (
&index_buffer,
encoder.set_bytes(3, 4, void_ptr(&ids_dim_size)); &inputs_buffer,
encoder.set_bytes(4, 4, void_ptr(&left_size)); &outputs_buffer,
encoder.set_bytes(5, 4, void_ptr(&dst_dim_size)); ids_dim_size,
encoder.set_bytes(6, 4, void_ptr(&right_size)); left_size,
dst_dim_size,
right_size
)
);
let grid_size = MTLSize { let grid_size = MTLSize {
width: right.len() as NSUInteger, width: right.len() as NSUInteger,
@ -1064,12 +1148,9 @@ mod tests {
let kernels = Kernels::new(); let kernels = Kernels::new();
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(v) as u64,
options,
);
let mut output = let mut output =
device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous( call_reduce_contiguous(
@ -1098,13 +1179,8 @@ mod tests {
let kernels = Kernels::new(); let kernels = Kernels::new();
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let input = new_buffer(&device, v);
let input = device.new_buffer_with_data( let mut output = new_buffer(&device, v);
v.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
call_last_softmax( call_last_softmax(
&device, &device,
command_buffer, command_buffer,