mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Implemented cos for now.
This commit is contained in:
@ -9,7 +9,7 @@ use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use metal::mps::{Float32, MPSDataType};
|
||||
use metal::MTLResourceOptions;
|
||||
use metal::{MTLResourceOptions, Buffer};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -47,6 +47,14 @@ impl MetalDevice {
|
||||
pub fn id(&self) -> u64 {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{
|
||||
let size = (element_count * dtype.size_in_bytes()) as u64;
|
||||
self.device.new_buffer(
|
||||
size,
|
||||
MTLResourceOptions::empty(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -106,11 +114,16 @@ impl BackendStorage for MetalStorage {
|
||||
todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||
// todo!()
|
||||
// TODO
|
||||
println!("TODO {:?}", B::NAME);
|
||||
Ok(self.clone())
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let dtype = self.dtype;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
todo!("Implement the kernel calling");
|
||||
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
|
||||
Ok(Self { buffer, device, dtype })
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
@ -271,13 +284,10 @@ impl MetalStorage {
|
||||
let elem_count = b * m * n;
|
||||
match (self.dtype, rhs.dtype) {
|
||||
(DType::F32, DType::F32) => {
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
if b != 1 {
|
||||
println!("TODO implement batched matmul for B={b}");
|
||||
// bail!("Didn't implemented strided matmul yet");
|
||||
let out_buffer = self.device.new_buffer(
|
||||
(elem_count * mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::empty(),
|
||||
);
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
@ -286,20 +296,12 @@ impl MetalStorage {
|
||||
}
|
||||
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
|
||||
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
|
||||
let out_buffer = self.device.new_buffer(
|
||||
(elem_count * mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::empty(),
|
||||
);
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
}
|
||||
let out_buffer = self.device.new_buffer(
|
||||
(elem_count * mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::empty(),
|
||||
);
|
||||
let m: u64 = m.try_into().expect("usize should fit u64");
|
||||
let n: u64 = n.try_into().expect("usize should fit u64");
|
||||
let k: u64 = k.try_into().expect("usize should fit u64");
|
||||
@ -359,6 +361,7 @@ impl MetalStorage {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
type Storage = MetalStorage;
|
||||
|
||||
|
Reference in New Issue
Block a user