mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Initial setup
This commit is contained in:
@ -55,6 +55,7 @@ tracing-subscriber = "0.3.7"
|
|||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
|
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
@ -13,6 +13,8 @@ readme = "README.md"
|
|||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||||
|
candle-metal = { path = "../candle-metal", version = "0.0.1", optional = true }
|
||||||
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
@ -39,3 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
|||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
|
metal = ["dep:candle-metal", "dep:metal"]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::backend::BackendDevice;
|
use crate::backend::BackendDevice;
|
||||||
use crate::cpu_backend::CpuDevice;
|
use crate::cpu_backend::CpuDevice;
|
||||||
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||||
|
|
||||||
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
||||||
/// can live on the same location (typically for cuda devices).
|
/// can live on the same location (typically for cuda devices).
|
||||||
@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
|||||||
pub enum DeviceLocation {
|
pub enum DeviceLocation {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda { gpu_id: usize },
|
Cuda { gpu_id: usize },
|
||||||
|
Metal,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Device {
|
pub enum Device {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda(crate::CudaDevice),
|
Cuda(crate::CudaDevice),
|
||||||
|
Metal(crate::MetalDevice),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait NdArray {
|
pub trait NdArray {
|
||||||
@ -103,14 +105,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
|
|||||||
impl<S: NdArray> NdArray for Vec<S> {
|
impl<S: NdArray> NdArray for Vec<S> {
|
||||||
fn shape(&self) -> Result<Shape> {
|
fn shape(&self) -> Result<Shape> {
|
||||||
if self.is_empty() {
|
if self.is_empty() {
|
||||||
crate::bail!("empty array")
|
bail!("empty array")
|
||||||
}
|
}
|
||||||
let shape0 = self[0].shape()?;
|
let shape0 = self[0].shape()?;
|
||||||
let n = self.len();
|
let n = self.len();
|
||||||
for v in self.iter() {
|
for v in self.iter() {
|
||||||
let shape = v.shape()?;
|
let shape = v.shape()?;
|
||||||
if shape != shape0 {
|
if shape != shape0 {
|
||||||
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
|
bail!("two elements have different shapes {shape:?} {shape0:?}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
|
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
|
||||||
@ -130,8 +132,9 @@ impl Device {
|
|||||||
|
|
||||||
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
|
Self::Cpu => CpuDevice.set_seed(seed),
|
||||||
Self::Cuda(c) => c.set_seed(seed),
|
Self::Cuda(c) => c.set_seed(seed),
|
||||||
|
Self::Metal(m) => m.set_seed(seed),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,21 +150,16 @@ impl Device {
|
|||||||
match self {
|
match self {
|
||||||
Self::Cpu => DeviceLocation::Cpu,
|
Self::Cpu => DeviceLocation::Cpu,
|
||||||
Self::Cuda(device) => device.location(),
|
Self::Cuda(device) => device.location(),
|
||||||
|
Device::Metal(device) => device.location(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_cpu(&self) -> bool {
|
pub fn is_cpu(&self) -> bool {
|
||||||
match self {
|
matches!(self, Self::Cpu)
|
||||||
Self::Cpu => true,
|
|
||||||
Self::Cuda(_) => false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_cuda(&self) -> bool {
|
pub fn is_cuda(&self) -> bool {
|
||||||
match self {
|
matches!(self, Self::Cuda(_))
|
||||||
Self::Cpu => false,
|
|
||||||
Self::Cuda(_) => true,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||||
@ -188,6 +186,11 @@ impl Device {
|
|||||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
bail!("Metal rand_uniform not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,6 +219,11 @@ impl Device {
|
|||||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
bail!("Metal rand_normal not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,6 +246,11 @@ impl Device {
|
|||||||
let storage = device.ones_impl(shape, dtype)?;
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = device.ones_impl(shape, dtype)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
bail!("Metal ones not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,6 +264,11 @@ impl Device {
|
|||||||
let storage = device.zeros_impl(shape, dtype)?;
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = device.zeros_impl(shape, dtype)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
bail!("Metal zeros not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,6 +280,12 @@ impl Device {
|
|||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = array.to_cpu_storage();
|
||||||
|
// let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
bail!("Metal storage not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -273,6 +297,12 @@ impl Device {
|
|||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = S::to_cpu_storage_owned(data);
|
||||||
|
// let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
bail!("Metal storage_owned not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ impl Tensor {
|
|||||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||||
format!(", cuda:{}", gpu_id)
|
format!(", cuda:{}", gpu_id)
|
||||||
}
|
}
|
||||||
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
write!(f, "Tensor[")?;
|
write!(f, "Tensor[")?;
|
||||||
@ -476,6 +477,7 @@ impl std::fmt::Display for Tensor {
|
|||||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||||
format!(", cuda:{}", gpu_id)
|
format!(", cuda:{}", gpu_id)
|
||||||
}
|
}
|
||||||
|
crate::DeviceLocation::Metal => todo!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
write!(
|
write!(
|
||||||
|
201
candle-core/src/dummy_metal_backend.rs
Normal file
201
candle-core/src/dummy_metal_backend.rs
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
|
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MetalDevice;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MetalStorage;
|
||||||
|
|
||||||
|
macro_rules! fail {
|
||||||
|
() => {
|
||||||
|
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::backend::BackendStorage for MetalStorage {
|
||||||
|
type Device = MetalDevice;
|
||||||
|
|
||||||
|
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dtype(&self) -> DType {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> &Self::Device {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv1d(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv2d(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &crate::conv::ParamsConv2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose2d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConvTranspose2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matmul(
|
||||||
|
&self,
|
||||||
|
_: &Self,
|
||||||
|
_: (usize, usize, usize, usize),
|
||||||
|
_: &Layout,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::backend::BackendDevice for MetalDevice {
|
||||||
|
type Storage = MetalStorage;
|
||||||
|
fn new(_: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _: u64) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, _: &Self) -> bool {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
@ -152,6 +152,9 @@ pub enum Error {
|
|||||||
#[error("the candle crate has not been built with cuda support")]
|
#[error("the candle crate has not been built with cuda support")]
|
||||||
NotCompiledWithCudaSupport,
|
NotCompiledWithCudaSupport,
|
||||||
|
|
||||||
|
#[error("the candle crate has not been built with metal support")]
|
||||||
|
NotCompiledWithMetalSupport,
|
||||||
|
|
||||||
#[error("cannot find tensor {path}")]
|
#[error("cannot find tensor {path}")]
|
||||||
CannotFindTensor { path: String },
|
CannotFindTensor { path: String },
|
||||||
|
|
||||||
@ -159,6 +162,9 @@ pub enum Error {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
||||||
|
|
||||||
|
#[error("Metal error {0}")]
|
||||||
|
Metal(String),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||||
|
|
||||||
|
@ -49,9 +49,12 @@ mod device;
|
|||||||
pub mod display;
|
pub mod display;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
mod dummy_cuda_backend;
|
mod dummy_cuda_backend;
|
||||||
|
mod dummy_metal_backend;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
mod indexer;
|
mod indexer;
|
||||||
pub mod layout;
|
pub mod layout;
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
mod metal_backend;
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
mod mkl;
|
mod mkl;
|
||||||
pub mod npy;
|
pub mod npy;
|
||||||
@ -68,6 +71,9 @@ pub mod test_utils;
|
|||||||
pub mod utils;
|
pub mod utils;
|
||||||
mod variable;
|
mod variable;
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub use dummy_metal_backend::{MetalDevice, MetalStorage};
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
pub use cpu_backend::CpuStorage;
|
||||||
pub use device::{Device, DeviceLocation};
|
pub use device::{Device, DeviceLocation};
|
||||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||||
|
236
candle-core/src/metal_backend.rs
Normal file
236
candle-core/src/metal_backend.rs
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
|
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
||||||
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
|
pub use candle_metal;
|
||||||
|
use metal;
|
||||||
|
|
||||||
|
/// Metal related errors
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum MetalError {
|
||||||
|
#[error("metal error")]
|
||||||
|
Metal,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MetalDevice {
|
||||||
|
device: metal::Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for MetalDevice {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "MetalDevice({:?})", self.device.registry_id())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for MetalDevice {
|
||||||
|
type Target = metal::DeviceRef;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetalDevice {
|
||||||
|
pub fn metal_device(&self) -> &metal::DeviceRef {
|
||||||
|
self.device.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> u64 {
|
||||||
|
self.registry_id()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MetalStorage {
|
||||||
|
pub buffer: metal::Buffer,
|
||||||
|
pub device: metal::Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendStorage for MetalStorage {
|
||||||
|
type Device = MetalDevice;
|
||||||
|
|
||||||
|
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Ok(self.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dtype(&self) -> DType {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> &Self::Device {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv2d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &ParamsConv2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose2d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &ParamsConvTranspose2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matmul(
|
||||||
|
&self,
|
||||||
|
_: &Self,
|
||||||
|
_: (usize, usize, usize, usize),
|
||||||
|
_: &Layout,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendDevice for MetalDevice {
|
||||||
|
type Storage = MetalStorage;
|
||||||
|
|
||||||
|
fn new(_ordinal: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
crate::DeviceLocation::Metal
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, _rhs: &Self) -> bool {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<MetalStorage> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
@ -6,7 +6,7 @@ use crate::op::{
|
|||||||
};
|
};
|
||||||
use crate::scalar::TensorOrScalar;
|
use crate::scalar::TensorOrScalar;
|
||||||
use crate::shape::{Dim, Dims};
|
use crate::shape::{Dim, Dims};
|
||||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
/// Unique identifier for tensors.
|
/// Unique identifier for tensors.
|
||||||
@ -1837,6 +1837,9 @@ impl Tensor {
|
|||||||
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
|
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
|
||||||
}
|
}
|
||||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||||
|
_ => {
|
||||||
|
bail!("not implemented yet")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let op = BackpropOp::new1(self, Op::ToDevice);
|
let op = BackpropOp::new1(self, Op::ToDevice);
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
|
13
candle-metal/Cargo.toml
Normal file
13
candle-metal/Cargo.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
[package]
|
||||||
|
name = "candle-metal"
|
||||||
|
version = "0.0.1"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
description = "Metal kernels for Candle"
|
||||||
|
repository = "https://github.com/huggingface/candle"
|
||||||
|
keywords = ["blas", "tensor", "machine-learning"]
|
||||||
|
categories = ["science"]
|
||||||
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
metal = { workspace = true, optional = true}
|
3
candle-metal/README.md
Normal file
3
candle-metal/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# candle-metal-kernels
|
||||||
|
|
||||||
|
This crate contains Metal kernels used from candle.
|
1
candle-metal/src/lib.rs
Normal file
1
candle-metal/src/lib.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
@ -81,6 +81,7 @@ impl PyDevice {
|
|||||||
match device {
|
match device {
|
||||||
Device::Cpu => Self::Cpu,
|
Device::Cpu => Self::Cpu,
|
||||||
Device::Cuda(_) => Self::Cuda,
|
Device::Cuda(_) => Self::Cuda,
|
||||||
|
Device::Metal(_) => unimplemented!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user