mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Compare commits
8 Commits
0.6.0
...
metal2-tmp
Author | SHA1 | Date | |
---|---|---|---|
d9c1f7e201 | |||
315ba4cf0c | |||
915f0e5b69 | |||
9975f2b239 | |||
d7cc660c68 | |||
c54ed0ab48 | |||
af5e77f409 | |||
8cf39d27ce |
@ -13,6 +13,7 @@ members = [
|
|||||||
exclude = [
|
exclude = [
|
||||||
"candle-flash-attn",
|
"candle-flash-attn",
|
||||||
"candle-kernels",
|
"candle-kernels",
|
||||||
|
"candle-metal-kernels",
|
||||||
"candle-onnx",
|
"candle-onnx",
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
@ -59,6 +60,8 @@ tracing-subscriber = "0.3.7"
|
|||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
|
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
metal = { path = "../metal-rs", features = ["mps"] }
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
@ -13,6 +13,8 @@ readme = "README.md"
|
|||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||||
|
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||||
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
@ -39,3 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
|||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
|
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||||
|
@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
|||||||
pub enum DeviceLocation {
|
pub enum DeviceLocation {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda { gpu_id: usize },
|
Cuda { gpu_id: usize },
|
||||||
|
Metal { gpu_id: usize },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Device {
|
pub enum Device {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda(crate::CudaDevice),
|
Cuda(crate::CudaDevice),
|
||||||
|
Metal(crate::MetalDevice),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait NdArray {
|
pub trait NdArray {
|
||||||
@ -128,10 +130,15 @@ impl Device {
|
|||||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new_metal(ordinal: usize) -> Result<Self> {
|
||||||
|
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
|
Self::Cpu => CpuDevice.set_seed(seed),
|
||||||
Self::Cuda(c) => c.set_seed(seed),
|
Self::Cuda(c) => c.set_seed(seed),
|
||||||
|
Self::Metal(m) => m.set_seed(seed),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,6 +146,7 @@ impl Device {
|
|||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Self::Cpu, Self::Cpu) => true,
|
(Self::Cpu, Self::Cpu) => true,
|
||||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
|
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
|
||||||
|
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -147,21 +155,20 @@ impl Device {
|
|||||||
match self {
|
match self {
|
||||||
Self::Cpu => DeviceLocation::Cpu,
|
Self::Cpu => DeviceLocation::Cpu,
|
||||||
Self::Cuda(device) => device.location(),
|
Self::Cuda(device) => device.location(),
|
||||||
|
Device::Metal(device) => device.location(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_cpu(&self) -> bool {
|
pub fn is_cpu(&self) -> bool {
|
||||||
match self {
|
matches!(self, Self::Cpu)
|
||||||
Self::Cpu => true,
|
|
||||||
Self::Cuda(_) => false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_cuda(&self) -> bool {
|
pub fn is_cuda(&self) -> bool {
|
||||||
match self {
|
matches!(self, Self::Cuda(_))
|
||||||
Self::Cpu => false,
|
}
|
||||||
Self::Cuda(_) => true,
|
|
||||||
}
|
pub fn is_metal(&self) -> bool {
|
||||||
|
matches!(self, Self::Metal(_))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||||
@ -194,6 +201,11 @@ impl Device {
|
|||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Device::Metal(_device) => {
|
||||||
|
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||||
|
// Ok(Storage::Metal(storage))
|
||||||
|
crate::bail!("Metal rand_uniform not implemented")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,6 +240,10 @@ impl Device {
|
|||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,6 +266,10 @@ impl Device {
|
|||||||
let storage = device.ones_impl(shape, dtype)?;
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,6 +283,10 @@ impl Device {
|
|||||||
let storage = device.zeros_impl(shape, dtype)?;
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,6 +298,11 @@ impl Device {
|
|||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = array.to_cpu_storage();
|
||||||
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,6 +314,11 @@ impl Device {
|
|||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = S::to_cpu_storage_owned(data);
|
||||||
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,9 @@ impl Tensor {
|
|||||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||||
format!(", cuda:{}", gpu_id)
|
format!(", cuda:{}", gpu_id)
|
||||||
}
|
}
|
||||||
|
crate::DeviceLocation::Metal { gpu_id } => {
|
||||||
|
format!(", metal:{}", gpu_id)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
write!(f, "Tensor[")?;
|
write!(f, "Tensor[")?;
|
||||||
@ -476,6 +479,9 @@ impl std::fmt::Display for Tensor {
|
|||||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||||
format!(", cuda:{}", gpu_id)
|
format!(", cuda:{}", gpu_id)
|
||||||
}
|
}
|
||||||
|
crate::DeviceLocation::Metal { gpu_id } => {
|
||||||
|
format!(", metal:{}", gpu_id)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
write!(
|
write!(
|
||||||
|
223
candle-core/src/dummy_metal_backend.rs
Normal file
223
candle-core/src/dummy_metal_backend.rs
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
|
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MetalDevice;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MetalStorage;
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum MetalError {
|
||||||
|
#[error("{0}")]
|
||||||
|
Message(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for MetalError {
|
||||||
|
fn from(e: String) -> Self {
|
||||||
|
MetalError::Message(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! fail {
|
||||||
|
() => {
|
||||||
|
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::backend::BackendStorage for MetalStorage {
|
||||||
|
type Device = MetalDevice;
|
||||||
|
|
||||||
|
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dtype(&self) -> DType {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> &Self::Device {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv1d(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConvTranspose1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv2d(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &crate::conv::ParamsConv2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose2d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConvTranspose2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matmul(
|
||||||
|
&self,
|
||||||
|
_: &Self,
|
||||||
|
_: (usize, usize, usize, usize),
|
||||||
|
_: &Layout,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::backend::BackendDevice for MetalDevice {
|
||||||
|
type Storage = MetalStorage;
|
||||||
|
fn new(_: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _: u64) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, _: &Self) -> bool {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{DType, DeviceLocation, Layout, Shape};
|
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MatMulUnexpectedStriding {
|
pub struct MatMulUnexpectedStriding {
|
||||||
@ -152,6 +152,9 @@ pub enum Error {
|
|||||||
#[error("the candle crate has not been built with cuda support")]
|
#[error("the candle crate has not been built with cuda support")]
|
||||||
NotCompiledWithCudaSupport,
|
NotCompiledWithCudaSupport,
|
||||||
|
|
||||||
|
#[error("the candle crate has not been built with metal support")]
|
||||||
|
NotCompiledWithMetalSupport,
|
||||||
|
|
||||||
#[error("cannot find tensor {path}")]
|
#[error("cannot find tensor {path}")]
|
||||||
CannotFindTensor { path: String },
|
CannotFindTensor { path: String },
|
||||||
|
|
||||||
@ -159,6 +162,9 @@ pub enum Error {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
||||||
|
|
||||||
|
#[error("Metal error {0}")]
|
||||||
|
Metal(#[from] MetalError),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||||
|
|
||||||
|
@ -49,9 +49,12 @@ mod device;
|
|||||||
pub mod display;
|
pub mod display;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
mod dummy_cuda_backend;
|
mod dummy_cuda_backend;
|
||||||
|
mod dummy_metal_backend;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
mod indexer;
|
mod indexer;
|
||||||
pub mod layout;
|
pub mod layout;
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
pub mod metal_backend;
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
mod mkl;
|
mod mkl;
|
||||||
pub mod npy;
|
pub mod npy;
|
||||||
@ -87,6 +90,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
|
|||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||||
|
|
||||||
|
#[cfg(not(feature = "metal"))]
|
||||||
|
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
859
candle-core/src/metal_backend.rs
Normal file
859
candle-core/src/metal_backend.rs
Normal file
@ -0,0 +1,859 @@
|
|||||||
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
|
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||||
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
|
use candle_metal_kernels;
|
||||||
|
use candle_metal_kernels::Kernels;
|
||||||
|
use core::mem;
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use metal;
|
||||||
|
use metal::mps::matrix::encode_gemm;
|
||||||
|
use metal::mps::Float32;
|
||||||
|
use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Metal related errors
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum MetalError {
|
||||||
|
#[error("{0}")]
|
||||||
|
Message(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
KernelError(#[from] candle_metal_kernels::MetalKernelError),
|
||||||
|
|
||||||
|
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||||
|
MatMulNonContiguous {
|
||||||
|
lhs_stride: Vec<usize>,
|
||||||
|
rhs_stride: Vec<usize>,
|
||||||
|
mnk: (usize, usize, usize),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for MetalError {
|
||||||
|
fn from(e: String) -> Self {
|
||||||
|
MetalError::Message(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MetalDevice {
|
||||||
|
device: metal::Device,
|
||||||
|
command_queue: metal::CommandQueue,
|
||||||
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for MetalDevice {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "MetalDevice({:?})", self.device.registry_id())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for MetalDevice {
|
||||||
|
type Target = metal::DeviceRef;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetalDevice {
|
||||||
|
// pub fn metal_device(&self) -> &metal::DeviceRef {
|
||||||
|
// self.device.as_ref()
|
||||||
|
// }
|
||||||
|
|
||||||
|
pub fn id(&self) -> NSUInteger {
|
||||||
|
self.registry_id()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn command_queue(&self) -> &CommandQueue {
|
||||||
|
&self.command_queue
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn kernels(&self) -> &Kernels {
|
||||||
|
&self.kernels
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &metal::Device {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||||
|
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||||
|
// debug!("Allocate 1 - buffer size {size}");
|
||||||
|
self.device
|
||||||
|
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MetalStorage {
|
||||||
|
buffer: metal::Buffer,
|
||||||
|
device: MetalDevice,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendStorage for MetalStorage {
|
||||||
|
type Device = MetalDevice;
|
||||||
|
|
||||||
|
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||||
|
Ok(self.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dtype(&self) -> DType {
|
||||||
|
self.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> &Self::Device {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
// TODO Is this necessary
|
||||||
|
// self.buffer.synchronize();
|
||||||
|
match self.dtype {
|
||||||
|
DType::U8 => Ok(CpuStorage::U8(
|
||||||
|
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
|
||||||
|
)),
|
||||||
|
DType::U32 => Ok(CpuStorage::U32(
|
||||||
|
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
||||||
|
)),
|
||||||
|
DType::I64 => Ok(CpuStorage::I64(
|
||||||
|
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
||||||
|
)),
|
||||||
|
DType::F16 => Ok(CpuStorage::F16(
|
||||||
|
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
||||||
|
)),
|
||||||
|
DType::BF16 => Ok(CpuStorage::BF16(
|
||||||
|
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
||||||
|
)),
|
||||||
|
DType::F32 => Ok(CpuStorage::F32(
|
||||||
|
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
||||||
|
)),
|
||||||
|
DType::F64 => Ok(CpuStorage::F64(
|
||||||
|
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
|
let device = self.device().clone();
|
||||||
|
|
||||||
|
let shape = layout.shape();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let dtype = self.dtype;
|
||||||
|
|
||||||
|
assert!(layout.is_contiguous());
|
||||||
|
assert_eq!(dtype, DType::F32);
|
||||||
|
|
||||||
|
let mut buffer = device.new_buffer(el, self.dtype);
|
||||||
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
|
candle_metal_kernels::call_affine(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
el,
|
||||||
|
&self.buffer,
|
||||||
|
&mut buffer,
|
||||||
|
mul as f32,
|
||||||
|
add as f32,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
return Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
|
// debug!("TODO reduce_op {op:?} {sum_dims:?}");
|
||||||
|
assert!(sum_dims.len() == 1);
|
||||||
|
assert!(sum_dims[0] == layout.shape().rank() - 1);
|
||||||
|
assert!(layout.is_contiguous());
|
||||||
|
let device = self.device.clone();
|
||||||
|
let src_stride = layout.stride();
|
||||||
|
let src_dims = layout.shape().dims();
|
||||||
|
let src_el: usize = src_dims.iter().product();
|
||||||
|
// Source dims and strides with the sum dims at the end.
|
||||||
|
let mut dims = vec![];
|
||||||
|
let mut stride = vec![];
|
||||||
|
let mut dst_el: usize = 1;
|
||||||
|
for (dim_idx, &d) in src_dims.iter().enumerate() {
|
||||||
|
if !sum_dims.contains(&dim_idx) {
|
||||||
|
dst_el *= d;
|
||||||
|
dims.push(d);
|
||||||
|
stride.push(src_stride[dim_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for &dim_idx in sum_dims.iter() {
|
||||||
|
dims.push(src_dims[dim_idx]);
|
||||||
|
stride.push(src_stride[dim_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The reduction loop requires the shared array to be properly initialized and for
|
||||||
|
// this we want the number of threads to be a power of two.
|
||||||
|
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||||
|
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
|
||||||
|
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
|
||||||
|
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
||||||
|
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
||||||
|
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
||||||
|
_ => todo!("Reduce op for non float"),
|
||||||
|
};
|
||||||
|
if check_empty && layout.shape().elem_count() == 0 {
|
||||||
|
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||||
|
}
|
||||||
|
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||||
|
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||||
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
|
candle_metal_kernels::call_reduce_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
name,
|
||||||
|
src_el,
|
||||||
|
dst_el,
|
||||||
|
&self.buffer,
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
|
let device = self.device();
|
||||||
|
let shape = layout.shape();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
|
let command_buffer = device.command_queue.new_command_buffer();
|
||||||
|
if layout.is_contiguous() {
|
||||||
|
let kernel_name = match (self.dtype, dtype) {
|
||||||
|
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||||
|
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_cast_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
&self.buffer,
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
} else {
|
||||||
|
todo!(
|
||||||
|
"TODO Implement the kernel calling cast {:?}-{:?}",
|
||||||
|
self.dtype,
|
||||||
|
dtype
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
// command_buffer.wait_until_scheduled();
|
||||||
|
// debug!(
|
||||||
|
// "cast {:?} - {:?} - {:?}",
|
||||||
|
// dtype,
|
||||||
|
// self.buffer.length(),
|
||||||
|
// buffer.length()
|
||||||
|
// );
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||||
|
let device = self.device();
|
||||||
|
let dtype = self.dtype;
|
||||||
|
let shape = layout.shape();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
|
let command_buffer = device.command_queue.new_command_buffer();
|
||||||
|
if layout.is_contiguous() {
|
||||||
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
|
||||||
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
|
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||||
|
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||||
|
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||||
|
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||||
|
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||||
|
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||||
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_unary_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
&self.buffer,
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
} else {
|
||||||
|
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||||
|
}
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn binary_impl<B: BinaryOpT>(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let device = self.device();
|
||||||
|
let dtype = self.dtype;
|
||||||
|
let shape = lhs_l.shape();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
|
let command_buffer = device.command_queue.new_command_buffer();
|
||||||
|
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
||||||
|
use candle_metal_kernels::binary::contiguous;
|
||||||
|
|
||||||
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
|
("add", DType::F32) => contiguous::add::FLOAT,
|
||||||
|
("badd", DType::F32) => contiguous::add::FLOAT,
|
||||||
|
("sub", DType::F32) => contiguous::sub::FLOAT,
|
||||||
|
("bsub", DType::F32) => contiguous::sub::FLOAT,
|
||||||
|
("mul", DType::F32) => contiguous::mul::FLOAT,
|
||||||
|
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
||||||
|
("div", DType::F32) => contiguous::div::FLOAT,
|
||||||
|
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
||||||
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_binary_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
&self.buffer,
|
||||||
|
&rhs.buffer,
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
} else {
|
||||||
|
use candle_metal_kernels::binary::strided;
|
||||||
|
|
||||||
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
|
("badd", DType::F32) => strided::add::FLOAT,
|
||||||
|
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||||
|
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||||
|
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||||
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_binary_strided(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
lhs_l.dims(),
|
||||||
|
&self.buffer,
|
||||||
|
&lhs_l.stride(),
|
||||||
|
lhs_l.start_offset(),
|
||||||
|
&rhs.buffer,
|
||||||
|
&rhs_l.stride(),
|
||||||
|
rhs_l.start_offset(),
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn where_cond(
|
||||||
|
&self,
|
||||||
|
layout: &Layout,
|
||||||
|
t: &Self,
|
||||||
|
t_l: &Layout,
|
||||||
|
f: &Self,
|
||||||
|
f_l: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let device = self.device.clone();
|
||||||
|
let shape = t_l.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let dtype = t.dtype;
|
||||||
|
let mut buffer = self.device.new_buffer(el, dtype);
|
||||||
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
|
candle_metal_kernels::call_where_cond_strided(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
"where_u8_f32",
|
||||||
|
&dims,
|
||||||
|
&self.buffer,
|
||||||
|
(layout.stride(), layout.start_offset()),
|
||||||
|
&t.buffer,
|
||||||
|
(&t_l.stride(), t_l.start_offset()),
|
||||||
|
&f.buffer,
|
||||||
|
(&f_l.stride(), f_l.start_offset()),
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &ParamsConvTranspose1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv2d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &ParamsConv2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose2d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &ParamsConvTranspose2D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
|
assert!(src_l.is_contiguous());
|
||||||
|
assert!(ids_l.is_contiguous());
|
||||||
|
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||||
|
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
|
let ids_el = ids_l.shape().elem_count();
|
||||||
|
let dst_el = ids_el * left_size * right_size;
|
||||||
|
let dtype = self.dtype;
|
||||||
|
let device = self.device();
|
||||||
|
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||||
|
let out = self.to_cpu_storage().unwrap();
|
||||||
|
let name = match (ids.dtype, self.dtype) {
|
||||||
|
(DType::U32, DType::F32) => "is_u32_f32",
|
||||||
|
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||||
|
};
|
||||||
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
|
// println!("INDEX SELECT");
|
||||||
|
candle_metal_kernels::call_index_select(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
name,
|
||||||
|
src_l.dims(),
|
||||||
|
ids_el,
|
||||||
|
dim,
|
||||||
|
&self.buffer,
|
||||||
|
&ids.buffer,
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_add(
|
||||||
|
&self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matmul(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
// Create descriptors
|
||||||
|
use metal::mps::matrix::*;
|
||||||
|
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
|
||||||
|
let size = core::mem::size_of::<f32>() as NSUInteger;
|
||||||
|
|
||||||
|
let elem_count = b * m * n;
|
||||||
|
|
||||||
|
let lhs_stride = lhs_l.stride();
|
||||||
|
let rhs_stride = rhs_l.stride();
|
||||||
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||||
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||||
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
// The a tensor has dims batching, k, n (rhs)
|
||||||
|
let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
|
false
|
||||||
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
Err(MetalError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
|
false
|
||||||
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
Err(MetalError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
// println!("{transpose_left} {transpose_right}");
|
||||||
|
|
||||||
|
let b = b as NSUInteger;
|
||||||
|
let m = m as NSUInteger;
|
||||||
|
let n = n as NSUInteger;
|
||||||
|
let k = k as NSUInteger;
|
||||||
|
|
||||||
|
let left_descriptor = if transpose_left {
|
||||||
|
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
||||||
|
} else {
|
||||||
|
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
||||||
|
};
|
||||||
|
let right_descriptor = if transpose_right {
|
||||||
|
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
||||||
|
} else {
|
||||||
|
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
||||||
|
};
|
||||||
|
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||||
|
|
||||||
|
// Create matrix objects
|
||||||
|
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
|
})?;
|
||||||
|
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||||
|
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let alpha = 1.0f64;
|
||||||
|
let beta = 0.0f64;
|
||||||
|
// Create kernel
|
||||||
|
let matrix_multiplication = MatrixMultiplication::init(
|
||||||
|
&self.device,
|
||||||
|
transpose_left,
|
||||||
|
transpose_right,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
matrix_multiplication.set_batch_size(b);
|
||||||
|
|
||||||
|
// Encode kernel to command buffer
|
||||||
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
|
matrix_multiplication.encode_to_command_buffer(
|
||||||
|
command_buffer,
|
||||||
|
&left_matrix,
|
||||||
|
&right_matrix,
|
||||||
|
&result_matrix,
|
||||||
|
);
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
// let left = self.buffer.read_to_vec::<f32>(10);
|
||||||
|
// let right = rhs.buffer.read_to_vec::<f32>(10);
|
||||||
|
// let out = out_buffer.read_to_vec::<f32>(40);
|
||||||
|
// todo!("Out {left:?} {right:?} {out:?}");
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
buffer: out_buffer,
|
||||||
|
device: self.device.clone(),
|
||||||
|
dtype: self.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
|
let src_shape = src_l.shape();
|
||||||
|
let el_count = src_shape.elem_count();
|
||||||
|
if el_count == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
|
let kernel_name = match self.dtype {
|
||||||
|
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||||
|
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||||
|
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||||
|
dtype => todo!("copy_strided not implemented for {dtype:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_unary_strided(
|
||||||
|
&self.device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
src_l.dims(),
|
||||||
|
&self.buffer,
|
||||||
|
&src_l.stride(),
|
||||||
|
src_l.start_offset(),
|
||||||
|
&mut dst.buffer,
|
||||||
|
dst_offset,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
// todo!("Output {:?}", dst.buffer.read_to_vec::<f32>(10));
|
||||||
|
// }
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetalStorage {
|
||||||
|
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn buffer(&self) -> &Buffer {
|
||||||
|
&self.buffer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendDevice for MetalDevice {
|
||||||
|
type Storage = MetalStorage;
|
||||||
|
|
||||||
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
|
|
||||||
|
// let capture = metal::CaptureManager::shared();
|
||||||
|
// let descriptor = metal::CaptureDescriptor::new();
|
||||||
|
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||||
|
// descriptor.set_capture_device(&device);
|
||||||
|
// let mut dir = std::env::current_dir()?;
|
||||||
|
// dir.push("out.gputrace");
|
||||||
|
// descriptor.set_output_url(dir);
|
||||||
|
|
||||||
|
// capture
|
||||||
|
// .start_capture(&descriptor)
|
||||||
|
// .map_err(MetalError::from)?;
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
// let command_buffer = _command_queue.new_owned_command_buffer();
|
||||||
|
let kernels = Arc::new(Kernels::new());
|
||||||
|
Ok(Self {
|
||||||
|
device,
|
||||||
|
command_queue,
|
||||||
|
// command_buffer,
|
||||||
|
kernels,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||||
|
todo!("set_seed")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
crate::DeviceLocation::Metal {
|
||||||
|
gpu_id: self.registry_id() as usize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, rhs: &Self) -> bool {
|
||||||
|
self.device.registry_id() == rhs.device.registry_id()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||||
|
// TODO Is there a faster way ?
|
||||||
|
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
|
||||||
|
self.storage_from_cpu_storage(&cpu_storage)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||||
|
// TODO Is there a faster way ?
|
||||||
|
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
||||||
|
self.storage_from_cpu_storage(&cpu_storage)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||||
|
let buffer = match storage {
|
||||||
|
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
|
||||||
|
option,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
// TODO is that necessary ?
|
||||||
|
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||||
|
// debug!("Allocate 2 - buffer size {}", buffer.length());
|
||||||
|
Ok(Self::Storage {
|
||||||
|
buffer,
|
||||||
|
device: self.clone(),
|
||||||
|
dtype: storage.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(
|
||||||
|
&self,
|
||||||
|
shape: &Shape,
|
||||||
|
dtype: DType,
|
||||||
|
mean: f64,
|
||||||
|
stddev: f64,
|
||||||
|
) -> Result<Self::Storage> {
|
||||||
|
// TODO is there a better way ?
|
||||||
|
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||||
|
self.storage_from_cpu_storage(&cpu_storage)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(
|
||||||
|
&self,
|
||||||
|
shape: &Shape,
|
||||||
|
dtype: DType,
|
||||||
|
mean: f64,
|
||||||
|
stddev: f64,
|
||||||
|
) -> Result<Self::Storage> {
|
||||||
|
// TODO is there a better way ?
|
||||||
|
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||||
|
self.storage_from_cpu_storage(&cpu_storage)
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,5 @@
|
|||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
|
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use num_traits::float::Float;
|
use num_traits::float::Float;
|
||||||
|
|
||||||
@ -184,6 +184,18 @@ pub trait CustomOp1 {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_storage: &MetalStorage,
|
||||||
|
_layout: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||||
/// The function should return the gradient of the argument.
|
/// The function should return the gradient of the argument.
|
||||||
@ -219,6 +231,20 @@ pub trait CustomOp2 {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
fn bwd(
|
fn bwd(
|
||||||
&self,
|
&self,
|
||||||
_arg1: &Tensor,
|
_arg1: &Tensor,
|
||||||
@ -261,6 +287,22 @@ pub trait CustomOp3 {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
fn bwd(
|
fn bwd(
|
||||||
&self,
|
&self,
|
||||||
_arg1: &Tensor,
|
_arg1: &Tensor,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
||||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||||
|
|
||||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||||
// out of memory. Instead try_clone should be used.
|
// out of memory. Instead try_clone should be used.
|
||||||
@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
|
|||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
Cuda(CudaStorage),
|
Cuda(CudaStorage),
|
||||||
|
+    Metal(MetalStorage),
 }

 impl Storage {
@@ -18,6 +19,10 @@ impl Storage {
                 let storage = storage.try_clone(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.try_clone(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -25,6 +30,7 @@ impl Storage {
         match self {
             Self::Cpu(_) => Device::Cpu,
             Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
+            Self::Metal(storage) => Device::Metal(storage.device().clone()),
         }
     }

@@ -32,6 +38,7 @@ impl Storage {
         match self {
             Self::Cpu(storage) => storage.dtype(),
             Self::Cuda(storage) => storage.dtype(),
+            Self::Metal(storage) => storage.dtype(),
         }
     }

@@ -65,6 +72,10 @@ impl Storage {
                 let storage = storage.affine(layout, mul, add)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.affine(layout, mul, add)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -78,6 +89,10 @@ impl Storage {
                 let storage = storage.powf(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.powf(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -91,6 +106,10 @@ impl Storage {
                 let storage = storage.elu(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.elu(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -112,6 +131,10 @@ impl Storage {
                 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -135,6 +158,10 @@ impl Storage {
                 let storage = storage.reduce_op(op, layout, s)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.reduce_op(op, layout, s)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -148,6 +175,10 @@ impl Storage {
                 let storage = storage.to_dtype(layout, dtype)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.to_dtype(layout, dtype)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -161,6 +192,10 @@ impl Storage {
                 let (storage, shape) = c.cuda_fwd(storage, l)?;
                 Ok((Self::Cuda(storage), shape))
             }
+            Self::Metal(storage) => {
+                let (storage, shape) = c.metal_fwd(storage, l)?;
+                Ok((Self::Metal(storage), shape))
+            }
         }
     }

@@ -181,6 +216,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -205,6 +244,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -219,6 +262,10 @@ impl Storage {
                 let storage = storage.unary_impl::<B>(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.unary_impl::<B>(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -239,6 +286,10 @@ impl Storage {
                 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -270,6 +321,10 @@ impl Storage {
                 let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -324,6 +379,10 @@ impl Storage {
                 let s = inp.conv2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -351,6 +410,10 @@ impl Storage {
                 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -375,6 +438,10 @@ impl Storage {
                 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -393,6 +460,10 @@ impl Storage {
                 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -406,6 +477,10 @@ impl Storage {
                 let storage = storage.upsample_nearest1d(layout, sz)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -419,6 +494,10 @@ impl Storage {
                 let storage = storage.upsample_nearest2d(layout, h, w)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest2d(layout, h, w)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -442,6 +521,10 @@ impl Storage {
                 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
+                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
+                Ok(Self::Metal(storage))
+            }
             (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -468,6 +551,10 @@ impl Storage {
                 let storage = s.gather(l, indexes, indexes_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes)) => {
+                let storage = s.gather(l, indexes, indexes_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -492,6 +579,10 @@ impl Storage {
                 let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -516,6 +607,10 @@ impl Storage {
                 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -537,6 +632,10 @@ impl Storage {
                 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -564,6 +663,10 @@ impl Storage {
                 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -583,6 +686,9 @@ impl Storage {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
             (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
+            (Self::Metal(src), Self::Metal(dst)) => {
+                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
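
Every arm added above follows the same shape as the existing CUDA dispatch: unwrap the Metal storage, forward to the backend method, and rewrap the result as Self::Metal. A minimal hedged sketch of how these arms get exercised from user code, assuming the crate is built with the `metal` feature and imported as `candle` (the workspace alias for candle-core); on this temporary branch only part of the op surface may actually work on Metal:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // Device::new_metal comes from the device.rs change in this branch;
    // tensors created on it carry Storage::Metal internally.
    let device = Device::new_metal(0)?;
    let a = Tensor::new(&[1f32, 2.0, 3.0], &device)?;
    // affine() routes through Storage::affine and hits the Self::Metal arm added above.
    let b = a.affine(1.2345, 2.3456)?;
    println!("{b}");
    Ok(())
}
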
candle-core/src/tensor.rs:
@@ -6,7 +6,7 @@ use crate::op::{
 };
 use crate::scalar::TensorOrScalar;
 use crate::shape::{Dim, Dims};
-use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};

 /// Unique identifier for tensors.
@@ -157,6 +157,8 @@ pub(crate) fn from_storage<S: Into<Shape>>(
 ) -> Tensor {
     let dtype = storage.dtype();
     let device = storage.device();
+    let shape = shape.into();
+    // println!("{:?} {storage:?}", shape);
     let tensor_ = Tensor_ {
         id: TensorId::new(),
         storage: Arc::new(RwLock::new(storage)),
@@ -166,7 +168,11 @@ pub(crate) fn from_storage<S: Into<Shape>>(
         dtype,
         device,
     };
-    Tensor(Arc::new(tensor_))
+    let result = Tensor(Arc::new(tensor_));
+    // todo!(" from_storage");
+    // let result = result.to_device(&Device::Cpu).unwrap();
+    // todo!(" {result}");
+    result
 }

 impl Tensor {
@@ -529,6 +535,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }

@@ -1454,6 +1461,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }

@@ -1484,6 +1492,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }

@@ -1524,6 +1533,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }

@@ -1841,7 +1851,11 @@ impl Tensor {
             (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
             }
+            (Storage::Cpu(storage), Device::Metal(metal)) => {
+                Storage::Metal(metal.storage_from_cpu_storage(storage)?)
+            }
             (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
+            (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
             (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                 // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                 // are the same.
@@ -1849,6 +1863,9 @@ impl Tensor {
                 Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
+            _ => {
+                bail!("not implemented yet")
+            }
         };
         let op = BackpropOp::new1(self, Op::ToDevice);
         let tensor_ = Tensor_ {
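
The `to_device` change only covers CPU to Metal and Metal to CPU; any other pairing (for example Metal to CUDA) now reaches the new `bail!("not implemented yet")` arm. A hedged sketch of the supported round trip, under the same feature and crate-name assumptions as the previous snippet; it is illustrative rather than a guarantee that every dtype works at this commit:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let metal = Device::new_metal(0)?;
    let cpu_t = Tensor::new(&[[1f32, 2.0], [3.0, 4.0]], &Device::Cpu)?;
    // (Storage::Cpu, Device::Metal) arm: upload through storage_from_cpu_storage.
    let metal_t = cpu_t.to_device(&metal)?;
    // (Storage::Metal, Device::Cpu) arm: download via to_cpu_storage.
    let back = metal_t.to_device(&Device::Cpu)?;
    assert_eq!(back.to_vec2::<f32>()?, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    Ok(())
}
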
candle-core/src/utils.rs:
@@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
     cfg!(feature = "cuda")
 }

+pub fn metal_is_available() -> bool {
+    cfg!(feature = "metal")
+}
+
 pub fn with_avx() -> bool {
     cfg!(target_feature = "avx")
 }

candle-examples/examples/llama2-c/main.rs:
@@ -329,14 +329,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .get_ids()
         .to_vec();

+    println!("{tokens:?}");
+
     let start_gen = std::time::Instant::now();
-    for index in 0.. {
+    for index in 0..1 {
         if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
+        // println!("Input {}", input);
+        // println!("Input {}", input.to_device(&candle::Device::Cpu)?);
         let logits = model.forward(&input, index_pos)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
candle-examples/src/lib.rs:
@@ -2,17 +2,28 @@ pub mod coco_classes;
 pub mod imagenet;
 pub mod token_output_stream;

+use candle::utils::{cuda_is_available, metal_is_available};
 use candle::{Device, Result, Tensor};

 pub fn device(cpu: bool) -> Result<Device> {
     if cpu {
         Ok(Device::Cpu)
+    } else if cuda_is_available() {
+        Ok(Device::new_cuda(0)?)
+    } else if metal_is_available() {
+        Ok(Device::new_metal(0)?)
     } else {
-        let device = Device::cuda_if_available(0)?;
-        if !device.is_cuda() {
+        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+        {
+            println!(
+                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
+            );
+        }
+        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
+        {
             println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
         }
-        Ok(device)
+        Ok(Device::Cpu)
     }
 }

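
For reference, the examples consume this helper in one place; a hedged sketch of the typical call site (the `cpu_flag` plumbing is an assumption standing in for each example's clap `--cpu` argument):

fn run(cpu_flag: bool) -> candle::Result<()> {
    // Picks CUDA when built with `--features cuda`, Metal with `--features metal`,
    // otherwise falls back to the CPU and prints the hint above.
    let device = candle_examples::device(cpu_flag)?;
    let x = candle::Tensor::zeros((2, 3), candle::DType::F32, &device)?;
    println!("{x}");
    Ok(())
}
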
new file: candle-metal-kernels/Cargo.toml (21 lines)

[package]
name = "candle-metal-kernels"
version = "0.3.0"
edition = "2021"

description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"
new file: candle-metal-kernels/README.md (3 lines)

# candle-metal-kernels

This crate contains Metal kernels used from candle.
new file: candle-metal-kernels/examples/affine.rs (75 lines)

use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_affine_bench(&device, &kernels, &f32_1k);
    run_affine_bench(&device, &kernels, &f32_10k);
    run_affine_bench(&device, &kernels, &f32_100k);
}

fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 10000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    let mul: f32 = 1.2345;
    let add: f32 = 2.3456;
    let total_time = autoreleasepool(|| {
        let command_buffer = command_queue.new_command_buffer();
        let start = Instant::now();
        for _ in 0..iterations {
            call_affine(
                &device,
                command_buffer,
                &kernels,
                v.len(),
                &input,
                &mut output,
                mul,
                add,
            )
            .unwrap();
        }
        command_buffer.commit();
        command_buffer.wait_until_completed();

        start.elapsed()
    });
    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
        type_name::<T>().split("::").last().unwrap(),
        "affine",
        v.len(),
        iterations,
        total_time,
        total_time / iterations
    );
}
new file: candle-metal-kernels/examples/binary.rs (182 lines)

use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
    let f16_1k = f16_map(&f32_1k);
    let f16_10k = f16_map(&f32_10k);
    let f16_100k = f16_map(&f32_100k);

    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
    let bf16_1k = bf16_map(&f32_1k);
    let bf16_10k = bf16_map(&f32_10k);
    let bf16_100k = bf16_map(&f32_100k);

    let f32_ckernels = [
        binary::contiguous::add::FLOAT,
        binary::contiguous::sub::FLOAT,
        binary::contiguous::mul::FLOAT,
        binary::contiguous::div::FLOAT,
    ];
    let f32_skernels = [
        binary::strided::add::FLOAT,
        binary::strided::sub::FLOAT,
        binary::strided::mul::FLOAT,
        binary::strided::div::FLOAT,
    ];
    let f16_ckernels = [
        binary::contiguous::add::HALF,
        binary::contiguous::sub::HALF,
        binary::contiguous::mul::HALF,
        binary::contiguous::div::HALF,
    ];
    let f16_skernels = [
        binary::strided::add::HALF,
        binary::strided::sub::HALF,
        binary::strided::mul::HALF,
        binary::strided::div::HALF,
    ];
    let bf16_ckernels = [
        binary::contiguous::add::BFLOAT,
        binary::contiguous::sub::BFLOAT,
        binary::contiguous::mul::BFLOAT,
        binary::contiguous::div::BFLOAT,
    ];
    let bf16_skernels = [
        binary::strided::add::BFLOAT,
        binary::strided::sub::BFLOAT,
        binary::strided::mul::BFLOAT,
        binary::strided::div::BFLOAT,
    ];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
    run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
    run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);

    // f16
    run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
    run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
    run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);

    // bf16
    run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
    run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
    run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}

fn run_binary_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: [binary::contiguous::Kernel; 4],
    strided: [binary::strided::Kernel; 4],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    for kernel_name in strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_strided(
                    device,
                    command_buffer,
                    &kernels,
                    kernel_name,
                    &shape,
                    &input,
                    &strides,
                    offset,
                    &input,
                    &strides,
                    offset,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });

        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }
}
new file: candle-metal-kernels/examples/cast.rs (84 lines)

use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let contiguous_kernels = ["cast_u32_f32"];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
    run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
    run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}

fn run_cast_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: &[&'static str],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_cast_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided?
}
new file: candle-metal-kernels/examples/unary.rs (197 lines)

use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
    let f16_1k = f16_map(&f32_1k);
    let f16_10k = f16_map(&f32_10k);
    let f16_100k = f16_map(&f32_100k);

    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
    let bf16_1k = bf16_map(&f32_1k);
    let bf16_10k = bf16_map(&f32_10k);
    let bf16_100k = bf16_map(&f32_100k);

    let f32_ckernels = [
        unary::contiguous::sin::FLOAT,
        unary::contiguous::cos::FLOAT,
        unary::contiguous::exp::FLOAT,
        unary::contiguous::sqr::FLOAT,
        unary::contiguous::sqrt::FLOAT,
        unary::contiguous::neg::FLOAT,
        unary::contiguous::copy::FLOAT,
    ];
    let f32_skernels = [
        unary::strided::sin::FLOAT,
        unary::strided::cos::FLOAT,
        unary::strided::exp::FLOAT,
        unary::strided::sqr::FLOAT,
        unary::strided::sqrt::FLOAT,
        unary::strided::neg::FLOAT,
        unary::strided::copy::FLOAT,
    ];
    let f16_ckernels = [
        unary::contiguous::sin::HALF,
        unary::contiguous::cos::HALF,
        unary::contiguous::exp::HALF,
        unary::contiguous::sqr::HALF,
        unary::contiguous::sqrt::HALF,
        unary::contiguous::neg::HALF,
        unary::contiguous::copy::HALF,
    ];
    let f16_skernels = [
        unary::strided::sin::HALF,
        unary::strided::cos::HALF,
        unary::strided::exp::HALF,
        unary::strided::sqr::HALF,
        unary::strided::sqrt::HALF,
        unary::strided::neg::HALF,
        unary::strided::copy::HALF,
    ];
    let bf16_ckernels = [
        unary::contiguous::sin::BFLOAT,
        unary::contiguous::cos::BFLOAT,
        unary::contiguous::exp::BFLOAT,
        unary::contiguous::sqr::BFLOAT,
        unary::contiguous::sqrt::BFLOAT,
        unary::contiguous::neg::BFLOAT,
        unary::contiguous::copy::BFLOAT,
    ];
    let bf16_skernels = [
        unary::strided::sin::BFLOAT,
        unary::strided::cos::BFLOAT,
        unary::strided::exp::BFLOAT,
        unary::strided::sqr::BFLOAT,
        unary::strided::sqrt::BFLOAT,
        unary::strided::neg::BFLOAT,
        unary::strided::copy::BFLOAT,
    ];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);

    // f16
    run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);

    // bf16
    run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}

fn run_unary_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: [unary::contiguous::Kernel; 7],
    strided: [unary::strided::Kernel; 7],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 10000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_unary_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    for kernel_name in strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_unary_strided(
                    device,
                    command_buffer,
                    &kernels,
                    kernel_name,
                    &shape,
                    &input,
                    &strides,
                    offset,
                    &mut output,
                    0,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });

        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }
}
new file: candle-metal-kernels/src/affine.metal (43 lines)

#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

using namespace metal;

#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    constant float &mul, \
    constant float &add, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    const TYPENAME m = TYPENAME(mul); \
    const TYPENAME a = TYPENAME(add); \
    output[id] = input[id] * m + a; \
} \

AFFINE(affine_float, float)
AFFINE(affine_half, half)


#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif
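
`get_strided_index` is repeated at the top of each of these .metal files: it maps a flat element index to an offset in an arbitrarily strided buffer, consuming dimensions from innermost to outermost. A hedged Rust restatement of the same arithmetic (not part of the repo), convenient for checking a layout by hand:

fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for d in 0..dims.len() {
        // Walk dimensions from the innermost (last) to the outermost (first).
        let dim_idx = dims.len() - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    strided_i
}

fn main() {
    // A [2, 3] view with strides [1, 2] (transposed storage):
    // flat index 4 is (row 1, col 1), which lands at offset 1 * 1 + 1 * 2 = 3.
    assert_eq!(get_strided_index(4, &[2, 3], &[1, 2]), 3);
}
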
new file: candle-metal-kernels/src/binary.metal (72 lines)

#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

using namespace metal;

#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    TYPENAME x = left[thread_position_in_grid]; \
    TYPENAME y = right[thread_position_in_grid]; \
    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *left_strides, \
    constant size_t *right_strides, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}

#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);

#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);


BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)

#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif
new file: candle-metal-kernels/src/cast.metal (51 lines)

#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}


using namespace metal;

#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint i [[ thread_position_in_grid ]] \
) { \
    if (i >= dim) { \
        return; \
    } \
    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \

CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)

#if __METAL_VERSION__ >= 310
#endif
new file: candle-metal-kernels/src/indexing.metal (102 lines)

#include <metal_stdlib>
using namespace metal;

# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
    constant size_t &left_size, \
    constant size_t &src_dim_size, \
    constant size_t &right_size, \
    constant size_t &ids_size, \
    const device TYPENAME *input, \
    const device INDEX_TYPENAME *input_ids, \
    device TYPENAME *output, \
    uint gid [[ thread_position_in_grid ]] \
) { \
    if (gid >= dst_size) { \
        return; \
    } \
    const size_t id_i = gid / right_size / left_size; \
    const size_t right_rank_i = gid % right_size; \
    const size_t left_rank_i = gid % left_size; \
    /* \
    // Force prevent out of bounds indexing \
    // since there doesn't seem to be a good way to force crash \
    // No need to check for zero we're only allowing unsized. \
    */ \
    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
    const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
    output[gid] = input[src_i]; \
}



template <typename T, typename I>
void index_add(
    device I *ids [[buffer(0)]],
    device T *inp [[buffer(1)]],
    device T *out [[buffer(2)]],

    constant uint &ids_dim_size,
    constant uint &left_size,
    constant uint &dst_dim_size,
    constant uint &right_size,

    uint gid [[ thread_position_in_grid ]] \
) {

    if (gid >= left_size * right_size) {
        return;
    }

    const uint i = gid;
    const uint pre = i / right_size;
    const uint post = i % right_size;

    for (uint j = 0; j < ids_dim_size; j++) {
        const uint idx = ids[j];
        const uint src_i = (pre * ids_dim_size + j) * right_size + post;
        const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
        out[dst_i] += inp[src_i];
    }
}

#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    device INDEX_TYPENAME *ids [[buffer(0)]], \
    device TYPENAME *inp [[buffer(1)]], \
    device TYPENAME *out [[buffer(2)]], \
    constant uint &ids_dim_size, \
    constant uint &left_size, \
    constant uint &dst_dim_size, \
    constant uint &right_size, \
    uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \


INDEX_OP(is_u32_f32, uint, float)


#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif

IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)

IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)

IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)

IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)
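
The `index_add` template accumulates rows of `inp` into rows of `out` chosen by `ids`; each GPU thread owns one (pre, post) position and loops over the ids. A hedged sequential Rust restatement of the same index math (not from the repo), useful for sanity-checking the kernel:

fn index_add(
    ids: &[usize], inp: &[f32], out: &mut [f32],
    ids_dim_size: usize, left_size: usize, dst_dim_size: usize, right_size: usize,
) {
    // The Metal kernel maps gid to (pre, post); here we simply loop over both.
    for pre in 0..left_size {
        for post in 0..right_size {
            for j in 0..ids_dim_size {
                let idx = ids[j];
                let src_i = (pre * ids_dim_size + j) * right_size + post;
                let dst_i = (pre * dst_dim_size + idx) * right_size + post;
                out[dst_i] += inp[src_i];
            }
        }
    }
}

fn main() {
    // Two source rows of shape [2, 2] added into a [3, 2] destination at rows 0 and 2.
    let inp = [1.0f32, 1.0, 2.0, 2.0];
    let mut out = [0.0f32; 6];
    index_add(&[0, 2], &inp, &mut out, 2, 1, 3, 2);
    assert_eq!(out, [1.0, 1.0, 0.0, 0.0, 2.0, 2.0]);
}
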
new file: candle-metal-kernels/src/lib.rs (1317 lines): diff suppressed because it is too large.
new file: candle-metal-kernels/src/reduce.metal (139 lines)

#include <metal_stdlib>
using namespace metal;

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

constant int THREADGROUP_SIZE = 256;

# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const TYPENAME *src, \
    device TYPENAME *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint blockDim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = 0; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        */ \
        TYPENAME x = shared_memory[tid]; \
        TYPENAME y = src[idx]; \
        shared_memory[tid] = FN; \
        idx += blockDim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
    \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            TYPENAME x = shared_memory[tid]; \
            TYPENAME y = shared_memory[tid + s]; \
            shared_memory[tid] = FN; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    dst[dst_id] = shared_memory[0]; \
} \

kernel void softmax_float(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const float *src,
    device float *dst,
    uint id [[ thread_position_in_grid ]],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint blockDim [[ threads_per_threadgroup ]]
) {

    threadgroup float shared_memory[THREADGROUP_SIZE];

    shared_memory[tid] = -INFINITY;
    // Elements summed in this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        shared_memory[tid] = max(shared_memory[tid], src[idx]);
        idx += blockDim;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    float max = shared_memory[0];

    shared_memory[tid] = 0;

    // Restart
    idx = start_idx + tid;
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        const float val = exp(src[idx] - max);
        dst[idx] = val;
        shared_memory[tid] += val;
        idx += blockDim;
    }
    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    const float inv_acc = 1/shared_memory[0];
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += blockDim;
    }
}


REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
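
`softmax_float` is a three-pass threadgroup kernel: a shared-memory max reduction, an exp-and-sum pass that also writes the numerators, then a normalization sweep. A hedged sequential Rust restatement of the same numerically stable scheme for a single block (not from the repo):

fn softmax(src: &[f32], dst: &mut [f32]) {
    // Pass 1: find the max for numerical stability.
    let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Pass 2: exponentiate the shifted values and accumulate their sum.
    let mut acc = 0.0f32;
    for (d, &s) in dst.iter_mut().zip(src.iter()) {
        *d = (s - max).exp();
        acc += *d;
    }
    // Pass 3: scale by the reciprocal of the sum.
    let inv_acc = 1.0 / acc;
    for d in dst.iter_mut() {
        *d *= inv_acc;
    }
}

fn main() {
    let src = [1.0f32, 2.0, 3.0];
    let mut dst = [0.0f32; 3];
    softmax(&src, &mut dst);
    assert!((dst.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}
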
57
candle-metal-kernels/src/ternary.metal
Normal file
57
candle-metal-kernels/src/ternary.metal
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
#
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
METAL_FUNC uint get_strided_index(
|
||||||
|
uint idx,
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t *strides
|
||||||
|
) {
|
||||||
|
uint strided_i = 0;
|
||||||
|
for (uint d = 0; d < num_dims; d++) {
|
||||||
|
uint dim_idx = num_dims - 1 - d;
|
||||||
|
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||||
|
idx /= dims[dim_idx];
|
||||||
|
}
|
||||||
|
return strided_i;
|
||||||
|
}
|
||||||
|

#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &numel, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t *strides_t, \
    constant size_t *strides_f, \
    device const ID_TYPENAME *ids, \
    device const TYPENAME *t, \
    device const TYPENAME *f, \
    device TYPENAME *out, \
    uint i [[ thread_position_in_grid ]] \
) { \
    uint strided_i = get_strided_index(i, num_dims, dims, strides); \
    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
    out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \

// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)

WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)
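For readability, the one instantiation that is actually compiled, WHERE_OP(float, uint8_t, where_u8_f32), would expand to roughly the kernel below. There is no bounds check against numel, so the implicit assumption is that the host dispatches exactly numel threads:

kernel void where_u8_f32(
    constant size_t &numel,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides,
    constant size_t *strides_t,
    constant size_t *strides_f,
    device const uint8_t *ids,
    device const float *t,
    device const float *f,
    device float *out,
    uint i [[ thread_position_in_grid ]]
) {
    // Each operand may have its own layout, so each gets its own strided index.
    uint strided_i   = get_strided_index(i, num_dims, dims, strides);
    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
    // Pick from t where the mask is non-zero, from f otherwise.
    out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
}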
candle-metal-kernels/src/unary.metal (new file, 78 lines)
@@ -0,0 +1,78 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }


using namespace metal;

#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}

#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);

#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);


UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)

#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)

UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif
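Likewise, each UNARY_OP(NAME) line instantiates four kernels: NAME_float, NAME_float_strided, NAME_half and NAME_half_strided. The float pair produced by UNARY_OP(cos), for example, would expand to roughly:

kernel void cos_float(
    constant size_t &dim,
    device const float *input,
    device float *output,
    uint thread_position_in_grid [[ thread_position_in_grid ]]
) {
    if (thread_position_in_grid >= dim) {
        return;
    }
    // Contiguous case: the grid index is the element index.
    output[thread_position_in_grid] = float(cos(input[thread_position_in_grid]));
}

kernel void cos_float_strided(
    constant size_t &dim,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides,
    device const float *input,
    device float *output,
    uint thread_position_in_grid [[ thread_position_in_grid ]]
) {
    if (thread_position_in_grid >= dim) {
        return;
    }
    // Strided case: the input is gathered through get_strided_index; the output stays contiguous.
    output[thread_position_in_grid] = float(cos(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]));
}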
@@ -9,6 +9,7 @@ pub struct Embedding {

 impl Embedding {
     pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
+        // todo!("Embedding {embeddings}");
         Self {
             embeddings,
             hidden_size,
@@ -71,11 +71,13 @@ impl PyDType {
 }

 static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
+static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);

 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 enum PyDevice {
     Cpu,
     Cuda,
+    Metal,
 }

 impl PyDevice {
@@ -83,6 +85,7 @@ impl PyDevice {
         match device {
             Device::Cpu => Self::Cpu,
             Device::Cuda(_) => Self::Cuda,
+            Device::Metal(_) => Self::Metal,
         }
     }

@@ -98,6 +101,15 @@ impl PyDevice {
                 *device = Some(d.clone());
                 Ok(d)
             }
+            Self::Metal => {
+                let mut device = METAL_DEVICE.lock().unwrap();
+                if let Some(device) = device.as_ref() {
+                    return Ok(device.clone());
+                };
+                let d = Device::new_metal(0).map_err(wrap_err)?;
+                *device = Some(d.clone());
+                Ok(d)
+            }
         }
     }
 }
@@ -119,6 +131,7 @@ impl ToPyObject for PyDevice {
         let str = match self {
             PyDevice::Cpu => "cpu",
             PyDevice::Cuda => "cuda",
+            PyDevice::Metal => "metal",
         };
         str.to_object(py)
     }

@@ -156,6 +156,7 @@ impl CausalSelfAttention {
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
+        todo!("X {x1}");
         let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
         let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@@ -173,6 +174,7 @@ impl CausalSelfAttention {
         let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
+        todo!("X {q}");
         let mut k = self.apply_rotary_emb(&k, index_pos)?;

         if self.cache.use_kv_cache {
@@ -295,6 +297,7 @@ impl Block {
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+        todo!("---X {}", x);
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
@@ -327,6 +330,7 @@ impl Llama {
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, _seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
+        //println!("Embeddings {}", self.wte.embeddings());
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;
         }