BF16 metal fix.

This commit is contained in:
Nicolas Patry
2023-11-13 14:44:20 +01:00
parent dd4a40f1c0
commit 1471f98f0b
3 changed files with 24 additions and 45 deletions

View File

@ -5,7 +5,7 @@ use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels; use candle_metal_kernels;
use candle_metal_kernels::Kernels; use candle_metal_kernels::Kernels;
use core::mem; use core::mem;
use half::{bf16, f16}; use half::f16;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
@ -89,6 +89,15 @@ impl MetalDevice {
self.device self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged) .new_buffer(size, MTLResourceOptions::StorageModeManaged)
} }
/// Creates a Metal buffer initialised from the contents of `data`.
///
/// The buffer is requested with `StorageModeManaged` and its length in
/// bytes is the total in-memory size of the `data` slice.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
    let option = metal::MTLResourceOptions::StorageModeManaged;
    // `size_of_val` on a slice gives its full byte size directly, which
    // replaces the hand-written `len * size_of::<T>()` product and matches
    // how `set_param` measures slices in the kernels crate.
    let byte_len = mem::size_of_val(data) as NSUInteger;
    self.device.new_buffer_with_data(
        data.as_ptr() as *const core::ffi::c_void,
        byte_len,
        option,
    )
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -114,9 +123,9 @@ impl BackendStorage for MetalStorage {
} }
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
let start = std::time::Instant::now(); // let start = std::time::Instant::now();
self.device.wait_until_completed(); self.device.wait_until_completed();
println!("Wait took {:?}", start.elapsed()); // println!("Wait took {:?}", start.elapsed());
match self.dtype { match self.dtype {
DType::U8 => Ok(CpuStorage::U8( DType::U8 => Ok(CpuStorage::U8(
@ -415,7 +424,6 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
self.device.wait_until_completed();
Ok(Self { Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -688,7 +696,7 @@ impl BackendStorage for MetalStorage {
metal::mps::MPS_FLOATBIT_ENCODING | 16, metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger, core::mem::size_of::<f16>() as NSUInteger,
), ),
dtype => todo!("Implement matmul {dtype:?}"), dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
}; };
let elem_count = b * m * n; let elem_count = b * m * n;
@ -916,43 +924,14 @@ impl BackendDevice for MetalDevice {
} }
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage { let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data( CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
storage.as_ptr() as *const core::ffi::c_void, CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
(storage.len() * mem::size_of::<u8>()) as NSUInteger, CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
option, CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
), CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
CpuStorage::U32(storage) => self.device.new_buffer_with_data( CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
storage.as_ptr() as *const core::ffi::c_void, CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
option,
),
}; };
Ok(Self::Storage { Ok(Self::Storage {
buffer, buffer,

View File

@ -60,8 +60,8 @@ impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes( encoder.set_bytes(
position, position,
(core::mem::size_of::<T>() * data.len()) as u64, core::mem::size_of_val(data) as u64,
data.as_ptr() as *const T as *const c_void, data.as_ptr() as *const c_void,
); );
} }
} }
@ -190,7 +190,7 @@ type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>; type Libraries = HashMap<Source, Library>;
type Pipelines = KernelMap<ComputePipelineState>; type Pipelines = KernelMap<ComputePipelineState>;
#[derive(Debug)] #[derive(Debug, Default)]
pub struct Kernels { pub struct Kernels {
libraries: RwLock<Libraries>, libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>, pipelines: RwLock<Pipelines>,

View File

@ -43,7 +43,7 @@ template <typename T> METAL_FUNC T erf(T in){
return T(sign*y); return T(sign*y);
} }
template <typename T> METAL_FUNC T id(T in){ return in; } template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; } template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
template <typename T> METAL_FUNC T gelu(T x){ template <typename T> METAL_FUNC T gelu(T x){
T x_sq = x * x; T x_sq = x * x;
T x_cube = x_sq * x; T x_cube = x_sq * x;