From 1471f98f0b67fdd414bc3b36ce9c039b44c14ccf Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 13 Nov 2023 14:44:20 +0100 Subject: [PATCH] BF16 metal fix. --- candle-core/src/metal_backend.rs | 61 +++++++++------------------- candle-metal-kernels/src/lib.rs | 6 +-- candle-metal-kernels/src/unary.metal | 2 +- 3 files changed, 24 insertions(+), 45 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index efce19c1..6da2e2a9 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -5,7 +5,7 @@ use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; use core::mem; -use half::{bf16, f16}; +use half::f16; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::sync::{Arc, RwLock}; @@ -89,6 +89,15 @@ impl MetalDevice { self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) } + + pub fn new_buffer_with_data(&self, data: &[T]) -> Buffer { + let option = metal::MTLResourceOptions::StorageModeManaged; + self.device.new_buffer_with_data( + data.as_ptr() as *const core::ffi::c_void, + (data.len() * mem::size_of::()) as NSUInteger, + option, + ) + } } #[derive(Debug, Clone)] @@ -114,9 +123,9 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { - let start = std::time::Instant::now(); + // let start = std::time::Instant::now(); self.device.wait_until_completed(); - println!("Wait took {:?}", start.elapsed()); + // println!("Wait took {:?}", start.elapsed()); match self.dtype { DType::U8 => Ok(CpuStorage::U8( @@ -415,7 +424,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - self.device.wait_until_completed(); Ok(Self { buffer, device: device.clone(), @@ -688,7 +696,7 @@ impl BackendStorage for MetalStorage { metal::mps::MPS_FLOATBIT_ENCODING | 16, core::mem::size_of::() as NSUInteger, ), - dtype => todo!("Implement matmul {dtype:?}"), + dtype => todo!("Dtype for matmul {dtype:?} is not supported"), }; let elem_count = b * m * n; @@ -916,43 +924,14 @@ impl BackendDevice for MetalDevice { } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { - let option = metal::MTLResourceOptions::StorageModeManaged; let buffer = match storage { - CpuStorage::U8(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::U32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::I64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::BF16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::F16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::F32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::F64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), + CpuStorage::U8(storage) => self.new_buffer_with_data(storage), + CpuStorage::U32(storage) => self.new_buffer_with_data(storage), + CpuStorage::I64(storage) => self.new_buffer_with_data(storage), + CpuStorage::BF16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F32(storage) => self.new_buffer_with_data(storage), + CpuStorage::F64(storage) => self.new_buffer_with_data(storage), }; Ok(Self::Storage { buffer, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6b2ab050..a9d108f4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -60,8 +60,8 @@ impl EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - (core::mem::size_of::() * data.len()) as u64, - data.as_ptr() as *const T as *const c_void, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, ); } } @@ -190,7 +190,7 @@ type KernelMap = HashMap<&'static str, T>; type Libraries = HashMap; type Pipelines = KernelMap; -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Kernels { libraries: RwLock, pipelines: RwLock, diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 5389a26b..88139af9 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -43,7 +43,7 @@ template METAL_FUNC T erf(T in){ return T(sign*y); } template METAL_FUNC T id(T in){ return in; } -template METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; } +template METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } template METAL_FUNC T gelu(T x){ T x_sq = x * x; T x_cube = x_sq * x;