Implemented cos for now.

This commit is contained in:
Nicolas Patry
2023-11-03 01:24:51 +01:00
parent 7161002a34
commit f57e3164ae
3 changed files with 172 additions and 18 deletions

View File

@ -9,7 +9,7 @@ use half::{bf16, f16};
use metal; use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType}; use metal::mps::{Float32, MPSDataType};
use metal::MTLResourceOptions; use metal::{MTLResourceOptions, Buffer};
/// Metal related errors /// Metal related errors
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -47,6 +47,14 @@ impl MetalDevice {
pub fn id(&self) -> u64 { pub fn id(&self) -> u64 {
self.registry_id() self.registry_id()
} }
/// Allocates a device buffer large enough for `element_count` values of `dtype`.
///
/// The byte length is `element_count * dtype.size_in_bytes()`, and the buffer is
/// created with empty `MTLResourceOptions`.
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
    let byte_len = element_count * dtype.size_in_bytes();
    self.device
        .new_buffer(byte_len as u64, MTLResourceOptions::empty())
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -106,11 +114,16 @@ impl BackendStorage for MetalStorage {
todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype) todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
} }
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> { fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
// todo!() let device = self.device().clone();
// TODO let dtype = self.dtype;
println!("TODO {:?}", B::NAME); let shape = layout.shape();
Ok(self.clone()) let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
todo!("Implement the kernel calling");
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
Ok(Self { buffer, device, dtype })
} }
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@ -271,13 +284,10 @@ impl MetalStorage {
let elem_count = b * m * n; let elem_count = b * m * n;
match (self.dtype, rhs.dtype) { match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => { (DType::F32, DType::F32) => {
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
if b != 1 { if b != 1 {
println!("TODO implement batched matmul for B={b}"); println!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet"); // bail!("Didn't implemented strided matmul yet");
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
return Ok(Self { return Ok(Self {
buffer: out_buffer, buffer: out_buffer,
device: self.device.clone(), device: self.device.clone(),
@ -286,20 +296,12 @@ impl MetalStorage {
} }
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous()); println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
return Ok(Self { return Ok(Self {
buffer: out_buffer, buffer: out_buffer,
device: self.device.clone(), device: self.device.clone(),
dtype: self.dtype(), dtype: self.dtype(),
}); });
} }
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
let m: u64 = m.try_into().expect("usize should fit u64"); let m: u64 = m.try_into().expect("usize should fit u64");
let n: u64 = n.try_into().expect("usize should fit u64"); let n: u64 = n.try_into().expect("usize should fit u64");
let k: u64 = k.try_into().expect("usize should fit u64"); let k: u64 = k.try_into().expect("usize should fit u64");
@ -359,6 +361,7 @@ impl MetalStorage {
} }
} }
impl BackendDevice for MetalDevice { impl BackendDevice for MetalDevice {
type Storage = MetalStorage; type Storage = MetalStorage;

View File

@ -1 +1,138 @@
use metal::{Buffer, Device, Function, Library, CompileOptions};
use std::collections::HashMap;
use std::sync::RwLock;
static UNARY: &'static str = include_str!("unary.metal");
/// Error type returned by kernel dispatch functions.
///
/// The enum has no variants yet, so no value of this type can be constructed.
/// `Debug` is derived so that a `Result<_, Error>` can be used with `unwrap`,
/// `expect` and `{:?}` formatting.
#[derive(Debug)]
pub enum Error {}
/// Cache of Metal libraries and of the functions fetched from them.
///
/// Both maps sit behind an `RwLock` and are filled through `&self`.
pub struct Kernels {
/// Libraries keyed by a static name (`call_unary` stores its library under "unary").
libraries: RwLock<HashMap<&'static str, Library>>,
/// Functions already looked up in a library, keyed by the function name.
funcs: RwLock<HashMap<String, Function>>,
}
impl Kernels {
    /// Creates a kernel cache with no libraries and no functions loaded.
    pub fn new() -> Self {
        let libraries = RwLock::new(HashMap::new());
        let funcs = RwLock::new(HashMap::new());
        Self { libraries, funcs }
    }

    /// Looks up the unary kernel `name` and hands it to the `call_unary` helper
    /// together with `input`, `output` and `length`.
    ///
    /// On a cache miss the function is fetched from the "unary" library, which is
    /// built on `device` from the embedded `UNARY` source the first time it is
    /// needed, and the function is then stored in the cache under `name`.
    ///
    /// Returns `Ok(())` whenever it returns.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned, if the unary library
    /// fails to build, or if the library has no function called `name`.
    pub fn call_unary(
        &self,
        device: &Device,
        name: &str,
        input: &Buffer,
        output: &mut Buffer,
        length: usize,
    ) -> Result<(), Error> {
        // Clone the cached function out in its own statement so the read guard
        // is dropped here. Keeping the lookup in an `if let` scrutinee would
        // keep the guard alive through the `else` branch, where taking the
        // write lock on the same `RwLock` would deadlock or panic.
        let cached = self
            .funcs
            .read()
            .expect("Failed to acquire kernel lock")
            .get(name)
            .cloned();

        let func = match cached {
            Some(func) => func,
            None => {
                // The `libraries` write guard is a temporary of this statement
                // and is released before the `funcs` write lock is taken.
                let func = self
                    .libraries
                    .write()
                    .expect("Failed to acquire lock")
                    .entry("unary")
                    .or_insert_with(|| {
                        device
                            .new_library_with_source(UNARY, &CompileOptions::new())
                            .expect("Failed to load unary library")
                    })
                    .get_function(name, None)
                    .expect("Could not find unary function");
                self.funcs
                    .write()
                    .expect("Failed to acquire lock")
                    .insert(name.to_string(), func.clone());
                func
            }
        };

        call_unary(&func, input, output, length);
        Ok(())
    }
}
/// Placeholder for dispatching the unary kernel `func` with `input`, `output` and `length`.
///
/// Not implemented yet: the body is a `todo!`, so every call panics and none of
/// the parameters is read.
fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
todo!("Call unary");
}
#[cfg(test)]
mod tests {
    use super::*;
    use metal::{
        ComputePipelineDescriptor, MTLResourceOptions, MTLResourceUsage, MTLSize,
    };

    /// Rounds every value of `v` to `digits` decimal places so that results can
    /// be compared with `assert_eq!`.
    fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
        let b = 10f32.powi(digits);
        v.iter().map(|t| f32::round(t * b) / b).collect()
    }

    /// Runs the `cos` kernel from the unary library on three values and compares
    /// the rounded output with the rounded host-side `f32::cos` results.
    #[test]
    fn cos() {
        let v = vec![1.0f32, 2.0, 3.0];
        let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
        let device = Device::system_default().unwrap();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();

        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            (v.len() * core::mem::size_of::<f32>()) as u64,
            option,
        );
        let output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, option);

        let library = device
            .new_library_with_source(UNARY, &CompileOptions::new())
            .expect("Failed to load unary library");
        let func = library.get_function("cos", None).unwrap();

        // The kernel takes a single argument buffer at index 0 that holds the
        // input and output buffer bindings.
        let argument_encoder = func.new_argument_encoder(0);
        let arg_buffer = device.new_buffer(
            argument_encoder.encoded_length(),
            MTLResourceOptions::empty(),
        );
        argument_encoder.set_argument_buffer(&arg_buffer, 0);
        argument_encoder.set_buffer(0, &input, 0);
        argument_encoder.set_buffer(1, &output, 0);

        let pipeline_state_descriptor = ComputePipelineDescriptor::new();
        pipeline_state_descriptor.set_compute_function(Some(&func));
        let pipeline_state = device
            .new_compute_pipeline_state_with_function(
                pipeline_state_descriptor.compute_function().unwrap(),
            )
            .unwrap();

        encoder.set_compute_pipeline_state(&pipeline_state);
        encoder.set_buffer(0, Some(&arg_buffer), 0);
        encoder.use_resource(&input, MTLResourceUsage::Read);
        encoder.use_resource(&output, MTLResourceUsage::Write);

        // The `cos` kernel indexes its buffers by thread position without any
        // bounds check, so launch exactly one thread per element: the group
        // size is capped at `width`, and the number of groups is the ceiling
        // of `len / group size`. `v` is non-empty, so the division is safe.
        let width: u64 = 16;
        let len = v.len() as u64;
        let threads_per_group = len.min(width);
        let thread_group_count = MTLSize {
            width: (len + threads_per_group - 1) / threads_per_group,
            height: 1,
            depth: 1,
        };
        let thread_group_size = MTLSize {
            width: threads_per_group,
            height: 1,
            depth: 1,
        };
        encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        let results = output.read_to_vec::<f32>(v.len());
        assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
        assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
    }
}

View File

@ -0,0 +1,14 @@
#include <metal_stdlib>
using namespace metal;
// Argument buffer for the unary kernel: the kernel receives one of these at
// buffer index 0 and reaches both data buffers through it.
struct Input {
// Values the kernel reads.
device float *input;
// Destination the kernel writes its results to.
device float *output;
};
// Element-wise cosine: each thread reads args.input[index] and stores its
// cosine in args.output[index], where index is the thread's position in the grid.
// There is no bounds check on index, so the dispatch must not launch more
// threads than the buffers have elements.
kernel void cos(device Input& args [[ buffer(0) ]], uint index [[thread_position_in_grid]])
{
    const float angle = args.input[index];
    // Qualified call so the library function is clearly distinct from this kernel's own name.
    args.output[index] = metal::cos(angle);
}