mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Broken metal ?
This commit is contained in:
@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::{void_ptr, Kernels, AFFINE};
|
||||
use candle_metal_kernels::{void_ptr, Kernels};
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use metal;
|
||||
@ -36,8 +36,7 @@ impl MetalError {
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
_command_queue: metal::CommandQueue,
|
||||
command_buffer: metal::CommandBuffer,
|
||||
command_queue: metal::CommandQueue,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
}
|
||||
|
||||
@ -66,13 +65,14 @@ impl MetalDevice {
|
||||
|
||||
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = (element_count * dtype.size_in_bytes()) as u64;
|
||||
self.device.new_buffer(size, MTLResourceOptions::empty())
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalStorage {
|
||||
buffer: Arc<metal::Buffer>,
|
||||
buffer: metal::Buffer,
|
||||
device: MetalDevice,
|
||||
dtype: DType,
|
||||
}
|
||||
@ -103,6 +103,7 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let command_buffer = self.device.command_queue.new_owned_command_buffer();
|
||||
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
@ -123,7 +124,7 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
let src_length = self.buffer.length() as usize - layout.start_offset();
|
||||
let src = self.device.new_buffer(src_length, self.dtype);
|
||||
let blit_encoder = self.device.command_buffer.new_blit_command_encoder();
|
||||
let blit_encoder = command_buffer.new_blit_command_encoder();
|
||||
blit_encoder.copy_from_buffer(
|
||||
self.buffer.as_ref(),
|
||||
layout.start_offset() as NSUInteger,
|
||||
@ -133,7 +134,7 @@ impl BackendStorage for MetalStorage {
|
||||
);
|
||||
blit_encoder.end_encoding();
|
||||
|
||||
let encoder = device.command_buffer.new_compute_command_encoder();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||
|
||||
@ -164,6 +165,10 @@ impl BackendStorage for MetalStorage {
|
||||
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
command_buffer.commit();
|
||||
// command_buffer.wait_until_completed();
|
||||
println!("Affine");
|
||||
|
||||
Ok(self.clone())
|
||||
}
|
||||
|
||||
@ -190,17 +195,45 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
//todo!("Implement the kernel calling");
|
||||
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||
}
|
||||
command_buffer.commit();
|
||||
// command_buffer.wait_until_completed();
|
||||
println!("Unary {:?}", B::KERNEL);
|
||||
|
||||
Ok(Self {
|
||||
buffer: Arc::new(buffer),
|
||||
device,
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
@ -368,7 +401,7 @@ impl MetalStorage {
|
||||
println!("TODO implement batched matmul for B={b}");
|
||||
// bail!("Didn't implemented strided matmul yet");
|
||||
return Ok(Self {
|
||||
buffer: Arc::new(out_buffer),
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
@ -380,15 +413,17 @@ impl MetalStorage {
|
||||
rhs_l.is_contiguous()
|
||||
);
|
||||
return Ok(Self {
|
||||
buffer: Arc::new(out_buffer),
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
}
|
||||
|
||||
println!("GEMM");
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
encode_gemm::<Float32, Float32, Float32>(
|
||||
&self.device,
|
||||
&self.device.command_buffer,
|
||||
&command_buffer,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
&self.buffer,
|
||||
@ -402,13 +437,15 @@ impl MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
println!("lhs {:?} {m} {k}", self.buffer.length());
|
||||
println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
||||
println!("out {:?} {m} {n}", out_buffer.length());
|
||||
println!("lhs {:?}", lhs_l.shape());
|
||||
command_buffer.commit();
|
||||
|
||||
// println!("lhs {:?} {m} {k}", self.buffer.length());
|
||||
// println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
||||
// println!("out {:?} {m} {n}", out_buffer.length());
|
||||
// println!("lhs {:?}", lhs_l.shape());
|
||||
|
||||
Ok(Self {
|
||||
buffer: Arc::new(out_buffer),
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
@ -423,13 +460,13 @@ impl BackendDevice for MetalDevice {
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
let _command_queue = device.new_command_queue();
|
||||
let command_buffer = _command_queue.new_owned_command_buffer();
|
||||
let kernels = Arc::new(Kernels::init(&device).map_err(MetalError::from)?);
|
||||
let command_queue = device.new_command_queue();
|
||||
// let command_buffer = _command_queue.new_owned_command_buffer();
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
Ok(Self {
|
||||
device,
|
||||
_command_queue,
|
||||
command_buffer,
|
||||
command_queue,
|
||||
// command_buffer,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
@ -498,7 +535,7 @@ impl BackendDevice for MetalDevice {
|
||||
),
|
||||
};
|
||||
Ok(Self::Storage {
|
||||
buffer: Arc::new(buffer),
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype: storage.dtype(),
|
||||
})
|
||||
|
@ -1,8 +1,7 @@
|
||||
use metal::{
|
||||
Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
|
||||
MTLSize,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
@ -11,11 +10,18 @@ const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
Affine,
|
||||
Indexing,
|
||||
Unary,
|
||||
}
|
||||
|
||||
macro_rules! unary{
|
||||
($($name:ident),+) => {
|
||||
|
||||
pub mod contiguous {
|
||||
pub struct Kernel(pub &'static str);
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -27,7 +33,7 @@ macro_rules! unary{
|
||||
}
|
||||
|
||||
pub mod strided {
|
||||
pub struct Kernel(pub &'static str);
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -41,17 +47,17 @@ macro_rules! unary{
|
||||
}
|
||||
|
||||
pub mod unary {
|
||||
unary!(cos, sin, exp);
|
||||
unary!(cos, sin, exp, sqr, sqrt, neg);
|
||||
}
|
||||
|
||||
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
||||
let mut l = HashMap::new();
|
||||
l.insert("affine", AFFINE);
|
||||
l.insert("indexing", INDEXING);
|
||||
l.insert("unary", UNARY);
|
||||
l
|
||||
});
|
||||
|
||||
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
||||
// let mut l = HashMap::new();
|
||||
// l.insert("affine", AFFINE);
|
||||
// l.insert("indexing", INDEXING);
|
||||
// l.insert("unary", UNARY);
|
||||
// l
|
||||
// });
|
||||
//
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalKernelError {
|
||||
#[error("Could not lock kernel map: {0}")]
|
||||
@ -69,7 +75,7 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
||||
}
|
||||
|
||||
type KernelMap<T> = HashMap<&'static str, T>;
|
||||
type Libraries = KernelMap<Library>;
|
||||
type Libraries = HashMap<Source, Library>;
|
||||
type Functions = KernelMap<Function>;
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -85,39 +91,42 @@ impl Kernels {
|
||||
Self { libraries, funcs }
|
||||
}
|
||||
|
||||
pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
||||
let kernels = Self::new();
|
||||
kernels.load_libraries(device)?;
|
||||
Ok(kernels)
|
||||
}
|
||||
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
||||
// let kernels = Self::new();
|
||||
// kernels.load_libraries(device)?;
|
||||
// Ok(kernels)
|
||||
// }
|
||||
|
||||
fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
||||
for name in LIBRARY_SOURCES.keys() {
|
||||
self.load_library(device, name)?;
|
||||
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
||||
// for name in LIBRARY_SOURCES.keys() {
|
||||
// self.load_library(device, name)?;
|
||||
// }
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
// LIBRARY_SOURCES.get(name).cloned()
|
||||
match source {
|
||||
Source::Affine => AFFINE,
|
||||
Source::Unary => UNARY,
|
||||
Source::Indexing => INDEXING,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_library_source(&self, name: &'static str) -> Option<&'static str> {
|
||||
LIBRARY_SOURCES.get(name).cloned()
|
||||
}
|
||||
|
||||
pub fn load_library(
|
||||
&self,
|
||||
device: &Device,
|
||||
name: &'static str,
|
||||
source: Source,
|
||||
) -> Result<Library, MetalKernelError> {
|
||||
let mut libraries = self.libraries.write()?;
|
||||
if let Some(lib) = libraries.get(name) {
|
||||
if let Some(lib) = libraries.get(&source) {
|
||||
Ok(lib.clone())
|
||||
} else {
|
||||
let source = self.get_library_source(name).ok_or_else(|| {
|
||||
MetalKernelError::LoadLibraryError(format!("No source found for {}", name))
|
||||
})?;
|
||||
let source_content = self.get_library_source(source);
|
||||
let lib = device
|
||||
.new_library_with_source(source, &CompileOptions::new())
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
|
||||
libraries.insert(name, lib.clone());
|
||||
libraries.insert(source, lib.clone());
|
||||
Ok(lib)
|
||||
}
|
||||
}
|
||||
@ -125,7 +134,7 @@ impl Kernels {
|
||||
pub fn load_function(
|
||||
&self,
|
||||
device: &Device,
|
||||
library_name: &'static str,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
) -> Result<Function, MetalKernelError> {
|
||||
let mut funcs = self.funcs.write()?;
|
||||
@ -133,7 +142,7 @@ impl Kernels {
|
||||
Ok(func.clone())
|
||||
} else {
|
||||
let func = self
|
||||
.load_library(device, library_name)?
|
||||
.load_library(device, source)?
|
||||
.get_function(name, None)
|
||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||
funcs.insert(name, func.clone());
|
||||
@ -144,15 +153,16 @@ impl Kernels {
|
||||
|
||||
pub fn call_unary_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBuffer,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, "unary", kernel_name.0)?;
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
@ -188,7 +198,7 @@ pub fn call_unary_contiguous(
|
||||
}
|
||||
pub fn call_unary_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBuffer,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: unary::strided::Kernel,
|
||||
input: &Buffer,
|
||||
@ -197,7 +207,7 @@ pub fn call_unary_strided(
|
||||
offset: usize,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, "unary", name.0)?;
|
||||
let func = kernels.load_function(device, Source::Unary, name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
@ -277,7 +287,7 @@ mod tests {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_owned_command_buffer();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
@ -310,7 +320,7 @@ mod tests {
|
||||
let device = device();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_owned_command_buffer();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||
|
@ -22,6 +22,9 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||
|
||||
|
||||
using namespace metal;
|
||||
|
||||
@ -59,12 +62,26 @@ kernel void FN_NAME_STRIDED( \
|
||||
} \
|
||||
}
|
||||
|
||||
UNARY(cos, float, cos_float, cos_float_strided);
|
||||
UNARY(cos, half, cos_half, cos_half_strided);
|
||||
UNARY(sin, float, sin_float, sin_float_strided);
|
||||
UNARY(sin, half, sin_half, sin_half_strided);
|
||||
#define UNARY_OP(NAME) \
|
||||
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
|
||||
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
|
||||
|
||||
#define BFLOAT_UNARY_OP(NAME) \
|
||||
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||
|
||||
|
||||
UNARY_OP(cos)
|
||||
UNARY_OP(sin)
|
||||
UNARY_OP(sqr)
|
||||
UNARY_OP(sqrt)
|
||||
UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
|
||||
UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided);
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
BFLOAT_UNARY_OP(sqr)
|
||||
BFLOAT_UNARY_OP(sqrt)
|
||||
BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user