From 4525b7b52a1e87e399b0a1c27da3beee487d1301 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 31 Oct 2023 18:09:10 +0100 Subject: [PATCH] Initial setup --- Cargo.toml | 1 + candle-core/Cargo.toml | 3 + candle-core/src/device.rs | 54 ++++-- candle-core/src/display.rs | 2 + candle-core/src/dummy_metal_backend.rs | 201 +++++++++++++++++++++ candle-core/src/error.rs | 6 + candle-core/src/lib.rs | 6 + candle-core/src/metal_backend.rs | 236 +++++++++++++++++++++++++ candle-core/src/tensor.rs | 5 +- candle-metal/Cargo.toml | 13 ++ candle-metal/README.md | 3 + candle-metal/src/lib.rs | 1 + candle-pyo3/src/lib.rs | 1 + 13 files changed, 519 insertions(+), 13 deletions(-) create mode 100644 candle-core/src/dummy_metal_backend.rs create mode 100644 candle-core/src/metal_backend.rs create mode 100644 candle-metal/Cargo.toml create mode 100644 candle-metal/README.md create mode 100644 candle-metal/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 89ffe530..d3130105 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } +metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 8e57127a..88fc9f91 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -13,6 +13,8 @@ readme = "README.md" accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true } +candle-metal = { path = "../candle-metal", version = "0.0.1", optional = true } +metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } @@ -39,3 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] +metal = ["dep:candle-metal", "dep:metal"] diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index d566ba42..0807176e 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -1,6 +1,6 @@ use crate::backend::BackendDevice; use crate::cpu_backend::CpuDevice; -use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType}; +use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType}; /// A `DeviceLocation` represents a physical device whereas multiple `Device` /// can live on the same location (typically for cuda devices). @@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType}; pub enum DeviceLocation { Cpu, Cuda { gpu_id: usize }, + Metal, } #[derive(Debug, Clone)] pub enum Device { Cpu, Cuda(crate::CudaDevice), + Metal(crate::MetalDevice), } pub trait NdArray { @@ -103,14 +105,14 @@ impl NdArray for Vec { fn shape(&self) -> Result { if self.is_empty() { - crate::bail!("empty array") + bail!("empty array") } let shape0 = self[0].shape()?; let n = self.len(); for v in self.iter() { let shape = v.shape()?; if shape != shape0 { - crate::bail!("two elements have different shapes {shape:?} {shape0:?}") + bail!("two elements have different shapes {shape:?} {shape0:?}") } } Ok(Shape::from([[n].as_slice(), shape0.dims()].concat())) @@ -130,8 +132,9 @@ impl Device { pub fn set_seed(&self, seed: u64) -> Result<()> { match self { - Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed), + Self::Cpu => CpuDevice.set_seed(seed), Self::Cuda(c) => c.set_seed(seed), + Self::Metal(m) => m.set_seed(seed), } } @@ -147,21 +150,16 @@ impl Device { match self { Self::Cpu => DeviceLocation::Cpu, Self::Cuda(device) => device.location(), + Device::Metal(device) => device.location(), } } pub fn is_cpu(&self) -> bool { - match self { - Self::Cpu => true, - Self::Cuda(_) => false, - } + matches!(self, Self::Cpu) } pub fn is_cuda(&self) -> bool { - match self { - Self::Cpu => false, - Self::Cuda(_) => true, - } + matches!(self, Self::Cuda(_)) } pub fn cuda_if_available(ordinal: usize) -> Result { @@ -188,6 +186,11 @@ impl Device { let storage = device.rand_uniform(shape, dtype, lo, up)?; Ok(Storage::Cuda(storage)) } + Device::Metal(_device) => { + // let storage = device.rand_uniform(shape, dtype, lo, up)?; + // Ok(Storage::Metal(storage)) + bail!("Metal rand_uniform not implemented") + } } } @@ -216,6 +219,11 @@ impl Device { let storage = device.rand_normal(shape, dtype, mean, std)?; Ok(Storage::Cuda(storage)) } + Device::Metal(_device) => { + // let storage = device.rand_normal(shape, dtype, mean, std)?; + // Ok(Storage::Metal(storage)) + bail!("Metal rand_normal not implemented") + } } } @@ -238,6 +246,11 @@ impl Device { let storage = device.ones_impl(shape, dtype)?; Ok(Storage::Cuda(storage)) } + Device::Metal(_device) => { + // let storage = device.ones_impl(shape, dtype)?; + // Ok(Storage::Metal(storage)) + bail!("Metal ones not implemented") + } } } @@ -251,6 +264,11 @@ impl Device { let storage = device.zeros_impl(shape, dtype)?; Ok(Storage::Cuda(storage)) } + Device::Metal(_device) => { + // let storage = device.zeros_impl(shape, dtype)?; + // Ok(Storage::Metal(storage)) + bail!("Metal zeros not implemented") + } } } @@ -262,6 +280,12 @@ impl Device { let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } + Device::Metal(_device) => { + // let storage = array.to_cpu_storage(); + // let storage = device.storage_from_cpu_storage(&storage)?; + // Ok(Storage::Metal(storage)) + bail!("Metal storage not implemented") + } } } @@ -273,6 +297,12 @@ impl Device { let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } + Device::Metal(_device) => { + // let storage = S::to_cpu_storage_owned(data); + // let storage = device.storage_from_cpu_storage(&storage)?; + // Ok(Storage::Metal(storage)) + bail!("Metal storage_owned not implemented") + } } } } diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index b497699b..215c28f6 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -14,6 +14,7 @@ impl Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } + _ => todo!(), }; write!(f, "Tensor[")?; @@ -476,6 +477,7 @@ impl std::fmt::Display for Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } + crate::DeviceLocation::Metal => todo!(), }; write!( diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs new file mode 100644 index 00000000..25491097 --- /dev/null +++ b/candle-core/src/dummy_metal_backend.rs @@ -0,0 +1,201 @@ +#![allow(dead_code)] +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; + +#[derive(Debug, Clone)] +pub struct MetalDevice; + +#[derive(Debug)] +pub struct MetalStorage; + +macro_rules! fail { + () => { + unimplemented!("metal support has not been enabled, add `metal` feature to enable.") + }; +} + +impl crate::backend::BackendStorage for MetalStorage { + type Device = MetalDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn dtype(&self) -> DType { + fail!() + } + + fn device(&self) -> &Self::Device { + fail!() + } + + fn to_cpu_storage(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn powf(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn elu(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn to_dtype(&self, _: &Layout, _: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn unary_impl(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv1d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv1D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv2d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv2D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose2D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn matmul( + &self, + _: &Self, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } +} + +impl crate::backend::BackendDevice for MetalDevice { + type Storage = MetalStorage; + fn new(_: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn set_seed(&self, _: u64) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn location(&self) -> crate::DeviceLocation { + fail!() + } + + fn same_device(&self, _: &Self) -> bool { + fail!() + } + + fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } +} diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 96a2b809..caf2f89c 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -152,6 +152,9 @@ pub enum Error { #[error("the candle crate has not been built with cuda support")] NotCompiledWithCudaSupport, + #[error("the candle crate has not been built with metal support")] + NotCompiledWithMetalSupport, + #[error("cannot find tensor {path}")] CannotFindTensor { path: String }, @@ -159,6 +162,9 @@ pub enum Error { #[error(transparent)] Cuda(Box), + #[error("Metal error {0}")] + Metal(String), + #[error(transparent)] TryFromIntError(#[from] core::num::TryFromIntError), diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 73830229..209ec9a7 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -49,9 +49,12 @@ mod device; pub mod display; mod dtype; mod dummy_cuda_backend; +mod dummy_metal_backend; pub mod error; mod indexer; pub mod layout; +#[cfg(feature = "accelerate")] +mod metal_backend; #[cfg(feature = "mkl")] mod mkl; pub mod npy; @@ -68,6 +71,9 @@ pub mod test_utils; pub mod utils; mod variable; +#[cfg(not(feature = "cuda"))] +pub use dummy_metal_backend::{MetalDevice, MetalStorage}; + pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; pub use dtype::{DType, FloatDType, IntDType, WithDType}; diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs new file mode 100644 index 00000000..00d236e3 --- /dev/null +++ b/candle-core/src/metal_backend.rs @@ -0,0 +1,236 @@ +use crate::backend::{BackendDevice, BackendStorage}; +use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Layout, Result, Shape}; +pub use candle_metal; +use metal; + +/// Metal related errors +#[derive(thiserror::Error, Debug)] +pub enum MetalError { + #[error("metal error")] + Metal, +} + +#[derive(Clone)] +pub struct MetalDevice { + device: metal::Device, +} + +impl std::fmt::Debug for MetalDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MetalDevice({:?})", self.device.registry_id()) + } +} + +impl std::ops::Deref for MetalDevice { + type Target = metal::DeviceRef; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl MetalDevice { + pub fn metal_device(&self) -> &metal::DeviceRef { + self.device.as_ref() + } + + pub fn id(&self) -> u64 { + self.registry_id() + } +} + +#[derive(Debug, Clone)] +pub struct MetalStorage { + pub buffer: metal::Buffer, + pub device: metal::Device, +} + +impl BackendStorage for MetalStorage { + type Device = MetalDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Ok(self.clone()) + } + + fn dtype(&self) -> DType { + todo!() + } + + fn device(&self) -> &Self::Device { + todo!() + } + + fn to_cpu_storage(&self) -> Result { + todo!() + } + + fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { + todo!() + } + + fn powf(&self, _: &Layout, _: f64) -> Result { + todo!() + } + + fn elu(&self, _: &Layout, _: f64) -> Result { + todo!() + } + + fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result { + todo!() + } + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { + todo!() + } + + fn to_dtype(&self, _: &Layout, _: DType) -> Result { + todo!() + } + + fn unary_impl(&self, _: &Layout) -> Result { + todo!() + } + + fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { + todo!() + } + + fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { + todo!() + } + + fn conv1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConv1D, + ) -> Result { + todo!() + } + + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConv2D, + ) -> Result { + todo!() + } + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConvTranspose2D, + ) -> Result { + todo!() + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + todo!() + } + + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + todo!() + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + todo!() + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { + todo!() + } + + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { + todo!() + } + + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + todo!() + } + + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result { + todo!() + } + + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + todo!() + } + + fn matmul( + &self, + _: &Self, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result { + todo!() + } + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { + todo!() + } +} + +impl BackendDevice for MetalDevice { + type Storage = MetalStorage; + + fn new(_ordinal: usize) -> Result { + todo!() + } + + fn set_seed(&self, _seed: u64) -> Result<()> { + todo!() + } + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Metal + } + + fn same_device(&self, _rhs: &Self) -> bool { + todo!() + } + + fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + todo!() + } + + fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + todo!() + } + + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { + todo!() + } + + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + todo!() + } + + fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + todo!() + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index adcdc59d..fb3c82bb 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -6,7 +6,7 @@ use crate::op::{ }; use crate::scalar::TensorOrScalar; use crate::shape::{Dim, Dims}; -use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape}; +use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; /// Unique identifier for tensors. @@ -1837,6 +1837,9 @@ impl Tensor { Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?) } (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()), + _ => { + bail!("not implemented yet") + } }; let op = BackpropOp::new1(self, Op::ToDevice); let tensor_ = Tensor_ { diff --git a/candle-metal/Cargo.toml b/candle-metal/Cargo.toml new file mode 100644 index 00000000..f892006e --- /dev/null +++ b/candle-metal/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "candle-metal" +version = "0.0.1" +edition = "2021" + +description = "Metal kernels for Candle" +repository = "https://github.com/huggingface/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" + +[dependencies] +metal = { workspace = true, optional = true} diff --git a/candle-metal/README.md b/candle-metal/README.md new file mode 100644 index 00000000..ec923e9a --- /dev/null +++ b/candle-metal/README.md @@ -0,0 +1,3 @@ +# candle-metal-kernels + +This crate contains Metal kernels used from candle. \ No newline at end of file diff --git a/candle-metal/src/lib.rs b/candle-metal/src/lib.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/candle-metal/src/lib.rs @@ -0,0 +1 @@ + diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ddd58fbe..f8335794 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -81,6 +81,7 @@ impl PyDevice { match device { Device::Cpu => Self::Cpu, Device::Cuda(_) => Self::Cuda, + Device::Metal(_) => unimplemented!(), } }