mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
BF16 metal fix.
This commit is contained in:
@ -5,7 +5,7 @@ use crate::{CpuStorage, DType, Layout, Result, Shape};
|
|||||||
use candle_metal_kernels;
|
use candle_metal_kernels;
|
||||||
use candle_metal_kernels::Kernels;
|
use candle_metal_kernels::Kernels;
|
||||||
use core::mem;
|
use core::mem;
|
||||||
use half::{bf16, f16};
|
use half::f16;
|
||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
@ -89,6 +89,15 @@ impl MetalDevice {
|
|||||||
self.device
|
self.device
|
||||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
|
||||||
|
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
data.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(data.len() * mem::size_of::<T>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -114,9 +123,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
let start = std::time::Instant::now();
|
// let start = std::time::Instant::now();
|
||||||
self.device.wait_until_completed();
|
self.device.wait_until_completed();
|
||||||
println!("Wait took {:?}", start.elapsed());
|
// println!("Wait took {:?}", start.elapsed());
|
||||||
|
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
DType::U8 => Ok(CpuStorage::U8(
|
DType::U8 => Ok(CpuStorage::U8(
|
||||||
@ -415,7 +424,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
self.device.wait_until_completed();
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -688,7 +696,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
||||||
core::mem::size_of::<f16>() as NSUInteger,
|
core::mem::size_of::<f16>() as NSUInteger,
|
||||||
),
|
),
|
||||||
dtype => todo!("Implement matmul {dtype:?}"),
|
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let elem_count = b * m * n;
|
let elem_count = b * m * n;
|
||||||
@ -916,43 +924,14 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
|
||||||
let buffer = match storage {
|
let buffer = match storage {
|
||||||
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
|
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
|
||||||
storage.as_ptr() as *const core::ffi::c_void,
|
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
|
||||||
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
|
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
|
||||||
option,
|
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
|
||||||
),
|
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
|
||||||
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
|
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
||||||
storage.as_ptr() as *const core::ffi::c_void,
|
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||||
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
|
|
||||||
option,
|
|
||||||
),
|
|
||||||
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
|
|
||||||
storage.as_ptr() as *const core::ffi::c_void,
|
|
||||||
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
|
|
||||||
option,
|
|
||||||
),
|
|
||||||
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
|
|
||||||
storage.as_ptr() as *const core::ffi::c_void,
|
|
||||||
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
|
|
||||||
option,
|
|
||||||
),
|
|
||||||
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
|
|
||||||
storage.as_ptr() as *const core::ffi::c_void,
|
|
||||||
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
|
|
||||||
option,
|
|
||||||
),
|
|
||||||
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
|
|
||||||
storage.as_ptr() as *const core::ffi::c_void,
|
|
||||||
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
|
|
||||||
option,
|
|
||||||
),
|
|
||||||
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
|
|
||||||
storage.as_ptr() as *const core::ffi::c_void,
|
|
||||||
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
|
|
||||||
option,
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
Ok(Self::Storage {
|
Ok(Self::Storage {
|
||||||
buffer,
|
buffer,
|
||||||
|
@ -60,8 +60,8 @@ impl<T> EncoderParam for &[T] {
|
|||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
encoder.set_bytes(
|
encoder.set_bytes(
|
||||||
position,
|
position,
|
||||||
(core::mem::size_of::<T>() * data.len()) as u64,
|
core::mem::size_of_val(data) as u64,
|
||||||
data.as_ptr() as *const T as *const c_void,
|
data.as_ptr() as *const c_void,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,7 +190,7 @@ type KernelMap<T> = HashMap<&'static str, T>;
|
|||||||
type Libraries = HashMap<Source, Library>;
|
type Libraries = HashMap<Source, Library>;
|
||||||
type Pipelines = KernelMap<ComputePipelineState>;
|
type Pipelines = KernelMap<ComputePipelineState>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Default)]
|
||||||
pub struct Kernels {
|
pub struct Kernels {
|
||||||
libraries: RwLock<Libraries>,
|
libraries: RwLock<Libraries>,
|
||||||
pipelines: RwLock<Pipelines>,
|
pipelines: RwLock<Pipelines>,
|
||||||
|
@ -43,7 +43,7 @@ template <typename T> METAL_FUNC T erf(T in){
|
|||||||
return T(sign*y);
|
return T(sign*y);
|
||||||
}
|
}
|
||||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||||
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
|
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
|
||||||
template <typename T> METAL_FUNC T gelu(T x){
|
template <typename T> METAL_FUNC T gelu(T x){
|
||||||
T x_sq = x * x;
|
T x_sq = x * x;
|
||||||
T x_cube = x_sq * x;
|
T x_cube = x_sq * x;
|
||||||
|
Reference in New Issue
Block a user