mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
BF16 metal fix.
This commit is contained in:
@ -5,7 +5,7 @@ use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use half::f16;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -89,6 +89,15 @@ impl MetalDevice {
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
(data.len() * mem::size_of::<T>()) as NSUInteger,
|
||||
option,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -114,9 +123,9 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
let start = std::time::Instant::now();
|
||||
// let start = std::time::Instant::now();
|
||||
self.device.wait_until_completed();
|
||||
println!("Wait took {:?}", start.elapsed());
|
||||
// println!("Wait took {:?}", start.elapsed());
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(
|
||||
@ -415,7 +424,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
self.device.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -688,7 +696,7 @@ impl BackendStorage for MetalStorage {
|
||||
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
||||
core::mem::size_of::<f16>() as NSUInteger,
|
||||
),
|
||||
dtype => todo!("Implement matmul {dtype:?}"),
|
||||
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
|
||||
};
|
||||
|
||||
let elem_count = b * m * n;
|
||||
@ -916,43 +924,14 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let buffer = match storage {
|
||||
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||
};
|
||||
Ok(Self::Storage {
|
||||
buffer,
|
||||
|
Reference in New Issue
Block a user