Broken metal ?

This commit is contained in:
Nicolas Patry
2023-11-07 14:20:13 +01:00
parent 1367e0278b
commit 76d3116f5d
3 changed files with 139 additions and 75 deletions

View File

@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape}; use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels; use candle_metal_kernels;
use candle_metal_kernels::{void_ptr, Kernels, AFFINE}; use candle_metal_kernels::{void_ptr, Kernels};
use core::mem; use core::mem;
use half::{bf16, f16}; use half::{bf16, f16};
use metal; use metal;
@ -36,8 +36,7 @@ impl MetalError {
#[derive(Clone)] #[derive(Clone)]
pub struct MetalDevice { pub struct MetalDevice {
device: metal::Device, device: metal::Device,
_command_queue: metal::CommandQueue, command_queue: metal::CommandQueue,
command_buffer: metal::CommandBuffer,
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
} }
@ -66,13 +65,14 @@ impl MetalDevice {
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64; let size = (element_count * dtype.size_in_bytes()) as u64;
self.device.new_buffer(size, MTLResourceOptions::empty()) self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MetalStorage { pub struct MetalStorage {
buffer: Arc<metal::Buffer>, buffer: metal::Buffer,
device: MetalDevice, device: MetalDevice,
dtype: DType, dtype: DType,
} }
@ -103,6 +103,7 @@ impl BackendStorage for MetalStorage {
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> { fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let device = self.device().clone(); let device = self.device().clone();
let command_buffer = self.device.command_queue.new_owned_command_buffer();
let shape = layout.shape(); let shape = layout.shape();
let dims = shape.dims(); let dims = shape.dims();
@ -123,7 +124,7 @@ impl BackendStorage for MetalStorage {
let src_length = self.buffer.length() as usize - layout.start_offset(); let src_length = self.buffer.length() as usize - layout.start_offset();
let src = self.device.new_buffer(src_length, self.dtype); let src = self.device.new_buffer(src_length, self.dtype);
let blit_encoder = self.device.command_buffer.new_blit_command_encoder(); let blit_encoder = command_buffer.new_blit_command_encoder();
blit_encoder.copy_from_buffer( blit_encoder.copy_from_buffer(
self.buffer.as_ref(), self.buffer.as_ref(),
layout.start_offset() as NSUInteger, layout.start_offset() as NSUInteger,
@ -133,7 +134,7 @@ impl BackendStorage for MetalStorage {
); );
blit_encoder.end_encoding(); blit_encoder.end_encoding();
let encoder = device.command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
@ -164,6 +165,10 @@ impl BackendStorage for MetalStorage {
encoder.dispatch_threads(grid_size, thread_group_size); encoder.dispatch_threads(grid_size, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
command_buffer.commit();
// command_buffer.wait_until_completed();
println!("Affine");
Ok(self.clone()) Ok(self.clone())
} }
@ -190,17 +195,45 @@ impl BackendStorage for MetalStorage {
} }
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone(); let device = self.device();
let dtype = self.dtype; let dtype = self.dtype;
let shape = layout.shape(); let shape = layout.shape();
let dims = shape.dims(); let dims = shape.dims();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
//todo!("Implement the kernel calling"); let command_buffer = device.command_queue.new_command_buffer();
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype); if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
}
command_buffer.commit();
// command_buffer.wait_until_completed();
println!("Unary {:?}", B::KERNEL);
Ok(Self { Ok(Self {
buffer: Arc::new(buffer), buffer,
device, device: device.clone(),
dtype, dtype,
}) })
} }
@ -368,7 +401,7 @@ impl MetalStorage {
println!("TODO implement batched matmul for B={b}"); println!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet"); // bail!("Didn't implemented strided matmul yet");
return Ok(Self { return Ok(Self {
buffer: Arc::new(out_buffer), buffer: out_buffer,
device: self.device.clone(), device: self.device.clone(),
dtype: self.dtype(), dtype: self.dtype(),
}); });
@ -380,15 +413,17 @@ impl MetalStorage {
rhs_l.is_contiguous() rhs_l.is_contiguous()
); );
return Ok(Self { return Ok(Self {
buffer: Arc::new(out_buffer), buffer: out_buffer,
device: self.device.clone(), device: self.device.clone(),
dtype: self.dtype(), dtype: self.dtype(),
}); });
} }
println!("GEMM");
let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>( encode_gemm::<Float32, Float32, Float32>(
&self.device, &self.device,
&self.device.command_buffer, &command_buffer,
transpose_left, transpose_left,
transpose_right, transpose_right,
&self.buffer, &self.buffer,
@ -402,13 +437,15 @@ impl MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
println!("lhs {:?} {m} {k}", self.buffer.length()); command_buffer.commit();
println!("rhs {:?} {k} {n}", rhs.buffer.length());
println!("out {:?} {m} {n}", out_buffer.length()); // println!("lhs {:?} {m} {k}", self.buffer.length());
println!("lhs {:?}", lhs_l.shape()); // println!("rhs {:?} {k} {n}", rhs.buffer.length());
// println!("out {:?} {m} {n}", out_buffer.length());
// println!("lhs {:?}", lhs_l.shape());
Ok(Self { Ok(Self {
buffer: Arc::new(out_buffer), buffer: out_buffer,
device: self.device.clone(), device: self.device.clone(),
dtype: self.dtype(), dtype: self.dtype(),
}) })
@ -423,13 +460,13 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> { fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal); let device = metal::Device::all().swap_remove(ordinal);
let _command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = _command_queue.new_owned_command_buffer(); // let command_buffer = _command_queue.new_owned_command_buffer();
let kernels = Arc::new(Kernels::init(&device).map_err(MetalError::from)?); let kernels = Arc::new(Kernels::new());
Ok(Self { Ok(Self {
device, device,
_command_queue, command_queue,
command_buffer, // command_buffer,
kernels, kernels,
}) })
} }
@ -498,7 +535,7 @@ impl BackendDevice for MetalDevice {
), ),
}; };
Ok(Self::Storage { Ok(Self::Storage {
buffer: Arc::new(buffer), buffer,
device: self.clone(), device: self.clone(),
dtype: storage.dtype(), dtype: storage.dtype(),
}) })

View File

@ -1,8 +1,7 @@
use metal::{ use metal::{
Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
MTLSize, MTLSize,
}; };
use once_cell::sync::Lazy;
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
use std::sync::RwLock; use std::sync::RwLock;
@ -11,11 +10,18 @@ const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal"); const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal"); const UNARY: &str = include_str!("unary.metal");
/// Identifies one of the Metal shader sources embedded in this crate.
///
/// Used as the key under which a compiled library is cached, which is why it
/// is `Copy`, `Eq` and `Hash`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Source {
    /// The `affine.metal` source.
    Affine,
    /// The `indexing.metal` source.
    Indexing,
    /// The `unary.metal` source.
    Unary,
}
macro_rules! unary{ macro_rules! unary{
($($name:ident),+) => { ($($name:ident),+) => {
pub mod contiguous { pub mod contiguous {
pub struct Kernel(pub &'static str); pub struct Kernel(pub(crate) &'static str);
$( $(
pub mod $name { pub mod $name {
use super::Kernel; use super::Kernel;
@ -27,7 +33,7 @@ macro_rules! unary{
} }
pub mod strided { pub mod strided {
pub struct Kernel(pub &'static str); pub struct Kernel(pub(crate) &'static str);
$( $(
pub mod $name { pub mod $name {
use super::Kernel; use super::Kernel;
@ -41,17 +47,17 @@ macro_rules! unary{
} }
pub mod unary { pub mod unary {
unary!(cos, sin, exp); unary!(cos, sin, exp, sqr, sqrt, neg);
} }
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| { // static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
let mut l = HashMap::new(); // let mut l = HashMap::new();
l.insert("affine", AFFINE); // l.insert("affine", AFFINE);
l.insert("indexing", INDEXING); // l.insert("indexing", INDEXING);
l.insert("unary", UNARY); // l.insert("unary", UNARY);
l // l
}); // });
//
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum MetalKernelError { pub enum MetalKernelError {
#[error("Could not lock kernel map: {0}")] #[error("Could not lock kernel map: {0}")]
@ -69,7 +75,7 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
} }
type KernelMap<T> = HashMap<&'static str, T>; type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = KernelMap<Library>; type Libraries = HashMap<Source, Library>;
type Functions = KernelMap<Function>; type Functions = KernelMap<Function>;
#[derive(Debug)] #[derive(Debug)]
@ -85,39 +91,42 @@ impl Kernels {
Self { libraries, funcs } Self { libraries, funcs }
} }
pub fn init(device: &Device) -> Result<Self, MetalKernelError> { // pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
let kernels = Self::new(); // let kernels = Self::new();
kernels.load_libraries(device)?; // kernels.load_libraries(device)?;
Ok(kernels) // Ok(kernels)
} // }
fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> { // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
for name in LIBRARY_SOURCES.keys() { // for name in LIBRARY_SOURCES.keys() {
self.load_library(device, name)?; // self.load_library(device, name)?;
// }
// Ok(())
// }
fn get_library_source(&self, source: Source) -> &'static str {
// LIBRARY_SOURCES.get(name).cloned()
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Indexing => INDEXING,
} }
Ok(())
}
fn get_library_source(&self, name: &'static str) -> Option<&'static str> {
LIBRARY_SOURCES.get(name).cloned()
} }
pub fn load_library( pub fn load_library(
&self, &self,
device: &Device, device: &Device,
name: &'static str, source: Source,
) -> Result<Library, MetalKernelError> { ) -> Result<Library, MetalKernelError> {
let mut libraries = self.libraries.write()?; let mut libraries = self.libraries.write()?;
if let Some(lib) = libraries.get(name) { if let Some(lib) = libraries.get(&source) {
Ok(lib.clone()) Ok(lib.clone())
} else { } else {
let source = self.get_library_source(name).ok_or_else(|| { let source_content = self.get_library_source(source);
MetalKernelError::LoadLibraryError(format!("No source found for {}", name))
})?;
let lib = device let lib = device
.new_library_with_source(source, &CompileOptions::new()) .new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
libraries.insert(name, lib.clone()); libraries.insert(source, lib.clone());
Ok(lib) Ok(lib)
} }
} }
@ -125,7 +134,7 @@ impl Kernels {
pub fn load_function( pub fn load_function(
&self, &self,
device: &Device, device: &Device,
library_name: &'static str, source: Source,
name: &'static str, name: &'static str,
) -> Result<Function, MetalKernelError> { ) -> Result<Function, MetalKernelError> {
let mut funcs = self.funcs.write()?; let mut funcs = self.funcs.write()?;
@ -133,7 +142,7 @@ impl Kernels {
Ok(func.clone()) Ok(func.clone())
} else { } else {
let func = self let func = self
.load_library(device, library_name)? .load_library(device, source)?
.get_function(name, None) .get_function(name, None)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
funcs.insert(name, func.clone()); funcs.insert(name, func.clone());
@ -144,15 +153,16 @@ impl Kernels {
pub fn call_unary_contiguous( pub fn call_unary_contiguous(
device: &Device, device: &Device,
command_buffer: &CommandBuffer, command_buffer: &CommandBufferRef,
kernels: &Kernels, kernels: &Kernels,
kernel_name: unary::contiguous::Kernel, kernel_name: unary::contiguous::Kernel,
length: usize, length: usize,
input: &Buffer, input: &Buffer,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
assert_eq!(input.length(), output.length()); // println!("Kernel {:?}", kernel_name.0);
let func = kernels.load_function(device, "unary", kernel_name.0)?; // assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new(); let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func)); pipeline_state_descriptor.set_compute_function(Some(&func));
@ -188,7 +198,7 @@ pub fn call_unary_contiguous(
} }
pub fn call_unary_strided( pub fn call_unary_strided(
device: &Device, device: &Device,
command_buffer: &CommandBuffer, command_buffer: &CommandBufferRef,
kernels: &Kernels, kernels: &Kernels,
name: unary::strided::Kernel, name: unary::strided::Kernel,
input: &Buffer, input: &Buffer,
@ -197,7 +207,7 @@ pub fn call_unary_strided(
offset: usize, offset: usize,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, "unary", name.0)?; let func = kernels.load_function(device, Source::Unary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new(); let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func)); pipeline_state_descriptor.set_compute_function(Some(&func));
@ -277,7 +287,7 @@ mod tests {
let device = device(); let device = device();
let kernels = Kernels::new(); let kernels = Kernels::new();
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_owned_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data( let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void, v.as_ptr() as *const core::ffi::c_void,
@ -310,7 +320,7 @@ mod tests {
let device = device(); let device = device();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_owned_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = device.new_buffer_with_data( let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void, v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64, (v.len() * core::mem::size_of::<T>()) as u64,

View File

@ -22,6 +22,9 @@ METAL_FUNC uint get_strided_index(
return strided_i; return strided_i;
} }
// Square of a value: multiplies the argument by itself.
template <typename T>
METAL_FUNC T sqr(T x) {
    return x * x;
}
// Arithmetic negation: applies unary minus to the argument.
template <typename T>
METAL_FUNC T neg(T x) {
    return -x;
}
using namespace metal; using namespace metal;
@ -59,12 +62,26 @@ kernel void FN_NAME_STRIDED( \
} \ } \
} }
UNARY(cos, float, cos_float, cos_float_strided); #define UNARY_OP(NAME) \
UNARY(cos, half, cos_half, cos_half_strided); UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(sin, float, sin_float, sin_float_strided); UNARY(NAME, half, NAME##_half, NAME##_half_strided);
UNARY(sin, half, sin_half, sin_half_strided);
// Expands to a single UNARY(...) invocation for the `bfloat` type, passing
// NAME_bfloat and NAME_bfloat_strided as the two generated-name arguments.
// NOTE(review): presumably these become the contiguous and strided kernel
// names; UNARY's definition is not fully visible here, so confirm.
// The expansion already ends in `;`, so call sites are written without one.
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
// Instantiate every supported unary op through UNARY_OP, one line per op.
// `sqr` and `neg` are the helper templates defined earlier in this file;
// NOTE(review): the remaining names are presumably Metal built-ins — confirm.
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
#if __METAL_VERSION__ >= 310 #if __METAL_VERSION__ >= 310
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided); BFLOAT_UNARY_OP(cos)
UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided); BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
#endif #endif