BF16 metal fix.

This commit is contained in:
Nicolas Patry
2023-11-13 14:44:20 +01:00
parent dd4a40f1c0
commit 1471f98f0b
3 changed files with 24 additions and 45 deletions

View File

@ -5,7 +5,7 @@ use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use core::mem;
use half::{bf16, f16};
use half::f16;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock};
@ -89,6 +89,15 @@ impl MetalDevice {
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
let option = metal::MTLResourceOptions::StorageModeManaged;
self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
(data.len() * mem::size_of::<T>()) as NSUInteger,
option,
)
}
}
#[derive(Debug, Clone)]
@ -114,9 +123,9 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
let start = std::time::Instant::now();
// let start = std::time::Instant::now();
self.device.wait_until_completed();
println!("Wait took {:?}", start.elapsed());
// println!("Wait took {:?}", start.elapsed());
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(
@ -415,7 +424,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
self.device.wait_until_completed();
Ok(Self {
buffer,
device: device.clone(),
@ -688,7 +696,7 @@ impl BackendStorage for MetalStorage {
metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger,
),
dtype => todo!("Implement matmul {dtype:?}"),
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
};
let elem_count = b * m * n;
@ -916,43 +924,14 @@ impl BackendDevice for MetalDevice {
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
option,
),
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
};
Ok(Self::Storage {
buffer,

View File

@ -60,8 +60,8 @@ impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
(core::mem::size_of::<T>() * data.len()) as u64,
data.as_ptr() as *const T as *const c_void,
core::mem::size_of_val(data) as u64,
data.as_ptr() as *const c_void,
);
}
}
@ -190,7 +190,7 @@ type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
type Pipelines = KernelMap<ComputePipelineState>;
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct Kernels {
libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>,

View File

@ -43,7 +43,7 @@ template <typename T> METAL_FUNC T erf(T in){
return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
template <typename T> METAL_FUNC T gelu(T x){
T x_sq = x * x;
T x_cube = x_sq * x;