Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00.

Commits in this change:
Starting to fix some tests.
A few fixes.
Going back to the remote metal-rs.
Reusing a single buffer (for now) to speed things up.
Adding some half kernels.
Make all tests panic instead of failing randomly.
Putting back f16 index select.
Add erf.
Working version for llama2-c.
Fixes + cache compute_pipeline_state.
BF16 metal fix.
Remove some prints.
new_owned -> new()..to_owned().
Better batched matmul.
Metal operational.
Reuse buffers on our own reference counts.
Tmp gemm.
Revert "Tmp gemm."
This reverts commit c65f68e988.
Interleave committing.
Speeding up copies using blit.
Fmt.
Fmt.
Remove the assert!
Fmt all.
Fixes after big rebase.
Add softmax for half and bfloat + tests.
Fixing the Llama example + accumulating softmax in float.
affine.metal:
@@ -33,6 +33,24 @@ kernel void FN_NAME( \
     const TYPENAME a = TYPENAME(add); \
     output[id] = input[id] * m + a; \
 } \
+kernel void FN_NAME##_strided( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant float &mul, \
+    constant float &add, \
+    device const TYPENAME *input, \
+    device TYPENAME *output, \
+    uint id [[ thread_position_in_grid ]] \
+) { \
+    if (id >= dim) { \
+        return; \
+    } \
+    const TYPENAME m = TYPENAME(mul); \
+    const TYPENAME a = TYPENAME(add); \
+    output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
+} \
 
 AFFINE(affine_float, float)
+AFFINE(affine_half, half)
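The `_strided` variant resolves each thread's linear id through the `get_strided_index` helper (its body appears in the reduce.metal and unary.metal hunks below) before applying `x * mul + add`. A host-side sketch of that index mapping in Rust, handy for checking kernel output against a CPU reference (names here are illustrative, not part of the diff):

```rust
/// Map a linear element index to a storage offset for a strided layout,
/// mirroring the Metal `get_strided_index` used by the `_strided` kernels.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided = 0;
    // Walk dimensions from innermost (last) to outermost (first).
    for (dim, stride) in dims.iter().zip(strides.iter()).rev() {
        strided += (idx % dim) * stride;
        idx /= dim;
    }
    strided
}

fn main() {
    // A 2x3 view over transposed storage: shape [2, 3], strides [1, 2].
    let (dims, strides) = ([2usize, 3], [1usize, 2]);
    let offsets: Vec<_> = (0..6).map(|i| get_strided_index(i, &dims, &strides)).collect();
    assert_eq!(offsets, vec![0, 2, 4, 1, 3, 5]);
}
```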
cast.metal:
@@ -23,12 +23,12 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const LEFT_TYPENAME *input, \
     device RIGHT_TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
+    output[tid] = RIGHT_TYPENAME(input[tid]); \
 } \
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *strides, \
     device const LEFT_TYPENAME *input, \
     device RIGHT_TYPENAME *output, \
-    uint i [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (i >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+    output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
 } \
 
-CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
+CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
+CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
+CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
+CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
+CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
 
 #if __METAL_VERSION__ >= 310
 #endif
indexing.metal:
@@ -16,16 +16,16 @@ kernel void NAME( \
     if (gid >= dst_size) { \
         return; \
     } \
-    const size_t id_i = gid / right_size / left_size; \
+    const size_t id_i = (gid / right_size) % ids_size; \
-    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
     const size_t right_rank_i = gid % right_size; \
-    const size_t left_rank_i = gid % left_size; \
+    const size_t left_rank_i = gid / right_size / ids_size; \
     /* \
     // Force prevent out of bounds indexing \
     // since there doesn't seem to be a good way to force crash \
     // No need to check for zero we're only allowing unsized. \
     */ \
+    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
-    const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
+    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
     output[gid] = input[src_i]; \
 }
@@ -75,6 +75,7 @@ kernel void FN_NAME( \
 
 INDEX_OP(is_u32_f32, uint, float)
+INDEX_OP(is_u32_f16, uint, half)
 
 #if __METAL_VERSION__ >= 310
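The fix decomposes the destination index as `gid = (left_rank_i * ids_size + id_i) * right_size + right_rank_i`, so the selected dimension sits between the left and right blocks, and reads the source at `left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i`. A CPU reference of the same addressing (a sketch, with the same clamping as the kernel):

```rust
/// Reference index_select over a flattened tensor, mirroring the kernel:
/// shape = left x src_dim x right, selecting along the middle dimension.
fn index_select(input: &[f32], ids: &[usize], left: usize, src_dim: usize, right: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(left * ids.len() * right);
    for left_i in 0..left {
        for &id in ids {
            let id = id.min(src_dim - 1); // clamp out-of-range ids, as the kernel does
            for right_i in 0..right {
                out.push(input[left_i * src_dim * right + id * right + right_i]);
            }
        }
    }
    out
}

fn main() {
    // Matches the index_select_f16 test below: shape [5, 2], ids [0, 4, 2], dim 0.
    let emb: Vec<f32> = (1..=10).map(|x| x as f32).collect();
    assert_eq!(
        index_select(&emb, &[0, 4, 2], 1, 5, 2),
        vec![1.0, 2.0, 9.0, 10.0, 5.0, 6.0]
    );
}
```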
lib.rs:
@@ -1,6 +1,6 @@
 use metal::{
-    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
-    ComputePipelineState, Device, Function, Library, MTLSize,
+    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
+    Device, Function, Library, MTLSize,
 };
 use std::collections::HashMap;
 use std::ffi::c_void;
@@ -59,8 +59,8 @@ impl<T> EncoderParam for &[T] {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
         encoder.set_bytes(
             position,
-            (core::mem::size_of::<T>() * data.len()) as u64,
-            data.as_ptr() as *const T as *const c_void,
+            core::mem::size_of_val(data) as u64,
+            data.as_ptr() as *const c_void,
         );
     }
 }
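The two forms compute the same byte length for a slice; `size_of_val` just does it directly and drops the double pointer cast. A quick check:

```rust
fn main() {
    let data: &[u32] = &[1, 2, 3];
    // size_of::<T>() * len and size_of_val agree for slices.
    assert_eq!(core::mem::size_of::<u32>() * data.len(), 12);
    assert_eq!(core::mem::size_of_val(data), 12);
}
```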
@@ -111,13 +111,7 @@ macro_rules! ops{
 ($($name:ident),+) => {
 
     pub mod contiguous {
-    #[derive(Clone, Copy)]
-    pub struct Kernel(pub(crate) &'static str);
-    impl std::fmt::Display for Kernel {
-        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-            write!(f, "{}", self.0)
-        }
-    }
+    pub struct Kernel(pub &'static str);
     $(
     pub mod $name {
     use super::Kernel;
@@ -126,16 +120,18 @@ macro_rules! ops{
     pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
     }
     )+
+    pub mod copy {
+        use super::Kernel;
+        pub const FLOAT: Kernel = Kernel("copy_float");
+        pub const HALF: Kernel = Kernel("copy_half");
+        pub const BFLOAT: Kernel = Kernel("copy_bfloat");
+        pub const U32: Kernel = Kernel("copy_u32");
+        pub const U8: Kernel = Kernel("copy_u8");
+    }
     }
 
     pub mod strided {
-    #[derive(Clone, Copy)]
-    pub struct Kernel(pub(crate) &'static str);
-    impl std::fmt::Display for Kernel {
-        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-            write!(f, "{}", self.0)
-        }
-    }
+    pub struct Kernel(pub &'static str);
     $(
     pub mod $name {
     use super::Kernel;
@@ -144,12 +140,20 @@ macro_rules! ops{
     pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
     }
     )+
+    pub mod copy {
+        use super::Kernel;
+        pub const FLOAT: Kernel = Kernel("copy_float_strided");
+        pub const HALF: Kernel = Kernel("copy_half_strided");
+        pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
+        pub const U32: Kernel = Kernel("copy_u32_strided");
+        pub const U8: Kernel = Kernel("copy_u8_strided");
+    }
     }
     };
 }
 
 pub mod unary {
-    ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
+    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
 }
 pub mod binary {
     ops!(add, sub, mul, div);
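For orientation, each name passed to `ops!` expands to a module holding the kernel-name constants, following the `concat!(stringify!($name), "_bfloat")` pattern visible above. A hand-expanded sketch for a single op (the `"cos_float"`/`"cos_half"` names are assumed by analogy with the `copy_*` constants):

```rust
// Sketch of what `ops!(cos, ...)` expands to for the `cos` case.
pub mod contiguous {
    pub struct Kernel(pub &'static str);
    pub mod cos {
        use super::Kernel;
        pub const FLOAT: Kernel = Kernel("cos_float");
        pub const HALF: Kernel = Kernel("cos_half");
        pub const BFLOAT: Kernel = Kernel("cos_bfloat");
    }
}
```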
@@ -161,8 +165,12 @@ pub enum MetalKernelError {
     LockError(String),
     #[error("Error while loading library: {0}")]
     LoadLibraryError(String),
-    #[error("Error while loading function: {0}")]
+    #[error("Error while loading function: {0:?}")]
     LoadFunctionError(String),
+    #[error("Failed to create compute function")]
+    FailedToCreateComputeFunction,
+    #[error("Failed to create pipeline")]
+    FailedToCreatePipeline(String),
 }
 
 impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -173,19 +181,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
 
 type KernelMap<T> = HashMap<&'static str, T>;
 type Libraries = HashMap<Source, Library>;
 type Functions = KernelMap<Function>;
+type Pipelines = KernelMap<ComputePipelineState>;
 
 #[derive(Debug, Default)]
 pub struct Kernels {
     libraries: RwLock<Libraries>,
-    funcs: RwLock<Functions>,
+    pipelines: RwLock<Pipelines>,
 }
 
 impl Kernels {
     pub fn new() -> Self {
         let libraries = RwLock::new(Libraries::new());
-        let funcs = RwLock::new(Functions::new());
-        Self { libraries, funcs }
+        let pipelines = RwLock::new(Pipelines::new());
+        Self {
+            libraries,
+            pipelines,
+        }
     }
 
     fn get_library_source(&self, source: Source) -> &'static str {
@@ -218,22 +229,43 @@ impl Kernels {
         }
     }
 
-    pub fn load_function(
+    fn load_function(
         &self,
         device: &Device,
         source: Source,
         name: &'static str,
     ) -> Result<Function, MetalKernelError> {
-        let mut funcs = self.funcs.write()?;
-        if let Some(func) = funcs.get(name) {
-            Ok(func.clone())
-        } else {
-            let func = self
-                .load_library(device, source)?
-                .get_function(name, None)
-                .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
-            funcs.insert(name, func.clone());
-            Ok(func)
-        }
+        let func = self
+            .load_library(device, source)?
+            .get_function(name, None)
+            .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+        Ok(func)
+        // let mut funcs = self.funcs.write()?;
+        // if let Some(func) = funcs.get(name) {
+        //     Ok(func.clone())
+        // } else {
+        //     funcs.insert(name, func.clone());
+        //     Ok(func)
+        // }
     }
 
+    pub fn load_pipeline(
+        &self,
+        device: &Device,
+        source: Source,
+        name: &'static str,
+    ) -> Result<ComputePipelineState, MetalKernelError> {
+        let mut pipelines = self.pipelines.write()?;
+        if let Some(pipeline) = pipelines.get(name) {
+            Ok(pipeline.clone())
+        } else {
+            let func = self.load_function(device, source, name)?;
+            let pipeline = device
+                .new_compute_pipeline_state_with_function(&func)
+                .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
+            pipelines.insert(name, pipeline.clone());
+
+            Ok(pipeline)
+        }
+    }
 }
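The point of `load_pipeline` is that `new_compute_pipeline_state_with_function` is expensive, so each pipeline state is built once per kernel name and cheap clones are handed out afterwards (this is the "cache compute_pipeline_state" commit). The caching pattern, reduced to plain Rust (a generic sketch, not the crate's API):

```rust
use std::collections::HashMap;
use std::sync::RwLock;

/// Build-once cache keyed by kernel name, mirroring how
/// `Kernels::load_pipeline` memoizes compute pipeline states.
struct Cache<V: Clone> {
    entries: RwLock<HashMap<&'static str, V>>,
}

impl<V: Clone> Cache<V> {
    fn new() -> Self {
        Self { entries: RwLock::new(HashMap::new()) }
    }

    fn get_or_create(&self, name: &'static str, create: impl FnOnce() -> V) -> V {
        let mut entries = self.entries.write().unwrap();
        if let Some(v) = entries.get(name) {
            v.clone() // cache hit: no expensive rebuild
        } else {
            let v = create();
            entries.insert(name, v.clone());
            v
        }
    }
}

fn main() {
    let cache = Cache::new();
    let a = cache.get_or_create("cos_float", || String::from("pipeline"));
    let b = cache.get_or_create("cos_float", || unreachable!("already cached"));
    assert_eq!(a, b);
}
```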
@@ -246,18 +278,9 @@ pub fn call_unary_contiguous(
     kernel_name: unary::contiguous::Kernel,
     length: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
-
+    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
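All of these launchers share a `linear_split` helper (not shown in this diff) that picks a 1-D dispatch geometry. A plausible sketch of what it computes, assuming the standard metal-rs API:

```rust
use metal::{ComputePipelineState, MTLSize};

// Assumed shape of the linear_split helper used by the call_* launchers:
// cap the threadgroup width at what the pipeline allows, then use enough
// groups to cover `length` elements.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size).max(1);
    let count = (size + width - 1) / width; // ceil-divide so every element is covered
    let thread_group_count = MTLSize { width: count, height: 1, depth: 1 };
    let thread_group_size = MTLSize { width, height: 1, depth: 1 };
    (thread_group_count, thread_group_size)
}
```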
@@ -279,18 +302,10 @@ pub fn call_unary_strided(
     input: &Buffer,
     strides: &[usize],
     offset: usize,
-    output: &mut Buffer,
+    output: &Buffer,
     output_offset: usize,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Unary, name.0)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
+    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
 
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
@@ -326,17 +341,9 @@ pub fn call_binary_contiguous(
     length: usize,
     left: &Buffer,
     right: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
+    let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -363,17 +370,9 @@ pub fn call_binary_strided(
     right_input: &Buffer,
     right_strides: &[usize],
     right_offset: usize,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Binary, name.0)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
+    let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
 
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
@@ -411,22 +410,52 @@ pub fn call_cast_contiguous(
     kernel_name: &'static str,
     length: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    input_offset: usize,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Cast, kernel_name)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
+    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, input, output));
+    set_params!(encoder, (length, (input, input_offset), output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_cast_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    shape: &[usize],
+    input: &Buffer,
+    input_strides: &[usize],
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    let length: usize = shape.iter().product();
+
+    set_params!(
+        encoder,
+        (
+            length,
+            shape.len(),
+            shape,
+            input_strides,
+            (input, input_offset),
+            output
+        )
+    );
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
@@ -435,7 +464,6 @@ pub fn call_cast_contiguous(
     Ok(())
 }
 
 #[allow(clippy::too_many_arguments)]
 pub fn call_reduce_contiguous(
     device: &Device,
     command_buffer: &CommandBufferRef,
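Passing `(input, input_offset)` instead of a bare buffer is what lets a launcher address a sub-range of a Metal buffer; under `set_params!`, such a pair presumably ends up as a `set_buffer` call carrying the offset. A minimal sketch of that binding (assumed shape, not the crate's exact trait):

```rust
use metal::{Buffer, ComputeCommandEncoderRef};

// Assumed sketch: how a (buffer, offset) parameter can be bound.
// Metal's set_buffer takes the offset in bytes from the start of the buffer.
fn bind_buffer(encoder: &ComputeCommandEncoderRef, index: u64, buffer: &Buffer, offset_in_bytes: u64) {
    encoder.set_buffer(index, Some(buffer), offset_in_bytes);
}
```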
@@ -444,24 +472,19 @@ pub fn call_reduce_contiguous(
     length: usize,
     out_length: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    input_offset: usize,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
-
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let elements_to_sum = length / out_length;
 
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, elements_to_sum, input, output));
+    set_params!(
+        encoder,
+        (length, elements_to_sum, (input, input_offset), output)
+    );
 
     let thread_group_count = MTLSize {
         width: out_length as u64,
@@ -495,18 +518,9 @@ pub fn call_last_softmax(
     length: usize,
     elements_to_sum: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
-
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -542,21 +556,14 @@ pub fn call_affine(
     device: &Device,
     command_buffer: &CommandBufferRef,
     kernels: &Kernels,
+    name: &'static str,
     size: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Affine, "affine_float")?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
+    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -570,6 +577,45 @@ pub fn call_affine(
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_affine_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    input: &Buffer,
+    input_stride: &[usize],
+    input_offset: usize,
+    output: &Buffer,
+    mul: f32,
+    add: f32,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+    let size: usize = shape.iter().product();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            size,
+            shape.len(),
+            shape,
+            input_stride,
+            mul,
+            add,
+            (input, input_offset),
+            output
+        )
+    );
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 pub fn call_where_cond_strided(
     device: &Device,
     command_buffer: &CommandBufferRef,
@@ -582,17 +628,9 @@ pub fn call_where_cond_strided(
     (left_stride, left_offset): (&[usize], usize),
     right: &Buffer,
     (right_stride, right_offset): (&[usize], usize),
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Ternary, name)?;
-    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
-    pipeline_state_descriptor.set_compute_function(Some(&func));
-
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(
-            pipeline_state_descriptor.compute_function().unwrap(),
-        )
-        .unwrap();
+    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -634,17 +672,14 @@ pub fn call_index_select(
     dim: usize,
     input: &Buffer,
     ids: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
     let src_dim_size = shape[dim];
     let dst_el = ids_size * left_size * right_size;
 
-    let func = kernels.load_function(device, Source::Indexing, name)?;
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(&func)
-        .unwrap();
+    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
reduce.metal:
@@ -1,6 +1,8 @@
 #include <metal_stdlib>
 using namespace metal;
 
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+
 METAL_FUNC uint get_strided_index(
     uint idx,
     constant size_t &num_dims,
@@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }
 
-constant int THREADGROUP_SIZE = 256;
+constant int THREADGROUP_SIZE = 1024;
 
-# define REDUCE(FN, NAME, TYPENAME) \
+# define REDUCE(FN, NAME, T) \
 kernel void NAME( \
     constant size_t &src_numel, \
     constant size_t &el_to_sum_per_block, \
-    device const TYPENAME *src, \
-    device TYPENAME *dst, \
+    device const T *src, \
+    device T *dst, \
     uint id [[ thread_position_in_grid ]], \
     uint tid [[ thread_index_in_threadgroup ]], \
     uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint blockDim [[ threads_per_threadgroup ]] \
+    uint block_dim [[ threads_per_threadgroup ]] \
 ) { \
     \
     threadgroup float shared_memory[THREADGROUP_SIZE]; \
@@ -45,10 +47,10 @@ kernel void NAME( \
     // TODO: Fast version for the contiguous case. \
     // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
     */ \
-        TYPENAME x = shared_memory[tid]; \
-        TYPENAME y = src[idx]; \
+        T x = shared_memory[tid]; \
+        T y = src[idx]; \
         shared_memory[tid] = FN; \
-        idx += blockDim; \
+        idx += block_dim; \
     } \
     \
     threadgroup_barrier(mem_flags::mem_none); \
@@ -56,10 +58,10 @@ kernel void NAME( \
     /* \
     // reduction in shared memory \
     */ \
-    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
         if (tid < s) { \
-            TYPENAME x = shared_memory[tid]; \
-            TYPENAME y = shared_memory[tid + s]; \
+            T x = shared_memory[tid]; \
+            T y = shared_memory[tid + s]; \
             shared_memory[tid] = FN; \
         } \
         threadgroup_barrier(mem_flags::mem_none); \
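`REDUCE` is a standard two-phase threadgroup reduction: each of the `block_dim` threads strides through its block's slice accumulating into shared memory, then the tree loop halves the number of active threads each step. A single-threaded Rust model of the same logic (illustrative):

```rust
/// Single-threaded model of the REDUCE tree: `block_dim` "threads" first
/// accumulate a strided slice into shared memory, then pairwise-combine.
fn block_reduce(src: &[f32], block_dim: usize, f: impl Fn(f32, f32) -> f32, init: f32) -> f32 {
    let mut shared = vec![init; block_dim];
    for (i, &v) in src.iter().enumerate() {
        let tid = i % block_dim; // thread `tid` handles elements tid, tid+block_dim, ...
        shared[tid] = f(shared[tid], v);
    }
    let mut s = block_dim / 2;
    while s > 0 {
        for tid in 0..s {
            shared[tid] = f(shared[tid], shared[tid + s]);
        }
        s /= 2;
    }
    shared[0]
}

fn main() {
    let v: Vec<f32> = (1..=8).map(|x| x as f32).collect();
    assert_eq!(block_reduce(&v, 4, f32::max, f32::NEG_INFINITY), 8.0); // fast_max
    assert_eq!(block_reduce(&v, 4, |x, y| x + y, 0.0), 36.0); // fast_sum
}
```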
@@ -68,72 +70,74 @@ kernel void NAME( \
     dst[dst_id] = shared_memory[0]; \
 } \
 
-kernel void softmax_float(
-    constant size_t &src_numel,
-    constant size_t &el_to_sum_per_block,
-    device const float *src,
-    device float *dst,
-    uint id [[ thread_position_in_grid ]],
-    uint tid [[ thread_index_in_threadgroup ]],
-    uint dst_id [[ threadgroup_position_in_grid ]],
-    uint blockDim [[ threads_per_threadgroup ]]
-) {
-    threadgroup float shared_memory[THREADGROUP_SIZE];
-    shared_memory[tid] = -INFINITY;
-    // Elements summed in this block range from dst_id * el_to_sum_per_block
-    // to (dst_id + 1) * el_to_sum_per_block.
-    size_t start_idx = dst_id * el_to_sum_per_block;
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
-    size_t idx = start_idx + tid;
-
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        shared_memory[tid] = max(shared_memory[tid], src[idx]);
-        idx += blockDim;
-    }
-
-    threadgroup_barrier(mem_flags::mem_none);
-
-    // reduction in shared memory
-    for (uint s = blockDim / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
-        }
-        threadgroup_barrier(mem_flags::mem_none);
-    }
-
-    float max = shared_memory[0];
-
-    shared_memory[tid] = 0;
-
-    // Restart
-    idx = start_idx + tid;
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        const float val = exp(src[idx] - max);
-        dst[idx] = val;
-        shared_memory[tid] += val;
-        idx += blockDim;
-    }
-    // reduction in shared memory
-    for (uint s = blockDim / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            shared_memory[tid] += shared_memory[tid + s];
-        }
-        threadgroup_barrier(mem_flags::mem_none);
-    }
-
-    const float inv_acc = 1/shared_memory[0];
-    idx = start_idx + tid;
-    while (idx < stop_idx) {
-        dst[idx] *= inv_acc;
-        idx += blockDim;
-    }
-}
-
 
 REDUCE(x + y, fast_sum_float, float)
 REDUCE(x * y, fast_mul_float, float)
 REDUCE(max(x, y), fast_max_float, float)
 
+#define SOFTMAX(NAME, T) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    shared_memory[tid] = -INFINITY; \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t idx = start_idx + tid; \
+    \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    while (idx < stop_idx) { \
+        shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
+        idx += block_dim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
+        } \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    float _max = shared_memory[0]; \
+    \
+    shared_memory[tid] = 0; \
+    \
+    idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        const T val = T(exp(src[idx] - _max)); \
+        dst[idx] = val; \
+        shared_memory[tid] += val; \
+        idx += block_dim; \
+    } \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            shared_memory[tid] += shared_memory[tid + s]; \
+        } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
+    } \
+    \
+    const T inv_acc = T(1/shared_memory[0]); \
+    idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        dst[idx] *= inv_acc; \
+        idx += block_dim; \
+    } \
+} \
+
+SOFTMAX(softmax_float, float)
+SOFTMAX(softmax_half, half)
 #if __METAL_VERSION__ >= 310
+SOFTMAX(softmax_bfloat, bfloat)
 #endif
ternary.metal:
@@ -32,6 +32,9 @@ kernel void FN_NAME( \
     device TYPENAME *out ,\
     uint i [[ thread_position_in_grid ]] \
 ) { \
     if (i >= numel){ \
         return; \
     } \
     uint strided_i = get_strided_index(i, num_dims, dims, strides); \
+    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
+    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
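These are the `where_cond` kernels: `out[i] = cond[i] ? t[i] : f[i]`, with each of the three operands resolved through its own strides (`strides`, `strides_t`, `strides_f`). A contiguous CPU reference (illustrative):

```rust
/// Reference where_cond over contiguous data; the kernel generalizes this by
/// resolving each operand through its own strides.
fn where_cond(cond: &[u8], on_true: &[f32], on_false: &[f32]) -> Vec<f32> {
    cond.iter()
        .zip(on_true.iter().zip(on_false.iter()))
        .map(|(&c, (&t, &f))| if c != 0 { t } else { f })
        .collect()
}

fn main() {
    assert_eq!(
        where_cond(&[1, 0, 1], &[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0]),
        vec![1.0, -2.0, 3.0]
    );
}
```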
tests:
@@ -1,5 +1,5 @@
 use super::*;
-use half::f16;
+use half::{bf16, f16};
 use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
 
 fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
@@ -23,13 +23,18 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
     v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
 }
 
+fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
+    let b = 10f32.powi(digits);
+    v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
+}
+
 fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
     let device = device();
     let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let mut output = new_buffer(&device, v);
+    let output = new_buffer(&device, v);
     call_unary_contiguous(
         &device,
         command_buffer,
@@ -37,7 +42,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
         name,
         v.len(),
         &input,
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -53,7 +58,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
     let options = MTLResourceOptions::StorageModeManaged;
     let left = new_buffer(&device, x);
     let right = new_buffer(&device, y);
-    let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+    let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
     call_binary_contiguous(
         &device,
         command_buffer,
@@ -62,7 +67,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
         x.len(),
         &left,
         &right,
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -81,7 +86,7 @@ fn run_strided<T: Clone>(
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let mut output = new_buffer(&device, v);
+    let output = new_buffer(&device, v);
     let kernels = Kernels::new();
     call_unary_strided(
         &device,
@@ -92,7 +97,7 @@ fn run_strided<T: Clone>(
         &input,
         strides,
         offset,
-        &mut output,
+        &output,
         0,
     )
     .unwrap();
@@ -220,7 +225,9 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let mut output = new_buffer(&device, v);
+    let options = MTLResourceOptions::StorageModeManaged;
+    let size = (v.len() * std::mem::size_of::<U>()) as u64;
+    let output = device.new_buffer(size, options);
 
     call_cast_contiguous(
         &device,
@@ -229,7 +236,8 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
         name,
         v.len(),
         &input,
-        &mut output,
+        0,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -245,11 +253,17 @@ fn cast_u32_f32() {
     assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
     assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
 
+    let v = vec![1.0f32, 2.0, 3.0];
+    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+    let results: Vec<f32> = cast(&input, "cast_f16_f32");
+    assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
+
     let v = vec![1.0f32; 10_000];
-    let results = run(&v, unary::contiguous::cos::FLOAT);
-    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
-    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
-    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+    let results: Vec<f32> = cast(&input, "cast_f16_f32");
+    assert_eq!(results.len(), 10_000);
+    assert_eq!(&results[..10], vec![1.0f32; 10]);
+    assert_eq!(results, vec![1.0f32; 10_000]);
 }
@@ -259,7 +273,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
     let command_buffer = command_queue.new_command_buffer();
 
     let input = new_buffer(&device, v);
-    let mut output = new_buffer(&device, v);
+    let output = new_buffer(&device, v);
 
     let size = v.len();
 
@@ -267,9 +281,45 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
         &device,
         command_buffer,
         &kernels,
+        "affine_float",
         size,
         &input,
-        &mut output,
+        &output,
         mul as f32,
         add as f32,
     )
     .unwrap();
     command_buffer.commit();
     command_buffer.wait_until_completed();
 
     output.read_to_vec::<T>(v.len())
 }
 
+fn _run_affine_strided<T: Clone>(
+    v: &[T],
+    shape: &[usize],
+    strides: &[usize],
+    mul: f64,
+    add: f64,
+) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+
+    let input = new_buffer(&device, v);
+    let output = new_buffer(&device, v);
+
+    call_affine_strided(
+        &device,
+        command_buffer,
+        &kernels,
+        "affine_float",
+        shape,
+        &input,
+        strides,
+        0,
+        &output,
+        mul as f32,
+        add as f32,
+    )
@@ -295,6 +345,16 @@ fn affine() {
     assert_eq!(result, vec![2.6; 40_000]);
 }
 
+// #[test]
+// fn affine_strided() {
+//     let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+//     let mul = 1.5;
+//     let add = 1.1;
+//     let result = run_affine_(&input, mul, add);
+//     assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
+// }
+
 #[test]
 fn index_select() {
     let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
@@ -313,7 +373,26 @@ fn index_select() {
         result,
         vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
     );
 }
 
+#[test]
+fn index_select_f16() {
+    let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+        .into_iter()
+        .map(|x| f16::from_f32(x))
+        .collect();
+    let shape = [5, 2];
+    let ids = [0u32, 4, 2];
+    let dim = 0;
+    let result = run_index_select(&embedding, &shape, &ids, dim);
+    assert_eq!(
+        approx_f16(result, 4),
+        vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+    );
+}
+
 #[test]
 fn index_select_dim1() {
     let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
     let shape = [5, 2];
     let ids = [0u32, 1, 0];
@@ -321,7 +400,7 @@ fn index_select_dim1() {
     let result = run_index_select(&embedding, &shape, &ids, dim);
     assert_eq!(
         result,
-        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
+        vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
     );
 }
@@ -341,20 +420,26 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
     let dst_el = ids.len() * left_size * right_size;
-    let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+    let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
 
+    let name = match core::mem::size_of::<T>() {
+        4 => "is_u32_f32",
+        2 => "is_u32_f16",
+        _ => unimplemented!(),
+    };
+
     let kernels = Kernels::new();
     call_index_select(
         &device,
         &command_buffer,
         &kernels,
-        "is_u32_f32",
+        name,
         shape,
         ids.len(),
         dim,
         &embeddings_buffer,
         &ids_buffer,
-        &mut dst_buffer,
+        &dst_buffer,
     )
     .unwrap();
@@ -451,7 +536,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
     let input = new_buffer(&device, v);
 
     let options = MTLResourceOptions::StorageModeManaged;
-    let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
     call_reduce_contiguous(
         &device,
         command_buffer,
@@ -460,7 +545,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
         v.len(),
         out_length,
         &input,
-        &mut output,
+        0,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -475,7 +561,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
-    let mut output = new_buffer(&device, v);
+    let output = new_buffer(&device, v);
     call_last_softmax(
         &device,
         command_buffer,
@@ -484,7 +570,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
         v.len(),
         last_dim,
         &input,
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -536,6 +622,28 @@ fn softmax() {
         approx(results, 4),
         vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
     );
 
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+        .iter()
+        .map(|v| f16::from_f32(*v))
+        .collect::<Vec<_>>();
+    let last_dim = 6;
+    let results = run_softmax(&v, last_dim, "softmax_half");
+    assert_eq!(
+        approx_f16(results, 4),
+        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
+    );
+
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+        .iter()
+        .map(|v| bf16::from_f32(*v))
+        .collect::<Vec<_>>();
+    let last_dim = 6;
+    let results = run_softmax(&v, last_dim, "softmax_bfloat");
+    assert_eq!(
+        approx_bf16(results, 4),
+        vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
+    );
 }
 
 fn run_where_cond<I: Clone, T: Clone>(
@@ -571,7 +679,7 @@ fn run_where_cond<I: Clone, T: Clone>(
         options,
     );
 
-    let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
     call_where_cond_strided(
         &device,
         command_buffer,
@@ -584,7 +692,7 @@ fn run_where_cond<I: Clone, T: Clone>(
         (&left_stride, left_offset),
         &right,
         (&cond_stride, cond_offset),
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
unary.metal:
@@ -1,4 +1,7 @@
 #include <metal_stdlib>
+#include <metal_math>
+#
 using namespace metal;
 
 METAL_FUNC uint get_strided_index(
     uint idx,
@@ -17,10 +20,39 @@ METAL_FUNC uint get_strided_index(
 
 template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
 template <typename T> METAL_FUNC T neg(T in){ return -in; }
+template <typename T> METAL_FUNC T erf(T in){
+    float x = (float) in;
+    // constants
+    float a1 =  0.254829592;
+    float a2 = -0.284496736;
+    float a3 =  1.421413741;
+    float a4 = -1.453152027;
+    float a5 =  1.061405429;
+    float p  =  0.3275911;
+
+    // Save the sign of x
+    int sign = 1;
+    if (x < 0)
+        sign = -1;
+    x = fabs(x);
+
+    // A&S formula 7.1.26
+    float t = 1.0/(1.0 + p*x);
+    float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
+
+    return T(sign*y);
+}
 template <typename T> METAL_FUNC T id(T in){ return in; }
+template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
+template <typename T> METAL_FUNC T gelu(T x){
+    T x_sq = x * x;
+    T x_cube = x_sq * x;
+    T alpha = x + static_cast<T>(0.044715) * x_cube;
+    T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
+    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
+}
 
 using namespace metal;
 
 #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
 kernel void FN_NAME( \
@@ -64,8 +96,16 @@ UNARY_OP(sqrt)
 UNARY_OP(neg)
 UNARY_OP(exp)
 UNARY_OP(log)
+UNARY_OP(gelu)
+UNARY_OP(ceil)
+UNARY_OP(floor)
+UNARY_OP(round)
+UNARY_OP(gelu_erf)
+UNARY_OP(erf)
 UNARY(id, float, copy_float, copy_float_strided)
 UNARY(id, half, copy_half, copy_half_strided)
+UNARY(id, uint8_t, copy_u8, copy_u8_strided)
+UNARY(id, uint32_t, copy_u32, copy_u32_strided)
 
 #if __METAL_VERSION__ >= 310
 BFLOAT_UNARY_OP(cos)
@@ -75,6 +115,12 @@ BFLOAT_UNARY_OP(sqrt)
 BFLOAT_UNARY_OP(neg)
 BFLOAT_UNARY_OP(exp)
 BFLOAT_UNARY_OP(log)
+BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(ceil)
+BFLOAT_UNARY_OP(floor)
+BFLOAT_UNARY_OP(round)
+BFLOAT_UNARY_OP(gelu_erf)
+BFLOAT_UNARY_OP(erf)
 
 UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
 #endif
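The `erf` template is the Abramowitz & Stegun formula 7.1.26 rational approximation (maximum absolute error around 1.5e-7), always evaluated in `float`, and `gelu_erf` builds the "exact" GELU from it. The same approximation in Rust, for sanity-checking kernel output on the CPU (illustrative):

```rust
/// Abramowitz & Stegun formula 7.1.26: erf(x) to ~1.5e-7 absolute error.
fn erf(x: f32) -> f32 {
    let (a1, a2, a3, a4, a5) = (0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429);
    let p = 0.3275911;
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
    sign * y
}

/// "Exact" GELU via erf, matching the kernel's gelu_erf.
fn gelu_erf(x: f32) -> f32 {
    x * (1.0 + erf(x * std::f32::consts::FRAC_1_SQRT_2)) / 2.0
}

fn main() {
    assert!((erf(1.0) - 0.8427007).abs() < 1e-5);
    assert!((gelu_erf(1.0) - 0.8413447).abs() < 1e-4);
}
```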
benchmarks:
@@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
     &device,
     command_buffer,
     &kernels,
+    "affine_float",
     v.len(),
     &input,
     &mut output,
@@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
     println!(
         "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
         type_name::<T>().split("::").last().unwrap(),
-        kernel_name.to_string(),
+        kernel_name.0,
         v.len(),
         iterations,
         total_time,
@@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
     let shape = vec![2, 5_000];
     let strides = vec![2, 1];
     let offset = 0;
-    for kernel_name in strided {
+    for kernel_name in &strided {
         let total_time = autoreleasepool(|| {
             let command_buffer = command_queue.new_command_buffer();
             let start = Instant::now();
@@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
     println!(
         "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
         type_name::<T>().split("::").last().unwrap(),
-        kernel_name.to_string(),
+        kernel_name.0,
         v.len(),
         iterations,
         total_time,