Implemented cos for now.

This commit is contained in:
Nicolas Patry
2023-11-03 01:24:51 +01:00
parent 7161002a34
commit f57e3164ae
3 changed files with 172 additions and 18 deletions

View File

@ -9,7 +9,7 @@ use half::{bf16, f16};
use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::MTLResourceOptions;
use metal::{MTLResourceOptions, Buffer};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -47,6 +47,14 @@ impl MetalDevice {
pub fn id(&self) -> u64 {
self.registry_id()
}
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{
let size = (element_count * dtype.size_in_bytes()) as u64;
self.device.new_buffer(
size,
MTLResourceOptions::empty(),
)
}
}
#[derive(Debug, Clone)]
@ -106,11 +114,16 @@ impl BackendStorage for MetalStorage {
todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
// todo!()
// TODO
println!("TODO {:?}", B::NAME);
Ok(self.clone())
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let dtype = self.dtype;
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
todo!("Implement the kernel calling");
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
Ok(Self { buffer, device, dtype })
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@ -271,13 +284,10 @@ impl MetalStorage {
let elem_count = b * m * n;
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
if b != 1 {
println!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet");
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
@ -286,20 +296,12 @@ impl MetalStorage {
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
}
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
let m: u64 = m.try_into().expect("usize should fit u64");
let n: u64 = n.try_into().expect("usize should fit u64");
let k: u64 = k.try_into().expect("usize should fit u64");
@ -359,6 +361,7 @@ impl MetalStorage {
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;