From 7fc26764b6fb95429c9130c6122a2f4c2038d1ba Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 29 Dec 2023 16:02:29 +0100 Subject: [PATCH] Implement generic fill. u8 uses speedy blit encoder --- candle-core/src/metal_backend.rs | 58 +++++++++++++++++------- candle-metal-kernels/Cargo.toml | 1 + candle-metal-kernels/src/lib.rs | 75 ++++++++++++++++++++++++++++++- candle-metal-kernels/src/tests.rs | 35 +++++++-------- 4 files changed, 134 insertions(+), 35 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6d8afab1..f2b55d4e 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -3,7 +3,8 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvT use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::Kernels; +use candle_metal_kernels::{CallFill, Fill, Kernels}; +use half::{bf16, f16}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; @@ -1403,25 +1404,52 @@ impl BackendDevice for MetalDevice { let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); - let blit = command_buffer.new_blit_command_encoder(); - blit.wait_for_fence(&self.fence); - blit.fill_buffer( + + // This assumes the specific zero type DType is equal to 0x00u8 + // (which is true for all current types) + Fill::call_fill( + &self.device, + &command_buffer, + &self.kernels, + shape.elem_count(), &buffer, - metal::NSRange { - location: 0, - length: buffer.length(), - }, - 0, - ); - blit.update_fence(&self.fence); - blit.end_encoding(); + 0u8, + ) + .map_err(MetalError::from)?; + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("ones"); + + macro_rules! fill { + ($value:expr) => { + Fill::call_fill( + &self.device, + &command_buffer, + &self.kernels, + shape.elem_count(), + &buffer, + $value, + ) + .map_err(MetalError::from)? + }; + } + match dtype { + DType::U8 => fill!(1u8), + DType::U32 => fill!(1u32), + DType::I64 => fill!(1i64), + DType::BF16 => fill!(bf16::ONE), + DType::F16 => fill!(f16::ONE), + DType::F32 => fill!(1f32), + DType::F64 => { + return Err(MetalError::Message(format!("metal doesn't support double")).into()) + } + } + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index ba09ffcb..6c64a8e5 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -15,6 +15,7 @@ once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +num-traits = "0.2.17" [dev-dependencies] rand = "0.8.5" diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a2730632..e0985d94 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -5,6 +5,7 @@ use metal::{ }; use std::collections::HashMap; use std::ffi::c_void; +use std::marker::PhantomData; use std::sync::RwLock; const AFFINE: &str = include_str!("affine.metal"); @@ -180,6 +181,8 @@ pub mod binary { #[derive(thiserror::Error, Debug)] pub enum MetalKernelError { + #[error("Invalid usage of kernel: {0}")] + InvalidUsage(String), #[error("Could not lock kernel map: {0}")] LockError(String), #[error("Error while loading library: {0}")] @@ -1575,7 +1578,77 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } -pub fn call_fill( +pub struct Fill { + _marker: PhantomData, +} + +pub trait CallFill { + const KERNEL_NAME: &'static str; + + fn call_fill( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: T, + ) -> Result<(), MetalKernelError>; +} + +macro_rules ! impl_call_fill { + ($($t:ty),*) => { + $( + impl CallFill<$t> for Fill<$t> { + const KERNEL_NAME: &'static str = concat!("fill_", stringify!($t)); + + fn call_fill(device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, elem_count: usize, buffer: &Buffer, value: $t) -> Result<(), MetalKernelError> { + _call_fill(device, command_buffer, kernels, Self::KERNEL_NAME, elem_count, buffer, value) + } + } + )* + }; +} +impl_call_fill!(u32, i64, f16, bf16, f32); + +impl CallFill for Fill { + const KERNEL_NAME: &'static str = ""; + + fn call_fill( + _: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: u8, + ) -> Result<(), MetalKernelError> { + _call_blit_fill(command_buffer, kernels, elem_count, buffer, value) + } +} + +fn _call_blit_fill( + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: u8, +) -> Result<(), MetalKernelError> { + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&kernels.fence); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: elem_count as NSUInteger, + }, + value, + ); + blit.update_fence(&kernels.fence); + blit.end_encoding(); + + Ok(()) +} + +fn _call_fill( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index c1c7b8ab..a4fb726f 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -590,7 +590,6 @@ fn softmax() { } let results = run_softmax(&v, last_dim, "softmax_f32"); let results = approx(results, 4); - println!("{results:?}"); assert_eq!( results.iter().map(|&s| s.round() as usize).sum::(), n @@ -807,22 +806,20 @@ fn gemm() { ); } -fn run_fill( - elem_count: usize, - value: T, - kernel_name: &'static str, -) -> Vec { +fn run_fill(elem_count: usize, value: T) -> Vec +where + Fill: CallFill, +{ let device = device(); let fence = device.new_fence(); let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let buffer = new_buffer(&device, &vec![0.0f32; elem_count]); - call_fill( + Fill::::call_fill( &device, command_buffer, &kernels, - kernel_name, elem_count, &buffer, value, @@ -836,18 +833,18 @@ fn run_fill( #[test] fn fill() { - fn assert_fill( - value: T, - name: &'static str, - ) { + fn assert_fill(value: T) + where + Fill: CallFill, + { for i in 0..4 { - assert_eq!(run_fill(8 ^ i, value, name), vec![value; 8 ^ i]); + assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]); } } - assert_fill(123u8, "fill_u8"); - assert_fill(456u32, "fill_u32"); - assert_fill(789i64, "fill_i64"); - assert_fill(f16::from_f32(1.23), "fill_f16"); - assert_fill(bf16::from_f32(4.56), "fill_bf16"); - assert_fill(7.89f32, "fill_f32"); + assert_fill(123u8); + assert_fill(456u32); + assert_fill(789i64); + assert_fill(f16::from_f32(1.23)); + assert_fill(bf16::from_f32(4.56)); + assert_fill(7.89f32); }