mirror of https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00

Compare commits (17 commits):

03641293ee, 064ba17bd7, e8ee253ee0, 8bd3d6b94b, 6a3ca7da0c, 586b6f6fff,
e4b0cc59f5, 0a6e0a8c9a, 972903021c, 6bc92e63cb, aa04015098, 8b5059e951,
26540641c1, 34d83377f6, 77197379cc, 916a8c5464, 243e83f2b9
@@ -201,10 +201,9 @@ impl Device {
                 Ok(Storage::Cuda(storage))
             }
         }
-        Device::Metal(_device) => {
-            // let storage = device.rand_uniform(shape, dtype, lo, up)?;
-            // Ok(Storage::Metal(storage))
-            crate::bail!("Metal rand_uniform not implemented")
+        Device::Metal(device) => {
+            let storage = device.rand_uniform(shape, dtype, lo, up)?;
+            Ok(Storage::Metal(storage))
         }
     }
 }
@@ -8,7 +8,26 @@ use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
+use std::path::Path;
-use std::sync::{Arc, RwLock};
+use std::sync::{Arc, RwLock, TryLockError};
+
+/// Simple way to catch lock errors without
+/// depending on T.
+#[derive(thiserror::Error, Debug)]
+pub enum LockError {
+    #[error("{0}")]
+    Poisoned(String),
+    #[error("Would block")]
+    WouldBlock,
+}
+
+impl<T> From<TryLockError<T>> for MetalError {
+    fn from(value: TryLockError<T>) -> Self {
+        match value {
+            TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
+            TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
+        }
+    }
+}
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
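The `From<TryLockError<T>>` conversion above exists so that lock failures can flow through `?` as `MetalError`s instead of panicking `unwrap()`s. A minimal std-only sketch of the same pattern, with `String` standing in for `MetalError` (an assumption for illustration):

```rust
use std::sync::{RwLock, TryLockError};

fn catch_lock_error<T>(e: TryLockError<T>) -> String {
    match e {
        TryLockError::Poisoned(p) => format!("poisoned: {p}"),
        TryLockError::WouldBlock => "would block".to_string(),
    }
}

fn main() -> Result<(), String> {
    let index = RwLock::new(0usize);
    let _read = index.read().unwrap();
    // A second, non-blocking write lock fails cleanly instead of deadlocking.
    match index.try_write().map_err(catch_lock_error) {
        Ok(_) => unreachable!(),
        Err(msg) => assert_eq!(msg, "would block"),
    }
    Ok(())
}
```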
@@ -24,6 +43,14 @@ pub enum MetalError {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("{0:?}")]
+    LockError(LockError),
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedDType {
+        msg: &'static str,
+        expected: DType,
+        got: DType,
+    },
 }
impl From<String> for MetalError {
@@ -32,15 +59,53 @@ impl From<String> for MetalError {
     }
 }
 
+type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
+
 #[derive(Clone)]
 pub struct MetalDevice {
     /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
     device: metal::Device,
 
     /// Single command queue for the entire device.
     command_queue: metal::CommandQueue,
-    command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
+    /// One command buffer at a time.
+    /// The scheduler works by allowing multiple
+    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
+    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
+    /// prevents overlapping of CPU and GPU commands (because the command buffer needs to be
+    /// committed to start to work).
+    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
+    /// by their START time, but there is no guarantee that command buffer 1 will finish before
+    /// command buffer 2 starts (or there are Metal bugs there).
+    command_buffer: Arc<RwLock<metal::CommandBuffer>>,
+    /// Keeps track of the current number of compute command encoders on the current
+    /// command buffer.
+    /// Arc and RwLock because of the interior mutability.
     command_buffer_index: Arc<RwLock<usize>>,
+    /// The maximum number of [compute command encoders](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc).
+    compute_per_buffer: usize,
     /// Every compute command encoder (and blit encoder) is guarded by this fence, forcing the
     /// execution order to be linear.
     /// It could be relaxed in some circumstances, by managing the dependencies ourselves in the
     /// compute graph.
     fence: metal::Fence,
     /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
     /// Heavily used by [`candle_metal_kernels`]; both fences need to match.
     kernels: Arc<candle_metal_kernels::Kernels>,
-    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
+    /// Simple allocator struct.
+    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
+    /// We store the buffers in [`Arc`] because it is much faster than Obj-C internal ref counting
+    /// (which could be linked to FFI communication overhead).
+    ///
+    /// Whenever a buffer has a strong_count of 1, we can reuse it: it means it was dropped in the
+    /// graph calculation, and only we, the allocator, kept a reference to it, so it is free
+    /// to be reused. However, in order for this to work, we need to guarantee the order of
+    /// operations, so that this buffer is not being used by another kernel at the same time.
+    /// Arc is the CPU reference count; it does not mean anything on the GPU side of things.
+    ///
+    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
+    /// (strong_count == 1).
+    buffers: AllocatedBuffers,
 }
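A std-only sketch of the reuse rule the `buffers` doc above describes: when the allocator holds the only `Arc` (strong_count == 1), the graph has dropped the buffer and it can be handed out again. The `Vec<u8>` pool below is a stand-in, not the real `metal::Buffer` type:

```rust
use std::sync::Arc;

fn main() {
    let pool: Vec<Arc<Vec<u8>>> = vec![Arc::new(vec![0u8; 1024])];
    // No one else holds the buffer: reusable.
    assert_eq!(Arc::strong_count(&pool[0]), 1);

    let in_flight = pool[0].clone(); // a "kernel" takes a reference
    assert_eq!(Arc::strong_count(&pool[0]), 2); // not reusable right now

    drop(in_flight);
    assert_eq!(Arc::strong_count(&pool[0]), 1); // reusable again
}
```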
impl std::fmt::Debug for MetalDevice {
@@ -70,23 +135,25 @@ impl MetalDevice {
         &self.command_queue
     }
 
-    pub fn command_buffer(&self) -> CommandBuffer {
-        let mut command_buffers = self.command_buffers.try_write().unwrap();
-        let mut command_buffer = command_buffers[0].to_owned();
-        let mut index = self.command_buffer_index.try_write().unwrap();
-        if *index > 20 {
+    pub fn command_buffer(&self) -> Result<CommandBuffer> {
+        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
+        let mut command_buffer = command_buffer_lock.to_owned();
+        let mut index = self
+            .command_buffer_index
+            .try_write()
+            .map_err(MetalError::from)?;
+        if *index > self.compute_per_buffer {
             command_buffer.commit();
             command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffers = vec![command_buffer.clone()];
+            *command_buffer_lock = command_buffer.clone();
             *index = 0;
         }
         *index += 1;
-        command_buffer
+        Ok(command_buffer)
     }
 
-    pub fn wait_until_completed(&self) {
-        let mut command_buffers = self.command_buffers.try_write().unwrap();
-        let command_buffer = &command_buffers[0];
+    pub fn wait_until_completed(&self) -> Result<()> {
+        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled
@@ -97,7 +164,8 @@ impl MetalDevice {
         }
         command_buffer.commit();
         command_buffer.wait_until_completed();
-        *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
+        *command_buffer = self.command_queue.new_command_buffer().to_owned();
+        Ok(())
     }
 
     pub fn kernels(&self) -> &Kernels {
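The flush-every-N policy in `command_buffer` above is what lets CPU encoding overlap with committed GPU work. A minimal stand-in sketch of that counter logic, with plain vectors in place of Metal command buffers (an illustrative assumption, not the real types):

```rust
struct Batcher {
    flushed: Vec<Vec<u32>>, // stands in for committed command buffers
    current: Vec<u32>,      // stands in for the open command buffer
    per_buffer: usize,
}

impl Batcher {
    fn encode(&mut self, op: u32) {
        if self.current.len() >= self.per_buffer {
            // "commit" the full buffer and start a fresh one
            self.flushed.push(std::mem::take(&mut self.current));
        }
        self.current.push(op);
    }
}

fn main() {
    let mut b = Batcher { flushed: vec![], current: vec![], per_buffer: 3 };
    for op in 0..7 {
        b.encode(op);
    }
    assert_eq!(b.flushed.len(), 2); // two committed batches of 3
    assert_eq!(b.current, vec![6]); // one op still pending
}
```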
@@ -108,57 +176,49 @@ impl MetalDevice {
         &self.device
     }
 
-    pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
-        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
-    }
-
-    fn _new_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        _name: &str,
-    ) -> Arc<Buffer> {
-        let mut buffers = self.buffers.try_write().unwrap();
-        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
-
-        for sub in &mut *subbuffers {
-            if Arc::strong_count(sub) == 1 {
-                return sub.clone();
-            }
-        }
-        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
-        let new_buffer = Arc::new(new_buffer);
-        subbuffers.push(new_buffer.clone());
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(s) > 1)
-                .map(|s| Arc::clone(s))
-                .collect();
-            *subbuffers = newbuffers;
-        }
-
-        new_buffer
+    /// Creates a new buffer (not necessarily zeroed).
+    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode),
+    /// which means the buffer data cannot be read on the CPU directly.
+    ///
+    /// [`name`] is only used to keep track of the resource origin in case of bugs.
+    pub fn new_buffer(
+        &self,
+        element_count: usize,
+        dtype: DType,
+        name: &str,
+    ) -> Result<Arc<Buffer>> {
+        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
+        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
 
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
-        self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
+    /// Creates a new buffer (not necessarily zeroed).
+    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode),
+    /// which means the buffer can be read on the CPU but will require manual
+    /// synchronization when the CPU memory is modified.
+    /// Used as a bridge to gather data back from the GPU.
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
+        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
     }
 
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+    /// Creates a new buffer from data.
+    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode).
+    ///
+    /// This method will block the computation because of the
+    /// lack of lifetime management through the GPU.
+    /// See the internal comment for technical details.
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
         let tmp = self.device.new_buffer_with_data(
             data.as_ptr() as *const core::ffi::c_void,
             size,
             metal::MTLResourceOptions::StorageModeManaged,
         );
-        let real = self._new_buffer(
+        let real = self.allocate_buffer(
             size,
             metal::MTLResourceOptions::StorageModePrivate,
             "with_data",
-        );
-        let command_buffer = self.command_buffer();
+        )?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -174,15 +234,45 @@ impl MetalDevice {
         // Putting this wait forces the GPU buffer to be filled
         // with the actual data, allowing the CPU storage to
         // deallocate properly.
-        self.wait_until_completed();
-        real
+        self.wait_until_completed()?;
+        Ok(real)
     }
 
+    /// The critical allocator algorithm.
+    fn allocate_buffer(
+        &self,
+        size: NSUInteger,
+        option: MTLResourceOptions,
+        _name: &str,
+    ) -> Result<Arc<Buffer>> {
+        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
+
+        for sub in &mut *subbuffers {
+            if Arc::strong_count(sub) == 1 {
+                return Ok(sub.clone());
+            }
+        }
+        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
+        let new_buffer = Arc::new(new_buffer);
+        subbuffers.push(new_buffer.clone());
+        for subbuffers in buffers.values_mut() {
+            let newbuffers = subbuffers
+                .iter()
+                .filter(|s| Arc::strong_count(s) > 1)
+                .map(Arc::clone)
+                .collect();
+            *subbuffers = newbuffers;
+        }
+        Ok(new_buffer)
+    }
+
+    /// Create a metal GPU capture trace on [`path`].
+    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+        let capture = metal::CaptureManager::shared();
+        let descriptor = metal::CaptureDescriptor::new();
+        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        descriptor.set_capture_device(&self);
+        descriptor.set_capture_device(self);
+        descriptor.set_output_url(path);
+
+        capture
@@ -194,8 +284,11 @@ impl MetalDevice {
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
+    /// The actual buffer containing the data.
     buffer: Arc<metal::Buffer>,
+    /// A reference to the device owning this buffer.
     device: MetalDevice,
+    /// The dtype is kept since buffers are untyped.
     dtype: DType,
 }
@@ -223,9 +316,9 @@ impl BackendStorage for MetalStorage {
                 self.dtype
             );
         }
-        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         {
-            let command_buffer = self.device.command_buffer();
+            let command_buffer = self.device.command_buffer()?;
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
@@ -234,7 +327,7 @@ impl BackendStorage for MetalStorage {
             blit.update_fence(&self.device.fence);
             blit.end_encoding();
         }
-        self.device.wait_until_completed();
+        self.device.wait_until_completed()?;
 
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
@@ -254,12 +347,12 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "affine");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "affine")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
-                DType::F32 => "affine_float",
-                DType::F16 => "affine_half",
+                DType::F32 => "affine_f32",
+                DType::F16 => "affine_f16",
                 dtype => crate::bail!("Affine {dtype:?}"),
             };
             candle_metal_kernels::call_affine(
@@ -276,8 +369,8 @@ impl BackendStorage for MetalStorage {
             .map_err(MetalError::from)?;
         } else {
             let name = match self.dtype {
-                DType::F32 => "affine_float_strided",
-                DType::F16 => "affine_half_strided",
+                DType::F32 => "affine_f32_strided",
+                DType::F16 => "affine_f16_strided",
                 dtype => crate::bail!("Affine {dtype:?}"),
             };
             candle_metal_kernels::call_affine_strided(
@@ -305,12 +398,12 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "powf");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "powf")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
-                DType::F32 => "powf_float",
-                DType::F16 => "powf_half",
+                DType::F32 => "powf_f32",
+                DType::F16 => "powf_f16",
                 dtype => crate::bail!("Powf {dtype:?}"),
             };
             candle_metal_kernels::call_powf(
@@ -326,8 +419,8 @@ impl BackendStorage for MetalStorage {
             .map_err(MetalError::from)?;
         } else {
             let name = match self.dtype {
-                DType::F32 => "powf_float_strided",
-                DType::F16 => "powf_half_strided",
+                DType::F32 => "powf_f32_strided",
+                DType::F16 => "powf_f16_strided",
                 dtype => crate::bail!("Powf {dtype:?}"),
             };
             candle_metal_kernels::call_powf_strided(
@@ -354,12 +447,12 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "elu");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "elu")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
-                DType::F32 => "elu_float",
-                DType::F16 => "elu_half",
+                DType::F32 => "elu_f32",
+                DType::F16 => "elu_f16",
                 dtype => crate::bail!("Elu {dtype:?}"),
             };
             candle_metal_kernels::call_elu(
@@ -375,8 +468,8 @@ impl BackendStorage for MetalStorage {
             .map_err(MetalError::from)?;
         } else {
             let name = match self.dtype {
-                DType::F32 => "elu_float_strided",
-                DType::F16 => "elu_half_strided",
+                DType::F32 => "elu_f32_strided",
+                DType::F16 => "elu_f16_strided",
                 dtype => crate::bail!("Elu {dtype:?}"),
             };
             candle_metal_kernels::call_elu_strided(
@@ -397,17 +490,9 @@ impl BackendStorage for MetalStorage {
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
-        if !(sum_dims.len() == 1
-            && sum_dims[0] == layout.shape().rank() - 1
-            && layout.stride()[sum_dims[0]] == 1)
-        {
-            crate::bail!("Non last dim reduce op not supported yet");
-        }
-
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
         let src_el: usize = src_dims.iter().product();
         // Source dims and strides with the sum dims at the end.
         let mut dims = vec![];
         let mut stride = vec![];
@@ -427,28 +512,41 @@ impl BackendStorage for MetalStorage {
         // The reduction loop requires the shared array to be properly initialized and for
         // this we want the number of threads to be a power of two.
         let (name, check_empty, return_index) = match (op, self.dtype) {
-            (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
-            (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
-            (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
-            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
-            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
-            _ => crate::bail!("Reduce op for non float"),
+            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
+            (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
+            (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
+            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
+            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
+            (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
+            (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
+            (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
+            (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
+            (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
+            (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
+            (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
+            (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
+            (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
+            (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
+            (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
+            (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
+            (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
+            (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
+            (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
+            (k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"),
         };
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
         if dtype == DType::U32 {
             crate::bail!("Implement return index reduce op");
         }
-        let buffer = device.new_buffer(dst_el, dtype, "reduce");
-        let command_buffer = self.device.command_buffer();
-        candle_metal_kernels::call_reduce_contiguous(
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
+        let command_buffer = self.device.command_buffer()?;
+        candle_metal_kernels::call_reduce_strided(
             &device.device,
             &command_buffer,
             &device.kernels,
             name,
             src_el,
             &dims,
             &stride,
             dst_el,
             &self.buffer,
             layout.start_offset() * self.dtype.size_in_bytes(),
@@ -459,21 +557,30 @@ impl BackendStorage for MetalStorage {
         Ok(Self::new(buffer, device, dtype))
     }
 
-    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
-        crate::bail!("cmp metal")
+    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+        let name = match op {
+            CmpOp::Eq => "eq",
+            CmpOp::Ne => "ne",
+            CmpOp::Le => "le",
+            CmpOp::Ge => "ge",
+            CmpOp::Lt => "lt",
+            CmpOp::Gt => "gt",
+        };
+        self.binary(name, rhs, lhs_l, rhs_l)
     }
 
     fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, "todtype");
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, "todtype")?;
+        let command_buffer = device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
                 (DType::U32, DType::U8) => "cast_u32_u8",
                 (DType::U8, DType::U32) => "cast_u8_u32",
+                (DType::U8, DType::F32) => "cast_u8_f32",
                 (DType::F32, DType::F16) => "cast_f32_f16",
                 (DType::F16, DType::F32) => "cast_f16_f32",
                 (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
@@ -494,6 +601,7 @@ impl BackendStorage for MetalStorage {
                 (DType::U32, DType::F32) => "cast_u32_f32_strided",
                 (DType::U32, DType::U8) => "cast_u32_u8_strided",
                 (DType::U8, DType::U32) => "cast_u8_u32_strided",
+                (DType::U8, DType::F32) => "cast_u8_f32_strided",
                 (DType::F32, DType::F16) => "cast_f32_f16_strided",
                 (DType::F16, DType::F32) => "cast_f16_f32_strided",
                 (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
@@ -520,8 +628,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         command_buffer.set_label(B::KERNEL);
         if layout.is_contiguous() && layout.start_offset() == 0 {
             use candle_metal_kernels::unary::contiguous;
@@ -621,72 +729,7 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let device = self.device();
-        let dtype = self.dtype;
-        let shape = lhs_l.shape();
-        let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
-        if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
-            && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
-            && &B::KERNEL[..1] != "b"
-        {
-            use candle_metal_kernels::binary::contiguous;
-
-            let kernel_name = match (B::KERNEL, dtype) {
-                ("add", DType::F32) => contiguous::add::FLOAT,
-                ("sub", DType::F32) => contiguous::sub::FLOAT,
-                ("mul", DType::F32) => contiguous::mul::FLOAT,
-                ("div", DType::F32) => contiguous::div::FLOAT,
-                ("add", DType::F16) => contiguous::add::HALF,
-                ("sub", DType::F16) => contiguous::sub::HALF,
-                ("mul", DType::F16) => contiguous::mul::HALF,
-                ("div", DType::F16) => contiguous::div::HALF,
-                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
-            };
-            candle_metal_kernels::call_binary_contiguous(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                kernel_name,
-                el_count,
-                &self.buffer,
-                &rhs.buffer,
-                &buffer,
-            )
-            .map_err(MetalError::from)?;
-        } else {
-            use candle_metal_kernels::binary::strided;
-
-            let kernel_name = match (B::KERNEL, dtype) {
-                ("badd", DType::F32) => strided::add::FLOAT,
-                ("bsub", DType::F32) => strided::sub::FLOAT,
-                ("bmul", DType::F32) => strided::mul::FLOAT,
-                ("bdiv", DType::F32) => strided::div::FLOAT,
-                ("badd", DType::F16) => strided::add::HALF,
-                ("bsub", DType::F16) => strided::sub::HALF,
-                ("bmul", DType::F16) => strided::mul::HALF,
-                ("bdiv", DType::F16) => strided::div::HALF,
-                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
-            };
-            candle_metal_kernels::call_binary_strided(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                kernel_name,
-                lhs_l.dims(),
-                &self.buffer,
-                lhs_l.stride(),
-                lhs_l.start_offset() * self.dtype.size_in_bytes(),
-                &rhs.buffer,
-                rhs_l.stride(),
-                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &buffer,
-            )
-            .map_err(MetalError::from)?;
-        }
-        command_buffer.set_label("binary");
-        Ok(Self::new(buffer, device.clone(), dtype))
+        self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
     }
fn where_cond(
@@ -702,15 +745,19 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el = shape.elem_count();
         let dtype = t.dtype;
-        let buffer = self.device.new_buffer(el, dtype, "where");
-        let command_buffer = self.device.command_buffer();
+        let buffer = self.device.new_buffer(el, dtype, "where")?;
+        let command_buffer = self.device.command_buffer()?;
         if t.dtype() != f.dtype() {
-            crate::bail!("Invalid ternary different dtypes for values");
+            crate::bail!(
+                "Invalid where: different dtypes for values {:?} != {:?}",
+                t.dtype(),
+                f.dtype()
+            );
         }
         let name = match (self.dtype, t.dtype()) {
             (DType::U8, DType::F32) => "where_u8_f32",
             (DType::U8, DType::F16) => "where_u8_f16",
-            (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
+            (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"),
         };
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
@@ -789,20 +836,84 @@ impl BackendStorage for MetalStorage {
         crate::bail!("upsample_nearest2d metal")
     }
 
-    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
-        crate::bail!("gather metal")
+    fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let (ids_o1, _) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
+        };
+        let ids_el = ids_l.dims()[dim];
+        let dst_el = ids_l.shape().elem_count();
+        let dtype = self.dtype;
+        let device = self.device();
+        let buffer = device.new_buffer(dst_el, dtype, "gather")?;
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U32, DType::F32) => "gather_u32_f32",
+            (DType::U32, DType::F16) => "gather_u32_f16",
+            (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"),
+        };
+        let command_buffer = self.device.command_buffer()?;
+        candle_metal_kernels::call_gather(
+            &device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            ids_el,
+            dim,
+            &self.buffer,
+            src_l.start_offset() * dtype.size_in_bytes(),
+            &ids.buffer,
+            ids_o1 * ids.dtype.size_in_bytes(),
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn scatter_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
     ) -> Result<Self> {
-        crate::bail!("scatter_add metal")
+        let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
+        self.copy_strided_src(&mut acc, 0, l)?;
+        let (ids_offset, _) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+        };
+        let src_offset = match src_l.contiguous_offsets() {
+            Some((o1, _)) => o1,
+            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+        };
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U32, DType::F32) => "sa_u32_f32",
+            _ => Err(MetalError::UnexpectedDType {
+                msg: "scatter-add ids should be u8/u32/i64",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let command_buffer = self.device.command_buffer()?;
+        candle_metal_kernels::call_scatter_add(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            l.dims(),
+            dim,
+            &src.buffer,
+            src_offset * src.dtype.size_in_bytes(),
+            &ids.buffer,
+            ids_offset * ids.dtype.size_in_bytes(),
+            &acc.buffer,
+        )
+        .map_err(MetalError::from)?;
+        Ok(acc)
     }
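A plain-Rust reference (an assumption that mirrors the `gather_u32_*` kernels dispatched above) of the gather addressing: output element `tid` takes its index from `ids[tid]` and reads the source at that position along `dim`:

```rust
fn gather_last_dim(src: &[f32], ids: &[usize], left: usize, src_dim: usize, ids_dim: usize) -> Vec<f32> {
    let right = 1; // right_size, 1 when dim is the last dimension
    (0..left * ids_dim * right)
        .map(|tid| {
            let right_i = tid % right;
            let left_i = tid / right / ids_dim;
            let idx = ids[tid];
            src[(left_i * src_dim + idx) * right + right_i]
        })
        .collect()
}

fn main() {
    // 2x3 source, gather 2 columns per row: ids has the output's shape (2x2).
    let src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out = gather_last_dim(&src, &[0, 2, 1, 1], 2, 3, 2);
    assert_eq!(out, vec![1.0, 3.0, 5.0, 5.0]);
}
```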
    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
@@ -819,13 +930,13 @@ impl BackendStorage for MetalStorage {
         let dst_el = ids_el * left_size * right_size;
         let dtype = self.dtype;
         let device = self.device();
-        let buffer = device.new_buffer(dst_el, dtype, "index_select");
+        let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
             (DType::U32, DType::F16) => "is_u32_f16",
             (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
         };
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -844,14 +955,49 @@ impl BackendStorage for MetalStorage {
 
     fn index_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
     ) -> Result<Self> {
-        crate::bail!("index_add metal")
+        let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
+        self.copy_strided_src(&mut acc, 0, l)?;
+        let (ids_offset, _) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
+        let src_offset = match src_l.contiguous_offsets() {
+            Some((o1, _)) => o1,
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U32, DType::F32) => "ia_u32_f32",
+            _ => Err(MetalError::UnexpectedDType {
+                msg: "index-add ids should be u8/u32/i64",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let command_buffer = self.device.command_buffer()?;
+        candle_metal_kernels::call_index_add(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            l.dims(),
+            ids_l.dims(),
+            dim,
+            &src.buffer,
+            src_offset * src.dtype.size_in_bytes(),
+            &ids.buffer,
+            ids_offset * ids.dtype.size_in_bytes(),
+            &acc.buffer,
+        )
+        .map_err(MetalError::from)?;
+        Ok(acc)
+    }
    fn matmul(
        &self,
@@ -860,7 +1006,7 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
+        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let name = match self.dtype {
             DType::F32 => "sgemm",
             DType::F16 => "hgemm",
@@ -869,7 +1015,7 @@ impl BackendStorage for MetalStorage {
             }
         };
 
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
         candle_metal_kernels::call_gemm(
             &self.device.device,
@@ -877,10 +1023,10 @@ impl BackendStorage for MetalStorage {
             &self.device.kernels,
             name,
             (b, m, n, k),
-            &lhs_l.stride(),
+            lhs_l.stride(),
             lhs_l.start_offset() * self.dtype.size_in_bytes(),
             &self.buffer,
-            &rhs_l.stride(),
+            rhs_l.stride(),
             rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
             &rhs.buffer,
             &buffer,
@@ -890,7 +1036,7 @@ impl BackendStorage for MetalStorage {
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         if src_l.is_contiguous() && self.dtype == dst.dtype() {
             command_buffer.set_label("copy_contiguous");
             let blit = command_buffer.new_blit_command_encoder();
@@ -945,6 +1091,111 @@ impl MetalStorage {
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
     }
+    pub fn binary(
+        &self,
+        op: &'static str,
+        rhs: &Self,
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+    ) -> Result<Self> {
+        let device = self.device();
+        let shape = lhs_l.shape();
+        let el_count = shape.elem_count();
+        let command_buffer = device.command_buffer()?;
+        let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
+            && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
+            && &op[..1] != "b"
+        {
+            use candle_metal_kernels::binary::contiguous;
+
+            let (kernel_name, dtype) = match (op, self.dtype) {
+                ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
+                ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
+                ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
+                ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
+                ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
+                ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
+                ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
+                ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
+                ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
+                ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
+                ("add", DType::F16) => (contiguous::add::HALF, self.dtype),
+                ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
+                ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
+                ("div", DType::F16) => (contiguous::div::HALF, self.dtype),
+                ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
+                ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
+                ("le", DType::F16) => (contiguous::le::HALF, DType::U8),
+                ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
+                ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
+                ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
+                (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
+            };
+            let buffer = device.new_buffer(el_count, dtype, op)?;
+            candle_metal_kernels::call_binary_contiguous(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                el_count,
+                &self.buffer,
+                &rhs.buffer,
+                &buffer,
+            )
+            .map_err(MetalError::from)?;
+            (buffer, dtype)
+        } else {
+            use candle_metal_kernels::binary::strided;
+
+            let (kernel_name, dtype) = match (op, self.dtype) {
+                ("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
+                ("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
+                ("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
+                ("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
+                ("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
+                ("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
+                ("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
+                ("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
+                ("le", DType::F32) => (strided::le::FLOAT, DType::U8),
+                ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
+                ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
+                ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
+                ("badd", DType::F16) => (strided::add::HALF, self.dtype),
+                ("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
+                ("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
+                ("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
+                ("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
+                ("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
+                ("eq", DType::F16) => (strided::eq::HALF, DType::U8),
+                ("ne", DType::F16) => (strided::ne::HALF, DType::U8),
+                ("le", DType::F16) => (strided::le::HALF, DType::U8),
+                ("lt", DType::F16) => (strided::lt::HALF, DType::U8),
+                ("ge", DType::F16) => (strided::ge::HALF, DType::U8),
+                ("gt", DType::F16) => (strided::gt::HALF, DType::U8),
+                (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
+            };
+            let buffer = device.new_buffer(el_count, dtype, op)?;
+            candle_metal_kernels::call_binary_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                lhs_l.dims(),
+                &self.buffer,
+                lhs_l.stride(),
+                lhs_l.start_offset() * self.dtype.size_in_bytes(),
+                &rhs.buffer,
+                rhs_l.stride(),
+                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+                &buffer,
+            )
+            .map_err(MetalError::from)?;
+            (buffer, dtype)
+        };
+        command_buffer.set_label("binary");
+        Ok(Self::new(buffer, device.clone(), dtype))
+    }
 }
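A small CPU-side sketch of the dtype convention used in `binary` above: arithmetic ops keep the input dtype, while comparisons produce a `u8` mask, which is why the match arms return `(kernel, dtype)` pairs:

```rust
// Minimal stand-in for the comparison path: same shape out, dtype u8.
fn eq_f32(lhs: &[f32], rhs: &[f32]) -> Vec<u8> {
    lhs.iter()
        .zip(rhs)
        .map(|(l, r)| u8::from(l == r))
        .collect()
}

fn main() {
    let mask = eq_f32(&[1.0, 2.0, 3.0], &[1.0, 0.0, 3.0]);
    assert_eq!(mask, vec![1, 0, 1]);
    println!("{mask:?}");
}
```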
impl BackendDevice for MetalDevice {
@@ -952,29 +1203,25 @@ impl BackendDevice for MetalDevice {
 
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
 
-        let n = 1;
         let command_queue = device.new_command_queue();
-
-        let command_buffers = (0..n)
-            .map(|i| {
-                let command_buffer = command_queue.new_command_buffer().to_owned();
-                command_buffer.enqueue();
-                command_buffer.set_label(&format!("num {i}"));
-                command_buffer
-            })
-            .collect();
-        let command_buffers = Arc::new(RwLock::new(command_buffers));
+        let command_buffer = command_queue.new_command_buffer().to_owned();
+        command_buffer.enqueue();
+        let command_buffer = Arc::new(RwLock::new(command_buffer));
         let command_buffer_index = Arc::new(RwLock::new(0));
         let fence = device.new_fence();
         let kernels = Arc::new(Kernels::new(fence.clone()));
         let buffers = Arc::new(RwLock::new(HashMap::new()));
+        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+            Ok(val) => val.parse()?,
+            _ => 20,
+        };
         Ok(Self {
             device,
             fence,
             command_queue,
-            command_buffers,
+            command_buffer,
             command_buffer_index,
+            compute_per_buffer,
             buffers,
             kernels,
         })
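A hedged sketch of the `CANDLE_METAL_COMPUTE_PER_BUFFER` knob introduced above: it bounds how many compute encoders share one command buffer before a commit. Unlike the code above, which propagates a parse error, this sketch falls back to the default on a malformed value:

```rust
fn compute_per_buffer() -> usize {
    std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(20) // same default as the code above
}

fn main() {
    // Run as: CANDLE_METAL_COMPUTE_PER_BUFFER=50 cargo run
    println!("batching {} encoders per command buffer", compute_per_buffer());
}
```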
@@ -995,8 +1242,8 @@ impl BackendDevice for MetalDevice {
     }
 
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros");
-        let command_buffer = self.command_buffer();
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("zeros");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -1028,12 +1275,8 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
-        };
-        Ok(Self::Storage::new(
-            buffer.into(),
-            self.clone(),
-            storage.dtype(),
-        ))
+        }?;
+        Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
     }
 
     fn rand_uniform(
@@ -1863,10 +1863,7 @@ impl Tensor {
                 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
             }
             (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
-            (Storage::Metal(storage), Device::Cpu) => {
-                // println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
-                Storage::Cpu(storage.to_cpu_storage()?)
-            }
+            (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
             (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                 // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                 // are the same.
@@ -900,9 +900,7 @@ fn matmul(device: &Device) -> Result<()> {
     let b = Tensor::from_slice(&data, (2, 2), device)?;
 
     let c = a.matmul(&b)?;
     let d = a.matmul(&c)?;
     assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
     assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);
 
     let data = vec![1.0f32, 2.0];
     let a = Tensor::from_slice(&data, (2, 1), device)?;
@@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \
 } \
 
 
-AFFINE(affine_float, float)
-AFFINE(affine_half, half)
-POWF(powf_float, float)
-POWF(powf_half, half)
-ELU(elu_float, float)
-ELU(elu_half, half)
+AFFINE(affine_f32, float)
+AFFINE(affine_f16, half)
+POWF(powf_f32, float)
+POWF(powf_f16, half)
+ELU(elu_f32, float)
+ELU(elu_f16, half)
 
 
 #if __METAL_VERSION__ >= 310
-AFFINE(affine_bfloat, bfloat);
-POWF(powf_bfloat, bfloat);
-ELU(elu_bfloat, bfloat);
+AFFINE(affine_bf16, bfloat);
+POWF(powf_bf16, bfloat);
+ELU(elu_bf16, bfloat);
 #endif
@@ -1,5 +1,8 @@
 #include <metal_stdlib>
 
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+
 METAL_FUNC uint get_strided_index(
     uint idx,
     constant size_t &num_dims,
@@ -22,15 +25,15 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const TYPENAME *left, \
     device const TYPENAME *right, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    device OUT_TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    TYPENAME x = left[thread_position_in_grid]; \
-    TYPENAME y = right[thread_position_in_grid]; \
-    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
+    TYPENAME x = left[tid]; \
+    TYPENAME y = right[tid]; \
+    output[tid] = OUT_TYPENAME(FN); \
 }\
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -40,33 +43,48 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *right_strides, \
     device const TYPENAME *left, \
     device const TYPENAME *right, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    device OUT_TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
-    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
-    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
+    TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
+    TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
+    output[tid] = OUT_TYPENAME(FN); \
 }
 
 #define BINARY_OP(FN, NAME) \
-BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
-BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
+BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
+BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
 
 #define BFLOAT_BINARY_OP(FN, NAME) \
-BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
 
+#define BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
+BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
+
 
 BINARY_OP(x + y, add)
 BINARY_OP(x - y, sub)
 BINARY_OP(x * y, mul)
 BINARY_OP(x / y, div)
+BINARY_OP(MIN(x, y), min)
+BINARY_OP(MAX(x, y), max)
+
+BINARY_OP_OUT(eq, x == y)
+BINARY_OP_OUT(ne, x != y)
+BINARY_OP_OUT(le, x <= y)
+BINARY_OP_OUT(lt, x < y)
+BINARY_OP_OUT(ge, x >= y)
+BINARY_OP_OUT(gt, x > y)
 
 #if __METAL_VERSION__ >= 310
 BFLOAT_BINARY_OP(x + y, add)
 BFLOAT_BINARY_OP(x - y, sub)
 BFLOAT_BINARY_OP(x * y, mul)
 BFLOAT_BINARY_OP(x / y, div)
+BFLOAT_BINARY_OP(MIN(x, y), min)
+BFLOAT_BINARY_OP(MAX(x, y), max)
 #endif
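A plain-Rust reference of `get_strided_index` above (an assumption for illustration): a flat thread id is decomposed against `dims` from the innermost dimension outwards and re-projected through `strides`, which is how the `*_strided` kernels address non-contiguous tensors:

```rust
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided = 0;
    // Walk dimensions from the innermost (rightmost) outwards.
    for (dim, stride) in dims.iter().zip(strides).rev() {
        strided += (idx % dim) * stride;
        idx /= dim;
    }
    strided
}

fn main() {
    // A 2x3 tensor stored transposed: strides (1, 2) instead of (3, 1).
    let dims = [2, 3];
    let strides = [1, 2];
    let offsets: Vec<usize> = (0..6).map(|i| get_strided_index(i, &dims, &strides)).collect();
    assert_eq!(offsets, vec![0, 2, 4, 1, 3, 5]);
}
```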
@@ -48,6 +48,7 @@ kernel void FN_NAME_STRIDED( \
 CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
 CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
 CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
+CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
 CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
 CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
@@ -1,6 +1,34 @@
 #include <metal_stdlib>
 using namespace metal;
 
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void index(
+    constant size_t &dst_size,
+    constant size_t &left_size,
+    constant size_t &src_dim_size,
+    constant size_t &right_size,
+    constant size_t &ids_size,
+    const device TYPENAME *input,
+    const device INDEX_TYPENAME *input_ids,
+    device TYPENAME *output,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    if (tid >= dst_size) {
+        return;
+    }
+    const size_t id_i = (tid / right_size) % ids_size;
+    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
+    const size_t right_rank_i = tid % right_size;
+    const size_t left_rank_i = tid / right_size / ids_size;
+    /*
+    // Force prevent out of bounds indexing
+    // since there doesn't seem to be a good way to force crash
+    // No need to check for zero we're only allowing unsized.
+    */
+    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
+    output[tid] = input[src_i];
+}
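A plain-Rust reference (assumption) of the addressing in the `index` template above: the output position splits into (left, id, right) coordinates, and the looked-up id replaces the indexed coordinate in the source offset, clamped to stay in bounds:

```rust
fn index_select(input: &[f32], ids: &[usize], left: usize, src_dim: usize, right: usize) -> Vec<f32> {
    let dst_size = left * ids.len() * right;
    (0..dst_size)
        .map(|tid| {
            let right_i = tid % right;
            let id_i = (tid / right) % ids.len();
            let left_i = tid / right / ids.len();
            // Clamp like the kernel does to prevent out-of-bounds reads.
            let src_dim_i = ids[id_i].min(src_dim - 1);
            input[left_i * src_dim * right + src_dim_i * right + right_i]
        })
        .collect()
}

fn main() {
    // A 1x4x1 source; select rows 2, 0, 2.
    let src = [10.0, 11.0, 12.0, 13.0];
    assert_eq!(index_select(&src, &[2, 0, 2], 1, 4, 1), vec![12.0, 10.0, 12.0]);
}
```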
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
@@ -11,93 +39,160 @@ kernel void NAME( \
     const device TYPENAME *input, \
     const device INDEX_TYPENAME *input_ids, \
     device TYPENAME *output, \
-    uint gid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (gid >= dst_size) { \
-        return; \
-    } \
-    const size_t id_i = (gid / right_size) % ids_size; \
-    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
-    const size_t right_rank_i = gid % right_size; \
-    const size_t left_rank_i = gid / right_size / ids_size; \
-    /* \
-    // Force prevent out of bounds indexing \
-    // since there doesn't seem to be a good way to force crash \
-    // No need to check for zero we're only allowing unsized. \
-    */ \
-    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
-    output[gid] = input[src_i]; \
+    index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
 }
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void gather(
+    constant size_t &dst_size,
+    constant size_t &left_size,
+    constant size_t &src_dim_size,
+    constant size_t &right_size,
+    constant size_t &ids_size,
+    const device TYPENAME *input,
+    const device INDEX_TYPENAME *input_ids,
+    device TYPENAME *output,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    if (tid >= dst_size) {
+        return;
+    }
+    const INDEX_TYPENAME input_i = input_ids[tid];
+    const size_t right_rank_i = tid % right_size;
+    const size_t left_rank_i = tid / right_size / ids_size;
+    const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
+    output[tid] = input[src_i];
+}
+
+# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &dst_size, \
+    constant size_t &left_size, \
+    constant size_t &src_dim_size, \
+    constant size_t &right_size, \
+    constant size_t &ids_size, \
+    const device TYPENAME *input, \
+    const device INDEX_TYPENAME *input_ids, \
+    device TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
+}
+
-template <typename T, typename I>
-void index_add(
-    device I *ids [[buffer(0)]],
-    device T *inp [[buffer(1)]],
-    device T *out [[buffer(2)]],
-    constant uint &ids_dim_size,
-    constant uint &left_size,
-    constant uint &dst_dim_size,
-    constant uint &right_size,
-    uint gid [[ thread_position_in_grid ]]
-) {
-    if (gid >= left_size * right_size) {
-        return;
-    }
-    const uint i = gid;
-    const uint pre = i / right_size;
-    const uint post = i % right_size;
-    for (uint j = 0; j < ids_dim_size; j++) {
-        const uint idx = ids[j];
-        const uint src_i = (pre * ids_dim_size + j) * right_size + post;
-        const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
-        out[dst_i] += inp[src_i];
-    }
-}
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void scatter_add(
+    constant size_t &dst_size,
+    constant size_t &left_size,
+    constant size_t &src_dim_size,
+    constant size_t &right_size,
+    constant size_t &dst_dim_size,
+    const device TYPENAME *input,
+    const device INDEX_TYPENAME *input_ids,
+    device TYPENAME *output,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    if (tid >= dst_size) {
+        return;
+    }
+    const size_t right_rank_i = tid % right_size;
+    const size_t left_rank_i = tid / right_size;
+    for (unsigned int j = 0; j < src_dim_size; ++j) {
+        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+        const INDEX_TYPENAME idx = input_ids[src_i];
+        const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+        output[dst_i] += input[src_i];
+    }
+}
+
+# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &dst_size, \
+    constant size_t &left_size, \
+    constant size_t &src_dim_size, \
+    constant size_t &right_size, \
+    constant size_t &dst_dim_size, \
+    const device TYPENAME *input, \
+    const device INDEX_TYPENAME *input_ids, \
+    device TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
+}
+
-#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-kernel void FN_NAME( \
-    device INDEX_TYPENAME *ids [[buffer(0)]], \
-    device TYPENAME *inp [[buffer(1)]], \
-    device TYPENAME *out [[buffer(2)]], \
-    constant uint &ids_dim_size, \
-    constant uint &left_size, \
-    constant uint &dst_dim_size, \
-    constant uint &right_size, \
-    uint gid [[ thread_position_in_grid ]] \
-) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void index_add(
+    constant size_t &dst_size,
+    constant size_t &left_size,
+    constant size_t &src_dim_size,
+    constant size_t &right_size,
+    constant size_t &dst_dim_size,
+    constant size_t &ids_dim_size,
+    const device TYPENAME *input,
+    const device INDEX_TYPENAME *input_ids,
+    device TYPENAME *output,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    if (tid >= dst_size) {
+        return;
+    }
+    const size_t right_rank_i = tid % right_size;
+    const size_t left_rank_i = tid / right_size;
+    for (unsigned int j = 0; j < ids_dim_size; ++j) {
+        const INDEX_TYPENAME idx = input_ids[j];
+        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+        const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+        output[dst_i] += input[src_i];
+    }
+}
+
+# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &dst_size, \
+    constant size_t &left_size, \
+    constant size_t &src_dim_size, \
+    constant size_t &right_size, \
+    constant size_t &dst_dim_size, \
+    constant size_t &ids_dim_size, \
+    const device TYPENAME *input, \
+    const device INDEX_TYPENAME *input_ids, \
+    device TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
+}
 
 
 INDEX_OP(is_u32_f32, uint, float)
 INDEX_OP(is_u32_f16, uint, half)
+GATHER_OP(gather_u32_f32, uint, float)
+GATHER_OP(gather_u32_f16, uint, half)
+SCATTER_ADD_OP(sa_u32_f32, uint, float)
+SCATTER_ADD_OP(sa_u32_f16, uint, half)
 
 
 #if __METAL_VERSION__ >= 310
-IA_OP(bfloat, int64_t, ia_i64_bf16)
-IA_OP(bfloat, uint32_t, ia_u32_bf16)
-IA_OP(bfloat, uint8_t, ia_u8_bf16)
+INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
+INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
+INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
 #endif
 
-IA_OP(half, uint32_t, ia_u32_f16)
-IA_OP(half, uint8_t, ia_u8_f16)
+INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
+INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
 
-IA_OP(float, int64_t, ia_i64_f32)
-IA_OP(uint8_t, int64_t, ia_i64_u8)
-IA_OP(int64_t, int64_t, ia_i64_i64)
-IA_OP(uint32_t, int64_t, ia_i64_u32)
+INDEX_ADD_OP(ia_i64_f32, int64_t, float)
+INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
+INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
+INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
 
-IA_OP(float, uint32_t, ia_u32_f32)
-IA_OP(uint8_t, uint32_t, ia_u32_u8)
-IA_OP(int64_t, uint32_t, ia_u32_i64)
-IA_OP(uint32_t, uint32_t, ia_u32_u32)
+INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
+INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
+INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
+INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
 
-IA_OP(float, uint8_t, ia_u8_f32)
-IA_OP(uint8_t, uint8_t, ia_u8_u8)
-IA_OP(uint32_t, uint8_t, ia_u8_u32)
-IA_OP(int64_t, uint8_t, ia_u8_i64)
+INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
+INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
+INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
+INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
@ -15,6 +15,10 @@ const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");

/// Most kernels apply similarly across tensors.
/// This creates a dispatch strategy that uses the maximum number of threads per threadgroup
/// (capped at the actual total buffer length).
/// Each kernel invocation can then do its op on its single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
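A worked example of the split this implies (a sketch assuming the usual ceil-div grid sizing; the helper name is hypothetical):

// With a 1024-thread-per-threadgroup cap and a 5000-element buffer:
fn split(max_threads: u64, length: u64) -> (u64, u64) {
    let width = max_threads.min(length);      // threads per threadgroup: 1024
    let count = (length + width - 1) / width; // threadgroups: ceil(5000 / 1024) = 5
    (count, width)
}
// 5 * 1024 = 5120 threads cover 5000 elements, which is why the kernels above
// guard with `if (tid >= dst_size) return;`.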
@ -36,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
    <P as EncoderParam>::set_param(encoder, position, data)
}

/// Helper functions to set the various parameters on the compute command encoder
/// on a single line.
/// Prevents getting the argument count wrong and mixing up lengths with sizes in bytes.
trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
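An illustrative impl of the trait (a sketch, not from the diff): a plain usize is passed by value, so the trait owns the size-in-bytes computation instead of each call site repeating it.

impl EncoderParam for usize {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        // Pass the scalar inline via set_bytes; the length is always
        // size_of::<usize>(), never confused with an element count.
        encoder.set_bytes(
            position,
            core::mem::size_of::<usize>() as u64,
            &data as *const usize as *const std::ffi::c_void,
        );
    }
}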
@ -117,16 +125,16 @@ macro_rules! ops{
$(
    pub mod $name {
        use super::Kernel;
        pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
        pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
        pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
        pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
        pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
        pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
    }
)+
pub mod copy {
    use super::Kernel;
    pub const FLOAT: Kernel = Kernel("copy_float");
    pub const HALF: Kernel = Kernel("copy_half");
    pub const BFLOAT: Kernel = Kernel("copy_bfloat");
    pub const FLOAT: Kernel = Kernel("copy_f32");
    pub const HALF: Kernel = Kernel("copy_f16");
    pub const BFLOAT: Kernel = Kernel("copy_bf16");
    pub const U32: Kernel = Kernel("copy_u32");
    pub const U8: Kernel = Kernel("copy_u8");
}
@ -137,16 +145,16 @@ macro_rules! ops{
$(
    pub mod $name {
        use super::Kernel;
        pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
        pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
        pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
        pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
        pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
        pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
    }
)+
pub mod copy {
    use super::Kernel;
    pub const FLOAT: Kernel = Kernel("copy_float_strided");
    pub const HALF: Kernel = Kernel("copy_half_strided");
    pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
    pub const FLOAT: Kernel = Kernel("copy_f32_strided");
    pub const HALF: Kernel = Kernel("copy_f16_strided");
    pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
    pub const U32: Kernel = Kernel("copy_u32_strided");
    pub const U8: Kernel = Kernel("copy_u8_strided");
}
@ -158,7 +166,7 @@ pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
ops!(add, sub, mul, div);
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}

#[derive(thiserror::Error, Debug)]
@ -220,6 +228,9 @@ impl Kernels {
            Source::Mfa => panic!("Invalid lib"),
        }
    }

    /// Load the given library from its [`source`].
    /// If it has previously been loaded it is simply fetched from the cache.
    pub fn load_library(
        &self,
        device: &Device,
@ -232,9 +243,11 @@ impl Kernels {
        let lib = match source {
            Source::Mfa => {
                let source_data = MFA;
                device
                    .new_library_with_data(source_data)
                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
                device.new_library_with_data(source_data).map_err(|e| {
                    MetalKernelError::LoadLibraryError(format!(
                        "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
                    ))
                })?
            }
            source => {
                let source_content = self.get_library_source(source);
@ -262,6 +275,9 @@ impl Kernels {
        Ok(func)
    }

    /// Load the given pipeline:
    /// loads the library from source, then gets the function [`name`] from
    /// that source.
    fn load_pipeline_with_constants(
        &self,
        device: &Device,
@ -290,6 +306,9 @@ impl Kernels {
        }
    }

    /// Load the given pipeline:
    /// loads the library from source, then gets the function [`name`] from
    /// that source (without constants).
    pub fn load_pipeline(
        &self,
        device: &Device,
@ -569,6 +588,64 @@ pub fn call_reduce_contiguous(
    Ok(())
}

pub fn call_reduce_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let length: usize = shape.iter().product();
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let elements_to_sum = length / out_length;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            shape.len(),
            shape,
            strides,
            elements_to_sum,
            (input, input_offset),
            output
        )
    );

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
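An example call shape for the new entry point (a sketch mirroring the test harness further down: sum a contiguous 6-element vector into 2 partial sums; buffers are assumed to exist):

call_reduce_strided(
    &device,
    command_buffer,
    &kernels,
    "fast_sum_f32_strided",
    &[6], // shape
    &[1], // strides: contiguous
    2,    // out_length, so elements_to_sum = 3 per output slot
    &input,
    0,    // input offset
    &output,
)?;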

#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
    device: &Device,
@ -578,6 +655,7 @@ pub fn call_last_softmax(
    length: usize,
    elements_to_sum: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@ -585,7 +663,10 @@ pub fn call_last_softmax(
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, elements_to_sum, input, output));
    set_params!(
        encoder,
        (length, elements_to_sum, (input, input_offset), output)
    );

    let out_length = length / elements_to_sum;

@ -929,6 +1010,164 @@ pub fn call_index_select(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_gather(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
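Shape arithmetic for gather, as a worked example (values are hypothetical):

let (shape, dim, ids_size) = ([2usize, 5, 3], 1, 4);
let left_size: usize = shape[..dim].iter().product();      // 2
let right_size: usize = shape[dim + 1..].iter().product(); // 3
let dst_el = ids_size * left_size * right_size;            // 4 * 2 * 3 = 24, i.e. output shape [2, 4, 3]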

pub fn call_scatter_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

pub fn call_index_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    ids_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];
    let ids_dim_size = ids_shape[0];

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            ids_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
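A hypothetical call sketch for the index-add path, matching the `ia_u32_f32` kernel instantiated in indexing.metal above (all buffers and values illustrative):

call_index_add(
    &device,
    command_buffer,
    &kernels,
    "ia_u32_f32",
    &[4, 8],  // src_shape: 4 rows to add
    &[10, 8], // dst_shape: into a 10-row table
    &[4],     // ids_shape: one destination row id per source row
    0,        // dim
    &input,
    0,
    &ids,
    0,
    &output,
)?;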

#[derive(Debug, PartialEq)]
pub enum Value {
    USize(usize),
@ -2,6 +2,7 @@
using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))

METAL_FUNC uint get_strided_index(
    uint idx,
@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index(

constant int THREADGROUP_SIZE = 2048;

#define REDUCE(FN, NAME, T) \

#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device uint *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = MAXVALUE; \
    shared_indices[tid] = 0xFFFFFFFF; \
    bool notset = true; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        if (notset || src[strided_i] < shared_memory[tid]) { \
            shared_memory[tid] = src[strided_i]; \
            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
            shared_indices[tid] = idx % dims[num_dims - 1]; \
            notset = false; \
        } \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
    \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
            shared_indices[tid] = shared_indices[tid + s]; \
            shared_memory[tid] = shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    if (tid == 0){ \
        dst[dst_id] = shared_indices[0]; \
    } \
} \


#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device uint *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = MINVALUE; \
    shared_indices[tid] = 0xFFFFFFFF; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    bool notset = true; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        if (notset || shared_memory[tid] < src[strided_i]) { \
            shared_memory[tid] = src[strided_i]; \
            shared_indices[tid] = idx % dims[num_dims - 1]; \
            notset = false; \
        } \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
    \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
            shared_indices[tid] = shared_indices[tid + s]; \
            shared_memory[tid] = shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    if (tid == 0){ \
        dst[dst_id] = shared_indices[0]; \
    } \
} \
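The argmin/argmax macros and REDUCE below all share the same threadgroup tree reduction; a CPU sketch of the halving loop (illustrative Rust, not part of the diff):

// Mirrors `for (uint s = block_dim / 2; s > 0; s >>= 1)`: each round folds the
// top half of the buffer into the bottom half. block_dim must be a power of
// two, which is why call_reduce_strided rounds its width with next_power_of_two().
fn tree_reduce_max(vals: &mut [f32]) -> f32 {
    assert!(vals.len().is_power_of_two());
    let mut s = vals.len() / 2;
    while s > 0 {
        for tid in 0..s {
            if vals[tid + s] > vals[tid] {
                vals[tid] = vals[tid + s];
            }
        }
        s /= 2;
    }
    vals[0]
}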

#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
@ -34,21 +156,21 @@ kernel void NAME( \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = 0; \
    shared_memory[tid] = START; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        T x = shared_memory[tid]; \
        T y = src[idx]; \
        T y = src[strided_i]; \
        shared_memory[tid] = FN; \
        idx += block_dim; \
    } \
@ -71,10 +193,6 @@ kernel void NAME( \
} \

REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)

#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
@ -142,8 +260,33 @@ kernel void NAME(
    } \
} \

SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)

SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
#endif
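The new START argument seeds shared memory with the operation's identity element, so threads whose idx never enters the while loop cannot perturb the result. A quick illustrative check of the max case in Rust:

// sum -> 0, product -> 1, max -> -HUGE_VAL*, min -> +HUGE_VAL*
// (per the instantiations above); e.g. the max identity:
let seed = f32::NEG_INFINITY; // plays the role of -HUGE_VALF
assert_eq!(seed.max(3.0), 3.0);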
@ -1,209 +0,0 @@

import Metal
import MetalPerformanceShadersGraph



let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 2;
var K = 2;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)


var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)

let M_group = M_simd * M_splits
let N_group = N_simd * N_splits

// Satisfy Metal API validation.
#if DEBUG
do {
    var garbage: SIMD4<UInt64> = .zero
    constants.setConstantValue(&garbage, type: .bool, index: 102)
    constants.setConstantValue(&garbage, type: .bool, index: 103)
    constants.setConstantValue(&garbage, type: .bool, index: 113)
    constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif

let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true

var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)

var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
let function = try! library.makeFunction(
    name: name, constantValues: constants)

let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group

var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
    let C_block_length = M_group * N_group;
    blockElements = max(C_block_length, blockElements)
}
if fused_bias {
    if D_trans {
        blockElements = max(blockElements, M_group)
    } else {
        blockElements = max(blockElements, N_group)
    }
}
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)

func ceilDivide(target: Int, granularity: UInt16) -> Int {
    (target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
    width: ceilDivide(target: N, granularity: N_group),
    height: ceilDivide(target: M, granularity: M_group),
    depth: 1)
let groupSize = MTLSize(
    width: Int(32 * M_splits * N_splits),
    height: 1,
    depth: 1)

let commandQueue = device.makeCommandQueue()!

let threadgroupMemoryLength = blockBytes;

let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)

var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)

var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
    arrayA[i] = Float(i)
}

for i in 0..<arrayB.count {
    arrayB[i] = Float(i)
}

let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])!

let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])!

let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!


let pipeline = try device.makeComputePipelineState(function: function)

func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){
    let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
    encoder.setComputePipelineState(pipeline)
    encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)

    encoder.setBuffer(bufferA, offset: 0, index: 0)
    encoder.setBuffer(bufferB, offset: 0, index: 1)
    encoder.setBuffer(bufferC, offset: 0, index: 2)
    let gridZ: Int = B
    if batched{
        func byteStride(shape: [Int]) -> Int {
            let rank = shape.count
            var output = elementSize * shape[rank - 2] * shape[rank - 1]
            if shape.dropLast(2).reduce(1, *) == 1 {
                output = 0
            }
            return output
        }
        let byteStrideA = M*K*elementSize
        let byteStrideB = N*K*elementSize
        let byteStrideC = M*N*elementSize

        let byteStrideD = 0
        withUnsafeTemporaryAllocation(
            of: SIMD4<UInt64>.self, capacity: gridZ
        ) { buffer in
            for i in 0..<buffer.count {
                buffer[i] = SIMD4(
                    UInt64(truncatingIfNeeded: i * byteStrideA),
                    UInt64(truncatingIfNeeded: i * byteStrideB),
                    UInt64(truncatingIfNeeded: i * byteStrideC),
                    UInt64(truncatingIfNeeded: i * byteStrideD))
            }

            let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
            assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
            encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
        }
    }
    gridSize.depth = gridZ


    encoder.dispatchThreadgroups(
        gridSize, threadsPerThreadgroup: groupSize
    )
    encoder.endEncoding()
}

var commandBuffer = commandQueue.makeCommandBuffer()!
call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC)
commandBuffer.commit()
commandBuffer = commandQueue.makeCommandBuffer()!
commandBuffer.encodeWaitForEvent(event, value: 2)
call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD)
commandBuffer.commit()

commandBuffer.waitUntilCompleted()
var contents = bufferC.contents();
var count = B * rowsA * columnsB;
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("First matmul is OK", Array(bufferedPointer))

contents = bufferD.contents();
count = B * rowsA * columnsB;
typedPointer = contents.bindMemory(to: Float.self, capacity: count)
bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("This should be filled", Array(bufferedPointer))
@ -312,7 +312,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
        &device,
        command_buffer,
        &kernels,
        "affine_float",
        "affine_f32",
        size,
        &input,
        &output,
@ -346,7 +346,7 @@ fn run_affine_strided<T: Clone>(
        &device,
        command_buffer,
        &kernels,
        "affine_float_strided",
        "affine_f32_strided",
        shape,
        &input,
        strides,
@ -574,12 +574,15 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T

    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
    call_reduce_contiguous(
    let dims = vec![v.len()];
    let strides = vec![1];
    call_reduce_strided(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        &dims,
        &strides,
        out_length,
        &input,
        0,
@ -608,6 +611,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
        v.len(),
        last_dim,
        &input,
        0,
        &output,
    )
    .unwrap();
@ -622,7 +626,7 @@ fn reduce_sum() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 1;

    let results = run_reduce(&v, out_length, "fast_sum_float");
    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
    assert_eq!(approx(results, 4), vec![21.0]);
}

@ -631,7 +635,7 @@ fn reduce_sum2() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 2;

    let results = run_reduce(&v, out_length, "fast_sum_float");
    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}

@ -639,7 +643,7 @@ fn reduce_sum2() {
fn softmax() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    let results = run_softmax(&v, last_dim, "softmax_f32");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@ -651,7 +655,7 @@ fn softmax() {
    for i in 0..n {
        v[i * last_dim] = 20.0;
    }
    let results = run_softmax(&v, last_dim, "softmax_float");
    let results = run_softmax(&v, last_dim, "softmax_f32");
    let results = approx(results, 4);
    println!("{results:?}");
    assert_eq!(
@ -665,7 +669,7 @@ fn softmax() {

    let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    let results = run_softmax(&v, last_dim, "softmax_f32");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@ -673,7 +677,7 @@ fn softmax() {

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 3;
    let results = run_softmax(&v, last_dim, "softmax_float");
    let results = run_softmax(&v, last_dim, "softmax_f32");
    assert_eq!(
        approx(results, 4),
        vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
@ -684,7 +688,7 @@ fn softmax() {
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_half");
    let results = run_softmax(&v, last_dim, "softmax_f16");
    assert_eq!(
        approx_f16(results, 4),
        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
@ -695,7 +699,7 @@ fn softmax() {
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_bfloat");
    let results = run_softmax(&v, last_dim, "softmax_bf16");
    assert_eq!(
        approx_bf16(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
@ -87,11 +87,11 @@ kernel void FN_NAME_STRIDED( \
}

#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);

#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);


UNARY_OP(cos)
@ -108,8 +108,8 @@ UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)

@ -129,5 +129,5 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)

UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif
@ -210,32 +210,33 @@ impl candle::CustomOp1 for SoftmaxLastDim {
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::{backend::BackendStorage, DType};
        let device = storage.device();
        let command_buffer = device.command_buffer();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_float",
            DType::F16 => "softmax_half",
            DType::BF16 => "softmax_bfloat",
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            &kernels,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            &mut output,
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .unwrap();
        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
@ -142,9 +142,10 @@ impl RotaryEmbedding {
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?;
        let cos = freqs.cos()?;
        Ok(Self { sin, cos })
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
@ -407,38 +408,3 @@ impl MixFormerSequentialForCausalLM {
        self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_rotary() {
        let dev = Device::new_metal(0).unwrap();
        for i in 0..10000 {
            let dim = 8;
            let max_seq_len = 12;
            let inv_freq: Vec<_> = (0..dim)
                .step_by(2)
                .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
                .collect();
            let inv_freq_len = inv_freq.len();
            let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap();
            let t = Tensor::arange(0u32, max_seq_len as u32, &dev)
                .unwrap()
                .to_dtype(DType::F32)
                .unwrap()
                .reshape((max_seq_len, 1))
                .unwrap();
            let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 1.0);
            let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 0.1);
            let freqs = t.matmul(&inv_freq).unwrap();
            let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 0.1);
            let sin = freqs.sin().unwrap().contiguous().unwrap();
            let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 0.099833414);
        }
    }
}