mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
5 Commits
Author | SHA1 | Date | |
---|---|---|---|
1f23cea90c | |||
ce33d6ad2a | |||
3d0ade406a | |||
2ca086939f | |||
4349ff1fc2 |
@ -5,43 +5,13 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{CpuStorage, Device, Layout, Shape, Tensor};
|
||||
use candle_core as candle;
|
||||
|
||||
struct ArgSort;
|
||||
impl candle::CustomOp1 for ArgSort {
|
||||
fn name(&self) -> &'static str {
|
||||
"arg-sort"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
storage: &CpuStorage,
|
||||
layout: &Layout,
|
||||
) -> candle::Result<(CpuStorage, Shape)> {
|
||||
if layout.shape().rank() != 1 {
|
||||
candle::bail!(
|
||||
"input should have a single dimension, got {:?}",
|
||||
layout.shape()
|
||||
)
|
||||
}
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let src = match layout.contiguous_offsets() {
|
||||
None => candle::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => &slice[o1..o2],
|
||||
};
|
||||
let mut dst = (0..src.len() as u32).collect::<Vec<u32>>();
|
||||
dst.sort_by(|&i, &j| src[i as usize].total_cmp(&src[j as usize]));
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let a = Tensor::new(&[0.0f32, 1.0, 3.0, 2.0, -12.0, 4.0, 3.5], &Device::Cpu)?;
|
||||
let indices = a.apply_op1(ArgSort)?;
|
||||
let a_sorted = a.gather(&indices, 0)?;
|
||||
println!("{indices}");
|
||||
println!("{a_sorted}");
|
||||
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
|
||||
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
|
||||
let new_a = a.slice_scatter(&b, 1, 2)?;
|
||||
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -4,11 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use half::f16;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::Arc;
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -36,7 +38,9 @@ impl From<String> for MetalError {
|
||||
pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
command_queue: metal::CommandQueue,
|
||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -58,10 +62,48 @@ impl MetalDevice {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
pub fn metal_device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
|
||||
self.command_buffer.try_read().unwrap()
|
||||
}
|
||||
|
||||
pub fn commit(&self) {
|
||||
let mut old = self.command_buffer.try_write().unwrap();
|
||||
match old.status() {
|
||||
metal::MTLCommandBufferStatus::NotEnqueued
|
||||
| metal::MTLCommandBufferStatus::Enqueued => {
|
||||
old.commit();
|
||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*old = command_buffer;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) {
|
||||
let mut old = self.command_buffer.try_write().unwrap();
|
||||
match old.status() {
|
||||
metal::MTLCommandBufferStatus::NotEnqueued
|
||||
| metal::MTLCommandBufferStatus::Enqueued => {
|
||||
old.commit();
|
||||
old.wait_until_completed();
|
||||
}
|
||||
metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => {
|
||||
old.wait_until_completed();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*old = command_buffer;
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
&self.kernels
|
||||
}
|
||||
@ -70,16 +112,107 @@ impl MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
|
||||
}
|
||||
|
||||
fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> {
|
||||
let mut buffers = self.buffers.try_write().unwrap();
|
||||
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
||||
|
||||
for sub in &mut *subbuffers {
|
||||
if Arc::strong_count(sub) == 1 {
|
||||
return sub.clone();
|
||||
}
|
||||
}
|
||||
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
||||
let new_buffer = Arc::new(new_buffer);
|
||||
subbuffers.push(new_buffer.clone());
|
||||
new_buffer
|
||||
}
|
||||
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||
let tmp = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate);
|
||||
{
|
||||
let command = self.command_buffer();
|
||||
let blit = command.new_blit_command_encoder();
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.end_encoding();
|
||||
}
|
||||
// This is necessary, for mmaped safetensors
|
||||
// Because of the unsafe slice cast we're doing.
|
||||
// The slice might not live long enough for metal
|
||||
// To actually fill the GPU buffer.
|
||||
// Putting this wait forces the GPU buffer to be filled
|
||||
// with the actual data allowing the CPU storage todo
|
||||
// deallocate properly.
|
||||
self.wait_until_completed();
|
||||
real
|
||||
}
|
||||
|
||||
pub fn new_matrix(
|
||||
&self,
|
||||
(b, m, n): (NSUInteger, NSUInteger, NSUInteger),
|
||||
size: NSUInteger,
|
||||
type_id: u32,
|
||||
dtype: DType,
|
||||
) -> Result<(Matrix, Arc<Buffer>)> {
|
||||
let elem_count = (b * m * n) as usize;
|
||||
let out_buffer = self.new_buffer(elem_count, dtype);
|
||||
|
||||
let result_descriptor =
|
||||
MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
Ok((result_matrix, out_buffer))
|
||||
}
|
||||
|
||||
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||
let capture = metal::CaptureManager::shared();
|
||||
let descriptor = metal::CaptureDescriptor::new();
|
||||
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||
descriptor.set_capture_device(&self);
|
||||
descriptor.set_output_url(path);
|
||||
|
||||
capture
|
||||
.start_capture(&descriptor)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalStorage {
|
||||
buffer: metal::Buffer,
|
||||
buffer: Arc<metal::Buffer>,
|
||||
matrices: Arc<
|
||||
RwLock<
|
||||
HashMap<
|
||||
(
|
||||
NSUInteger,
|
||||
NSUInteger,
|
||||
NSUInteger,
|
||||
bool,
|
||||
NSUInteger,
|
||||
NSUInteger,
|
||||
u32,
|
||||
),
|
||||
Matrix,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
device: MetalDevice,
|
||||
dtype: DType,
|
||||
}
|
||||
@ -108,14 +241,23 @@ impl BackendStorage for MetalStorage {
|
||||
self.dtype
|
||||
);
|
||||
}
|
||||
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
||||
let command_buffer = self.device.command_buffer();
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
drop(command_buffer);
|
||||
self.device.wait_until_completed();
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
|
||||
DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
|
||||
DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
|
||||
DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
|
||||
DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))),
|
||||
DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,30 +268,48 @@ impl BackendStorage for MetalStorage {
|
||||
let el = shape.elem_count();
|
||||
let dtype = self.dtype;
|
||||
|
||||
if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
|
||||
crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
|
||||
let buffer = device.new_buffer(el, self.dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float",
|
||||
DType::F16 => "affine_half",
|
||||
dtype => crate::bail!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
el,
|
||||
&self.buffer,
|
||||
&buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float_strided",
|
||||
DType::F16 => "affine_half_strided",
|
||||
dtype => crate::bail!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
|
||||
let mut buffer = device.new_buffer(el, self.dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
candle_metal_kernels::call_affine(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
el,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
return Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
});
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
@ -163,11 +323,11 @@ impl BackendStorage for MetalStorage {
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
if !(sum_dims.len() == 1
|
||||
&& sum_dims[0] == layout.shape().rank() - 1
|
||||
&& layout.is_contiguous()
|
||||
&& layout.start_offset() == 0)
|
||||
&& layout.stride()[sum_dims[0]] == 1)
|
||||
{
|
||||
crate::bail!("Non contiguous reduce op not supported yet");
|
||||
crate::bail!("Non last dim reduce op not supported yet");
|
||||
}
|
||||
|
||||
let device = self.device.clone();
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
@ -202,8 +362,11 @@ impl BackendStorage for MetalStorage {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
if dtype == DType::U32 {
|
||||
crate::bail!("Implement return index reduce op");
|
||||
}
|
||||
let buffer = device.new_buffer(dst_el, dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -212,17 +375,12 @@ impl BackendStorage for MetalStorage {
|
||||
src_el,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
@ -233,11 +391,15 @@ impl BackendStorage for MetalStorage {
|
||||
let device = self.device();
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
let buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_cast_contiguous(
|
||||
@ -247,24 +409,34 @@ impl BackendStorage for MetalStorage {
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
crate::bail!(
|
||||
"TODO Implement the kernel calling cast {:?}-{:?}",
|
||||
self.dtype,
|
||||
dtype
|
||||
);
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32_strided",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8_strided",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32_strided",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_cast_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
@ -272,8 +444,8 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = self.dtype;
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
let buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
@ -285,6 +457,25 @@ impl BackendStorage for MetalStorage {
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||
("ulog", DType::F16) => contiguous::log::HALF,
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
@ -294,20 +485,58 @@ impl BackendStorage for MetalStorage {
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||
("usin", DType::F32) => strided::sin::FLOAT,
|
||||
("usqr", DType::F32) => strided::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => strided::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => strided::neg::FLOAT,
|
||||
("uexp", DType::F32) => strided::exp::FLOAT,
|
||||
("ulog", DType::F32) => strided::log::FLOAT,
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
("usqrt", DType::F16) => strided::sqrt::HALF,
|
||||
("uneg", DType::F16) => strided::neg::HALF,
|
||||
("uexp", DType::F16) => strided::exp::HALF,
|
||||
("ulog", DType::F16) => strided::log::HALF,
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
0,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
command_buffer.set_label("unary");
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(
|
||||
@ -320,8 +549,8 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = self.dtype;
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
let buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
{
|
||||
@ -336,6 +565,14 @@ impl BackendStorage for MetalStorage {
|
||||
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("div", DType::F32) => contiguous::div::FLOAT,
|
||||
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
||||
("add", DType::F16) => contiguous::add::HALF,
|
||||
("badd", DType::F16) => contiguous::add::HALF,
|
||||
("sub", DType::F16) => contiguous::sub::HALF,
|
||||
("bsub", DType::F16) => contiguous::sub::HALF,
|
||||
("mul", DType::F16) => contiguous::mul::HALF,
|
||||
("bmul", DType::F16) => contiguous::mul::HALF,
|
||||
("div", DType::F16) => contiguous::div::HALF,
|
||||
("bdiv", DType::F16) => contiguous::div::HALF,
|
||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
@ -346,7 +583,7 @@ impl BackendStorage for MetalStorage {
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&mut buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
@ -357,6 +594,10 @@ impl BackendStorage for MetalStorage {
|
||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||
("badd", DType::F16) => strided::add::HALF,
|
||||
("bsub", DType::F16) => strided::sub::HALF,
|
||||
("bmul", DType::F16) => strided::mul::HALF,
|
||||
("bdiv", DType::F16) => strided::div::HALF,
|
||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
@ -366,23 +607,19 @@ impl BackendStorage for MetalStorage {
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
&lhs_l.stride(),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&rhs_l.stride(),
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
command_buffer.set_label("binary");
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn where_cond(
|
||||
@ -398,14 +635,22 @@ impl BackendStorage for MetalStorage {
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let dtype = t.dtype;
|
||||
let mut buffer = self.device.new_buffer(el, dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let buffer = self.device.new_buffer(el, dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if t.dtype() != f.dtype() {
|
||||
crate::bail!("Invalid ternary different dtypes for values");
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
"where_u8_f32",
|
||||
&dims,
|
||||
name,
|
||||
dims,
|
||||
&self.buffer,
|
||||
(
|
||||
layout.stride(),
|
||||
@ -415,16 +660,10 @@ impl BackendStorage for MetalStorage {
|
||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
&mut buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
@ -513,12 +752,13 @@ impl BackendStorage for MetalStorage {
|
||||
let dst_el = ids_el * left_size * right_size;
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let buffer = device.new_buffer(dst_el, dtype);
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
(left, right) => crate::bail!("index select metal {left:?} {right:?}"),
|
||||
};
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_index_select(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -529,16 +769,10 @@ impl BackendStorage for MetalStorage {
|
||||
dim,
|
||||
&self.buffer,
|
||||
&ids.buffer,
|
||||
&mut buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
@ -561,156 +795,132 @@ impl BackendStorage for MetalStorage {
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
// Create descriptors
|
||||
use metal::mps::matrix::*;
|
||||
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
|
||||
let size = core::mem::size_of::<f32>() as NSUInteger;
|
||||
|
||||
let elem_count = b * m * n;
|
||||
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// The a tensor has dims batching, k, n (rhs)
|
||||
let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
false
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
false
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
let buffer = self.device.new_buffer(b * m * n, self.dtype);
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "sgemm",
|
||||
DType::F16 => "hgemm",
|
||||
dtype => {
|
||||
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
|
||||
}
|
||||
};
|
||||
|
||||
let b = b as NSUInteger;
|
||||
let m = m as NSUInteger;
|
||||
let n = n as NSUInteger;
|
||||
let k = k as NSUInteger;
|
||||
|
||||
let left_descriptor = if transpose_left {
|
||||
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
||||
};
|
||||
let right_descriptor = if transpose_right {
|
||||
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
||||
};
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
matrix_multiplication.set_batch_size(b);
|
||||
|
||||
// Encode kernel to command buffer
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||
dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
let command_buffer = self.device.command_buffer();
|
||||
command_buffer.set_label("matmul");
|
||||
candle_metal_kernels::call_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
kernel_name,
|
||||
src_l.dims(),
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&lhs_l.stride(),
|
||||
lhs_l.start_offset(),
|
||||
&self.buffer,
|
||||
&src_l.stride(),
|
||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut dst.buffer,
|
||||
dst_offset,
|
||||
&rhs_l.stride(),
|
||||
rhs_l.start_offset(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// Create kernel
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if src_l.is_contiguous() && self.dtype == dst.dtype() {
|
||||
command_buffer.set_label("copy_contiguous");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(
|
||||
&self.buffer,
|
||||
src_offset,
|
||||
dst.buffer(),
|
||||
dst_offset,
|
||||
self.buffer.length() - src_offset,
|
||||
);
|
||||
blit.end_encoding();
|
||||
} else {
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||
DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
|
||||
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
|
||||
dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
kernel_name,
|
||||
src_l.dims(),
|
||||
&self.buffer,
|
||||
src_l.stride(),
|
||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&dst.buffer,
|
||||
dst_offset * dst.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.set_label("copy_strided");
|
||||
}
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalStorage {
|
||||
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
|
||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
|
||||
let matrices = Arc::new(RwLock::new(HashMap::new()));
|
||||
Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
matrices,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
fn matrix(
|
||||
&self,
|
||||
(b, m, n): (NSUInteger, NSUInteger, NSUInteger),
|
||||
transpose: bool,
|
||||
size: NSUInteger,
|
||||
offset: NSUInteger,
|
||||
type_id: u32,
|
||||
) -> Result<Matrix> {
|
||||
let key = (b, m, n, transpose, size, offset, type_id);
|
||||
|
||||
let mut matrices = self.matrices.try_write().unwrap();
|
||||
if let Some(matrix) = matrices.get(&key) {
|
||||
Ok(matrix.clone())
|
||||
} else {
|
||||
let descriptor = if transpose {
|
||||
MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
|
||||
};
|
||||
let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
matrices.insert(key, matrix.clone());
|
||||
Ok(matrix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
@ -720,10 +930,14 @@ impl BackendDevice for MetalDevice {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
Ok(Self {
|
||||
device,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
buffers,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
@ -743,9 +957,8 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
// TODO Is there a faster way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype);
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
@ -755,49 +968,20 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let buffer = match storage {
|
||||
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||
};
|
||||
Ok(Self::Storage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype: storage.dtype(),
|
||||
})
|
||||
Ok(Self::Storage::new(
|
||||
buffer.into(),
|
||||
self.clone(),
|
||||
storage.dtype(),
|
||||
))
|
||||
}
|
||||
|
||||
fn rand_uniform(
|
||||
|
@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
|
@ -33,6 +33,24 @@ kernel void FN_NAME( \
|
||||
const TYPENAME a = TYPENAME(add); \
|
||||
output[id] = input[id] * m + a; \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME m = TYPENAME(mul); \
|
||||
const TYPENAME a = TYPENAME(add); \
|
||||
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
|
||||
} \
|
||||
|
||||
AFFINE(affine_float, float)
|
||||
AFFINE(affine_half, half)
|
||||
|
@ -23,12 +23,12 @@ kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
|
||||
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \
|
||||
constant size_t *strides, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#endif
|
||||
|
@ -16,16 +16,16 @@ kernel void NAME( \
|
||||
if (gid >= dst_size) { \
|
||||
return; \
|
||||
} \
|
||||
const size_t id_i = gid / right_size / left_size; \
|
||||
const size_t id_i = (gid / right_size) % ids_size; \
|
||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
||||
const size_t right_rank_i = gid % right_size; \
|
||||
const size_t left_rank_i = gid % left_size; \
|
||||
const size_t left_rank_i = gid / right_size / ids_size; \
|
||||
/* \
|
||||
// Force prevent out of bounds indexing \
|
||||
// since there doesn't seem to be a good way to force crash \
|
||||
// No need to check for zero we're only allowing unsized. \
|
||||
*/ \
|
||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
||||
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
|
||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
|
||||
output[gid] = input[src_i]; \
|
||||
}
|
||||
|
||||
@ -75,6 +75,7 @@ kernel void FN_NAME( \
|
||||
|
||||
|
||||
INDEX_OP(is_u32_f32, uint, float)
|
||||
INDEX_OP(is_u32_f16, uint, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
|
@ -1,6 +1,6 @@
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
|
||||
ComputePipelineState, Device, Function, Library, MTLSize,
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
|
||||
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
||||
let size = length as u64;
|
||||
@ -59,8 +60,8 @@ impl<T> EncoderParam for &[T] {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_bytes(
|
||||
position,
|
||||
(core::mem::size_of::<T>() * data.len()) as u64,
|
||||
data.as_ptr() as *const T as *const c_void,
|
||||
core::mem::size_of_val(data) as u64,
|
||||
data.as_ptr() as *const c_void,
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -105,19 +106,14 @@ pub enum Source {
|
||||
Ternary,
|
||||
Cast,
|
||||
Reduce,
|
||||
Mfa,
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
($($name:ident),+) => {
|
||||
|
||||
pub mod contiguous {
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
impl std::fmt::Display for Kernel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -126,16 +122,18 @@ macro_rules! ops{
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
||||
}
|
||||
)+
|
||||
pub mod copy {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel("copy_float");
|
||||
pub const HALF: Kernel = Kernel("copy_half");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
|
||||
pub const U32: Kernel = Kernel("copy_u32");
|
||||
pub const U8: Kernel = Kernel("copy_u8");
|
||||
}
|
||||
}
|
||||
|
||||
pub mod strided {
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
impl std::fmt::Display for Kernel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -144,12 +142,20 @@ macro_rules! ops{
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
||||
}
|
||||
)+
|
||||
pub mod copy {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel("copy_float_strided");
|
||||
pub const HALF: Kernel = Kernel("copy_half_strided");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
|
||||
pub const U32: Kernel = Kernel("copy_u32_strided");
|
||||
pub const U8: Kernel = Kernel("copy_u8_strided");
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub mod unary {
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
|
||||
}
|
||||
pub mod binary {
|
||||
ops!(add, sub, mul, div);
|
||||
@ -161,8 +167,12 @@ pub enum MetalKernelError {
|
||||
LockError(String),
|
||||
#[error("Error while loading library: {0}")]
|
||||
LoadLibraryError(String),
|
||||
#[error("Error while loading function: {0}")]
|
||||
#[error("Error while loading function: {0:?}")]
|
||||
LoadFunctionError(String),
|
||||
#[error("Failed to create compute function")]
|
||||
FailedToCreateComputeFunction,
|
||||
#[error("Failed to create pipeline")]
|
||||
FailedToCreatePipeline(String),
|
||||
}
|
||||
|
||||
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
||||
@ -171,21 +181,103 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
||||
}
|
||||
}
|
||||
|
||||
type KernelMap<T> = HashMap<&'static str, T>;
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Value {
|
||||
USize(usize),
|
||||
Bool(bool),
|
||||
F32(f32),
|
||||
U16(u16),
|
||||
}
|
||||
|
||||
impl std::hash::Hash for Value {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
Value::F32(v) => v.to_bits().hash(state),
|
||||
Value::USize(v) => v.hash(state),
|
||||
Value::U16(v) => v.hash(state),
|
||||
Value::Bool(v) => v.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Value {
|
||||
fn data_type(&self) -> MTLDataType {
|
||||
match self {
|
||||
Value::USize(_) => MTLDataType::UInt,
|
||||
Value::F32(_) => MTLDataType::Float,
|
||||
Value::U16(_) => MTLDataType::UShort,
|
||||
Value::Bool(_) => MTLDataType::Bool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Not true, good enough for our purposes.
|
||||
impl Eq for Value {}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Hash)]
|
||||
struct ConstantValues(Vec<(usize, Value)>);
|
||||
|
||||
impl ConstantValues {
|
||||
pub fn new(values: Vec<(usize, Value)>) -> Self {
|
||||
Self(values)
|
||||
}
|
||||
|
||||
fn function_constant_values(&self) -> FunctionConstantValues {
|
||||
let f = FunctionConstantValues::new();
|
||||
for (index, value) in &self.0 {
|
||||
let ty = value.data_type();
|
||||
match value {
|
||||
Value::USize(v) => {
|
||||
f.set_constant_value_at_index(
|
||||
v as *const usize as *const c_void,
|
||||
ty,
|
||||
*index as u64,
|
||||
);
|
||||
}
|
||||
Value::F32(v) => {
|
||||
f.set_constant_value_at_index(
|
||||
v as *const f32 as *const c_void,
|
||||
ty,
|
||||
*index as u64,
|
||||
);
|
||||
}
|
||||
Value::U16(v) => {
|
||||
f.set_constant_value_at_index(
|
||||
v as *const u16 as *const c_void,
|
||||
ty,
|
||||
*index as u64,
|
||||
);
|
||||
}
|
||||
Value::Bool(v) => {
|
||||
f.set_constant_value_at_index(
|
||||
v as *const bool as *const c_void,
|
||||
ty,
|
||||
*index as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
f
|
||||
}
|
||||
}
|
||||
|
||||
type Libraries = HashMap<Source, Library>;
|
||||
type Functions = KernelMap<Function>;
|
||||
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Kernels {
|
||||
libraries: RwLock<Libraries>,
|
||||
funcs: RwLock<Functions>,
|
||||
pipelines: RwLock<Pipelines>,
|
||||
}
|
||||
|
||||
impl Kernels {
|
||||
pub fn new() -> Self {
|
||||
let libraries = RwLock::new(Libraries::new());
|
||||
let funcs = RwLock::new(Functions::new());
|
||||
Self { libraries, funcs }
|
||||
let pipelines = RwLock::new(Pipelines::new());
|
||||
Self {
|
||||
libraries,
|
||||
pipelines,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
@ -197,6 +289,7 @@ impl Kernels {
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Mfa => unimplemented!("Mfa is not a source"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -209,33 +302,75 @@ impl Kernels {
|
||||
if let Some(lib) = libraries.get(&source) {
|
||||
Ok(lib.clone())
|
||||
} else {
|
||||
let source_content = self.get_library_source(source);
|
||||
let lib = device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
|
||||
let lib = match source {
|
||||
Source::Mfa => {
|
||||
let source_data = MFA;
|
||||
device
|
||||
.new_library_with_data(source_data)
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
}
|
||||
source => {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
}
|
||||
};
|
||||
libraries.insert(source, lib.clone());
|
||||
Ok(lib)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_function(
|
||||
fn load_function(
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
constants: Option<FunctionConstantValues>,
|
||||
) -> Result<Function, MetalKernelError> {
|
||||
let mut funcs = self.funcs.write()?;
|
||||
if let Some(func) = funcs.get(name) {
|
||||
Ok(func.clone())
|
||||
let func = self
|
||||
.load_library(device, source)?
|
||||
.get_function(name, constants)
|
||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||
Ok(func)
|
||||
}
|
||||
|
||||
fn load_pipeline_with_constants(
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
constants: Option<ConstantValues>,
|
||||
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||
let mut pipelines = self.pipelines.write()?;
|
||||
let key = (name, constants);
|
||||
if let Some(pipeline) = pipelines.get(&key) {
|
||||
Ok(pipeline.clone())
|
||||
} else {
|
||||
let func = self
|
||||
.load_library(device, source)?
|
||||
.get_function(name, None)
|
||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||
funcs.insert(name, func.clone());
|
||||
Ok(func)
|
||||
let (name, constants) = key;
|
||||
let func = self.load_function(
|
||||
device,
|
||||
source,
|
||||
name,
|
||||
constants.as_ref().map(|c| c.function_constant_values()),
|
||||
)?;
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&func)
|
||||
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
|
||||
pipelines.insert((name, constants), pipeline.clone());
|
||||
|
||||
Ok(pipeline)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_pipeline(
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||
self.load_pipeline_with_constants(device, source, name, None)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@ -246,18 +381,9 @@ pub fn call_unary_contiguous(
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -279,18 +405,10 @@ pub fn call_unary_strided(
|
||||
input: &Buffer,
|
||||
strides: &[usize],
|
||||
offset: usize,
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
output_offset: usize,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Unary, name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -326,17 +444,9 @@ pub fn call_binary_contiguous(
|
||||
length: usize,
|
||||
left: &Buffer,
|
||||
right: &Buffer,
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -363,17 +473,9 @@ pub fn call_binary_strided(
|
||||
right_input: &Buffer,
|
||||
right_strides: &[usize],
|
||||
right_offset: usize,
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Binary, name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -411,22 +513,52 @@ pub fn call_cast_contiguous(
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, output));
|
||||
set_params!(encoder, (length, (input, input_offset), output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_cast_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
shape: &[usize],
|
||||
input: &Buffer,
|
||||
input_strides: &[usize],
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
length,
|
||||
shape.len(),
|
||||
shape,
|
||||
input_strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
@ -435,7 +567,6 @@ pub fn call_cast_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_reduce_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -444,24 +575,19 @@ pub fn call_reduce_contiguous(
|
||||
length: usize,
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, elements_to_sum, input, output));
|
||||
set_params!(
|
||||
encoder,
|
||||
(length, elements_to_sum, (input, input_offset), output)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
@ -495,18 +621,9 @@ pub fn call_last_softmax(
|
||||
length: usize,
|
||||
elements_to_sum: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -542,21 +659,14 @@ pub fn call_affine(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
size: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Affine, "affine_float")?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -570,6 +680,45 @@ pub fn call_affine(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_affine_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
input: &Buffer,
|
||||
input_stride: &[usize],
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
size,
|
||||
shape.len(),
|
||||
shape,
|
||||
input_stride,
|
||||
mul,
|
||||
add,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_where_cond_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -582,17 +731,9 @@ pub fn call_where_cond_strided(
|
||||
(left_stride, left_offset): (&[usize], usize),
|
||||
right: &Buffer,
|
||||
(right_stride, right_offset): (&[usize], usize),
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Ternary, name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -634,17 +775,14 @@ pub fn call_index_select(
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
ids: &Buffer,
|
||||
output: &mut Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
let right_size: usize = shape[dim + 1..].iter().product();
|
||||
let src_dim_size = shape[dim];
|
||||
let dst_el = ids_size * left_size * right_size;
|
||||
|
||||
let func = kernels.load_function(device, Source::Indexing, name)?;
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&func)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
@ -671,5 +809,169 @@ pub fn call_index_select(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_gemm(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
lhs_buffer: &Buffer,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
rhs_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
assert!(rhs_stride.len() >= 2);
|
||||
assert!(lhs_stride.len() >= 2);
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
false
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
todo!();
|
||||
// Err(MetalError::MatMulNonContiguous {
|
||||
// lhs_stride: lhs_stride.to_vec(),
|
||||
// rhs_stride: rhs_stride.to_vec(),
|
||||
// mnk: (m, n, k),
|
||||
// })?
|
||||
};
|
||||
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
false
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
todo!();
|
||||
// Err(MetalError::MatMulNonContiguous {
|
||||
// lhs_stride: lhs_stride.to_vec(),
|
||||
// rhs_stride: rhs_stride.to_vec(),
|
||||
// mnk: (m, n, k),
|
||||
// })?
|
||||
};
|
||||
let d_trans = false;
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
let batched = b > 1;
|
||||
let fused_activation = false;
|
||||
let fused_bias = false;
|
||||
let m_simd = 16;
|
||||
let n_simd = 16;
|
||||
let k_simd = 16;
|
||||
let m_splits = 2;
|
||||
let n_splits = 2;
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(0, Value::USize(m)),
|
||||
(1, Value::USize(n)),
|
||||
(2, Value::USize(k)),
|
||||
(10, Value::Bool(a_trans)),
|
||||
(11, Value::Bool(b_trans)),
|
||||
(13, Value::Bool(d_trans)),
|
||||
(20, Value::F32(alpha)),
|
||||
(21, Value::F32(beta)),
|
||||
(100, Value::Bool(batched)),
|
||||
(101, Value::Bool(fused_activation)),
|
||||
// Garbage
|
||||
(102, Value::Bool(false)),
|
||||
(103, Value::Bool(false)),
|
||||
(113, Value::Bool(false)),
|
||||
(50_000, Value::Bool(false)),
|
||||
// End garbage
|
||||
(200, Value::U16(m_simd)),
|
||||
(201, Value::U16(n_simd)),
|
||||
(202, Value::U16(k_simd)),
|
||||
(210, Value::U16(m_splits)),
|
||||
(211, Value::U16(n_splits)),
|
||||
(50_001, Value::Bool(fused_bias)),
|
||||
]));
|
||||
// println!("Constants {constants:?}");
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
||||
let m_group = m_simd * m_splits;
|
||||
let n_group = n_simd * n_splits;
|
||||
|
||||
let a_block_length = m_group * k_simd;
|
||||
let b_block_length = k_simd * n_group;
|
||||
|
||||
let mut block_elements = a_block_length + b_block_length;
|
||||
if (m % 8 != 0) && (n % 8 != 0) {
|
||||
let c_block_length = m_group * n_group;
|
||||
block_elements = std::cmp::max(c_block_length, block_elements)
|
||||
}
|
||||
if fused_bias {
|
||||
if d_trans {
|
||||
block_elements = std::cmp::max(block_elements, m_group);
|
||||
} else {
|
||||
block_elements = std::cmp::max(block_elements, n_group);
|
||||
}
|
||||
}
|
||||
// TODO adapt for f16
|
||||
let bytes = match name {
|
||||
"sgemm" => 4,
|
||||
"hgemm" => 2,
|
||||
other => {
|
||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||
"{other} is not a valid kernel for gemm"
|
||||
)));
|
||||
}
|
||||
};
|
||||
let block_bytes = block_elements * bytes;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
// println!("Threadgroup {block_bytes}");
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(2, Some(output), 0);
|
||||
// TODO Tensor D
|
||||
|
||||
let grid_z = b;
|
||||
if batched {
|
||||
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
||||
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
||||
let byte_stride_c = m * n * bytes as usize;
|
||||
// TODO byte_stride_d
|
||||
let byte_stride_d = 0;
|
||||
|
||||
let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
|
||||
for i in 0..b {
|
||||
buffer.push((i * byte_stride_a) as u64);
|
||||
buffer.push((i * byte_stride_b) as u64);
|
||||
buffer.push((i * byte_stride_c) as u64);
|
||||
buffer.push((i * byte_stride_d) as u64);
|
||||
}
|
||||
encoder.set_bytes(
|
||||
10,
|
||||
buffer.len() as NSUInteger * core::mem::size_of::<u64>(),
|
||||
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||
);
|
||||
}
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: divide(n, n_group.into()),
|
||||
height: divide(m, m_group.into()),
|
||||
depth: grid_z as NSUInteger,
|
||||
};
|
||||
let group_size = MTLSize {
|
||||
width: 32 * (m_splits as u64) * (n_splits as u64),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
// println!("grid size {grid_size:?} group size {group_size:?}");
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
((m + b - 1) / b) as NSUInteger
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
BIN
candle-metal-kernels/src/libMetalFlashAttention.metallib
Normal file
BIN
candle-metal-kernels/src/libMetalFlashAttention.metallib
Normal file
Binary file not shown.
@ -1,6 +1,8 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 256;
|
||||
constant int THREADGROUP_SIZE = 1024;
|
||||
|
||||
# define REDUCE(FN, NAME, TYPENAME) \
|
||||
# define REDUCE(FN, NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const TYPENAME *src, \
|
||||
device TYPENAME *dst, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint blockDim [[ threads_per_threadgroup ]] \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
@ -45,10 +47,10 @@ kernel void NAME( \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
*/ \
|
||||
TYPENAME x = shared_memory[tid]; \
|
||||
TYPENAME y = src[idx]; \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[idx]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += blockDim; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
@ -56,10 +58,10 @@ kernel void NAME( \
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
TYPENAME x = shared_memory[tid]; \
|
||||
TYPENAME y = shared_memory[tid + s]; \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
@ -68,72 +70,74 @@ kernel void NAME( \
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
kernel void softmax_float(
|
||||
constant size_t &src_numel,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const float *src,
|
||||
device float *dst,
|
||||
uint id [[ thread_position_in_grid ]],
|
||||
uint tid [[ thread_index_in_threadgroup ]],
|
||||
uint dst_id [[ threadgroup_position_in_grid ]],
|
||||
uint blockDim [[ threads_per_threadgroup ]]
|
||||
) {
|
||||
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE];
|
||||
|
||||
shared_memory[tid] = -INFINITY;
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
|
||||
size_t idx = start_idx + tid;
|
||||
|
||||
while (idx < stop_idx) {
|
||||
// TODO: Fast version for the contiguous case.
|
||||
shared_memory[tid] = max(shared_memory[tid], src[idx]);
|
||||
idx += blockDim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// reduction in shared memory
|
||||
for (uint s = blockDim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
float max = shared_memory[0];
|
||||
|
||||
shared_memory[tid] = 0;
|
||||
|
||||
// Restart
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
// TODO: Fast version for the contiguous case.
|
||||
const float val = exp(src[idx] - max);
|
||||
dst[idx] = val;
|
||||
shared_memory[tid] += val;
|
||||
idx += blockDim;
|
||||
}
|
||||
// reduction in shared memory
|
||||
for (uint s = blockDim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] += shared_memory[tid + s];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
const float inv_acc = 1/shared_memory[0];
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
dst[idx] *= inv_acc;
|
||||
idx += blockDim;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_float, float)
|
||||
REDUCE(x * y, fast_mul_float, float)
|
||||
REDUCE(max(x, y), fast_max_float, float)
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
while (idx < stop_idx) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const T val = T(exp(src[idx] - _max)); \
|
||||
dst[idx] = val; \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
|
||||
SOFTMAX(softmax_float, float)
|
||||
SOFTMAX(softmax_half, half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
SOFTMAX(softmax_bfloat, bfloat)
|
||||
#endif
|
||||
|
@ -32,6 +32,9 @@ kernel void FN_NAME( \
|
||||
device TYPENAME *out ,\
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= numel){ \
|
||||
return; \
|
||||
} \
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||
|
211
candle-metal-kernels/src/test.swift
Normal file
211
candle-metal-kernels/src/test.swift
Normal file
@ -0,0 +1,211 @@
|
||||
|
||||
import Metal
|
||||
import MetalPerformanceShadersGraph
|
||||
|
||||
|
||||
|
||||
let type = MTLDataType.float;
|
||||
let dataType = type;
|
||||
var B = 2;
|
||||
var M = 2;
|
||||
var N = 4;
|
||||
var K = 3;
|
||||
var A_trans = false;
|
||||
var B_trans = false;
|
||||
var D_trans = false;
|
||||
var alpha = Float(1.0);
|
||||
var beta = Float(0.0);
|
||||
var batched = B > 1;
|
||||
var fused_activation = false;
|
||||
var fused_bias = false;
|
||||
let constants = MTLFunctionConstantValues()
|
||||
constants.setConstantValue(&M, type: .uint, index: 0)
|
||||
constants.setConstantValue(&N, type: .uint, index: 1)
|
||||
constants.setConstantValue(&K, type: .uint, index: 2)
|
||||
constants.setConstantValue(&A_trans, type: .bool, index: 10)
|
||||
constants.setConstantValue(&B_trans, type: .bool, index: 11)
|
||||
constants.setConstantValue(&D_trans, type: .bool, index: 13)
|
||||
constants.setConstantValue(&alpha, type: .float, index: 20)
|
||||
constants.setConstantValue(&beta, type: .float, index: 21)
|
||||
constants.setConstantValue(&batched, type: .bool, index: 100)
|
||||
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
|
||||
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
|
||||
|
||||
|
||||
var M_simd = UInt16(16)
|
||||
var N_simd = UInt16(16)
|
||||
var K_simd = UInt16(32)
|
||||
var M_splits = UInt16(2)
|
||||
var N_splits = UInt16(2)
|
||||
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
|
||||
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
|
||||
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
|
||||
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
|
||||
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
|
||||
|
||||
let M_group = M_simd * M_splits
|
||||
let N_group = N_simd * N_splits
|
||||
|
||||
// Satisfy Metal API validation.
|
||||
#if DEBUG
|
||||
do {
|
||||
var garbage: SIMD4<UInt64> = .zero
|
||||
constants.setConstantValue(&garbage, type: .bool, index: 102)
|
||||
constants.setConstantValue(&garbage, type: .bool, index: 103)
|
||||
constants.setConstantValue(&garbage, type: .bool, index: 113)
|
||||
constants.setConstantValue(&garbage, type: .bool, index: 50000)
|
||||
}
|
||||
#endif
|
||||
print(constants)
|
||||
|
||||
let device = MTLCopyAllDevices().first!
|
||||
device.shouldMaximizeConcurrentCompilation = true
|
||||
|
||||
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
|
||||
libraryURL.append(component: "src")
|
||||
libraryURL.append(component: "libMetalFlashAttention.metallib")
|
||||
let library = try! device.makeLibrary(URL: libraryURL)
|
||||
|
||||
var name: String
|
||||
switch dataType {
|
||||
case .half: name = "hgemm"
|
||||
case .float: name = "sgemm"
|
||||
default: fatalError()
|
||||
}
|
||||
let function = try! library.makeFunction(
|
||||
name: name, constantValues: constants)
|
||||
|
||||
let A_block_length = M_group * K_simd
|
||||
let B_block_length = K_simd * N_group
|
||||
|
||||
var blockElements = A_block_length + B_block_length;
|
||||
if (M % 8 != 0) && (N % 8 != 0) {
|
||||
let C_block_length = M_group * N_group;
|
||||
blockElements = max(C_block_length, blockElements)
|
||||
}
|
||||
if fused_bias {
|
||||
if D_trans {
|
||||
blockElements = max(blockElements, M_group)
|
||||
} else {
|
||||
blockElements = max(blockElements, N_group)
|
||||
}
|
||||
}
|
||||
// let blockBytes = blockElements * UInt16(dataType.size)
|
||||
let elementSize = 4
|
||||
let blockBytes = blockElements * UInt16(elementSize)
|
||||
|
||||
func ceilDivide(target: Int, granularity: UInt16) -> Int {
|
||||
(target + Int(granularity) - 1) / Int(granularity)
|
||||
}
|
||||
var gridSize = MTLSize(
|
||||
width: ceilDivide(target: N, granularity: N_group),
|
||||
height: ceilDivide(target: M, granularity: M_group),
|
||||
depth: 1)
|
||||
let groupSize = MTLSize(
|
||||
width: Int(32 * M_splits * N_splits),
|
||||
height: 1,
|
||||
depth: 1)
|
||||
|
||||
let commandQueue = device.makeCommandQueue()!
|
||||
let commandBuffer = commandQueue.makeCommandBuffer()!
|
||||
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
|
||||
let pipeline = try device.makeComputePipelineState(function: function)
|
||||
|
||||
let threadgroupMemoryLength = blockBytes;
|
||||
print(threadgroupMemoryLength)
|
||||
encoder.setComputePipelineState(pipeline)
|
||||
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
|
||||
|
||||
|
||||
let rowsA = M;
|
||||
let columnsA = K;
|
||||
let rowsB = K;
|
||||
let columnsB = N;
|
||||
let rowsC = M;
|
||||
let columnsC = N;
|
||||
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
|
||||
|
||||
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
|
||||
|
||||
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
|
||||
for i in 0..<arrayA.count {
|
||||
arrayA[i] = Float(i)
|
||||
}
|
||||
|
||||
for i in 0..<arrayB.count {
|
||||
arrayB[i] = Float(i)
|
||||
}
|
||||
|
||||
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])
|
||||
|
||||
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])
|
||||
|
||||
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])
|
||||
|
||||
print(arrayA)
|
||||
print(arrayB)
|
||||
|
||||
|
||||
encoder.setBuffer(bufferA, offset: 0, index: 0)
|
||||
encoder.setBuffer(bufferB, offset: 0, index: 1)
|
||||
encoder.setBuffer(bufferC, offset: 0, index: 2)
|
||||
var gridZ: Int = B
|
||||
if batched{
|
||||
func byteStride(shape: [Int]) -> Int {
|
||||
let rank = shape.count
|
||||
var output = elementSize * shape[rank - 2] * shape[rank - 1]
|
||||
if shape.dropLast(2).reduce(1, *) == 1 {
|
||||
output = 0
|
||||
}
|
||||
return output
|
||||
}
|
||||
let byteStrideA = M*K*elementSize
|
||||
let byteStrideB = N*K*elementSize
|
||||
let byteStrideC = M*N*elementSize
|
||||
|
||||
let byteStrideD = 0
|
||||
// if let shapeD = tensors.d?.shape {
|
||||
// let rank = shapeD.count
|
||||
// byteStrideD = elementSize * shapeD[rank - 1]
|
||||
// if shapeD.dropLast(1).reduce(1, *) == 1 {
|
||||
// byteStrideD = 0
|
||||
// }
|
||||
// }
|
||||
withUnsafeTemporaryAllocation(
|
||||
of: SIMD4<UInt64>.self, capacity: gridZ
|
||||
) { buffer in
|
||||
for i in 0..<buffer.count {
|
||||
buffer[i] = SIMD4(
|
||||
UInt64(truncatingIfNeeded: i * byteStrideA),
|
||||
UInt64(truncatingIfNeeded: i * byteStrideB),
|
||||
UInt64(truncatingIfNeeded: i * byteStrideC),
|
||||
UInt64(truncatingIfNeeded: i * byteStrideD))
|
||||
}
|
||||
|
||||
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
|
||||
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
|
||||
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
|
||||
print("BATCHED")
|
||||
print(buffer)
|
||||
}
|
||||
}
|
||||
gridSize.depth = gridZ
|
||||
|
||||
|
||||
print(gridSize, groupSize)
|
||||
encoder.dispatchThreadgroups(
|
||||
gridSize, threadsPerThreadgroup: groupSize
|
||||
)
|
||||
encoder.endEncoding()
|
||||
commandBuffer.commit()
|
||||
|
||||
commandBuffer.waitUntilCompleted()
|
||||
var contents = bufferC!.contents();
|
||||
|
||||
var count = B * rowsA * columnsB;
|
||||
|
||||
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
|
||||
|
||||
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
|
||||
|
||||
print(Array(bufferedPointer))
|
@ -1,5 +1,5 @@
|
||||
use super::*;
|
||||
use half::f16;
|
||||
use half::{bf16, f16};
|
||||
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||
|
||||
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||
@ -23,13 +23,18 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
|
||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||
}
|
||||
|
||||
fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||
}
|
||||
|
||||
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let mut output = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
call_unary_contiguous(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -37,7 +42,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -53,7 +58,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let left = new_buffer(&device, x);
|
||||
let right = new_buffer(&device, y);
|
||||
let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
|
||||
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
|
||||
call_binary_contiguous(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -62,7 +67,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
|
||||
x.len(),
|
||||
&left,
|
||||
&right,
|
||||
&mut output,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -81,7 +86,7 @@ fn run_strided<T: Clone>(
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let mut output = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
let kernels = Kernels::new();
|
||||
call_unary_strided(
|
||||
&device,
|
||||
@ -92,7 +97,7 @@ fn run_strided<T: Clone>(
|
||||
&input,
|
||||
strides,
|
||||
offset,
|
||||
&mut output,
|
||||
&output,
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
@ -220,7 +225,9 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let mut output = new_buffer(&device, v);
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let size = (v.len() * std::mem::size_of::<U>()) as u64;
|
||||
let output = device.new_buffer(size, options);
|
||||
|
||||
call_cast_contiguous(
|
||||
&device,
|
||||
@ -229,7 +236,8 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
0,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -245,11 +253,17 @@ fn cast_u32_f32() {
|
||||
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
|
||||
|
||||
let v = vec![1.0f32; 10_000];
|
||||
let results = run(&v, unary::contiguous::cos::FLOAT);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
||||
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results.len(), 10_000);
|
||||
assert_eq!(&results[..10], vec![1.0f32; 10]);
|
||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
}
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
@ -259,7 +273,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let input = new_buffer(&device, v);
|
||||
let mut output = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
|
||||
let size = v.len();
|
||||
|
||||
@ -267,9 +281,10 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"affine_float",
|
||||
size,
|
||||
&input,
|
||||
&mut output,
|
||||
&output,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
@ -280,6 +295,42 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
output.read_to_vec::<T>(v.len())
|
||||
}
|
||||
|
||||
fn run_affine_strided<T: Clone>(
|
||||
v: &[T],
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let input = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
|
||||
call_affine_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"affine_float_strided",
|
||||
shape,
|
||||
&input,
|
||||
strides,
|
||||
0,
|
||||
&output,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
let len: usize = shape.iter().product();
|
||||
output.read_to_vec::<T>(len)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn affine() {
|
||||
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
@ -295,6 +346,18 @@ fn affine() {
|
||||
assert_eq!(result, vec![2.6; 40_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn affine_strided() {
|
||||
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
let mul = 1.5;
|
||||
let add = 1.1;
|
||||
let shape = [4];
|
||||
let strides = [2];
|
||||
let result = run_affine_strided(&input, &shape, &strides, mul, add);
|
||||
// 1 on 2
|
||||
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select() {
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
@ -313,7 +376,26 @@ fn index_select() {
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_f16() {
|
||||
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||
.into_iter()
|
||||
.map(|x| f16::from_f32(x))
|
||||
.collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
assert_eq!(
|
||||
approx_f16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_dim1() {
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 1, 0];
|
||||
@ -321,7 +403,7 @@ fn index_select() {
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||
);
|
||||
}
|
||||
|
||||
@ -341,20 +423,26 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
let right_size: usize = shape[dim + 1..].iter().product();
|
||||
let dst_el = ids.len() * left_size * right_size;
|
||||
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
|
||||
let name = match core::mem::size_of::<T>() {
|
||||
4 => "is_u32_f32",
|
||||
2 => "is_u32_f16",
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
let kernels = Kernels::new();
|
||||
call_index_select(
|
||||
&device,
|
||||
&command_buffer,
|
||||
&kernels,
|
||||
"is_u32_f32",
|
||||
name,
|
||||
shape,
|
||||
ids.len(),
|
||||
dim,
|
||||
&embeddings_buffer,
|
||||
&ids_buffer,
|
||||
&mut dst_buffer,
|
||||
&dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -451,7 +539,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
let input = new_buffer(&device, v);
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_reduce_contiguous(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -460,7 +548,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
v.len(),
|
||||
out_length,
|
||||
&input,
|
||||
&mut output,
|
||||
0,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -475,7 +564,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let mut output = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
call_last_softmax(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -484,7 +573,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
|
||||
v.len(),
|
||||
last_dim,
|
||||
&input,
|
||||
&mut output,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -536,6 +625,28 @@ fn softmax() {
|
||||
approx(results, 4),
|
||||
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
|
||||
);
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
let last_dim = 6;
|
||||
let results = run_softmax(&v, last_dim, "softmax_half");
|
||||
assert_eq!(
|
||||
approx_f16(results, 4),
|
||||
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
|
||||
);
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
let last_dim = 6;
|
||||
let results = run_softmax(&v, last_dim, "softmax_bfloat");
|
||||
assert_eq!(
|
||||
approx_bf16(results, 4),
|
||||
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
|
||||
);
|
||||
}
|
||||
|
||||
fn run_where_cond<I: Clone, T: Clone>(
|
||||
@ -571,7 +682,7 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
options,
|
||||
);
|
||||
|
||||
let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_where_cond_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -584,7 +695,7 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
(&left_stride, left_offset),
|
||||
&right,
|
||||
(&cond_stride, cond_offset),
|
||||
&mut output,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -614,3 +725,76 @@ fn where_cond() {
|
||||
);
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
fn run_gemm<T: Clone>(
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs: &[T],
|
||||
rhs_stride: Vec<usize>,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let lhs = device.new_buffer_with_data(
|
||||
lhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(lhs) as u64,
|
||||
options,
|
||||
);
|
||||
let rhs = device.new_buffer_with_data(
|
||||
rhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(rhs) as u64,
|
||||
options,
|
||||
);
|
||||
let length = b * m * n;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs_stride,
|
||||
0,
|
||||
&lhs,
|
||||
&rhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
output.read_to_vec::<T>(length)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gemm() {
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
let (b, m, n, k) = (2, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![
|
||||
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
|
||||
518.0, 548.0, 578.0
|
||||
]
|
||||
);
|
||||
}
|
||||
|
@ -1,4 +1,7 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_math>
|
||||
#
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
@ -17,10 +20,39 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||
template <typename T> METAL_FUNC T erf(T in){
|
||||
float x = (float) in;
|
||||
// constants
|
||||
float a1 = 0.254829592;
|
||||
float a2 = -0.284496736;
|
||||
float a3 = 1.421413741;
|
||||
float a4 = -1.453152027;
|
||||
float a5 = 1.061405429;
|
||||
float p = 0.3275911;
|
||||
|
||||
// Save the sign of x
|
||||
int sign = 1;
|
||||
if (x < 0)
|
||||
sign = -1;
|
||||
x = fabs(x);
|
||||
|
||||
// A&S formula 7.1.26
|
||||
float t = 1.0/(1.0 + p*x);
|
||||
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
|
||||
|
||||
return T(sign*y);
|
||||
}
|
||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
|
||||
template <typename T> METAL_FUNC T gelu(T x){
|
||||
T x_sq = x * x;
|
||||
T x_cube = x_sq * x;
|
||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
}
|
||||
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -64,8 +96,16 @@ UNARY_OP(sqrt)
|
||||
UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
UNARY_OP(log)
|
||||
UNARY_OP(gelu)
|
||||
UNARY_OP(ceil)
|
||||
UNARY_OP(floor)
|
||||
UNARY_OP(round)
|
||||
UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY(id, float, copy_float, copy_float_strided)
|
||||
UNARY(id, half, copy_half, copy_half_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
@ -75,6 +115,12 @@ BFLOAT_UNARY_OP(sqrt)
|
||||
BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
BFLOAT_UNARY_OP(round)
|
||||
BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
|
||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||
#endif
|
||||
|
@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"affine_float",
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
kernel_name.0,
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
|
||||
let shape = vec![2, 5_000];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
for kernel_name in strided {
|
||||
for kernel_name in &strided {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
kernel_name.0,
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
@ -19,6 +19,7 @@ num-traits = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -29,3 +30,4 @@ default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
metal = ["candle/metal", "dep:candle-metal-kernels"]
|
||||
|
@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &candle::MetalStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::{backend::BackendStorage, DType};
|
||||
let device = storage.device();
|
||||
let command_buffer = device.command_buffer();
|
||||
let kernels = device.kernels();
|
||||
let name = match storage.dtype() {
|
||||
DType::F32 => "softmax_float",
|
||||
DType::F16 => "softmax_half",
|
||||
DType::BF16 => "softmax_bfloat",
|
||||
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
|
||||
};
|
||||
|
||||
let n = layout.stride().len();
|
||||
if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
|
||||
candle::bail!("Non contiguous softmax-last-dim is not implemented");
|
||||
}
|
||||
|
||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let mut output = device.new_buffer(elem_count, storage.dtype());
|
||||
candle_metal_kernels::call_last_softmax(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
elem_count,
|
||||
last_dim,
|
||||
storage.buffer(),
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
||||
Ok((newstorage, layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user