mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
More scaffolding, now need to implement matmul (for precompute_cos_sin to work).
This commit is contained in:
@ -223,10 +223,9 @@ impl Device {
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(_device) => {
|
||||
// let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
// Ok(Storage::Metal(storage))
|
||||
bail!("Metal rand_normal not implemented")
|
||||
Device::Metal(device) => {
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -250,10 +249,9 @@ impl Device {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(_device) => {
|
||||
// let storage = device.ones_impl(shape, dtype)?;
|
||||
// Ok(Storage::Metal(storage))
|
||||
bail!("Metal ones not implemented")
|
||||
Device::Metal(device) => {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -268,10 +266,9 @@ impl Device {
|
||||
let storage = device.zeros_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(_device) => {
|
||||
// let storage = device.zeros_impl(shape, dtype)?;
|
||||
// Ok(Storage::Metal(storage))
|
||||
bail!("Metal zeros not implemented")
|
||||
Device::Metal(device) => {
|
||||
let storage = device.zeros_impl(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -284,11 +281,10 @@ impl Device {
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(_device) => {
|
||||
// let storage = array.to_cpu_storage();
|
||||
// let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
// Ok(Storage::Metal(storage))
|
||||
bail!("Metal storage not implemented")
|
||||
Device::Metal(device) => {
|
||||
let storage = array.to_cpu_storage();
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
pub use candle_metal;
|
||||
use metal;
|
||||
use core::mem;
|
||||
use half::{f16, bf16};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -43,8 +45,10 @@ impl MetalDevice {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalStorage {
|
||||
pub buffer: metal::Buffer,
|
||||
pub device: metal::Device,
|
||||
buffer: metal::Buffer,
|
||||
device: MetalDevice,
|
||||
dtype: DType
|
||||
|
||||
}
|
||||
|
||||
impl BackendStorage for MetalStorage {
|
||||
@ -55,11 +59,11 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn dtype(&self) -> DType {
|
||||
todo!()
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn device(&self) -> &Self::Device {
|
||||
todo!()
|
||||
&self.device
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
@ -86,8 +90,8 @@ impl BackendStorage for MetalStorage {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||
todo!()
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||
@ -182,12 +186,19 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
_: &Self,
|
||||
_: (usize, usize, usize, usize),
|
||||
_: &Layout,
|
||||
_: &Layout,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
let elem_count = b * m * n;
|
||||
let dev = &self.device;
|
||||
match (self.dtype, rhs.dtype){
|
||||
(DType::F32, DType::F32) => {
|
||||
todo!("MATMUL {b} {m} {n} {k}");
|
||||
}
|
||||
_ => todo!("Unimplemented matmul for this pair")
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||
@ -223,8 +234,60 @@ impl BackendDevice for MetalDevice {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
todo!("Storage")
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
|
||||
let buffer = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u8>()) as u64,
|
||||
option
|
||||
)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u32>()) as u64,
|
||||
option
|
||||
)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<i64>()) as u64,
|
||||
option
|
||||
)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<bf16>()) as u64,
|
||||
option
|
||||
)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f16>()) as u64,
|
||||
option
|
||||
)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f32>()) as u64,
|
||||
option
|
||||
)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f64>()) as u64,
|
||||
option
|
||||
)
|
||||
}
|
||||
};
|
||||
Ok(Self::Storage{buffer, device: self.clone(), dtype: storage.dtype()})
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
|
Reference in New Issue
Block a user