Compare commits

...

13 Commits

9 changed files with 250 additions and 28 deletions

View File

@ -3,6 +3,7 @@ mod benchmarks;
use criterion::criterion_main; use criterion::criterion_main;
criterion_main!( criterion_main!(
benchmarks::affine::benches, benchmarks::affine::benches,
benchmarks::fill::benches,
benchmarks::matmul::benches, benchmarks::matmul::benches,
benchmarks::random::benches, benchmarks::random::benches,
benchmarks::where_cond::benches benchmarks::where_cond::benches

View File

@ -0,0 +1,44 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
/// Builds a tensor of ones with the given `shape` and `dtype` on `device`, then drops it.
///
/// # Panics
/// Panics if `Tensor::ones` returns an error.
fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) {
    let filled = Tensor::ones(shape, dtype, device);
    drop(filled.unwrap());
}
/// Registers a Criterion group that times filling a `1 x 1024 x 1024` tensor of `dtype`
/// with ones on `device`.
///
/// The group is named via `device.bench_name(name)` and reports throughput in bytes,
/// computed as the element count multiplied by `dtype.size_in_bytes()`.
///
/// # Panics
/// Panics if the fill itself or `device.sync()` fails.
fn run_fill_benchmark(c: &mut Criterion, device: &Device, name: &str, dtype: DType) {
    let (batch, height, width) = (1, 1024, 1024);
    // Number of bytes produced by a single fill (not a FLOP count).
    let bytes_per_iter = batch * height * width * dtype.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes_per_iter as u64));
    group.bench_function("iter", move |bencher| {
        bencher.iter_custom(|iters| {
            let timer = Instant::now();
            (0..iters).for_each(|_| {
                run(
                    black_box((batch, height, width)),
                    black_box(dtype),
                    black_box(device),
                );
            });
            // Sync before reading the clock; presumably this waits for work still queued
            // on the device so that it is included in the measurement.
            device.sync().unwrap();
            timer.elapsed()
        })
    });
    group.finish();
}
/// Runs the fill benchmarks (`fill_u8`, then `fill_f32`) on every device exposed by
/// `BenchDeviceHandler`.
///
/// # Panics
/// Panics if the device handler cannot be created.
fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        for (name, dtype) in [("fill_u8", DType::U8), ("fill_f32", DType::F32)] {
            run_fill_benchmark(c, &device, name, dtype);
        }
    }
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,4 +1,5 @@
pub(crate) mod affine; pub(crate) mod affine;
pub(crate) mod fill;
pub(crate) mod matmul; pub(crate) mod matmul;
pub(crate) mod random; pub(crate) mod random;
pub(crate) mod where_cond; pub(crate) mod where_cond;

View File

@ -4,6 +4,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape}; use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels; use candle_metal_kernels;
use candle_metal_kernels::Kernels; use candle_metal_kernels::Kernels;
use half::{bf16, f16};
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap; use std::collections::HashMap;
@ -1591,9 +1592,41 @@ impl BackendDevice for MetalDevice {
} }
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ? let buffer = self.new_buffer(shape.elem_count(), dtype, "ones")?;
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; let command_buffer = self.command_buffer()?;
self.storage_from_cpu_storage(&cpu_storage) command_buffer.set_label("ones");
macro_rules! fill {
($value:expr) => {
candle_metal_kernels::call_fill(
&self.device,
&command_buffer,
&self.kernels,
shape.elem_count(),
&buffer,
$value,
)
.map_err(MetalError::from)?
};
}
match dtype {
DType::U8 => candle_metal_kernels::call_fill_u8(
&command_buffer,
shape.elem_count(),
&buffer,
1u8,
)
.map_err(MetalError::from)?,
DType::U32 => fill!(1u32),
DType::I64 => fill!(1i64),
DType::BF16 => fill!(bf16::ONE),
DType::F16 => fill!(f16::ONE),
DType::F32 => fill!(1f32),
DType::F64 => {
return Err(MetalError::Message("Metal doesn't support double".to_string()).into())
}
}
Ok(MetalStorage::new(buffer, self.clone(), dtype))
} }
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {

View File

@ -9,17 +9,13 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"] categories = ["science"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
metal = { version = "0.27.0", features = ["mps"] } metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0" once_cell = "1.18.0"
thiserror = "1" thiserror = "1"
tracing = "0.1.37" tracing = "0.1.37"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
num-traits = "0.2.17"
[dev-dependencies] [dev-dependencies]
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
rand = "0.8.5" rand = "0.8.5"

View File

@ -0,0 +1,34 @@
#include <metal_stdlib>
using namespace metal;
// Shared body of the fill kernels: the thread at grid position `gid` stores `value`
// into `buffer[gid]`, provided `gid` is below `numel`.
template<typename T>
void fill(
    device T *buffer [[buffer(0)]],
    constant T &value,
    constant size_t &numel,
    uint gid [[thread_position_in_grid]]
) {
    // Threads whose position is at or beyond `numel` write nothing.
    if (gid < numel) {
        buffer[gid] = value;
    }
}
// Declares a `kernel` entry point called FN_NAME that forwards its arguments to fill<T>.
//
// The final line of the definition deliberately has NO trailing backslash: a
// continuation there would splice whatever line follows (the first instantiation)
// into the macro body.
#define FILL_OP(T, FN_NAME)                     \
kernel void FN_NAME(                            \
    device T *buffer [[buffer(0)]],             \
    constant T &value,                          \
    constant size_t &numel,                     \
    uint gid [[thread_position_in_grid]]        \
) { fill<T>(buffer, value, numel, gid); }

// Kernels that are always compiled.
FILL_OP(uint8_t, fill_u8)
FILL_OP(uint32_t, fill_u32)
FILL_OP(half, fill_f16)
FILL_OP(float, fill_f32)

// The 64-bit integer kernel is only compiled when __METAL_VERSION__ >= 220.
// NOTE(review): the Rust side registers `fill_i64` and `fill_bf16` unconditionally, so
// looking up either kernel presumably fails when its guard is false. Confirm that this
// is surfaced as a clean error to the caller.
#if __METAL_VERSION__ >= 220
FILL_OP(int64_t, fill_i64)
#endif

// The bfloat kernel is only compiled when the toolchain defines __HAVE_BFLOAT__.
#if defined(__HAVE_BFLOAT__)
FILL_OP(bfloat, fill_bf16)
#endif

View File

@ -1,3 +1,4 @@
use half::{bf16, f16};
use metal::{ use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal"); const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal"); const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal"); const CAST: &str = include_str!("cast.metal");
const FILL: &str = include_str!("fill.metal");
const CONV: &str = include_str!("conv.metal"); const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal"); const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal"); const RANDOM: &str = include_str!("random.metal");
@ -47,29 +49,26 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
/// Helper functions to create the various objects on the compute command encoder /// Helper functions to create the various objects on the compute command encoder
/// on a single line. /// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes. /// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam { pub trait EncoderParam: private::Sealed {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
} }
macro_rules! primitive {
($type:ty) => { macro_rules! primitives {
impl EncoderParam for $type { ($($type:ty),+) => {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { $(
encoder.set_bytes( impl EncoderParam for $type {
position, fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
core::mem::size_of::<$type>() as u64, encoder.set_bytes(
&data as *const $type as *const c_void, position,
); core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);
}
} }
} )+
}; };
} }
primitive!(bool); primitives!(bool, usize, u8, u32, u64, i32, i64, f16, bf16, f32);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
impl<T> EncoderParam for &[T] { impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
@ -112,6 +111,22 @@ macro_rules! set_params {
); );
} }
// Sealing module for `EncoderParam`: the trait requires `private::Sealed`, and `Sealed`
// is only implemented for the types listed here, so no other type can implement it.
mod private {
    use super::*;

    /// Marker supertrait that gates which types may implement `EncoderParam`.
    pub trait Sealed {}

    /// Emits an empty `Sealed` impl for each listed type.
    macro_rules! sealed {
        ($($ty:ty),+ $(,)?) => {
            $(
                impl Sealed for $ty {}
            )+
        };
    }

    // Scalar parameter types.
    sealed!(usize, u8, u32, u64, i32, i64, f16, bf16, f32, bool);
    // Buffer references, alone or paired with a `usize`.
    sealed!(&Buffer, (&Buffer, usize), &mut Buffer, (&mut Buffer, usize));

    // Slices of any element type.
    impl<T> Sealed for &[T] {}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source { pub enum Source {
Affine, Affine,
@ -123,6 +138,7 @@ pub enum Source {
Reduce, Reduce,
Mfa, Mfa,
Conv, Conv,
Fill,
Random, Random,
Quantized, Quantized,
} }
@ -192,6 +208,8 @@ pub mod binary {
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum MetalKernelError { pub enum MetalKernelError {
#[error("Invalid usage of kernel: {0}")]
InvalidUsage(String),
#[error("Could not lock kernel map: {0}")] #[error("Could not lock kernel map: {0}")]
LockError(String), LockError(String),
#[error("Error while loading library: {0}")] #[error("Error while loading library: {0}")]
@ -244,6 +262,7 @@ impl Kernels {
Source::Indexing => INDEXING, Source::Indexing => INDEXING,
Source::Cast => CAST, Source::Cast => CAST,
Source::Reduce => REDUCE, Source::Reduce => REDUCE,
Source::Fill => FILL,
Source::Conv => CONV, Source::Conv => CONV,
Source::Random => RANDOM, Source::Random => RANDOM,
Source::Quantized => QUANTIZED, Source::Quantized => QUANTIZED,
@ -1769,9 +1788,68 @@ pub fn call_quantized_matmul_t(
Ok(()) Ok(())
} }
#[inline(always)]
fn divide(m: usize, b: usize) -> NSUInteger { fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger ((m + b - 1) / b) as NSUInteger
} }
/// Encodes a compute pass on `command_buffer` that writes `value` into the first
/// `elem_count` elements of `buffer`.
///
/// The kernel is selected by `T::FILL_KERNEL` and loaded from the `Source::Fill`
/// library. This function only encodes the work; it neither commits the command
/// buffer nor waits for it.
///
/// # Errors
/// Returns a `MetalKernelError` if the fill pipeline cannot be loaded.
pub fn call_fill<T: FillOp>(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    elem_count: usize,
    buffer: &Buffer,
    value: T,
) -> Result<(), MetalKernelError> {
    // Nothing to write. Returning early also avoids asking `linear_split` / the
    // dispatch below to handle a zero-sized grid.
    if elem_count == 0 {
        return Ok(());
    }

    let pipeline = kernels.load_pipeline(device, Source::Fill, T::FILL_KERNEL)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    // No threadgroup memory is reserved: the fill kernels take only a device buffer,
    // two constants and the thread position, so a reservation proportional to
    // `elem_count` would be unused and could exceed the device's threadgroup limits.
    set_params!(encoder, (buffer, value, elem_count));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count);
    // Declare the write access before encoding the dispatch that performs it.
    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}
/// Encodes a blit pass on `command_buffer` that sets the first `elem_count` bytes of
/// `buffer` to `value`.
///
/// This function only encodes the work; it neither commits the command buffer nor
/// waits for it. As written it always returns `Ok(())`.
pub fn call_fill_u8(
    command_buffer: &CommandBufferRef,
    elem_count: usize,
    buffer: &Buffer,
    value: u8,
) -> Result<(), MetalKernelError> {
    // One byte per element, so the element count is also the byte length of the range.
    let byte_range = metal::NSRange {
        location: 0,
        length: elem_count as NSUInteger,
    };

    let blit_encoder = command_buffer.new_blit_command_encoder();
    blit_encoder.fill_buffer(buffer, byte_range, value);
    blit_encoder.end_encoding();
    Ok(())
}
/// Value types that can be passed to `call_fill`.
///
/// Implementors must also be `EncoderParam` so the value can be bound as a kernel
/// argument.
pub trait FillOp: EncoderParam {
    /// Name of the kernel function loaded for this value type.
    const FILL_KERNEL: &'static str;
}

/// Implements `FillOp` for each listed type, deriving the kernel name as
/// `"fill_"` followed by the type's spelling (e.g. `u32` becomes `"fill_u32"`).
macro_rules! impl_call_fill {
    ($($ty:ty),+ $(,)?) => {
        $(
            impl FillOp for $ty {
                const FILL_KERNEL: &'static str = concat!("fill_", stringify!($ty));
            }
        )+
    };
}

impl_call_fill!(u8, u32, i64, f16, bf16, f32);
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;

View File

@ -927,6 +927,42 @@ fn gemm() {
); );
} }
/// Fills a freshly allocated device buffer with `elem_count` copies of `value` using
/// `call_fill`, waits for the command buffer to complete, and reads the buffer back.
///
/// # Panics
/// Panics if `call_fill` returns an error.
fn run_fill<T: FillOp + Clone>(elem_count: usize, value: T) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    // The scratch buffer is initialised from f32 zeros, but the kernel writes
    // `elem_count` values of `T`. Allocate enough f32 slots to cover
    // `elem_count * size_of::<T>()` bytes so that 8-byte types such as `i64` do not
    // write (or get read back) past the end of the allocation.
    const F32_SIZE: usize = std::mem::size_of::<f32>();
    let f32_per_elem = (std::mem::size_of::<T>() + F32_SIZE - 1) / F32_SIZE;
    let buffer = new_buffer(&device, &vec![0.0f32; elem_count * f32_per_elem]);

    call_fill(
        &device,
        command_buffer,
        &kernels,
        elem_count,
        &buffer,
        value,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&buffer, elem_count)
}

#[test]
fn fill() {
    /// Checks that filling 1, 8, 64 and 512 elements yields exactly that many copies
    /// of `value`.
    fn assert_fill<T: FillOp + Copy + std::fmt::Debug + PartialEq>(value: T) {
        for exp in 0..4u32 {
            // `8 ^ i` is bitwise XOR in Rust (8, 9, 10, 11); use an explicit power so
            // the sizes really span several orders of magnitude, including one element.
            let elem_count = 8usize.pow(exp);
            assert_eq!(run_fill(elem_count, value), vec![value; elem_count]);
        }
    }
    assert_fill(123u8);
    assert_fill(456u32);
    assert_fill(789i64);
    assert_fill(f16::from_f32(1.23));
    assert_fill(bf16::from_f32(4.56));
    assert_fill(7.89f32);
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> { fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let kernels = Kernels::new();

View File

@ -1,6 +1,5 @@
#include <metal_stdlib> #include <metal_stdlib>
#include <metal_math> #include <metal_math>
#
using namespace metal; using namespace metal;
METAL_FUNC uint get_strided_index( METAL_FUNC uint get_strided_index(