BF16 metal fix.

This commit is contained in:
Nicolas Patry
2023-11-13 14:44:20 +01:00
parent dd4a40f1c0
commit 1471f98f0b
3 changed files with 24 additions and 45 deletions

View File

@ -5,7 +5,7 @@ use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use core::mem;
use half::{bf16, f16};
use half::f16;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock};
@ -89,6 +89,15 @@ impl MetalDevice {
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
let option = metal::MTLResourceOptions::StorageModeManaged;
self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
(data.len() * mem::size_of::<T>()) as NSUInteger,
option,
)
}
}
#[derive(Debug, Clone)]
@ -114,9 +123,9 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
let start = std::time::Instant::now();
// let start = std::time::Instant::now();
self.device.wait_until_completed();
println!("Wait took {:?}", start.elapsed());
// println!("Wait took {:?}", start.elapsed());
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(
@ -415,7 +424,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
self.device.wait_until_completed();
Ok(Self {
buffer,
device: device.clone(),
@ -688,7 +696,7 @@ impl BackendStorage for MetalStorage {
metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger,
),
dtype => todo!("Implement matmul {dtype:?}"),
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
};
let elem_count = b * m * n;
@ -916,43 +924,14 @@ impl BackendDevice for MetalDevice {
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
option,
),
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
};
Ok(Self::Storage {
buffer,