mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Metal part 1 - Scaffolding for metal. (#1308)
* Metal part 1 - Scaffolding for metal. * Remove tracing.
This commit is contained in:
@ -60,6 +60,7 @@ tracing-subscriber = "0.3.7"
|
|||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
|
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
@ -13,6 +13,7 @@ readme = "README.md"
|
|||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||||
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
@ -39,3 +40,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
|||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
|
metal = ["dep:metal"]
|
||||||
|
@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
|||||||
pub enum DeviceLocation {
|
pub enum DeviceLocation {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda { gpu_id: usize },
|
Cuda { gpu_id: usize },
|
||||||
|
Metal,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Device {
|
pub enum Device {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda(crate::CudaDevice),
|
Cuda(crate::CudaDevice),
|
||||||
|
Metal(crate::MetalDevice),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait NdArray {
|
pub trait NdArray {
|
||||||
@ -128,10 +130,15 @@ impl Device {
|
|||||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new_metal(ordinal: usize) -> Result<Self> {
|
||||||
|
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
|
Self::Cpu => CpuDevice.set_seed(seed),
|
||||||
Self::Cuda(c) => c.set_seed(seed),
|
Self::Cuda(c) => c.set_seed(seed),
|
||||||
|
Self::Metal(m) => m.set_seed(seed),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,21 +154,20 @@ impl Device {
|
|||||||
match self {
|
match self {
|
||||||
Self::Cpu => DeviceLocation::Cpu,
|
Self::Cpu => DeviceLocation::Cpu,
|
||||||
Self::Cuda(device) => device.location(),
|
Self::Cuda(device) => device.location(),
|
||||||
|
Device::Metal(device) => device.location(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_cpu(&self) -> bool {
|
pub fn is_cpu(&self) -> bool {
|
||||||
match self {
|
matches!(self, Self::Cpu)
|
||||||
Self::Cpu => true,
|
|
||||||
Self::Cuda(_) => false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_cuda(&self) -> bool {
|
pub fn is_cuda(&self) -> bool {
|
||||||
match self {
|
matches!(self, Self::Cuda(_))
|
||||||
Self::Cpu => false,
|
|
||||||
Self::Cuda(_) => true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_metal(&self) -> bool {
|
||||||
|
matches!(self, Self::Metal(_))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||||
@ -194,6 +200,11 @@ impl Device {
|
|||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
crate::bail!("Metal rand_uniform not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,6 +239,10 @@ impl Device {
|
|||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,6 +265,10 @@ impl Device {
|
|||||||
let storage = device.ones_impl(shape, dtype)?;
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,6 +282,10 @@ impl Device {
|
|||||||
let storage = device.zeros_impl(shape, dtype)?;
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,6 +297,11 @@ impl Device {
|
|||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = array.to_cpu_storage();
|
||||||
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,6 +313,11 @@ impl Device {
|
|||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = S::to_cpu_storage_owned(data);
|
||||||
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ impl Tensor {
|
|||||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||||
format!(", cuda:{}", gpu_id)
|
format!(", cuda:{}", gpu_id)
|
||||||
}
|
}
|
||||||
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
write!(f, "Tensor[")?;
|
write!(f, "Tensor[")?;
|
||||||
@ -476,6 +477,7 @@ impl std::fmt::Display for Tensor {
|
|||||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||||
format!(", cuda:{}", gpu_id)
|
format!(", cuda:{}", gpu_id)
|
||||||
}
|
}
|
||||||
|
crate::DeviceLocation::Metal => todo!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
write!(
|
write!(
|
||||||
|
223
candle-core/src/dummy_metal_backend.rs
Normal file
223
candle-core/src/dummy_metal_backend.rs
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
|
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MetalDevice;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MetalStorage;
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum MetalError {
|
||||||
|
#[error("{0}")]
|
||||||
|
Message(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for MetalError {
|
||||||
|
fn from(e: String) -> Self {
|
||||||
|
MetalError::Message(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! fail {
|
||||||
|
() => {
|
||||||
|
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::backend::BackendStorage for MetalStorage {
|
||||||
|
type Device = MetalDevice;
|
||||||
|
|
||||||
|
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dtype(&self) -> DType {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> &Self::Device {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv1d(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConvTranspose1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv2d(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &crate::conv::ParamsConv2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose2d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConvTranspose2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matmul(
|
||||||
|
&self,
|
||||||
|
_: &Self,
|
||||||
|
_: (usize, usize, usize, usize),
|
||||||
|
_: &Layout,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::backend::BackendDevice for MetalDevice {
|
||||||
|
type Storage = MetalStorage;
|
||||||
|
fn new(_: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _: u64) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, _: &Self) -> bool {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{DType, DeviceLocation, Layout, Shape};
|
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MatMulUnexpectedStriding {
|
pub struct MatMulUnexpectedStriding {
|
||||||
@ -152,6 +152,9 @@ pub enum Error {
|
|||||||
#[error("the candle crate has not been built with cuda support")]
|
#[error("the candle crate has not been built with cuda support")]
|
||||||
NotCompiledWithCudaSupport,
|
NotCompiledWithCudaSupport,
|
||||||
|
|
||||||
|
#[error("the candle crate has not been built with metal support")]
|
||||||
|
NotCompiledWithMetalSupport,
|
||||||
|
|
||||||
#[error("cannot find tensor {path}")]
|
#[error("cannot find tensor {path}")]
|
||||||
CannotFindTensor { path: String },
|
CannotFindTensor { path: String },
|
||||||
|
|
||||||
@ -159,6 +162,9 @@ pub enum Error {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
||||||
|
|
||||||
|
#[error("Metal error {0}")]
|
||||||
|
Metal(#[from] MetalError),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||||
|
|
||||||
|
@ -49,6 +49,7 @@ mod device;
|
|||||||
pub mod display;
|
pub mod display;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
mod dummy_cuda_backend;
|
mod dummy_cuda_backend;
|
||||||
|
mod dummy_metal_backend;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
mod indexer;
|
mod indexer;
|
||||||
pub mod layout;
|
pub mod layout;
|
||||||
@ -87,6 +88,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
|
|||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||||
|
|
||||||
|
#[cfg(not(feature = "metal"))]
|
||||||
|
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
|
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use num_traits::float::Float;
|
use num_traits::float::Float;
|
||||||
|
|
||||||
@ -184,6 +184,18 @@ pub trait CustomOp1 {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_storage: &MetalStorage,
|
||||||
|
_layout: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||||
/// The function should return the gradient of the argument.
|
/// The function should return the gradient of the argument.
|
||||||
@ -219,6 +231,20 @@ pub trait CustomOp2 {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
fn bwd(
|
fn bwd(
|
||||||
&self,
|
&self,
|
||||||
_arg1: &Tensor,
|
_arg1: &Tensor,
|
||||||
@ -261,6 +287,22 @@ pub trait CustomOp3 {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
fn bwd(
|
fn bwd(
|
||||||
&self,
|
&self,
|
||||||
_arg1: &Tensor,
|
_arg1: &Tensor,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
||||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||||
|
|
||||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||||
// out of memory. Instead try_clone should be used.
|
// out of memory. Instead try_clone should be used.
|
||||||
@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
|
|||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
Cuda(CudaStorage),
|
Cuda(CudaStorage),
|
||||||
|
Metal(MetalStorage),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Storage {
|
impl Storage {
|
||||||
@ -18,6 +19,10 @@ impl Storage {
|
|||||||
let storage = storage.try_clone(layout)?;
|
let storage = storage.try_clone(layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.try_clone(layout)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -25,6 +30,7 @@ impl Storage {
|
|||||||
match self {
|
match self {
|
||||||
Self::Cpu(_) => Device::Cpu,
|
Self::Cpu(_) => Device::Cpu,
|
||||||
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
||||||
|
Self::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,6 +38,7 @@ impl Storage {
|
|||||||
match self {
|
match self {
|
||||||
Self::Cpu(storage) => storage.dtype(),
|
Self::Cpu(storage) => storage.dtype(),
|
||||||
Self::Cuda(storage) => storage.dtype(),
|
Self::Cuda(storage) => storage.dtype(),
|
||||||
|
Self::Metal(storage) => storage.dtype(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,6 +72,10 @@ impl Storage {
|
|||||||
let storage = storage.affine(layout, mul, add)?;
|
let storage = storage.affine(layout, mul, add)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.affine(layout, mul, add)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,6 +89,10 @@ impl Storage {
|
|||||||
let storage = storage.powf(layout, alpha)?;
|
let storage = storage.powf(layout, alpha)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.powf(layout, alpha)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,6 +106,10 @@ impl Storage {
|
|||||||
let storage = storage.elu(layout, alpha)?;
|
let storage = storage.elu(layout, alpha)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.elu(layout, alpha)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,6 +131,10 @@ impl Storage {
|
|||||||
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||||
|
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
(lhs, rhs) => {
|
(lhs, rhs) => {
|
||||||
// Should not happen because of the same device check above but we're defensive
|
// Should not happen because of the same device check above but we're defensive
|
||||||
// anyway.
|
// anyway.
|
||||||
@ -135,6 +158,10 @@ impl Storage {
|
|||||||
let storage = storage.reduce_op(op, layout, s)?;
|
let storage = storage.reduce_op(op, layout, s)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.reduce_op(op, layout, s)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,6 +175,10 @@ impl Storage {
|
|||||||
let storage = storage.to_dtype(layout, dtype)?;
|
let storage = storage.to_dtype(layout, dtype)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.to_dtype(layout, dtype)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,6 +192,10 @@ impl Storage {
|
|||||||
let (storage, shape) = c.cuda_fwd(storage, l)?;
|
let (storage, shape) = c.cuda_fwd(storage, l)?;
|
||||||
Ok((Self::Cuda(storage), shape))
|
Ok((Self::Cuda(storage), shape))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let (storage, shape) = c.metal_fwd(storage, l)?;
|
||||||
|
Ok((Self::Metal(storage), shape))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,6 +216,10 @@ impl Storage {
|
|||||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
|
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
|
||||||
Ok((Self::Cuda(s), shape))
|
Ok((Self::Cuda(s), shape))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(s1), Self::Metal(s2)) => {
|
||||||
|
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
|
||||||
|
Ok((Self::Metal(s), shape))
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -205,6 +244,10 @@ impl Storage {
|
|||||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
|
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||||
Ok((Self::Cuda(s), shape))
|
Ok((Self::Cuda(s), shape))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
|
||||||
|
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||||
|
Ok((Self::Metal(s), shape))
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -219,6 +262,10 @@ impl Storage {
|
|||||||
let storage = storage.unary_impl::<B>(layout)?;
|
let storage = storage.unary_impl::<B>(layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.unary_impl::<B>(layout)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -239,6 +286,10 @@ impl Storage {
|
|||||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||||
|
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
(lhs, rhs) => {
|
(lhs, rhs) => {
|
||||||
// Should not happen because of the same device check above but we're defensive
|
// Should not happen because of the same device check above but we're defensive
|
||||||
// anyway.
|
// anyway.
|
||||||
@ -270,6 +321,10 @@ impl Storage {
|
|||||||
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||||
Ok(Self::Cuda(s))
|
Ok(Self::Cuda(s))
|
||||||
}
|
}
|
||||||
|
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||||
|
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||||
|
Ok(Self::Metal(s))
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -324,6 +379,10 @@ impl Storage {
|
|||||||
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||||
Ok(Self::Cuda(s))
|
Ok(Self::Cuda(s))
|
||||||
}
|
}
|
||||||
|
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||||
|
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||||
|
Ok(Self::Metal(s))
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -351,6 +410,10 @@ impl Storage {
|
|||||||
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||||
Ok(Self::Cuda(s))
|
Ok(Self::Cuda(s))
|
||||||
}
|
}
|
||||||
|
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||||
|
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||||
|
Ok(Self::Metal(s))
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -375,6 +438,10 @@ impl Storage {
|
|||||||
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -393,6 +460,10 @@ impl Storage {
|
|||||||
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -406,6 +477,10 @@ impl Storage {
|
|||||||
let storage = storage.upsample_nearest1d(layout, sz)?;
|
let storage = storage.upsample_nearest1d(layout, sz)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.upsample_nearest1d(layout, sz)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -419,6 +494,10 @@ impl Storage {
|
|||||||
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Self::Metal(storage) => {
|
||||||
|
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -442,6 +521,10 @@ impl Storage {
|
|||||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
|
||||||
|
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -468,6 +551,10 @@ impl Storage {
|
|||||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(s), Self::Metal(indexes)) => {
|
||||||
|
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -492,6 +579,10 @@ impl Storage {
|
|||||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||||
|
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -516,6 +607,10 @@ impl Storage {
|
|||||||
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||||
|
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -537,6 +632,10 @@ impl Storage {
|
|||||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||||
|
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -564,6 +663,10 @@ impl Storage {
|
|||||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||||
|
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||||
|
Ok(Self::Metal(storage))
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -583,6 +686,9 @@ impl Storage {
|
|||||||
match (self, dst) {
|
match (self, dst) {
|
||||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
|
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
|
||||||
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
|
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
|
||||||
|
(Self::Metal(src), Self::Metal(dst)) => {
|
||||||
|
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
|
@ -6,7 +6,7 @@ use crate::op::{
|
|||||||
};
|
};
|
||||||
use crate::scalar::TensorOrScalar;
|
use crate::scalar::TensorOrScalar;
|
||||||
use crate::shape::{Dim, Dims};
|
use crate::shape::{Dim, Dims};
|
||||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
/// Unique identifier for tensors.
|
/// Unique identifier for tensors.
|
||||||
@ -529,6 +529,7 @@ impl Tensor {
|
|||||||
match &*self.storage() {
|
match &*self.storage() {
|
||||||
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
|
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
|
||||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
|
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1454,6 +1455,7 @@ impl Tensor {
|
|||||||
match &*self.storage() {
|
match &*self.storage() {
|
||||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
|
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1484,6 +1486,7 @@ impl Tensor {
|
|||||||
match &*self.storage() {
|
match &*self.storage() {
|
||||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
|
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1524,6 +1527,7 @@ impl Tensor {
|
|||||||
match &*self.storage() {
|
match &*self.storage() {
|
||||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
|
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1849,6 +1853,9 @@ impl Tensor {
|
|||||||
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
|
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
|
||||||
}
|
}
|
||||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||||
|
_ => {
|
||||||
|
bail!("not implemented yet")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let op = BackpropOp::new1(self, Op::ToDevice);
|
let op = BackpropOp::new1(self, Op::ToDevice);
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
|
@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
|
|||||||
cfg!(feature = "cuda")
|
cfg!(feature = "cuda")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn metal_is_available() -> bool {
|
||||||
|
cfg!(feature = "metal")
|
||||||
|
}
|
||||||
|
|
||||||
pub fn with_avx() -> bool {
|
pub fn with_avx() -> bool {
|
||||||
cfg!(target_feature = "avx")
|
cfg!(target_feature = "avx")
|
||||||
}
|
}
|
||||||
|
@ -2,17 +2,28 @@ pub mod coco_classes;
|
|||||||
pub mod imagenet;
|
pub mod imagenet;
|
||||||
pub mod token_output_stream;
|
pub mod token_output_stream;
|
||||||
|
|
||||||
|
use candle::utils::{cuda_is_available, metal_is_available};
|
||||||
use candle::{Device, Result, Tensor};
|
use candle::{Device, Result, Tensor};
|
||||||
|
|
||||||
pub fn device(cpu: bool) -> Result<Device> {
|
pub fn device(cpu: bool) -> Result<Device> {
|
||||||
if cpu {
|
if cpu {
|
||||||
Ok(Device::Cpu)
|
Ok(Device::Cpu)
|
||||||
|
} else if cuda_is_available() {
|
||||||
|
Ok(Device::new_cuda(0)?)
|
||||||
|
} else if metal_is_available() {
|
||||||
|
Ok(Device::new_metal(0)?)
|
||||||
} else {
|
} else {
|
||||||
let device = Device::cuda_if_available(0)?;
|
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||||
if !device.is_cuda() {
|
{
|
||||||
|
println!(
|
||||||
|
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||||
|
{
|
||||||
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
|
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
|
||||||
}
|
}
|
||||||
Ok(device)
|
Ok(Device::Cpu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,11 +71,13 @@ impl PyDType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||||
|
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
enum PyDevice {
|
enum PyDevice {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda,
|
Cuda,
|
||||||
|
Metal,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PyDevice {
|
impl PyDevice {
|
||||||
@ -83,6 +85,7 @@ impl PyDevice {
|
|||||||
match device {
|
match device {
|
||||||
Device::Cpu => Self::Cpu,
|
Device::Cpu => Self::Cpu,
|
||||||
Device::Cuda(_) => Self::Cuda,
|
Device::Cuda(_) => Self::Cuda,
|
||||||
|
Device::Metal(_) => Self::Metal,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,6 +101,15 @@ impl PyDevice {
|
|||||||
*device = Some(d.clone());
|
*device = Some(d.clone());
|
||||||
Ok(d)
|
Ok(d)
|
||||||
}
|
}
|
||||||
|
Self::Metal => {
|
||||||
|
let mut device = METAL_DEVICE.lock().unwrap();
|
||||||
|
if let Some(device) = device.as_ref() {
|
||||||
|
return Ok(device.clone());
|
||||||
|
};
|
||||||
|
let d = Device::new_metal(0).map_err(wrap_err)?;
|
||||||
|
*device = Some(d.clone());
|
||||||
|
Ok(d)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -119,6 +131,7 @@ impl ToPyObject for PyDevice {
|
|||||||
let str = match self {
|
let str = match self {
|
||||||
PyDevice::Cpu => "cpu",
|
PyDevice::Cpu => "cpu",
|
||||||
PyDevice::Cuda => "cuda",
|
PyDevice::Cuda => "cuda",
|
||||||
|
PyDevice::Metal => "metal",
|
||||||
};
|
};
|
||||||
str.to_object(py)
|
str.to_object(py)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user