Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 19:58:35 +00:00
Compare commits: branch metal4-mfa (1 commit)

Commit SHA1: 03ad494fcd
@@ -5,13 +5,43 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::Result;
-use candle_core::{Device, Tensor};
+use candle::{CpuStorage, Device, Layout, Shape, Tensor};
+use candle_core as candle;
 
+struct ArgSort;
+impl candle::CustomOp1 for ArgSort {
+    fn name(&self) -> &'static str {
+        "arg-sort"
+    }
+
+    fn cpu_fwd(
+        &self,
+        storage: &CpuStorage,
+        layout: &Layout,
+    ) -> candle::Result<(CpuStorage, Shape)> {
+        if layout.shape().rank() != 1 {
+            candle::bail!(
+                "input should have a single dimension, got {:?}",
+                layout.shape()
+            )
+        }
+        let slice = storage.as_slice::<f32>()?;
+        let src = match layout.contiguous_offsets() {
+            None => candle::bail!("input has to be contiguous"),
+            Some((o1, o2)) => &slice[o1..o2],
+        };
+        let mut dst = (0..src.len() as u32).collect::<Vec<u32>>();
+        dst.sort_by(|&i, &j| src[i as usize].total_cmp(&src[j as usize]));
+        let storage = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((storage, layout.shape().clone()))
+    }
+}
+
 fn main() -> Result<()> {
-    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
-    let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
-    let new_a = a.slice_scatter(&b, 1, 2)?;
-    assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
-    assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    let a = Tensor::new(&[0.0f32, 1.0, 3.0, 2.0, -12.0, 4.0, 3.5], &Device::Cpu)?;
+    let indices = a.apply_op1(ArgSort)?;
+    let a_sorted = a.gather(&indices, 0)?;
+    println!("{indices}");
+    println!("{a_sorted}");
     Ok(())
 }
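For context, here is the core of what the added ArgSort example computes, as a standalone sketch in plain Rust (no candle types needed; the expected indices follow from the example's input values):

    fn arg_sort(src: &[f32]) -> Vec<u32> {
        // Indices 0..n, reordered so that src[dst[0]] <= src[dst[1]] <= ...
        let mut dst = (0..src.len() as u32).collect::<Vec<u32>>();
        // total_cmp gives a total order on f32, so NaN values cannot panic the sort.
        dst.sort_by(|&i, &j| src[i as usize].total_cmp(&src[j as usize]));
        dst
    }

    fn main() {
        let v = [0.0f32, 1.0, 3.0, 2.0, -12.0, 4.0, 3.5];
        assert_eq!(arg_sort(&v), [4, 0, 1, 3, 2, 6, 5]);
    }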
@@ -4,13 +4,11 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
 use candle_metal_kernels::Kernels;
-use half::f16;
+use core::mem;
+use half::{bf16, f16};
 use metal;
-use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
-use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::collections::HashMap;
-use std::path::Path;
-use std::sync::{Arc, RwLock};
+use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::sync::Arc;
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -38,9 +36,7 @@ impl From<String> for MetalError {
 pub struct MetalDevice {
     device: metal::Device,
     command_queue: metal::CommandQueue,
-    command_buffer: Arc<RwLock<metal::CommandBuffer>>,
     kernels: Arc<candle_metal_kernels::Kernels>,
-    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -62,48 +58,10 @@ impl MetalDevice {
         self.registry_id()
     }
 
-    pub fn metal_device(&self) -> &metal::Device {
-        &self.device
-    }
-
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
 
-    pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
-        self.command_buffer.try_read().unwrap()
-    }
-
-    pub fn commit(&self) {
-        let mut old = self.command_buffer.try_write().unwrap();
-        match old.status() {
-            metal::MTLCommandBufferStatus::NotEnqueued
-            | metal::MTLCommandBufferStatus::Enqueued => {
-                old.commit();
-                let command_buffer = self.command_queue.new_command_buffer().to_owned();
-                *old = command_buffer;
-            }
-            _ => {}
-        }
-    }
-
-    pub fn wait_until_completed(&self) {
-        let mut old = self.command_buffer.try_write().unwrap();
-        match old.status() {
-            metal::MTLCommandBufferStatus::NotEnqueued
-            | metal::MTLCommandBufferStatus::Enqueued => {
-                old.commit();
-                old.wait_until_completed();
-            }
-            metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => {
-                old.wait_until_completed();
-            }
-            _ => {}
-        }
-        let command_buffer = self.command_queue.new_command_buffer().to_owned();
-        *old = command_buffer;
-    }
-
     pub fn kernels(&self) -> &Kernels {
         &self.kernels
     }
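The removed `commit`/`wait_until_completed` pair implemented a shared, status-checked command buffer: ops keep encoding into one buffer, which is only committed (and swapped for a fresh one) while its work is still un-enqueued. A minimal sketch of that pattern, with toy stand-ins for the metal crate types (the names below are illustrative, not the metal crate's API):

    use std::sync::RwLock;

    #[derive(Clone, Copy, PartialEq)]
    enum Status { NotEnqueued, Committed }

    struct CommandBuffer { status: Status }
    impl CommandBuffer {
        fn new() -> Self { Self { status: Status::NotEnqueued } }
        fn commit(&mut self) { self.status = Status::Committed; }
    }

    struct Device { current: RwLock<CommandBuffer> }
    impl Device {
        // Commit only if the shared buffer still holds un-enqueued work,
        // then install a fresh buffer so later ops can keep encoding.
        fn commit(&self) {
            let mut cur = self.current.write().unwrap();
            if cur.status == Status::NotEnqueued {
                cur.commit();
                *cur = CommandBuffer::new();
            }
        }
    }

    fn main() {
        let dev = Device { current: RwLock::new(CommandBuffer::new()) };
        dev.commit(); // flushes pending work and swaps in a fresh buffer
    }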
@@ -112,107 +70,16 @@ impl MetalDevice {
         &self.device
     }
 
-    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
+    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
-    }
-
-    fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> {
-        let mut buffers = self.buffers.try_write().unwrap();
-        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
-
-        for sub in &mut *subbuffers {
-            if Arc::strong_count(sub) == 1 {
-                return sub.clone();
-            }
-        }
-        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
-        let new_buffer = Arc::new(new_buffer);
-        subbuffers.push(new_buffer.clone());
-        new_buffer
-    }
-
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
-        self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
-    }
-
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
-        let size = core::mem::size_of_val(data) as NSUInteger;
-        let tmp = self.device.new_buffer_with_data(
-            data.as_ptr() as *const core::ffi::c_void,
-            size,
-            metal::MTLResourceOptions::StorageModeManaged,
-        );
-        let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate);
-        {
-            let command = self.command_buffer();
-            let blit = command.new_blit_command_encoder();
-            blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-            blit.end_encoding();
-        }
-        // This is necessary, for mmaped safetensors
-        // Because of the unsafe slice cast we're doing.
-        // The slice might not live long enough for metal
-        // To actually fill the GPU buffer.
-        // Putting this wait forces the GPU buffer to be filled
-        // with the actual data allowing the CPU storage todo
-        // deallocate properly.
-        self.wait_until_completed();
-        real
-    }
-
-    pub fn new_matrix(
-        &self,
-        (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
-        size: NSUInteger,
-        type_id: u32,
-        dtype: DType,
-    ) -> Result<(Matrix, Arc<Buffer>)> {
-        let elem_count = (b * m * n) as usize;
-        let out_buffer = self.new_buffer(elem_count, dtype);
-
-        let result_descriptor =
-            MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
-        let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
-        Ok((result_matrix, out_buffer))
-    }
-
-    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
-        let capture = metal::CaptureManager::shared();
-        let descriptor = metal::CaptureDescriptor::new();
-        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        descriptor.set_capture_device(&self);
-        descriptor.set_output_url(path);
-
-        capture
-            .start_capture(&descriptor)
-            .map_err(MetalError::from)?;
-        Ok(())
-    }
+        self.device
+            .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+    }
 }
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
-    buffer: Arc<metal::Buffer>,
-    matrices: Arc<
-        RwLock<
-            HashMap<
-                (
-                    NSUInteger,
-                    NSUInteger,
-                    NSUInteger,
-                    bool,
-                    NSUInteger,
-                    NSUInteger,
-                    u32,
-                ),
-                Matrix,
-            >,
-        >,
-    >,
+    buffer: metal::Buffer,
     device: MetalDevice,
     dtype: DType,
 }
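The removed `_new_buffer` is a simple pool: buffers are keyed by (size, storage mode), and a pooled buffer counts as free again once the pool holds the only `Arc` to it. The same idea in self-contained Rust, with `Vec<u8>` standing in for `metal::Buffer`:

    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    struct BufferPool {
        // Keyed by size in bytes; the real pool also keys on the storage mode.
        buffers: RwLock<HashMap<usize, Vec<Arc<Vec<u8>>>>>,
    }

    impl BufferPool {
        fn alloc(&self, size: usize) -> Arc<Vec<u8>> {
            let mut buffers = self.buffers.write().unwrap();
            let subbuffers = buffers.entry(size).or_insert_with(Vec::new);
            for sub in subbuffers.iter() {
                // strong_count == 1 means only the pool still references it,
                // so no live storage can observe the reuse.
                if Arc::strong_count(sub) == 1 {
                    return sub.clone();
                }
            }
            let fresh = Arc::new(vec![0u8; size]);
            subbuffers.push(fresh.clone());
            fresh
        }
    }

    fn main() {
        let pool = BufferPool { buffers: RwLock::new(HashMap::new()) };
        let a = pool.alloc(1024);
        let a_ptr = Arc::as_ptr(&a);
        drop(a); // implicitly returns the buffer to the pool
        let b = pool.alloc(1024);
        assert_eq!(Arc::as_ptr(&b), a_ptr); // the same allocation was reused
    }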
@@ -241,23 +108,14 @@ impl BackendStorage for MetalStorage {
                 self.dtype
             );
         }
 
-        let buffer = self.device.new_buffer_managed(self.buffer.length());
-        let command_buffer = self.device.command_buffer();
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-        blit.end_encoding();
-        drop(command_buffer);
-        self.device.wait_until_completed();
-
         match self.dtype {
-            DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
-            DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
-            DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
-            DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
-            DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
-            DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))),
-            DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
+            DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
+            DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
+            DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
+            DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
+            DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
+            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
+            DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
         }
     }
@@ -268,48 +126,30 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype);
-        let command_buffer = self.device.command_buffer();
-        if layout.is_contiguous() && layout.start_offset() == 0 {
-            let name = match self.dtype {
-                DType::F32 => "affine_float",
-                DType::F16 => "affine_half",
-                dtype => crate::bail!("Affine {dtype:?}"),
-            };
-            candle_metal_kernels::call_affine(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                name,
-                el,
-                &self.buffer,
-                &buffer,
-                mul as f32,
-                add as f32,
-            )
-            .map_err(MetalError::from)?;
-        } else {
-            let name = match self.dtype {
-                DType::F32 => "affine_float_strided",
-                DType::F16 => "affine_half_strided",
-                dtype => crate::bail!("Affine {dtype:?}"),
-            };
-            candle_metal_kernels::call_affine_strided(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                name,
-                layout.dims(),
-                &self.buffer,
-                layout.stride(),
-                layout.start_offset() * dtype.size_in_bytes(),
-                &buffer,
-                mul as f32,
-                add as f32,
-            )
-            .map_err(MetalError::from)?;
+        if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
+            crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
         }
-        Ok(Self::new(buffer, device.clone(), dtype))
+
+        let mut buffer = device.new_buffer(el, self.dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_affine(
+            &device.device,
+            &command_buffer,
+            &device.kernels,
+            el,
+            &self.buffer,
+            &mut buffer,
+            mul as f32,
+            add as f32,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        return Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        });
     }
 
     fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -323,11 +163,11 @@ impl BackendStorage for MetalStorage {
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         if !(sum_dims.len() == 1
             && sum_dims[0] == layout.shape().rank() - 1
-            && layout.stride()[sum_dims[0]] == 1)
+            && layout.is_contiguous()
+            && layout.start_offset() == 0)
         {
-            crate::bail!("Non last dim reduce op not supported yet");
+            crate::bail!("Non contiguous reduce op not supported yet");
         }
 
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
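Both sides only accept a reduction over the last, contiguous dimension. As a reference for that case, the operation is just a per-row fold (sketch in plain Rust, with sum standing in for the reduce op):

    fn reduce_last_dim_sum(src: &[f32], row_len: usize) -> Vec<f32> {
        // One output element per row of length row_len.
        src.chunks(row_len).map(|row| row.iter().sum()).collect()
    }

    fn main() {
        assert_eq!(reduce_last_dim_sum(&[1.0, 2.0, 3.0, 4.0], 2), [3.0, 7.0]);
    }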
@@ -362,11 +202,8 @@ impl BackendStorage for MetalStorage {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
-        if dtype == DType::U32 {
-            crate::bail!("Implement return index reduce op");
-        }
-        let buffer = device.new_buffer(dst_el, dtype);
-        let command_buffer = self.device.command_buffer();
+        let mut buffer = device.new_buffer(dst_el, dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
             &command_buffer,
@@ -375,12 +212,17 @@ impl BackendStorage for MetalStorage {
             src_el,
             dst_el,
             &self.buffer,
-            layout.start_offset() * self.dtype.size_in_bytes(),
-            &buffer,
+            &mut buffer,
         )
         .map_err(MetalError::from)?;
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
 
-        Ok(Self::new(buffer, device, dtype))
+        Ok(Self {
+            buffer,
+            device,
+            dtype,
+        })
     }
 
     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -391,15 +233,11 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_buffer();
+        let mut buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_queue.new_command_buffer();
         if layout.is_contiguous() {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
-                (DType::U32, DType::U8) => "cast_u32_u8",
-                (DType::U8, DType::U32) => "cast_u8_u32",
-                (DType::F32, DType::F16) => "cast_f32_f16",
-                (DType::F16, DType::F32) => "cast_f16_f32",
                 (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
             };
             candle_metal_kernels::call_cast_contiguous(
@@ -409,34 +247,24 @@ impl BackendStorage for MetalStorage {
                 kernel_name,
                 el_count,
                 &self.buffer,
-                layout.start_offset() * self.dtype.size_in_bytes(),
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         } else {
-            let kernel_name = match (self.dtype, dtype) {
-                (DType::U32, DType::F32) => "cast_u32_f32_strided",
-                (DType::U32, DType::U8) => "cast_u32_u8_strided",
-                (DType::U8, DType::U32) => "cast_u8_u32_strided",
-                (DType::F32, DType::F16) => "cast_f32_f16_strided",
-                (DType::F16, DType::F32) => "cast_f16_f32_strided",
-                (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
-            };
-            candle_metal_kernels::call_cast_strided(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                kernel_name,
-                layout.dims(),
-                &self.buffer,
-                layout.stride(),
-                layout.start_offset() * self.dtype.size_in_bytes(),
-                &buffer,
-            )
-            .map_err(MetalError::from)?;
+            crate::bail!(
+                "TODO Implement the kernel calling cast {:?}-{:?}",
+                self.dtype,
+                dtype
+            );
         }
-
-        Ok(Self::new(buffer, device.clone(), dtype))
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@@ -444,8 +272,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_buffer();
+        let mut buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_queue.new_command_buffer();
         if layout.is_contiguous() && layout.start_offset() == 0 {
             use candle_metal_kernels::unary::contiguous;
 
@@ -457,25 +285,6 @@ impl BackendStorage for MetalStorage {
                 ("uneg", DType::F32) => contiguous::neg::FLOAT,
                 ("uexp", DType::F32) => contiguous::exp::FLOAT,
                 ("ulog", DType::F32) => contiguous::log::FLOAT,
-                ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
-                ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
-                ("uerf", DType::F32) => contiguous::erf::FLOAT,
-                ("uceil", DType::F32) => contiguous::ceil::FLOAT,
-                ("ufloor", DType::F32) => contiguous::floor::FLOAT,
-                ("uround", DType::F32) => contiguous::round::FLOAT,
-                ("ucos", DType::F16) => contiguous::cos::HALF,
-                ("usin", DType::F16) => contiguous::sin::HALF,
-                ("usqr", DType::F16) => contiguous::sqr::HALF,
-                ("usqrt", DType::F16) => contiguous::sqrt::HALF,
-                ("uneg", DType::F16) => contiguous::neg::HALF,
-                ("uexp", DType::F16) => contiguous::exp::HALF,
-                ("ulog", DType::F16) => contiguous::log::HALF,
-                ("ugelu", DType::F16) => contiguous::gelu::HALF,
-                ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
-                ("uerf", DType::F16) => contiguous::erf::HALF,
-                ("uceil", DType::F16) => contiguous::ceil::HALF,
-                ("ufloor", DType::F16) => contiguous::floor::HALF,
-                ("uround", DType::F16) => contiguous::round::HALF,
                 (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_unary_contiguous(
@@ -485,58 +294,20 @@ impl BackendStorage for MetalStorage {
                 kernel_name,
                 el_count,
                 &self.buffer,
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         } else {
-            use candle_metal_kernels::unary::strided;
-            let kernel_name = match (B::KERNEL, dtype) {
-                ("ucos", DType::F32) => strided::cos::FLOAT,
-                ("usin", DType::F32) => strided::sin::FLOAT,
-                ("usqr", DType::F32) => strided::sqr::FLOAT,
-                ("usqrt", DType::F32) => strided::sqrt::FLOAT,
-                ("uneg", DType::F32) => strided::neg::FLOAT,
-                ("uexp", DType::F32) => strided::exp::FLOAT,
-                ("ulog", DType::F32) => strided::log::FLOAT,
-                ("ugelu", DType::F32) => strided::gelu::FLOAT,
-                ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
-                ("uerf", DType::F32) => strided::erf::FLOAT,
-                ("uceil", DType::F32) => strided::ceil::FLOAT,
-                ("ufloor", DType::F32) => strided::floor::FLOAT,
-                ("uround", DType::F32) => strided::round::FLOAT,
-                ("ucos", DType::F16) => strided::cos::HALF,
-                ("usin", DType::F16) => strided::sin::HALF,
-                ("usqr", DType::F16) => strided::sqr::HALF,
-                ("usqrt", DType::F16) => strided::sqrt::HALF,
-                ("uneg", DType::F16) => strided::neg::HALF,
-                ("uexp", DType::F16) => strided::exp::HALF,
-                ("ulog", DType::F16) => strided::log::HALF,
-                ("ugelu", DType::F16) => strided::gelu::HALF,
-                ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
-                ("uerf", DType::F16) => strided::erf::HALF,
-                ("uceil", DType::F16) => strided::ceil::HALF,
-                ("ufloor", DType::F16) => strided::floor::HALF,
-                ("uround", DType::F16) => strided::round::HALF,
-                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
-            };
-            candle_metal_kernels::call_unary_strided(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                kernel_name,
-                layout.dims(),
-                &self.buffer,
-                layout.stride(),
-                layout.start_offset() * self.dtype.size_in_bytes(),
-                &buffer,
-                0,
-            )
-            .map_err(MetalError::from)?;
+            crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
         }
-        command_buffer.set_label("unary");
-        drop(command_buffer);
-        self.device.commit();
-        Ok(Self::new(buffer, device.clone(), dtype))
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn binary_impl<B: BinaryOpT>(
@@ -549,8 +320,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_buffer();
+        let mut buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_queue.new_command_buffer();
         if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
             && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
         {
@@ -565,14 +336,6 @@ impl BackendStorage for MetalStorage {
                 ("bmul", DType::F32) => contiguous::mul::FLOAT,
                 ("div", DType::F32) => contiguous::div::FLOAT,
                 ("bdiv", DType::F32) => contiguous::div::FLOAT,
-                ("add", DType::F16) => contiguous::add::HALF,
-                ("badd", DType::F16) => contiguous::add::HALF,
-                ("sub", DType::F16) => contiguous::sub::HALF,
-                ("bsub", DType::F16) => contiguous::sub::HALF,
-                ("mul", DType::F16) => contiguous::mul::HALF,
-                ("bmul", DType::F16) => contiguous::mul::HALF,
-                ("div", DType::F16) => contiguous::div::HALF,
-                ("bdiv", DType::F16) => contiguous::div::HALF,
                 (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_binary_contiguous(
@@ -583,7 +346,7 @@ impl BackendStorage for MetalStorage {
                 el_count,
                 &self.buffer,
                 &rhs.buffer,
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         } else {
@@ -594,10 +357,6 @@ impl BackendStorage for MetalStorage {
                 ("bsub", DType::F32) => strided::sub::FLOAT,
                 ("bmul", DType::F32) => strided::mul::FLOAT,
                 ("bdiv", DType::F32) => strided::div::FLOAT,
-                ("badd", DType::F16) => strided::add::HALF,
-                ("bsub", DType::F16) => strided::sub::HALF,
-                ("bmul", DType::F16) => strided::mul::HALF,
-                ("bdiv", DType::F16) => strided::div::HALF,
                 (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_binary_strided(
@@ -607,19 +366,23 @@ impl BackendStorage for MetalStorage {
                 kernel_name,
                 lhs_l.dims(),
                 &self.buffer,
-                lhs_l.stride(),
+                &lhs_l.stride(),
                 lhs_l.start_offset() * self.dtype.size_in_bytes(),
                 &rhs.buffer,
-                rhs_l.stride(),
+                &rhs_l.stride(),
                 rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         }
-        command_buffer.set_label("binary");
-        drop(command_buffer);
-        self.device.commit();
-        Ok(Self::new(buffer, device.clone(), dtype))
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn where_cond(
@@ -635,22 +398,14 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el = shape.elem_count();
         let dtype = t.dtype;
-        let buffer = self.device.new_buffer(el, dtype);
-        let command_buffer = self.device.command_buffer();
-        if t.dtype() != f.dtype() {
-            crate::bail!("Invalid ternary different dtypes for values");
-        }
-        let name = match (self.dtype, t.dtype()) {
-            (DType::U8, DType::F32) => "where_u8_f32",
-            (DType::U8, DType::F16) => "where_u8_f16",
-            (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
-        };
+        let mut buffer = self.device.new_buffer(el, dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
             &command_buffer,
             &device.kernels,
-            name,
-            dims,
+            "where_u8_f32",
+            &dims,
             &self.buffer,
             (
                 layout.stride(),
@@ -660,10 +415,16 @@ impl BackendStorage for MetalStorage {
             (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
             &f.buffer,
             (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
-            &buffer,
+            &mut buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device, dtype))
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        Ok(Self {
+            buffer,
+            device,
+            dtype,
+        })
     }
 
     fn conv1d(
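For reference, `where_cond` is an elementwise select driven by a u8 mask; the kernel applies this with strided indexing on all three inputs. Contiguous sketch of the semantics in plain Rust:

    fn where_cond(cond: &[u8], t: &[f32], f: &[f32]) -> Vec<f32> {
        cond.iter()
            .zip(t.iter().zip(f.iter()))
            .map(|(&c, (&t, &f))| if c != 0 { t } else { f })
            .collect()
    }

    fn main() {
        let out = where_cond(&[1, 0, 1], &[1.0, 2.0, 3.0], &[9.0, 8.0, 7.0]);
        assert_eq!(out, [1.0, 8.0, 3.0]);
    }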
@@ -752,13 +513,12 @@ impl BackendStorage for MetalStorage {
         let dst_el = ids_el * left_size * right_size;
         let dtype = self.dtype;
         let device = self.device();
-        let buffer = device.new_buffer(dst_el, dtype);
+        let mut buffer = device.new_buffer(dst_el, dtype);
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
-            (DType::U32, DType::F16) => "is_u32_f16",
             (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
         };
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_queue.new_command_buffer();
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -769,10 +529,16 @@ impl BackendStorage for MetalStorage {
             dim,
             &self.buffer,
             &ids.buffer,
-            &buffer,
+            &mut buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device.clone(), dtype))
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn index_add(
@@ -795,19 +561,11 @@ impl BackendStorage for MetalStorage {
         rhs_l: &Layout,
     ) -> Result<Self> {
         // Create descriptors
-        let (type_id, size, name) = match self.dtype {
-            DType::F32 => (
-                metal::mps::MPS_FLOATBIT_ENCODING | 32,
-                core::mem::size_of::<f32>() as NSUInteger,
-                "sgemm",
-            ),
-            DType::F16 => (
-                metal::mps::MPS_FLOATBIT_ENCODING | 16,
-                core::mem::size_of::<f16>() as NSUInteger,
-                "hgemm",
-            ),
-            dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
-        };
+        use metal::mps::matrix::*;
+        let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
+        let size = core::mem::size_of::<f32>() as NSUInteger;
+
+        let elem_count = b * m * n;
 
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
@@ -839,130 +597,120 @@ impl BackendStorage for MetalStorage {
             })?
         };
 
-        let result_buffer = self.device.new_buffer(b * m * n, self.dtype);
-        let command_buffer = self.device.command_buffer();
-        command_buffer.set_label("mfa gemm");
-        candle_metal_kernels::call_mfa_gemm(
-            &self.device.device,
-            &command_buffer,
-            &self.device.kernels,
-            name,
-            &self.buffer,
-            lhs_l.shape().dims(),
-            &rhs.buffer,
-            rhs_l.shape().dims(),
-            &result_buffer,
-            (b, m, n, k),
-            transpose_left,
-            transpose_right,
-        )
-        .map_err(MetalError::from)?;
-
-        drop(command_buffer);
-        self.device.commit();
-
-        Ok(Self::new(
-            self.buffer.clone(),
-            self.device.clone(),
-            self.dtype(),
-        ))
+        let b = b as NSUInteger;
+        let m = m as NSUInteger;
+        let n = n as NSUInteger;
+        let k = k as NSUInteger;
+
+        let left_descriptor = if transpose_left {
+            MatrixDescriptor::init_single(k, m, m * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(m, k, k * size, type_id)
+        };
+        let right_descriptor = if transpose_right {
+            MatrixDescriptor::init_single(n, k, k * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(k, n, n * size, type_id)
+        };
+        let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
+
+        // Create matrix objects
+        let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+        let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+
+        let out_buffer = self.device.new_buffer(elem_count, self.dtype);
+        let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+
+        let alpha = 1.0f64;
+        let beta = 0.0f64;
+        // Create kernel
+        let matrix_multiplication = MatrixMultiplication::init(
+            &self.device,
+            transpose_left,
+            transpose_right,
+            m,
+            n,
+            k,
+            alpha,
+            beta,
+        )
+        .ok_or_else(|| {
+            MetalError::from("Failed to create matrix multiplication kernel".to_string())
+        })?;
+
+        matrix_multiplication.set_batch_size(b);
+
+        // Encode kernel to command buffer
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        matrix_multiplication.encode_to_command_buffer(
+            command_buffer,
+            &left_matrix,
+            &right_matrix,
+            &result_matrix,
+        );
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        Ok(Self {
+            buffer: out_buffer,
+            device: self.device.clone(),
+            dtype: self.dtype(),
+        })
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-        let command_buffer = self.device.command_buffer();
-        if src_l.is_contiguous() && self.dtype == dst.dtype() {
-            command_buffer.set_label("copy_contiguous");
-            let blit = command_buffer.new_blit_command_encoder();
-            let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
-            let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
-            blit.copy_from_buffer(
-                &self.buffer,
-                src_offset,
-                dst.buffer(),
-                dst_offset,
-                self.buffer.length() - src_offset,
-            );
-            blit.end_encoding();
-        } else {
-            let src_shape = src_l.shape();
-            let el_count = src_shape.elem_count();
-            if el_count == 0 {
-                return Ok(());
-            }
-            let kernel_name = match self.dtype {
-                DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
-                DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
-                DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
-                DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
-                DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
-                dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
-            };
-            candle_metal_kernels::call_unary_strided(
-                &self.device.device,
-                &command_buffer,
-                &self.device.kernels,
-                kernel_name,
-                src_l.dims(),
-                &self.buffer,
-                src_l.stride(),
-                src_l.start_offset() * self.dtype.size_in_bytes(),
-                &dst.buffer,
-                dst_offset * dst.dtype.size_in_bytes(),
-            )
-            .map_err(MetalError::from)?;
-            command_buffer.set_label("copy_strided");
-        }
-        drop(command_buffer);
-        self.device.commit();
+        let src_shape = src_l.shape();
+        let el_count = src_shape.elem_count();
+        if el_count == 0 {
+            return Ok(());
+        }
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        let kernel_name = match self.dtype {
+            DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+            DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+            DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+            dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
+        };
+        candle_metal_kernels::call_unary_strided(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            kernel_name,
+            src_l.dims(),
+            &self.buffer,
+            &src_l.stride(),
+            src_l.start_offset() * self.dtype.size_in_bytes(),
+            &mut dst.buffer,
+            dst_offset,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
         Ok(())
     }
 }
 
 impl MetalStorage {
-    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
-        let matrices = Arc::new(RwLock::new(HashMap::new()));
+    pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
         Self {
             buffer,
             device,
             dtype,
-            matrices,
         }
     }
 
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
     }
-
-    fn matrix(
-        &self,
-        (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
-        transpose: bool,
-        size: NSUInteger,
-        offset: NSUInteger,
-        type_id: u32,
-    ) -> Result<Matrix> {
-        let key = (b, m, n, transpose, size, offset, type_id);
-
-        let mut matrices = self.matrices.try_write().unwrap();
-        if let Some(matrix) = matrices.get(&key) {
-            Ok(matrix.clone())
-        } else {
-            let descriptor = if transpose {
-                MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
-            } else {
-                MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
-            };
-            let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor)
-                .ok_or_else(|| {
-                    MetalError::from("Failed to create matrix multiplication kernel".to_string())
-                })?;
-            matrices.insert(key, matrix.clone());
-            Ok(matrix)
-        }
-    }
 }
 
 impl BackendDevice for MetalDevice {
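The MPS side builds one row-major MatrixDescriptor per operand: a transposed operand swaps rows and columns, and the row stride in bytes is always the column count times the element size. A small sketch of that arithmetic (the helper below is illustrative, not an API from the diff):

    // Returns (rows, cols, row_bytes) as passed to MatrixDescriptor::init_single.
    fn descriptor(m: u64, k: u64, size: u64, transpose: bool) -> (u64, u64, u64) {
        if transpose {
            (k, m, m * size) // the m x k operand stored transposed
        } else {
            (m, k, k * size)
        }
    }

    fn main() {
        let size = core::mem::size_of::<f32>() as u64;
        assert_eq!(descriptor(2, 3, size, false), (2, 3, 12));
        assert_eq!(descriptor(2, 3, size, true), (3, 2, 8));
    }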
@@ -972,14 +720,10 @@ impl BackendDevice for MetalDevice {
         let device = metal::Device::all().swap_remove(ordinal);
 
         let command_queue = device.new_command_queue();
-        let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
         let kernels = Arc::new(Kernels::new());
-        let buffers = Arc::new(RwLock::new(HashMap::new()));
         Ok(Self {
             device,
             command_queue,
-            command_buffer,
-            buffers,
             kernels,
         })
     }
@@ -999,8 +743,9 @@ impl BackendDevice for MetalDevice {
     }
 
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let buffer = self.new_buffer(shape.elem_count(), dtype);
-        Ok(MetalStorage::new(buffer, self.clone(), dtype))
+        // TODO Is there a faster way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
+        self.storage_from_cpu_storage(&cpu_storage)
     }
 
     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||||
let buffer = match storage {
|
let buffer = match storage {
|
||||||
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
|
||||||
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
|
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
|
||||||
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
|
option,
|
||||||
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
|
),
|
||||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
|
||||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
};
|
};
|
||||||
Ok(Self::Storage::new(
|
Ok(Self::Storage {
|
||||||
buffer.into(),
|
buffer,
|
||||||
self.clone(),
|
device: self.clone(),
|
||||||
storage.dtype(),
|
dtype: storage.dtype(),
|
||||||
))
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand_uniform(
|
fn rand_uniform(
|
||||||
|
@@ -57,7 +57,6 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
-metal = ["candle/metal", "candle-nn/metal"]
 
 [[example]]
 name = "llama_multiprocess"
@@ -11,7 +11,6 @@ license = "MIT OR Apache-2.0"
 
 [dependencies]
 metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
-metal-flash-attention = { path = "../../../metal-flash-attention" }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
@@ -50,7 +50,6 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
         &device,
         command_buffer,
         &kernels,
-        "affine_float",
         v.len(),
         &input,
         &mut output,
|
@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
|
|||||||
println!(
|
println!(
|
||||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||||
type_name::<T>().split("::").last().unwrap(),
|
type_name::<T>().split("::").last().unwrap(),
|
||||||
kernel_name.0,
|
kernel_name.to_string(),
|
||||||
v.len(),
|
v.len(),
|
||||||
iterations,
|
iterations,
|
||||||
total_time,
|
total_time,
|
||||||
@@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
     let shape = vec![2, 5_000];
     let strides = vec![2, 1];
     let offset = 0;
-    for kernel_name in &strided {
+    for kernel_name in strided {
         let total_time = autoreleasepool(|| {
             let command_buffer = command_queue.new_command_buffer();
             let start = Instant::now();
@@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
     println!(
         "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
         type_name::<T>().split("::").last().unwrap(),
-        kernel_name.0,
+        kernel_name.to_string(),
         v.len(),
         iterations,
         total_time,
@@ -33,24 +33,6 @@ kernel void FN_NAME( \
     const TYPENAME a = TYPENAME(add); \
     output[id] = input[id] * m + a; \
 } \
-kernel void FN_NAME##_strided( \
-    constant size_t &dim, \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *strides, \
-    constant float &mul, \
-    constant float &add, \
-    device const TYPENAME *input, \
-    device TYPENAME *output, \
-    uint id [[ thread_position_in_grid ]] \
-) { \
-    if (id >= dim) { \
-        return; \
-    } \
-    const TYPENAME m = TYPENAME(mul); \
-    const TYPENAME a = TYPENAME(add); \
-    output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
-} \
 
 AFFINE(affine_float, float)
 AFFINE(affine_half, half)
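Both affine kernels compute the same elementwise map, y = x * mul + add; the removed _strided variant differs only in reading x through get_strided_index. Reference semantics in Rust:

    fn affine(input: &[f32], mul: f32, add: f32) -> Vec<f32> {
        input.iter().map(|&x| x * mul + add).collect()
    }

    fn main() {
        assert_eq!(affine(&[1.0, 2.0], 3.0, 0.5), [3.5, 6.5]);
    }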
@@ -23,12 +23,12 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const LEFT_TYPENAME *input, \
     device RIGHT_TYPENAME *output, \
-    uint tid [[ thread_position_in_grid ]] \
+    uint thread_position_in_grid [[ thread_position_in_grid ]] \
 ) { \
-    if (tid >= dim) { \
+    if (thread_position_in_grid >= dim) { \
         return; \
     } \
-    output[tid] = RIGHT_TYPENAME(input[tid]); \
+    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
 } \
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -37,19 +37,15 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *strides, \
     device const LEFT_TYPENAME *input, \
     device RIGHT_TYPENAME *output, \
-    uint tid [[ thread_position_in_grid ]] \
+    uint i [[ thread_position_in_grid ]] \
 ) { \
-    if (tid >= dim) { \
+    if (i >= dim) { \
         return; \
     } \
-    output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
+    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
 } \
 
-CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
-CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
-CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
-CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
-CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
+CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
 
 #if __METAL_VERSION__ >= 310
 #endif
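Both CAST variants rely on a get_strided_index helper defined elsewhere in the kernel sources (not shown in this diff); its usual definition decomposes a flat index into per-dimension coordinates and re-linearizes them with the source strides. A Rust rendering of that logic, assuming the standard definition:

    fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
        let mut strided = 0;
        // Walk dimensions from innermost to outermost.
        for (dim, stride) in dims.iter().zip(strides.iter()).rev() {
            strided += (idx % dim) * stride;
            idx /= dim;
        }
        strided
    }

    fn main() {
        // A 2x3 view with strides [1, 2]: logical (0, 1) has flat index 1
        // and lives at memory offset 0 * 1 + 1 * 2 = 2.
        assert_eq!(get_strided_index(1, &[2, 3], &[1, 2]), 2);
    }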
@@ -16,16 +16,16 @@ kernel void NAME( \
     if (gid >= dst_size) { \
         return; \
     } \
-    const size_t id_i = (gid / right_size) % ids_size; \
-    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
+    const size_t id_i = gid / right_size / left_size; \
     const size_t right_rank_i = gid % right_size; \
-    const size_t left_rank_i = gid / right_size / ids_size; \
+    const size_t left_rank_i = gid % left_size; \
     /* \
     // Force prevent out of bounds indexing \
     // since there doesn't seem to be a good way to force crash \
     // No need to check for zero we're only allowing unsized. \
     */ \
-    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
+    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
+    const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
     output[gid] = input[src_i]; \
 }
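The two sides compute the source index from the flat output index gid differently; here are both, transcribed into Rust exactly as written in the kernels so they can be compared side by side (both clamp out-of-range ids to src_dim_size - 1):

    // Left (removed) version of the src index computation.
    fn src_index_old(gid: usize, ids: &[u32], src_dim_size: usize, right_size: usize) -> usize {
        let ids_size = ids.len();
        let id_i = (gid / right_size) % ids_size;
        let input_i = (ids[id_i] as usize).min(src_dim_size - 1);
        let right_rank_i = gid % right_size;
        let left_rank_i = gid / right_size / ids_size;
        left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i
    }

    // Right (new) version.
    fn src_index_new(
        gid: usize,
        ids: &[u32],
        src_dim_size: usize,
        left_size: usize,
        right_size: usize,
    ) -> usize {
        let id_i = gid / right_size / left_size;
        let right_rank_i = gid % right_size;
        let left_rank_i = gid % left_size;
        let input_i = (ids[id_i] as usize).min(src_dim_size - 1);
        ((input_i * right_size) + right_rank_i) * left_size + left_rank_i
    }

    fn main() {
        let ids = [1u32, 0];
        // Same gid, different conventions, generally different source offsets.
        println!("old: {}", src_index_old(3, &ids, 4, 2));
        println!("new: {}", src_index_new(3, &ids, 4, 3, 2));
    }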
@@ -75,7 +75,6 @@ kernel void FN_NAME( \
 
 
 INDEX_OP(is_u32_f32, uint, float)
-INDEX_OP(is_u32_f16, uint, half)
 
 
 #if __METAL_VERSION__ >= 310
@@ -1,12 +1,9 @@
 use metal::{
-    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
-    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceUsage, MTLSize,
-    NSUInteger,
+    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
+    ComputePipelineState, Device, Function, Library, MTLSize,
 };
-use std::collections::{BTreeMap, HashMap};
+use std::collections::HashMap;
 use std::ffi::c_void;
-use std::hash::Hash;
-use std::io::{stdout, Write};
 use std::sync::RwLock;
 
 const AFFINE: &str = include_str!("affine.metal");
@@ -16,7 +13,6 @@ const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
-const MFA_LIB: &[u8] = include_bytes!("mfa.metallib");
 
 fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
     let size = length as u64;
@@ -63,8 +59,8 @@ impl<T> EncoderParam for &[T] {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
         encoder.set_bytes(
             position,
-            core::mem::size_of_val(data) as u64,
-            data.as_ptr() as *const c_void,
+            (core::mem::size_of::<T>() * data.len()) as u64,
+            data.as_ptr() as *const T as *const c_void,
         );
     }
 }
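The two spellings of the byte count are equivalent for a slice: `size_of_val(data)` on a `&[T]` is exactly `size_of::<T>() * data.len()`, so this change is behavior-preserving; the extra `as *const T` only makes the pointer cast explicit. Quick check:

    fn main() {
        let data: &[u32] = &[1, 2, 3];
        assert_eq!(
            core::mem::size_of_val(data),
            core::mem::size_of::<u32>() * data.len()
        );
    }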
@@ -109,14 +105,19 @@ pub enum Source {
     Ternary,
     Cast,
     Reduce,
-    MetalFlashAttention,
 }
 
 macro_rules! ops{
     ($($name:ident),+) => {
 
         pub mod contiguous {
-        pub struct Kernel(pub &'static str);
+        #[derive(Clone, Copy)]
+        pub struct Kernel(pub(crate) &'static str);
+        impl std::fmt::Display for Kernel {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                write!(f, "{}", self.0)
+            }
+        }
         $(
         pub mod $name {
             use super::Kernel;
@ -125,18 +126,16 @@ macro_rules! ops{
|
|||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
pub mod copy {
|
|
||||||
use super::Kernel;
|
|
||||||
pub const FLOAT: Kernel = Kernel("copy_float");
|
|
||||||
pub const HALF: Kernel = Kernel("copy_half");
|
|
||||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
|
|
||||||
pub const U32: Kernel = Kernel("copy_u32");
|
|
||||||
pub const U8: Kernel = Kernel("copy_u8");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod strided {
|
pub mod strided {
|
||||||
pub struct Kernel(pub &'static str);
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct Kernel(pub(crate) &'static str);
|
||||||
|
impl std::fmt::Display for Kernel {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
@ -145,20 +144,12 @@ macro_rules! ops{
|
|||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
pub mod copy {
|
|
||||||
use super::Kernel;
|
|
||||||
pub const FLOAT: Kernel = Kernel("copy_float_strided");
|
|
||||||
pub const HALF: Kernel = Kernel("copy_half_strided");
|
|
||||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
|
|
||||||
pub const U32: Kernel = Kernel("copy_u32_strided");
|
|
||||||
pub const U8: Kernel = Kernel("copy_u8_strided");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
|
ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
ops!(add, sub, mul, div);
|
ops!(add, sub, mul, div);
|
||||||
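
Each name handed to `ops!` becomes a module holding the kernel-name constants for one op. Assuming the non-bfloat constants follow the same `concat!` pattern as the `_bfloat` line visible in the hunk, `unary::contiguous::cos` expands to roughly:

    pub mod cos {
        use super::Kernel;
        // Assumed expansion, mirroring the `_bfloat` pattern shown above:
        pub const FLOAT: Kernel = Kernel("cos_float");
        pub const HALF: Kernel = Kernel("cos_half");
        pub const BFLOAT: Kernel = Kernel("cos_bfloat");
    }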
@@ -170,12 +161,8 @@ pub enum MetalKernelError {
     LockError(String),
     #[error("Error while loading library: {0}")]
     LoadLibraryError(String),
-    #[error("Error while loading function: {0:?}")]
+    #[error("Error while loading function: {0}")]
     LoadFunctionError(String),
-    #[error("Failed to create compute function")]
-    FailedToCreateComputeFunction,
-    #[error("Failed to create pipeline")]
-    FailedToCreatePipeline(String),
 }

 impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -184,52 +171,32 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
     }
 }

-type KernelMap<T> = HashMap<KernelKey, T>;
+type KernelMap<T> = HashMap<&'static str, T>;
 type Libraries = HashMap<Source, Library>;
-type Pipelines = KernelMap<ComputePipelineState>;
+type Functions = KernelMap<Function>;

 #[derive(Debug, Default)]
 pub struct Kernels {
     libraries: RwLock<Libraries>,
-    pipelines: RwLock<Pipelines>,
-}
-
-enum LibraryDefinition {
-    Source(&'static str),
-    Data(&'static [u8]),
-}
-
-impl From<&'static str> for LibraryDefinition {
-    fn from(s: &'static str) -> Self {
-        Self::Source(s)
-    }
-}
-impl From<&'static [u8]> for LibraryDefinition {
-    fn from(s: &'static [u8]) -> Self {
-        Self::Data(s)
-    }
-}
+    funcs: RwLock<Functions>,
+}

 impl Kernels {
     pub fn new() -> Self {
         let libraries = RwLock::new(Libraries::new());
-        let pipelines = RwLock::new(Pipelines::new());
-        Self {
-            libraries,
-            pipelines,
-        }
+        let funcs = RwLock::new(Functions::new());
+        Self { libraries, funcs }
     }

-    fn get_library_source(&self, source: Source) -> LibraryDefinition {
+    fn get_library_source(&self, source: Source) -> &'static str {
         match source {
-            Source::Affine => AFFINE.into(),
-            Source::Unary => UNARY.into(),
-            Source::Binary => BINARY.into(),
-            Source::Ternary => TERNARY.into(),
-            Source::Indexing => INDEXING.into(),
-            Source::Cast => CAST.into(),
-            Source::Reduce => REDUCE.into(),
-            Source::MetalFlashAttention => MFA_LIB.into(),
+            Source::Affine => AFFINE,
+            Source::Unary => UNARY,
+            Source::Binary => BINARY,
+            Source::Ternary => TERNARY,
+            Source::Indexing => INDEXING,
+            Source::Cast => CAST,
+            Source::Reduce => REDUCE,
         }
     }

@@ -242,204 +209,31 @@ impl Kernels {
         if let Some(lib) = libraries.get(&source) {
             Ok(lib.clone())
         } else {
-            let lib = match self.get_library_source(source) {
-                LibraryDefinition::Source(source_content) => device
-                    .new_library_with_source(source_content, &CompileOptions::new())
-                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
-                LibraryDefinition::Data(data) => device
-                    .new_library_with_data(data)
-                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
-            };
-
+            let source_content = self.get_library_source(source);
+            let lib = device
+                .new_library_with_source(source_content, &CompileOptions::new())
+                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
             libraries.insert(source, lib.clone());
             Ok(lib)
         }
     }

-    fn load_function(
+    pub fn load_function(
         &self,
         device: &Device,
         source: Source,
-        key: KernelKey,
+        name: &'static str,
     ) -> Result<Function, MetalKernelError> {
-        let func = self
-            .load_library(device, source)?
-            .get_function(
-                key.name,
-                key.constants.map(|c| c.create_function_constant_values()),
-            )
-            .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
-        Ok(func)
-    }
-
-    pub fn load_pipeline<T: Into<KernelKey>>(
-        &self,
-        device: &Device,
-        source: Source,
-        key: T,
-    ) -> Result<ComputePipelineState, MetalKernelError> {
-        let key: KernelKey = key.into();
-        let mut pipelines = self.pipelines.write()?;
-        if let Some(pipeline) = pipelines.get(&key) {
-            Ok(pipeline.clone())
+        let mut funcs = self.funcs.write()?;
+        if let Some(func) = funcs.get(name) {
+            Ok(func.clone())
         } else {
-            let func = self.load_function(device, source, key.clone())?;
-            let pipeline = device
-                .new_compute_pipeline_state_with_function(&func)
-                .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
-            pipelines.insert(key, pipeline.clone());
-
-            Ok(pipeline)
+            let func = self
+                .load_library(device, source)?
+                .get_function(name, None)
+                .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+            funcs.insert(name, func.clone());
+            Ok(func)
-        }
-    }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-struct KernelKey {
-    name: &'static str,
-    constants: Option<ConstantValues>,
-}
-
-impl KernelKey {
-    fn new(name: &'static str) -> Self {
-        Self {
-            name,
-            constants: None,
-        }
-    }
-
-    fn with_constants(mut self, constants: ConstantValues) -> Self {
-        self.constants = Some(constants);
-        self
-    }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
-enum ConstantValueId {
-    Index(NSUInteger),
-    Name(&'static str),
-}
-
-trait MetalDType {
-    const MTL_DATA_TYPE: MTLDataType;
-}
-macro_rules! metal_dtype {
-    ($ty:ty, $mtl_data_type:ident) => {
-        impl MetalDType for $ty {
-            const MTL_DATA_TYPE: MTLDataType = MTLDataType::$mtl_data_type;
-        }
-    };
-}
-metal_dtype!(f32, Float);
-metal_dtype!(u32, UInt);
-metal_dtype!(u16, UShort);
-metal_dtype!(bool, Bool);
-
-#[derive(Debug, Clone, PartialEq)]
-enum ConstantValueType {
-    Float(f32),
-    Uint(u32),
-    UShort(u16),
-    Bool(bool),
-}
-
-impl Hash for ConstantValueType {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        use ConstantValueType::*;
-        match self {
-            Float(v) => v.to_bits().hash(state),
-            Uint(v) => v.hash(state),
-            UShort(v) => v.hash(state),
-            Bool(v) => v.hash(state),
-        }
-    }
-}
-
-impl Eq for ConstantValueType {}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-struct ConstantValues(BTreeMap<ConstantValueId, ConstantValueType>);
-
-macro_rules! add_indexed_constant {
-    ($fcv:expr, $value:expr, $ty:ty, $idx:expr) => {
-        $fcv.set_constant_value_at_index(
-            $value as *const $ty as *const c_void,
-            <$ty>::MTL_DATA_TYPE,
-            $idx,
-        )
-    };
-}
-macro_rules! add_named_constant {
-    ($fcv:expr, $value:expr, $ty:ty, $name:expr) => {
-        $fcv.set_constant_value_with_name(
-            $value as *const $ty as *const c_void,
-            <$ty>::MTL_DATA_TYPE,
-            $name,
-        )
-    };
-}
-
-impl Hash for ConstantValues {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        for (id, value) in &self.0 {
-            id.hash(state);
-            value.hash(state);
-        }
-    }
-}
-
-impl ConstantValues {
-    fn new() -> Self {
-        Self(BTreeMap::new())
-    }
-
-    fn set(mut self, id: impl Into<ConstantValueId>, value: impl Into<ConstantValueType>) -> Self {
-        self.0.insert(id.into(), value.into());
-        self
-    }
-
-    fn create_function_constant_values(&self) -> FunctionConstantValues {
-        use ConstantValueId::*;
-        use ConstantValueType::*;
-        let mut function_values = FunctionConstantValues::new();
-
-        for (id, value) in &self.0 {
-            match (&id, &value) {
-                (Index(index), Float(value)) => {
-                    add_indexed_constant!(function_values, value, f32, *index);
-                }
-                (Index(index), Uint(value)) => {
-                    add_indexed_constant!(function_values, value, u32, *index);
-                }
-                (Index(index), UShort(value)) => {
-                    add_indexed_constant!(function_values, value, u16, *index);
-                }
-                (Index(index), Bool(value)) => {
-                    add_indexed_constant!(function_values, value, bool, *index);
-                }
-                (Name(name), Float(value)) => {
-                    add_named_constant!(function_values, value, f32, name);
-                }
-                (Name(name), Uint(value)) => {
-                    add_named_constant!(function_values, value, u32, name);
-                }
-                (Name(name), UShort(value)) => {
-                    add_named_constant!(function_values, value, u16, name);
-                }
-                (Name(name), Bool(value)) => {
-                    add_named_constant!(function_values, value, bool, name);
-                }
-            }
-        }
-        function_values
-    }
-}
-
-impl From<&'static str> for KernelKey {
-    fn from(name: &'static str) -> Self {
-        Self {
-            name,
-            constants: None,
         }
     }
 }
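
On the `+` side the cache now stores plain `Function`s keyed by kernel name; the constant-specialized `KernelKey`/`ConstantValues` machinery leaves with the MFA code. A minimal usage sketch of the new entry point (device setup as in the tests below; "cos_float" assumed to exist in the unary library):

    let device = metal::Device::system_default().unwrap();
    let kernels = Kernels::new();
    // First call compiles the library and caches the looked-up function...
    let f1 = kernels.load_function(&device, Source::Unary, "cos_float").unwrap();
    // ...the second is served from the `funcs` RwLock cache.
    let f2 = kernels.load_function(&device, Source::Unary, "cos_float").unwrap();
    assert_eq!(f1.name(), f2.name());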
@@ -452,9 +246,18 @@ pub fn call_unary_contiguous(
     kernel_name: unary::contiguous::Kernel,
     length: usize,
     input: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
+    let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);

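
The calling convention after this change is the one the tests at the bottom of this diff use: queue, command buffer, buffers, dispatch, then commit and wait. A condensed sketch (`new_buffer` and `read_to_vec` are the helpers the tests define and use):

    let device = metal::Device::system_default().unwrap();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let v = [1.0f32, 2.0, 3.0];
    let input = new_buffer(&device, &v);
    let mut output = new_buffer(&device, &v);
    call_unary_contiguous(
        &device,
        command_buffer,
        &kernels,
        unary::contiguous::cos::FLOAT,
        v.len(),
        &input,
        &mut output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    let result = output.read_to_vec::<f32>(v.len());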
@@ -476,10 +279,18 @@ pub fn call_unary_strided(
     input: &Buffer,
     strides: &[usize],
     offset: usize,
-    output: &Buffer,
+    output: &mut Buffer,
     output_offset: usize,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
+    let func = kernels.load_function(device, Source::Unary, name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
@@ -515,9 +326,17 @@ pub fn call_binary_contiguous(
     length: usize,
     left: &Buffer,
     right: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
+    let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -544,9 +363,17 @@ pub fn call_binary_strided(
     right_input: &Buffer,
     right_strides: &[usize],
     right_offset: usize,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
+    let func = kernels.load_function(device, Source::Binary, name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
@@ -584,60 +411,31 @@ pub fn call_cast_contiguous(
     kernel_name: &'static str,
     length: usize,
     input: &Buffer,
-    input_offset: usize,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
+    let func = kernels.load_function(device, Source::Cast, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);

-    set_params!(encoder, (length, (input, input_offset), output));
+    set_params!(encoder, (length, input, output));

     let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
     Ok(())
 }

 #[allow(clippy::too_many_arguments)]
-pub fn call_cast_strided(
-    device: &Device,
-    command_buffer: &CommandBufferRef,
-    kernels: &Kernels,
-    kernel_name: &'static str,
-    shape: &[usize],
-    input: &Buffer,
-    input_strides: &[usize],
-    input_offset: usize,
-    output: &Buffer,
-) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
-
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    let length: usize = shape.iter().product();
-
-    set_params!(
-        encoder,
-        (
-            length,
-            shape.len(),
-            shape,
-            input_strides,
-            (input, input_offset),
-            output
-        )
-    );
-
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
-
-    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.end_encoding();
-    Ok(())
-}
-
 pub fn call_reduce_contiguous(
     device: &Device,
     command_buffer: &CommandBufferRef,
@@ -646,19 +444,24 @@ pub fn call_reduce_contiguous(
     length: usize,
     out_length: usize,
     input: &Buffer,
-    input_offset: usize,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let elements_to_sum = length / out_length;

     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);

-    set_params!(
-        encoder,
-        (length, elements_to_sum, (input, input_offset), output)
-    );
+    set_params!(encoder, (length, elements_to_sum, input, output));

     let thread_group_count = MTLSize {
         width: out_length as u64,
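
Dispatch geometry for the reduction: one threadgroup per output element, with `elements_to_sum = length / out_length` inputs funneled into each group. Worked example:

    // Reducing 1024 contiguous values down to 4 partial results:
    let length = 1024usize;
    let out_length = 4usize;
    let elements_to_sum = length / out_length;
    assert_eq!(elements_to_sum, 256); // each threadgroup folds 256 elements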
@@ -692,9 +495,18 @@ pub fn call_last_softmax(
     length: usize,
     elements_to_sum: usize,
     input: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);

@@ -730,14 +542,21 @@ pub fn call_affine(
     device: &Device,
     command_buffer: &CommandBufferRef,
     kernels: &Kernels,
-    name: &'static str,
     size: usize,
     input: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+    let func = kernels.load_function(device, Source::Affine, "affine_float")?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -751,45 +570,6 @@ pub fn call_affine(
 }

 #[allow(clippy::too_many_arguments)]
-pub fn call_affine_strided(
-    device: &Device,
-    command_buffer: &CommandBufferRef,
-    kernels: &Kernels,
-    name: &'static str,
-    shape: &[usize],
-    input: &Buffer,
-    input_stride: &[usize],
-    input_offset: usize,
-    output: &Buffer,
-    mul: f32,
-    add: f32,
-) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
-    let size: usize = shape.iter().product();
-
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    set_params!(
-        encoder,
-        (
-            size,
-            shape.len(),
-            shape,
-            input_stride,
-            mul,
-            add,
-            (input, input_offset),
-            output
-        )
-    );
-
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
-    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.end_encoding();
-    Ok(())
-}
-
 pub fn call_where_cond_strided(
     device: &Device,
     command_buffer: &CommandBufferRef,
@@ -802,9 +582,17 @@ pub fn call_where_cond_strided(
     (left_stride, left_offset): (&[usize], usize),
     right: &Buffer,
     (right_stride, right_offset): (&[usize], usize),
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
+    let func = kernels.load_function(device, Source::Ternary, name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -846,14 +634,17 @@ pub fn call_index_select(
     dim: usize,
     input: &Buffer,
     ids: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
     let src_dim_size = shape[dim];
     let dst_el = ids_size * left_size * right_size;

-    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+    let func = kernels.load_function(device, Source::Indexing, name)?;
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(&func)
+        .unwrap();

     let encoder = command_buffer.new_compute_command_encoder();

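
The output element count factors around the indexed dimension: `dst_el = ids_size * left_size * right_size`. With the test case further down (`shape = [5, 2]`, `dim = 0`, three ids):

    let shape = [5usize, 2];
    let dim = 0;
    let ids_len = 3usize;
    let left_size: usize = shape[..dim].iter().product();      // empty product = 1
    let right_size: usize = shape[dim + 1..].iter().product(); // 2
    assert_eq!(ids_len * left_size * right_size, 6);           // 3 rows of 2 values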
@@ -880,230 +671,5 @@ pub fn call_index_select(
     Ok(())
 }

-impl From<NSUInteger> for ConstantValueId {
-    fn from(idx: NSUInteger) -> Self {
-        Self::Index(idx)
-    }
-}
-
-impl From<usize> for ConstantValueId {
-    fn from(idx: usize) -> Self {
-        ConstantValueId::from(idx as NSUInteger)
-    }
-}
-
-impl From<i32> for ConstantValueId {
-    fn from(idx: i32) -> Self {
-        ConstantValueId::from(idx as NSUInteger)
-    }
-}
-
-impl From<&'static str> for ConstantValueId {
-    fn from(name: &'static str) -> Self {
-        Self::Name(name)
-    }
-}
-
-macro_rules! to_constant_value {
-    ($ty:ty, $constant_value_type:ident) => {
-        to_constant_value!($ty, $ty, $constant_value_type);
-    };
-    ($ty:ty, $via:ty, $constant_value_type:ident) => {
-        impl From<$ty> for ConstantValueType {
-            fn from(v: $ty) -> Self {
-                Self::$constant_value_type(v as $via)
-            }
-        }
-    };
-}
-to_constant_value!(f32, Float);
-to_constant_value!(u32, Uint);
-to_constant_value!(usize, u32, Uint);
-to_constant_value!(u16, UShort);
-to_constant_value!(bool, Bool);
-
-struct MFAGemmConfig {
-    m: usize,
-    k: usize,
-    n: usize,
-    transpose_left: bool,
-    transpose_right: bool,
-    batched: bool,
-    m_simd: u16,
-    n_simd: u16,
-    k_simd: u16,
-    m_splits: u16,
-    n_splits: u16,
-    m_group: u16,
-    n_group: u16,
-}
-
-impl From<MFAGemmConfig> for ConstantValues {
-    fn from(conf: MFAGemmConfig) -> Self {
-        ConstantValues::new()
-            .set(0, conf.m)
-            .set(1, conf.k)
-            .set(2, conf.n)
-            .set(10, conf.transpose_left)
-            .set(11, conf.transpose_right)
-            .set(12, false)
-            .set(20, 1.0)
-            .set(21, 0.0)
-            .set(100, conf.batched)
-            .set(101, false)
-            .set(50001, false)
-            .set(200, conf.m_simd)
-            .set(201, conf.n_simd)
-            .set(202, conf.k_simd)
-            .set(210, conf.m_splits)
-            .set(211, conf.n_splits)
-            // garbage
-            .set(102, false)
-            .set(103, false)
-            .set(113, false)
-            .set(50000, false)
-    }
-}
-
-#[allow(clippy::too_many_arguments)]
-pub fn call_mfa_gemm(
-    device: &Device,
-    command_buffer: &CommandBufferRef,
-    kernels: &Kernels,
-    name: &'static str,
-    lhs: &Buffer,
-    lhs_dims: &[usize],
-    rhs: &Buffer,
-    rhs_dims: &[usize],
-    output: &Buffer,
-    (b, m, n, k): (usize, usize, usize, usize),
-    transpose_left: bool,
-    transpose_right: bool,
-) -> Result<(), MetalKernelError> {
-    let batched = b > 1;
-
-    let mut c_elements = m * n;
-    if batched {
-        c_elements *= 2;
-    }
-
-    let is_half = name == "hgemm";
-    let is_float = name == "sgemm";
-
-    let mut m_group = 32;
-    let mut n_group = 32;
-    let mut k_simd = 32;
-    if c_elements > 10 ^ 6 {
-        m_group = 48;
-        n_group = 48;
-    }
-    // If K_simd is perfectly equal to matrix K, the compiler can elide a large
-    // amount of logic in the kernel.
-    if k >= 33 && k <= 40 {
-        k_simd = 40;
-    } else if is_half && k >= 73 && k >= 80 {
-        k_simd = 80;
-    } else if c_elements > 10 ^ 6 {
-        if k <= 16 {
-            k_simd = 16;
-        } else if k <= 24 {
-            k_simd = 24;
-        } else if k <= 32 {
-            k_simd = 32;
-        } else if k <= 48 {
-            k_simd = 24;
-        } else if k <= 64 {
-            k_simd = 32;
-        } else if is_float {
-            k_simd = 24;
-        }
-    }
-
-    let m_splits = 2;
-    let n_splits = 2;
-    let m_simd = m_group / m_splits;
-    let n_simd = n_group / n_splits;
-
-    let config = MFAGemmConfig {
-        m,
-        k,
-        n,
-        transpose_left,
-        transpose_right,
-        batched,
-        m_simd,
-        n_simd,
-        k_simd,
-        m_splits,
-        n_splits,
-        m_group,
-        n_group,
-    };
-
-    let pipeline = kernels.load_pipeline(
-        device,
-        Source::MetalFlashAttention,
-        KernelKey::new(name).with_constants(config.into()),
-    )?;
-    let block_type_size = if is_half { 2 } else { 4 };
-    let a_block_bytes = m_group * k_simd * block_type_size;
-    let b_block_bytes = k_simd * n_group * block_type_size;
-    let c_block_bytes = m_group * n_group * block_type_size;
-    let mut thread_group_memory_length = a_block_bytes + b_block_bytes;
-
-    if m % 8 > 0 && n % 8 > 0 {
-        thread_group_memory_length = core::cmp::max(thread_group_memory_length, c_block_bytes);
-    }
-
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.set_compute_pipeline_state(&pipeline);
-    encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
-    encoder.use_resources(&[&lhs, &rhs], MTLResourceUsage::Read);
-    encoder.use_resource(&output, MTLResourceUsage::Write);
-    encoder.set_buffers(0, &[Some(lhs), Some(rhs), Some(output)], &[0; 3]);
-
-    let ceil_divide = |a, b| (a + b - 1) / b;
-
-    let mut grid_z = 1;
-
-    if batched {
-        grid_z = lhs_dims[..lhs_dims.len() - 2].iter().product();
-        let byte_stride = |shape: &[usize]| -> u64 {
-            let rank = shape.len();
-            let mut output = core::mem::size_of::<f32>() * shape[rank - 2] * shape[rank - 1];
-            if shape[..shape.len() - 2].iter().product::<usize>() == 1 {
-                output = 0;
-            }
-            output as u64
-        };
-        let byte_stride_a = byte_stride(lhs_dims);
-        let byte_stride_b = byte_stride(rhs_dims);
-        let byte_stride_c = byte_stride(&[m, n]);
-
-        type BatchConfig = (u64, u64, u64, u64);
-        let mut batching_conf: Vec<BatchConfig> = vec![];
-        for i in 0..grid_z {
-            batching_conf.push((
-                i as u64 * byte_stride_a,
-                i as u64 * byte_stride_b,
-                i as u64 * byte_stride_c,
-                0u64,
-            ));
-        }
-        set_param(encoder, 10, batching_conf.as_slice());
-    }
-
-    let grid_size = MTLSize::new(
-        ceil_divide(n as NSUInteger, n_group as NSUInteger),
-        ceil_divide(m as NSUInteger, m_group as NSUInteger),
-        grid_z as NSUInteger,
-    );
-
-    let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);
-    encoder.dispatch_thread_groups(grid_size, group_size);
-    encoder.end_encoding();
-    Ok(())
-}
-
 #[cfg(test)]
 mod tests;

Binary file not shown.
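The "Binary file not shown" entry is the prebuilt mfa.metallib that the deleted `include_bytes!` referenced. Worth noting about the removed tuning heuristic: in Rust `^` is bitwise XOR, not exponentiation, so `10 ^ 6` evaluates to 12 rather than 1_000_000 and `c_elements > 10 ^ 6` fired for nearly any matrix; likewise `k >= 73 && k >= 80` reads as if `k <= 80` was intended. A million-element threshold would be spelled:

    assert_eq!(10 ^ 6, 12); // 0b1010 ^ 0b0110 == 0b1100
    let c_elements = 2048 * 2048;
    let big = c_elements > 1_000_000; // or 10_i32.pow(6)
    assert!(big);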

@@ -1,8 +1,6 @@
 #include <metal_stdlib>
 using namespace metal;

-#define MAX(x, y) ((x) > (y) ? (x) : (y))
-
 METAL_FUNC uint get_strided_index(
     uint idx,
     constant size_t &num_dims,
@@ -18,18 +16,18 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }

-constant int THREADGROUP_SIZE = 1024;
+constant int THREADGROUP_SIZE = 256;

-# define REDUCE(FN, NAME, T) \
+# define REDUCE(FN, NAME, TYPENAME) \
 kernel void NAME( \
     constant size_t &src_numel, \
     constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
+    device const TYPENAME *src, \
+    device TYPENAME *dst, \
     uint id [[ thread_position_in_grid ]], \
     uint tid [[ thread_index_in_threadgroup ]], \
     uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
+    uint blockDim [[ threads_per_threadgroup ]] \
 ) { \
     \
     threadgroup float shared_memory[THREADGROUP_SIZE]; \
@@ -47,10 +45,10 @@ kernel void NAME( \
     // TODO: Fast version for the contiguous case. \
     // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
     */ \
-    T x = shared_memory[tid]; \
-    T y = src[idx]; \
+    TYPENAME x = shared_memory[tid]; \
+    TYPENAME y = src[idx]; \
     shared_memory[tid] = FN; \
-    idx += block_dim; \
+    idx += blockDim; \
 } \
     \
     threadgroup_barrier(mem_flags::mem_none); \
@@ -58,10 +56,10 @@ kernel void NAME( \
     /* \
     // reduction in shared memory \
     */ \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
         if (tid < s) { \
-            T x = shared_memory[tid]; \
-            T y = shared_memory[tid + s]; \
+            TYPENAME x = shared_memory[tid]; \
+            TYPENAME y = shared_memory[tid + s]; \
             shared_memory[tid] = FN; \
         } \
         threadgroup_barrier(mem_flags::mem_none); \
@@ -70,74 +68,72 @@ kernel void NAME( \
     dst[dst_id] = shared_memory[0]; \
 } \

+kernel void softmax_float(
+    constant size_t &src_numel,
+    constant size_t &el_to_sum_per_block,
+    device const float *src,
+    device float *dst,
+    uint id [[ thread_position_in_grid ]],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint blockDim [[ threads_per_threadgroup ]]
+) {
+
+    threadgroup float shared_memory[THREADGROUP_SIZE];
+
+    shared_memory[tid] = -INFINITY;
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+        shared_memory[tid] = max(shared_memory[tid], src[idx]);
+        idx += blockDim;
+    }
+
+    threadgroup_barrier(mem_flags::mem_none);
+
+    // reduction in shared memory
+    for (uint s = blockDim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
+        }
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    float max = shared_memory[0];
+
+    shared_memory[tid] = 0;
+
+    // Restart
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+        const float val = exp(src[idx] - max);
+        dst[idx] = val;
+        shared_memory[tid] += val;
+        idx += blockDim;
+    }
+    // reduction in shared memory
+    for (uint s = blockDim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] += shared_memory[tid + s];
+        }
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    const float inv_acc = 1/shared_memory[0];
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        dst[idx] *= inv_acc;
+        idx += blockDim;
+    }
+}
+

 REDUCE(x + y, fast_sum_float, float)
 REDUCE(x * y, fast_mul_float, float)
 REDUCE(max(x, y), fast_max_float, float)
-
-#define SOFTMAX(NAME, T) \
-kernel void NAME( \
-    constant size_t &src_numel, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
-    \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
-    shared_memory[tid] = -INFINITY; \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
-    size_t idx = start_idx + tid; \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    while (idx < stop_idx) { \
-        shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
-        idx += block_dim; \
-    } \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
-        } \
-    } \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    float _max = shared_memory[0]; \
-    \
-    shared_memory[tid] = 0; \
-    \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        const T val = T(exp(src[idx] - _max)); \
-        dst[idx] = val; \
-        shared_memory[tid] += val; \
-        idx += block_dim; \
-    } \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] += shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    const T inv_acc = T(1/shared_memory[0]); \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        dst[idx] *= inv_acc; \
-        idx += block_dim; \
-    } \
-} \
-
-SOFTMAX(softmax_float, float)
-SOFTMAX(softmax_half, half)
-#if __METAL_VERSION__ >= 310
-SOFTMAX(softmax_bfloat, bfloat)
-#endif
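
These hunks are reduce.metal (per the `REDUCE` include_str! above). Both softmax versions on either side of the hunk implement the same numerically stable three-pass scheme per row: a max-reduction, an exponentiate-and-accumulate pass, then a normalization:

    softmax(x)_i = exp(x_i - m) / sum_k exp(x_k - m),  where m = max_j x_j

A scalar reference in Rust, matching the kernel's passes:

    fn softmax_row(x: &[f32]) -> Vec<f32> {
        let m = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max); // pass 1: row max
        let exps: Vec<f32> = x.iter().map(|v| (v - m).exp()).collect(); // pass 2: exponentiate
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|v| v / sum).collect() // pass 3: normalize by the accumulated sum
    }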

@@ -32,9 +32,6 @@ kernel void FN_NAME( \
     device TYPENAME *out ,\
     uint i [[ thread_position_in_grid ]] \
 ) { \
-    if (i >= numel){ \
-        return; \
-    } \
     uint strided_i = get_strided_index(i, num_dims, dims, strides); \
     uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
     uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
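
This hunk is ternary.metal (the where-cond kernel). Dropping the `i >= numel` early-return makes the kernel assume the dispatched grid covers exactly `numel` threads; if the launch ever rounds the group count up, the extra threads index past the buffers. A sketch of the over-dispatch this guard protected against (the 256-thread group size is illustrative):

    let numel = 1000u64;
    let threads_per_group = 256u64;
    let groups = (numel + threads_per_group - 1) / threads_per_group; // 4 groups
    assert_eq!(groups * threads_per_group, 1024); // 24 threads land past the end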

@@ -1,5 +1,5 @@
 use super::*;
-use half::{bf16, f16};
+use half::f16;
 use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};

 fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
@@ -23,18 +23,13 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
     v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
 }

-fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
-    let b = 10f32.powi(digits);
-    v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
-}
-
 fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
     let device = device();
     let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let output = new_buffer(&device, v);
+    let mut output = new_buffer(&device, v);
     call_unary_contiguous(
         &device,
         command_buffer,
@@ -42,7 +37,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
         name,
         v.len(),
         &input,
-        &output,
+        &mut output,
     )
     .unwrap();
     command_buffer.commit();
@@ -58,7 +53,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
     let options = MTLResourceOptions::StorageModeManaged;
     let left = new_buffer(&device, x);
     let right = new_buffer(&device, y);
-    let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+    let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
     call_binary_contiguous(
         &device,
         command_buffer,
@@ -67,7 +62,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
         x.len(),
         &left,
         &right,
-        &output,
+        &mut output,
     )
     .unwrap();
     command_buffer.commit();
@@ -86,7 +81,7 @@ fn run_strided<T: Clone>(
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let output = new_buffer(&device, v);
+    let mut output = new_buffer(&device, v);
     let kernels = Kernels::new();
     call_unary_strided(
         &device,
@@ -97,7 +92,7 @@ fn run_strided<T: Clone>(
         &input,
         strides,
         offset,
-        &output,
+        &mut output,
         0,
     )
     .unwrap();
@@ -225,9 +220,7 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let options = MTLResourceOptions::StorageModeManaged;
-    let size = (v.len() * std::mem::size_of::<U>()) as u64;
-    let output = device.new_buffer(size, options);
+    let mut output = new_buffer(&device, v);

     call_cast_contiguous(
         &device,
@@ -236,8 +229,7 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
         name,
         v.len(),
         &input,
-        0,
-        &output,
+        &mut output,
     )
     .unwrap();
     command_buffer.commit();
@@ -253,17 +245,11 @@ fn cast_u32_f32() {
     assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
     assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);

-    let v = vec![1.0f32, 2.0, 3.0];
-    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
-    let results: Vec<f32> = cast(&input, "cast_f16_f32");
-    assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
-
     let v = vec![1.0f32; 10_000];
-    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
-    let results: Vec<f32> = cast(&input, "cast_f16_f32");
-    assert_eq!(results.len(), 10_000);
-    assert_eq!(&results[..10], vec![1.0f32; 10]);
-    assert_eq!(results, vec![1.0f32; 10_000]);
+    let results = run(&v, unary::contiguous::cos::FLOAT);
+    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
+    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
 }

 fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
@@ -273,7 +259,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
     let command_buffer = command_queue.new_command_buffer();

     let input = new_buffer(&device, v);
-    let output = new_buffer(&device, v);
+    let mut output = new_buffer(&device, v);

     let size = v.len();

@@ -281,10 +267,9 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
         &device,
         command_buffer,
         &kernels,
-        "affine_float",
         size,
         &input,
-        &output,
+        &mut output,
         mul as f32,
         add as f32,
     )
@@ -295,42 +280,6 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
     output.read_to_vec::<T>(v.len())
 }

-fn run_affine_strided<T: Clone>(
-    v: &[T],
-    shape: &[usize],
-    strides: &[usize],
-    mul: f64,
-    add: f64,
-) -> Vec<T> {
-    let device = device();
-    let kernels = Kernels::new();
-    let command_queue = device.new_command_queue();
-    let command_buffer = command_queue.new_command_buffer();
-
-    let input = new_buffer(&device, v);
-    let output = new_buffer(&device, v);
-
-    call_affine_strided(
-        &device,
-        command_buffer,
-        &kernels,
-        "affine_float_strided",
-        shape,
-        &input,
-        strides,
-        0,
-        &output,
-        mul as f32,
-        add as f32,
-    )
-    .unwrap();
-    command_buffer.commit();
-    command_buffer.wait_until_completed();
-
-    let len: usize = shape.iter().product();
-    output.read_to_vec::<T>(len)
-}
-
 #[test]
 fn affine() {
     let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
@@ -346,18 +295,6 @@ fn affine() {
     assert_eq!(result, vec![2.6; 40_000]);
 }

-#[test]
-fn affine_strided() {
-    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
-    let mul = 1.5;
-    let add = 1.1;
-    let shape = [4];
-    let strides = [2];
-    let result = run_affine_strided(&input, &shape, &strides, mul, add);
-    // 1 on 2
-    assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
-}
-
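The expected constants in the affine tests follow directly from y = mul * x + add: with mul = 1.5 and add = 1.1, the input 1.0 maps to 1.5 * 1.0 + 1.1 = 2.6, the value asserted for all 40_000 elements. As a check:

    let (mul, add) = (1.5f32, 1.1f32);
    let y = mul * 1.0 + add;
    assert!((y - 2.6).abs() < 1e-6);
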
#[test]
|
#[test]
|
||||||
fn index_select() {
|
fn index_select() {
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||||
@ -376,26 +313,7 @@ fn index_select() {
|
|||||||
result,
|
result,
|
||||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn index_select_f16() {
|
|
||||||
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
|
||||||
.into_iter()
|
|
||||||
.map(|x| f16::from_f32(x))
|
|
||||||
.collect();
|
|
||||||
let shape = [5, 2];
|
|
||||||
let ids = [0u32, 4, 2];
|
|
||||||
let dim = 0;
|
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
|
||||||
assert_eq!(
|
|
||||||
approx_f16(result, 4),
|
|
||||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn index_select_dim1() {
|
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
let ids = [0u32, 1, 0];
|
let ids = [0u32, 1, 0];
|
||||||
@ -403,7 +321,7 @@ fn index_select_dim1() {
|
|||||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result,
|
result,
|
||||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,26 +341,20 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
     let dst_el = ids.len() * left_size * right_size;
-    let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
-
-    let name = match core::mem::size_of::<T>() {
-        4 => "is_u32_f32",
-        2 => "is_u32_f16",
-        _ => unimplemented!(),
-    };
-
+    let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
     let kernels = Kernels::new();
     call_index_select(
         &device,
         &command_buffer,
         &kernels,
-        name,
+        "is_u32_f32",
         shape,
         ids.len(),
         dim,
         &embeddings_buffer,
         &ids_buffer,
-        &dst_buffer,
+        &mut dst_buffer,
     )
     .unwrap();

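One detail worth noting in `run_index_select`: dispatching the kernel name on `core::mem::size_of::<T>()` cannot tell f16 from bf16 (both two bytes wide), nor f32 from u32 (both four). A sketch of a dtype-keyed alternative, where the trait and its impls are purely illustrative and only the two kernel names come from the code above:

```rust
// Illustrative alternative: key the kernel name off the concrete dtype
// rather than the element width, so 2-byte f16 and bf16 cannot collide.
trait IndexSelectName {
    const KERNEL: &'static str;
}
impl IndexSelectName for f32 {
    const KERNEL: &'static str = "is_u32_f32";
}
impl IndexSelectName for half::f16 {
    const KERNEL: &'static str = "is_u32_f16";
}

fn kernel_name<T: IndexSelectName>() -> &'static str {
    T::KERNEL
}
```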
@@ -539,7 +451,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T>
     let input = new_buffer(&device, v);

     let options = MTLResourceOptions::StorageModeManaged;
-    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+    let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
     call_reduce_contiguous(
         &device,
         command_buffer,
@@ -548,8 +460,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T>
         v.len(),
         out_length,
         &input,
-        0,
-        &output,
+        &mut output,
     )
     .unwrap();
     command_buffer.commit();
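`run_reduce` allocates `out_length` output elements for `v.len()` inputs, i.e. each output slot reduces one contiguous chunk of `v.len() / out_length` values; the concrete reduction (sum, min, max, ...) is selected by the kernel name string. A CPU sketch under that assumption, with sum as the example:

```rust
// CPU reference for a contiguous reduction: split the input into
// `out_length` equal chunks and reduce each one (here: a sum).
fn reduce_contiguous_ref(v: &[f32], out_length: usize) -> Vec<f32> {
    let chunk = v.len() / out_length;
    v.chunks(chunk).map(|c| c.iter().sum()).collect()
}
```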
@@ -564,7 +475,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T>
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let output = new_buffer(&device, v);
+    let mut output = new_buffer(&device, v);
     call_last_softmax(
         &device,
         command_buffer,
@@ -573,7 +484,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T>
         v.len(),
         last_dim,
         &input,
-        &output,
+        &mut output,
     )
     .unwrap();
     command_buffer.commit();
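For the expectations below: a numerically stable softmax over the last dimension subtracts the row max before exponentiating. A CPU sketch of that computation (a reference model, not the kernel itself):

```rust
// CPU reference for softmax over the last dimension, processing the
// flat input in rows of `last_dim` elements. Subtracting the row max
// keeps exp() from overflowing without changing the result.
fn softmax_last_dim_ref(v: &[f32], last_dim: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(v.len());
    for row in v.chunks(last_dim) {
        let max = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        out.extend(exps.iter().map(|e| e / sum));
    }
    out
}
```

For the row `[1.0, 2.0, 3.0]` this gives `[0.0900, 0.2447, 0.6652]`, matching the f32 expectation in the softmax test.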
@@ -625,28 +536,6 @@ fn softmax() {
         approx(results, 4),
         vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
     );
-
-    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
-        .iter()
-        .map(|v| f16::from_f32(*v))
-        .collect::<Vec<_>>();
-    let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_half");
-    assert_eq!(
-        approx_f16(results, 4),
-        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
-    );
-
-    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
-        .iter()
-        .map(|v| bf16::from_f32(*v))
-        .collect::<Vec<_>>();
-    let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_bfloat");
-    assert_eq!(
-        approx_bf16(results, 4),
-        vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
-    );
 }

 fn run_where_cond<I: Clone, T: Clone>(
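The half and bfloat variants above expect slightly different vectors (0.0316, 0.0858, 0.2332, 0.6338 versus 0.0315, 0.0859, 0.2324, 0.6328) because f16 carries 10 explicit mantissa bits while bf16 keeps only 7, trading precision for an f32-sized exponent range. A quick round-trip makes the gap visible, using the `half` crate these tests already import:

```rust
use half::{bf16, f16};

fn main() {
    // Round-trip an f32 through each 16-bit format to see the precision gap.
    let x = 0.6338f32;
    println!("{}", f16::from_f32(x).to_f32()); // ≈ 0.6338 (10 mantissa bits)
    println!("{}", bf16::from_f32(x).to_f32()); // ≈ 0.6328 (7 mantissa bits)
}
```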
@@ -682,7 +571,7 @@ fn run_where_cond<I: Clone, T: Clone>(
         options,
     );

-    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+    let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
     call_where_cond_strided(
         &device,
         command_buffer,
@@ -695,7 +584,7 @@ fn run_where_cond<I: Clone, T: Clone>(
         (&left_stride, left_offset),
         &right,
         (&cond_stride, cond_offset),
-        &output,
+        &mut output,
     )
     .unwrap();
     command_buffer.commit();
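`run_where_cond` exercises the tensor ternary: `out[i] = if cond[i] != 0 { left[i] } else { right[i] }`, with each operand addressed through its own (strides, offset) pair. A contiguous CPU sketch of just the selection, with the strided addressing omitted for brevity:

```rust
// CPU reference for where_cond on contiguous inputs: pick from `left`
// where the condition is non-zero, otherwise from `right`.
fn where_cond_ref(cond: &[u8], left: &[f32], right: &[f32]) -> Vec<f32> {
    cond.iter()
        .zip(left.iter().zip(right))
        .map(|(&c, (&l, &r))| if c != 0 { l } else { r })
        .collect()
}
```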

@@ -1,7 +1,4 @@
 #include <metal_stdlib>
-#include <metal_math>
-#
-using namespace metal;

 METAL_FUNC uint get_strided_index(
     uint idx,
@@ -20,39 +17,10 @@ METAL_FUNC uint get_strided_index(

 template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
 template <typename T> METAL_FUNC T neg(T in){ return -in; }
-template <typename T> METAL_FUNC T erf(T in){
-    float x = (float) in;
-    // constants
-    float a1 =  0.254829592;
-    float a2 = -0.284496736;
-    float a3 =  1.421413741;
-    float a4 = -1.453152027;
-    float a5 =  1.061405429;
-    float p  =  0.3275911;
-
-    // Save the sign of x
-    int sign = 1;
-    if (x < 0)
-        sign = -1;
-    x = fabs(x);
-
-    // A&S formula 7.1.26
-    float t = 1.0/(1.0 + p*x);
-    float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
-
-    return T(sign*y);
-}
 template <typename T> METAL_FUNC T id(T in){ return in; }
-template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
-template <typename T> METAL_FUNC T gelu(T x){
-    T x_sq = x * x;
-    T x_cube = x_sq * x;
-    T alpha = x + static_cast<T>(0.044715) * x_cube;
-    T beta =  (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
-    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
-}
-
-
+using namespace metal;

 #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
 kernel void FN_NAME( \
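Two pieces of numerics are worth spelling out from the erf and gelu templates above. The erf body is the Abramowitz & Stegun formula 7.1.26 rational approximation, accurate to roughly 1.5e-7 in absolute error, and gelu uses the usual tanh surrogate, gelu(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))); note that M_2_SQRTPI_F · M_SQRT1_2_F = (2/√π)·(1/√2) = √(2/π). The same two functions in Rust, handy for checking kernel output on the CPU (a sketch, with constants copied from the kernel source):

```rust
// Abramowitz & Stegun 7.1.26: max absolute error is about 1.5e-7.
fn erf_approx(x: f32) -> f32 {
    let (a1, a2, a3) = (0.254829592f32, -0.284496736, 1.421413741);
    let (a4, a5, p) = (-1.453152027f32, 1.061405429, 0.3275911);
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
    sign * y
}

// Tanh-based gelu: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_approx(x: f32) -> f32 {
    let c = (2.0f32 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}
```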
@@ -96,16 +64,8 @@ UNARY_OP(sqrt)
 UNARY_OP(neg)
 UNARY_OP(exp)
 UNARY_OP(log)
-UNARY_OP(gelu)
-UNARY_OP(ceil)
-UNARY_OP(floor)
-UNARY_OP(round)
-UNARY_OP(gelu_erf)
-UNARY_OP(erf)
 UNARY(id, float, copy_float, copy_float_strided)
 UNARY(id, half, copy_half, copy_half_strided)
-UNARY(id, uint8_t, copy_u8, copy_u8_strided)
-UNARY(id, uint32_t, copy_u32, copy_u32_strided)

 #if __METAL_VERSION__ >= 310
 BFLOAT_UNARY_OP(cos)
@@ -115,12 +75,6 @@ BFLOAT_UNARY_OP(sqrt)
 BFLOAT_UNARY_OP(neg)
 BFLOAT_UNARY_OP(exp)
 BFLOAT_UNARY_OP(log)
-BFLOAT_UNARY_OP(gelu)
-BFLOAT_UNARY_OP(ceil)
-BFLOAT_UNARY_OP(floor)
-BFLOAT_UNARY_OP(round)
-BFLOAT_UNARY_OP(gelu_erf)
-BFLOAT_UNARY_OP(erf)

 UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
 #endif
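The `UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED)` macro stamps out a contiguous and a strided kernel per (function, type) pair, and the bfloat instantiations sit behind `__METAL_VERSION__ >= 310` because bfloat only became a language type in Metal 3.1. A `macro_rules!` sketch of the same stamping idea in Rust (illustrative only, not the kernel macro):

```rust
// Stamp out a contiguous and a strided CPU variant per (op, type),
// mirroring how the Metal UNARY macro generates kernel pairs.
macro_rules! unary {
    ($fn:expr, $ty:ty, $name:ident, $name_strided:ident) => {
        fn $name(src: &[$ty], dst: &mut [$ty]) {
            for (d, s) in dst.iter_mut().zip(src) {
                *d = $fn(*s);
            }
        }
        fn $name_strided(src: &[$ty], stride: usize, dst: &mut [$ty]) {
            for (i, d) in dst.iter_mut().enumerate() {
                *d = $fn(src[i * stride]);
            }
        }
    };
}
unary!(|x: f32| x * x, f32, sqr_f32, sqr_f32_strided);

fn main() {
    let src = [1.0f32, 2.0, 3.0, 4.0];
    let mut dst = [0.0f32; 2];
    sqr_f32_strided(&src, 2, &mut dst);
    assert_eq!(dst, [1.0, 9.0]); // squares of src[0] and src[2]
}
```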

@@ -19,7 +19,6 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }

 [dev-dependencies]
 anyhow = { workspace = true }
@@ -30,4 +29,3 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
-metal = ["candle/metal", "dep:candle-metal-kernels"]
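In this manifest, the `optional = true` dependency plus the `dep:candle-metal-kernels` feature syntax keeps the kernel crate out of the build unless the `metal` feature is enabled (e.g. `cargo build --features metal`), and the same feature forwards to `candle/metal` so the core crate's Metal backend is switched on in lockstep.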

@@ -201,46 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
-
-    #[cfg(feature = "metal")]
-    fn metal_fwd(
-        &self,
-        storage: &candle::MetalStorage,
-        layout: &Layout,
-    ) -> Result<(candle::MetalStorage, Shape)> {
-        use candle::{backend::BackendStorage, DType};
-        let device = storage.device();
-        let command_buffer = device.command_buffer();
-        let kernels = device.kernels();
-        let name = match storage.dtype() {
-            DType::F32 => "softmax_float",
-            DType::F16 => "softmax_half",
-            DType::BF16 => "softmax_bfloat",
-            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
-        };
-
-        let n = layout.stride().len();
-        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
-            candle::bail!("Non contiguous softmax-last-dim is not implemented");
-        }
-
-        let last_dim = layout.dims()[layout.shape().rank() - 1];
-        let elem_count = layout.shape().elem_count();
-        let mut output = device.new_buffer(elem_count, storage.dtype());
-        candle_metal_kernels::call_last_softmax(
-            device.metal_device(),
-            &command_buffer,
-            &kernels,
-            name,
-            elem_count,
-            last_dim,
-            storage.buffer(),
-            &mut output,
-        )
-        .unwrap();
-        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
-        Ok((newstorage, layout.shape().clone()))
-    }
 }

 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
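Downstream callers reach this custom op through the public helper above, which dispatches to whichever backend forward is available. A usage sketch, with the tensor values illustrative and the CPU device used for simplicity:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // softmax_last_dim normalizes along the final axis of the tensor.
    let xs = Tensor::new(&[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]], &Device::Cpu)?;
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;
    println!("{ys}"); // each row sums to 1, e.g. [0.0900, 0.2447, 0.6652]
    Ok(())
}
```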