Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00

Compare commits: 0.9.0-alph...tmp_broken (28 commits)

eb24875856  3f662e54cd  480a3e22e6  0c24a885a6  76d3116f5d  1367e0278b  7ff17d92b3
cd68c96803  4d87305c48  677495f9b8  dedc8c3656  63cce76b84  634a4e7168  8124d1003f
6d4c8c0707  e6d33a8efb  c921cc3784  d4d6850c78  e708d35e7f  0794e70a19  f57e3164ae
7161002a34  82cce52e73  71fcb31873  198009453a  492d164235  2d84c16fed  4525b7b52a
Cargo.toml
@@ -55,6 +55,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
+metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }

 [profile.release-with-debug]
 inherits = "release"
candle-core/Cargo.toml
@@ -12,7 +12,10 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
+tracing = { workspace = true }
 candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
+metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
 half = { workspace = true }
@@ -39,3 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
 cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
+metal = ["dep:candle-metal-kernels", "dep:metal"]
candle-core/src/device.rs
@@ -1,6 +1,6 @@
 use crate::backend::BackendDevice;
 use crate::cpu_backend::CpuDevice;
-use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
+use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};

 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
 /// can live on the same location (typically for cuda devices).
@@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
+    Metal,
 }

 #[derive(Debug, Clone)]
 pub enum Device {
     Cpu,
     Cuda(crate::CudaDevice),
+    Metal(crate::MetalDevice),
 }

 pub trait NdArray {
@@ -103,14 +105,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
 impl<S: NdArray> NdArray for Vec<S> {
     fn shape(&self) -> Result<Shape> {
         if self.is_empty() {
-            crate::bail!("empty array")
+            bail!("empty array")
         }
         let shape0 = self[0].shape()?;
         let n = self.len();
         for v in self.iter() {
             let shape = v.shape()?;
             if shape != shape0 {
-                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
+                bail!("two elements have different shapes {shape:?} {shape0:?}")
             }
         }
         Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@@ -128,10 +130,15 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }

+    pub fn new_metal(ordinal: usize) -> Result<Self> {
+        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
+    }
+
     pub fn set_seed(&self, seed: u64) -> Result<()> {
         match self {
-            Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
+            Self::Cpu => CpuDevice.set_seed(seed),
             Self::Cuda(c) => c.set_seed(seed),
+            Self::Metal(m) => m.set_seed(seed),
         }
     }

@@ -147,21 +154,16 @@ impl Device {
         match self {
             Self::Cpu => DeviceLocation::Cpu,
             Self::Cuda(device) => device.location(),
+            Device::Metal(device) => device.location(),
         }
     }

     pub fn is_cpu(&self) -> bool {
-        match self {
-            Self::Cpu => true,
-            Self::Cuda(_) => false,
-        }
+        matches!(self, Self::Cpu)
     }

     pub fn is_cuda(&self) -> bool {
-        match self {
-            Self::Cpu => false,
-            Self::Cuda(_) => true,
-        }
+        matches!(self, Self::Cuda(_))
     }

     pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@@ -188,6 +190,11 @@ impl Device {
                 let storage = device.rand_uniform(shape, dtype, lo, up)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(_device) => {
+                // let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                // Ok(Storage::Metal(storage))
+                bail!("Metal rand_uniform not implemented")
+            }
         }
     }

@@ -216,6 +223,10 @@ impl Device {
                 let storage = device.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -238,6 +249,10 @@ impl Device {
                 let storage = device.ones_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = device.ones_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -251,6 +266,10 @@ impl Device {
                 let storage = device.zeros_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = device.zeros_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -262,6 +281,11 @@ impl Device {
                 let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = array.to_cpu_storage();
+                let storage = device.storage_from_cpu_storage(&storage)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -273,6 +297,11 @@ impl Device {
                 let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = S::to_cpu_storage_owned(data);
+                let storage = device.storage_from_cpu_storage(&storage)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }
 }
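Usage note: a minimal sketch of how the new constructor composes with the existing fallback pattern, mirroring `cuda_if_available` above. The helper name `metal_if_available` is made up for illustration and is not part of this diff. When the `metal` feature is off, `Device::new_metal` routes through the dummy backend and returns `Err(Error::NotCompiledWithMetalSupport)`, so the fallback also covers non-Metal builds.

```rust
use candle_core::{Device, Result};

// Hypothetical helper: try the Metal device at `ordinal`, fall back to CPU.
pub fn metal_if_available(ordinal: usize) -> Result<Device> {
    Ok(Device::new_metal(ordinal).unwrap_or(Device::Cpu))
}
```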
candle-core/src/display.rs
@@ -14,6 +14,7 @@ impl Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
+            _ => todo!(),
         };

         write!(f, "Tensor[")?;
@@ -476,6 +477,7 @@ impl std::fmt::Display for Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
+            crate::DeviceLocation::Metal => todo!(),
         };

         write!(
candle-core/src/dummy_metal_backend.rs (new file, 201 lines)
@@ -0,0 +1,201 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

#[derive(Debug, Clone)]
pub struct MetalDevice;

#[derive(Debug)]
pub struct MetalStorage;

macro_rules! fail {
    () => {
        unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
    };
}

impl crate::backend::BackendStorage for MetalStorage {
    type Device = MetalDevice;

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn dtype(&self) -> DType {
        fail!()
    }

    fn device(&self) -> &Self::Device {
        fail!()
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv1d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv2d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn scatter_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn matmul(
        &self,
        _: &Self,
        _: (usize, usize, usize, usize),
        _: &Layout,
        _: &Layout,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
}

impl crate::backend::BackendDevice for MetalDevice {
    type Storage = MetalStorage;
    fn new(_: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn set_seed(&self, _: u64) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn location(&self) -> crate::DeviceLocation {
        fail!()
    }

    fn same_device(&self, _: &Self) -> bool {
        fail!()
    }

    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }
}
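Usage note: a sketch of what this stub buys downstream code, assuming the `metal` feature is NOT enabled. The program still type-checks against `MetalDevice`/`MetalStorage`, and the failure surfaces as a clean runtime error rather than a missing type.

```rust
use candle_core::{Device, Error};

fn main() {
    match Device::new_metal(0) {
        Err(Error::NotCompiledWithMetalSupport) => {
            eprintln!("rebuild with `--features metal` to enable the Metal backend")
        }
        Err(e) => eprintln!("unexpected error: {e}"),
        Ok(_) => println!("metal backend available"),
    }
}
```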
candle-core/src/error.rs
@@ -1,4 +1,4 @@
-use crate::{DType, DeviceLocation, Layout, Shape};
+use crate::{metal_backend, DType, DeviceLocation, Layout, Shape};

 #[derive(Debug, Clone)]
 pub struct MatMulUnexpectedStriding {
@@ -152,6 +152,9 @@ pub enum Error {
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,

+    #[error("the candle crate has not been built with metal support")]
+    NotCompiledWithMetalSupport,
+
     #[error("cannot find tensor {path}")]
     CannotFindTensor { path: String },

@@ -159,6 +162,9 @@ pub enum Error {
     #[error(transparent)]
     Cuda(Box<dyn std::error::Error + Send + Sync>),

+    #[error("Metal error {0}")]
+    Metal(#[from] metal_backend::MetalError),
+
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),
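Usage note: because the new `Metal` variant carries `#[from] metal_backend::MetalError`, thiserror generates a `From<MetalError> for Error` impl, so `?` and `.into()` lift backend errors into candle's top-level `Error` without explicit `map_err` calls. A minimal sketch, assuming a `metal` build (the `lift` function is illustrative only):

```rust
use candle_core::metal_backend::MetalError;

// `?` on a Result<_, MetalError> inside a candle_core::Result fn works the same way.
fn lift(e: MetalError) -> candle_core::Error {
    candle_core::Error::from(e)
}
```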
candle-core/src/lib.rs
@@ -52,6 +52,10 @@ mod dummy_cuda_backend;
 pub mod error;
 mod indexer;
 pub mod layout;
+#[cfg(feature = "metal")]
+pub mod metal_backend;
+#[cfg(feature = "accelerate")]
+mod metal_backend;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
@@ -87,6 +91,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
 #[cfg(not(feature = "cuda"))]
 pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

+#[cfg(feature = "metal")]
+pub use metal_backend::{MetalDevice, MetalStorage};
+
+#[cfg(not(feature = "metal"))]
+pub use dummy_metal_backend::{MetalDevice, MetalStorage};
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
candle-core/src/metal_backend.rs (new file, 806 lines)
@@ -0,0 +1,806 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::{void_ptr, Kernels, Source};
use core::mem;
use half::{bf16, f16};
use metal;
use metal::mps::matrix::encode_gemm;
use metal::mps::Float32;
use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
use std::sync::Arc;
use tracing::debug;

/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
    #[error("{0}")]
    Message(String),
    #[error(transparent)]
    KernelError(#[from] candle_metal_kernels::MetalKernelError),
}

impl From<String> for MetalError {
    fn from(e: String) -> Self {
        MetalError::Message(e)
    }
}

impl MetalError {
    fn msg<S: AsRef<str>>(msg: S) -> Self {
        MetalError::Message(msg.as_ref().to_string())
    }
}

#[derive(Clone)]
pub struct MetalDevice {
    device: metal::Device,
    command_queue: metal::CommandQueue,
    kernels: Arc<candle_metal_kernels::Kernels>,
}

impl std::fmt::Debug for MetalDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetalDevice({:?})", self.device.registry_id())
    }
}

impl std::ops::Deref for MetalDevice {
    type Target = metal::DeviceRef;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

impl MetalDevice {
    // pub fn metal_device(&self) -> &metal::DeviceRef {
    //     self.device.as_ref()
    // }

    pub fn id(&self) -> u64 {
        self.registry_id()
    }

    fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
        let size = (element_count * dtype.size_in_bytes()) as u64;
        // debug!("Allocate 1 - buffer size {size}");
        self.device
            .new_buffer(size, MTLResourceOptions::StorageModeManaged)
    }
}

#[derive(Debug, Clone)]
pub struct MetalStorage {
    buffer: metal::Buffer,
    device: MetalDevice,
    dtype: DType,
}

impl BackendStorage for MetalStorage {
    type Device = MetalDevice;

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }

    fn dtype(&self) -> DType {
        self.dtype
    }

    fn device(&self) -> &Self::Device {
        &self.device
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        match self.dtype {
            DType::F32 => Ok(CpuStorage::F32(
                self.buffer.read_to_vec(self.buffer.length() as usize / 4),
            )),
            dtype => todo!("Unsupported dtype {dtype:?}"),
        }
    }

    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        let device = self.device().clone();

        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let dtype = self.dtype;

        assert!(layout.is_contiguous());
        assert_eq!(dtype, DType::F32);

        let mut buffer = device.new_buffer(el, self.dtype);
        let command_buffer = self.device.command_queue.new_command_buffer();
        candle_metal_kernels::call_affine(
            &device.device,
            &command_buffer,
            &device.kernels,
            el,
            &self.buffer,
            &mut buffer,
            mul as f32,
            add as f32,
        )
        .unwrap();
        command_buffer.commit();
        return Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        });
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
        todo!()
    }

    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        todo!()
    }

    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        debug!("TODO reduce_op {op:?}");
        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
        let src_el: usize = src_dims.iter().product();
        // Source dims and strides with the sum dims at the end.
        let mut dims = vec![];
        let mut stride = vec![];
        let mut dst_el: usize = 1;
        for (dim_idx, &d) in src_dims.iter().enumerate() {
            if !sum_dims.contains(&dim_idx) {
                dst_el *= d;
                dims.push(d);
                stride.push(src_stride[dim_idx]);
            }
        }
        for &dim_idx in sum_dims.iter() {
            dims.push(src_dims[dim_idx]);
            stride.push(src_stride[dim_idx]);
        }
        // let el_to_sum_per_block = src_el / dst_el;
        // // The reduction loop requires the shared array to be properly initialized and for
        // // this we want the number of threads to be a power of two.
        // let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
        // let cfg = LaunchConfig {
        //     // TODO: Maybe use grid_y if the output is too large?
        //     // TODO: Specialized implementation when reducing on no or all dimensions or when
        //     // reducing only aggregate a small number of elements together.
        //     grid_dim: (dst_el as u32, 1, 1),
        //     block_dim: (block_dim as u32, 1, 1),
        //     shared_mem_bytes: 0,
        // };
        // let ds = dev
        //     .htod_copy([dims.as_slice(), stride.as_slice()].concat())
        //     .w()?;
        // let src = &src.slice(layout.start_offset()..);
        // let (name, check_empty, return_index) = match self.1 {
        //     ReduceOp::Sum => ("fast_sum", false, false),
        //     ReduceOp::Min => ("fast_min", true, false),
        //     ReduceOp::Max => ("fast_max", true, false),
        //     ReduceOp::ArgMin => ("fast_argmin", true, true),
        //     ReduceOp::ArgMax => ("fast_argmax", true, true),
        // };
        // if check_empty && layout.shape().elem_count() == 0 {
        //     Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
        // }
        // let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
        // if return_index {
        //     // SAFETY: filled in by the follow up kernel.
        //     let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
        //     // SAFETY: ffi.
        //     unsafe { func.launch(cfg, params) }.w()?;
        //     Ok(S::U32(out))
        // } else {
        //     // SAFETY: filled in by the follow up kernel.
        //     let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
        //     // SAFETY: ffi.
        //     unsafe { func.launch(cfg, params) }.w()?;
        //     Ok(wrap(out))
        // }
        // Ok(self.clone())
        // todo!()
        let dtype = self.dtype;
        let device = self.device();
        let buffer = device.new_buffer(dst_el, dtype);
        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        todo!()
    }

    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        let device = self.device();
        let shape = layout.shape();
        let dims = shape.dims();
        let el_count = shape.elem_count();
        let mut buffer = device.new_buffer(el_count, dtype);
        let command_buffer = device.command_queue.new_command_buffer();
        if layout.is_contiguous() {
            use candle_metal_kernels::unary::contiguous;

            let kernel_name = match (self.dtype, dtype) {
                (DType::U32, DType::F32) => "cast_u32_f32",
                (left, right) => todo!("to dtype {left:?} - {right:?}"),
            };
            candle_metal_kernels::call_cast_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                el_count,
                &self.buffer,
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            todo!(
                "TODO Implement the kernel calling cast {:?}-{:?}",
                self.dtype,
                dtype
            );
        }

        let start = std::time::Instant::now();
        command_buffer.commit();
        // command_buffer.wait_until_scheduled();
        debug!(
            "cast {:?} - {:?} - {:?} - {:?}",
            dtype,
            start.elapsed(),
            self.buffer.length(),
            buffer.length()
        );
        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        let device = self.device();
        let dtype = self.dtype;
        let shape = layout.shape();
        let dims = shape.dims();
        let el_count = shape.elem_count();
        let mut buffer = device.new_buffer(el_count, dtype);
        // TODO remove
        // return Ok(Self {
        //     buffer,
        //     device: device.clone(),
        //     dtype,
        // });
        let command_buffer = device.command_queue.new_command_buffer();
        if layout.is_contiguous() {
            use candle_metal_kernels::unary::contiguous;

            let kernel_name = match (B::KERNEL, dtype) {
                ("ucos", DType::F32) => contiguous::cos::FLOAT,
                ("usin", DType::F32) => contiguous::sin::FLOAT,
                ("usqr", DType::F32) => contiguous::sqr::FLOAT,
                ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
                ("uneg", DType::F32) => contiguous::neg::FLOAT,
                ("uexp", DType::F32) => contiguous::exp::FLOAT,
                (name, dtype) => todo!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_unary_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                el_count,
                &self.buffer,
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            todo!("TODO Implement the kernel calling {}", B::KERNEL);
        }

        let start = std::time::Instant::now();
        command_buffer.commit();
        // command_buffer.wait_until_scheduled();
        debug!(
            "Unary {:?} - {:?} - {:?} - {:?}",
            B::KERNEL,
            start.elapsed(),
            self.buffer.length(),
            buffer.length()
        );

        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn binary_impl<B: BinaryOpT>(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let device = self.device();
        let dtype = self.dtype;
        let shape = lhs_l.shape();
        let dims = shape.dims();
        let el_count = shape.elem_count();
        let mut buffer = device.new_buffer(el_count, dtype);
        let command_buffer = device.command_queue.new_command_buffer();
        if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
            use candle_metal_kernels::binary::contiguous;

            let kernel_name = match (B::KERNEL, dtype) {
                ("add", DType::F32) => contiguous::add::FLOAT,
                ("badd", DType::F32) => contiguous::add::FLOAT,
                ("sub", DType::F32) => contiguous::sub::FLOAT,
                ("bsub", DType::F32) => contiguous::sub::FLOAT,
                ("mul", DType::F32) => contiguous::mul::FLOAT,
                ("bmul", DType::F32) => contiguous::mul::FLOAT,
                ("div", DType::F32) => contiguous::div::FLOAT,
                ("bdiv", DType::F32) => contiguous::div::FLOAT,
                (name, dtype) => todo!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_binary_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                el_count,
                &self.buffer,
                &rhs.buffer,
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            use candle_metal_kernels::binary::strided;

            let kernel_name = match (B::KERNEL, dtype) {
                ("badd", DType::F32) => strided::add::FLOAT,
                ("bsub", DType::F32) => strided::sub::FLOAT,
                ("bmul", DType::F32) => strided::mul::FLOAT,
                ("bdiv", DType::F32) => strided::div::FLOAT,
                (name, dtype) => todo!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_binary_strided(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                lhs_l.dims(),
                &self.buffer,
                &lhs_l.stride(),
                lhs_l.start_offset(),
                &rhs.buffer,
                &rhs_l.stride(),
                rhs_l.start_offset(),
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        }

        let start = std::time::Instant::now();
        command_buffer.commit();
        // command_buffer.wait_until_scheduled();
        debug!(
            "Binary {:?} - {:?} - {:?} - {:?}",
            B::KERNEL,
            start.elapsed(),
            self.buffer.length(),
            buffer.length()
        );

        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
        debug!("TODO where_cond");
        Ok(rhs.clone())
        // todo!()
    }

    fn conv1d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &ParamsConv1D,
    ) -> Result<Self> {
        todo!()
    }

    fn conv2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &ParamsConv2D,
    ) -> Result<Self> {
        todo!()
    }

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &ParamsConvTranspose2D,
    ) -> Result<Self> {
        todo!()
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        todo!()
    }

    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        todo!()
    }

    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
        todo!()
    }

    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
        todo!()
    }

    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
        todo!()
    }

    fn scatter_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        todo!()
    }

    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        debug!(
            "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
            self.buffer.length(),
            ids.buffer.length(),
        );
        let src = self;
        let ids_shape = ids_l.shape();
        let ids_dims = ids_shape.dims();
        // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
        // let src = match src_l.contiguous_offsets() {
        //     Some((o1, o2)) => src.slice(o1..o2),
        //     None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
        // };
        let left_size: usize = src_l.dims()[..dim].iter().product();
        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
        let src_dim_size = src_l.dims()[dim];
        let ids_dim_size = ids_shape.elem_count();
        let dst_el = ids_shape.elem_count() * left_size * right_size;
        let dtype = self.dtype;
        let device = self.device();
        let buffer = device.new_buffer(dst_el, dtype);
        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
        // todo!()
    }

    fn index_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        todo!()
    }

    fn matmul(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let transpose_left = false;
        let transpose_right = false;
        let alpha = 1.0;
        let beta = 0.0;
        self.matmul_generic(
            rhs,
            (b, m, n, k),
            lhs_l,
            rhs_l,
            transpose_left,
            transpose_right,
            alpha,
            beta,
        )
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        let src_shape = src_l.shape();
        let dims = src_shape.dims();
        let el_count = src_shape.elem_count();
        if el_count == 0 {
            return Ok(());
        }
        if src_l.is_contiguous() {
            let command_buffer = self.device.command_queue.new_command_buffer();
            let blip = command_buffer.new_blit_command_encoder();
            blip.copy_from_buffer(
                &self.buffer,
                src_l.start_offset() as u64,
                &dst.buffer,
                dst_offset as u64,
                self.buffer.length(),
            );
        } else {
            let command_buffer = self.device.command_queue.new_command_buffer();
            let kernel_name = match self.dtype {
                DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
                DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
                DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
                dtype => todo!("copy_strided not implemented for {dtype:?}"),
            };
            candle_metal_kernels::call_unary_strided(
                &self.device.device,
                &command_buffer,
                &self.device.kernels,
                kernel_name,
                src_l.dims(),
                &self.buffer,
                &src_l.stride(),
                src_l.start_offset(),
                &mut dst.buffer,
                dst_offset,
            )
            .map_err(MetalError::from)?;
            command_buffer.commit();
        }
        Ok(())
    }
}

impl MetalStorage {
    pub(crate) fn matmul_t(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let transpose_left = false;
        let transpose_right = true;
        let alpha = 1.0;
        let beta = 0.0;
        self.matmul_generic(
            rhs,
            (b, m, n, k),
            lhs_l,
            rhs_l,
            transpose_left,
            transpose_right,
            alpha,
            beta,
        )
    }
    pub(crate) fn matmul_generic(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
        transpose_left: bool,
        transpose_right: bool,
        alpha: f64,
        beta: f64,
    ) -> Result<Self> {
        let elem_count = b * m * n;
        match (self.dtype, rhs.dtype) {
            (DType::F32, DType::F32) => {
                let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
                if b != 1 {
                    debug!("TODO implement batched matmul for B={b}");
                    // bail!("Didn't implemented strided matmul yet");
                    return Ok(Self {
                        buffer: out_buffer,
                        device: self.device.clone(),
                        dtype: self.dtype(),
                    });
                }
                if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
                    debug!(
                        "TODO non contiguous matmul yet {:?} {:?}",
                        lhs_l.is_contiguous(),
                        rhs_l.is_contiguous()
                    );
                    return Ok(Self {
                        buffer: out_buffer,
                        device: self.device.clone(),
                        dtype: self.dtype(),
                    });
                }

                debug!("GEMM");
                let command_buffer = self.device.command_queue.new_command_buffer();
                encode_gemm::<Float32, Float32, Float32>(
                    &self.device,
                    &command_buffer,
                    transpose_left,
                    transpose_right,
                    &self.buffer,
                    &rhs.buffer,
                    &mut out_buffer,
                    m as NSUInteger,
                    n as NSUInteger,
                    k as NSUInteger,
                    alpha as f32,
                    beta as f32,
                    Some(b as NSUInteger),
                )
                .map_err(MetalError::from)?;

                command_buffer.commit();
                // command_buffer.wait_until_scheduled();

                Ok(Self {
                    buffer: out_buffer,
                    device: self.device.clone(),
                    dtype: self.dtype(),
                })
            }
            _ => todo!("Unimplemented matmul for this pair"),
        }
    }
}

impl BackendDevice for MetalDevice {
    type Storage = MetalStorage;

    fn new(ordinal: usize) -> Result<Self> {
        let device = metal::Device::all().swap_remove(ordinal);

        // let capture = metal::CaptureManager::shared();
        // let descriptor = metal::CaptureDescriptor::new();
        // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
        // descriptor.set_capture_device(&device);
        // let mut dir = std::env::current_dir()?;
        // dir.push("out.gputrace");
        // descriptor.set_output_url(dir);

        // capture
        //     .start_capture(&descriptor)
        //     .map_err(MetalError::from)?;
        let command_queue = device.new_command_queue();
        // let command_buffer = _command_queue.new_owned_command_buffer();
        let kernels = Arc::new(Kernels::new());
        Ok(Self {
            device,
            command_queue,
            // command_buffer,
            kernels,
        })
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        todo!("set_seed")
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Metal
    }

    fn same_device(&self, rhs: &Self) -> bool {
        self.device.registry_id() == rhs.device.registry_id()
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
        // TODO Is there a faster way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        // TODO Is there a faster way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
        let option = metal::MTLResourceOptions::StorageModeManaged;
        let buffer = match storage {
            CpuStorage::U8(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<u8>()) as u64,
                option,
            ),
            CpuStorage::U32(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<u32>()) as u64,
                option,
            ),
            CpuStorage::I64(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<i64>()) as u64,
                option,
            ),
            CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<bf16>()) as u64,
                option,
            ),
            CpuStorage::F16(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<f16>()) as u64,
                option,
            ),
            CpuStorage::F32(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<f32>()) as u64,
                option,
            ),
            CpuStorage::F64(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<f64>()) as u64,
                option,
            ),
        };
        // debug!("Allocate 2 - buffer size {}", buffer.length());
        Ok(Self::Storage {
            buffer,
            device: self.clone(),
            dtype: storage.dtype(),
        })
    }

    fn rand_uniform(
        &self,
        shape: &Shape,
        dtype: DType,
        mean: f64,
        stddev: f64,
    ) -> Result<Self::Storage> {
        // TODO is there a better way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }

    fn rand_normal(
        &self,
        shape: &Shape,
        dtype: DType,
        mean: f64,
        stddev: f64,
    ) -> Result<Self::Storage> {
        // TODO is there a better way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }
}
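Usage note: a rough sketch of exercising the new backend through the public Tensor API, assuming a macOS build with `--features metal`. Contiguous f32 is the one path the unary kernels cover so far; also note the `Display` impl still hits `todo!()` for Metal in this branch, so read values back explicitly instead of printing the tensor.

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_metal(0)?;
    let t = Tensor::from_vec(vec![0f32, 1., 2., 3.], (2, 2), &dev)?;
    let c = t.cos()?; // dispatches to call_unary_contiguous with the "ucos" kernel
    let v: Vec<Vec<f32>> = c.to_vec2()?; // goes through MetalStorage::to_cpu_storage
    println!("{v:?}");
    Ok(())
}
```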
candle-core/src/op.rs
@@ -1,5 +1,5 @@
 #![allow(clippy::redundant_closure_call)]
-use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
+use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
 use half::{bf16, f16};
 use num_traits::float::Float;

@@ -174,6 +174,18 @@ pub trait CustomOp1 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _storage: &MetalStorage,
+        _layout: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     /// This function takes as argument the argument `arg` used in the forward pass, the result
     /// produced by the forward operation `res` and the gradient of the result `grad_res`.
     /// The function should return the gradient of the argument.
@@ -209,6 +221,20 @@ pub trait CustomOp2 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     fn bwd(
         &self,
         _arg1: &Tensor,
@@ -251,6 +277,22 @@ pub trait CustomOp3 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     fn bwd(
         &self,
         _arg1: &Tensor,
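Usage note: a hedged sketch of what the new default hook gives user-defined ops, assuming the `metal` feature is enabled. `Noop` is made up for illustration, and it ignores strides, which a real op must not do; the point is only that overriding `metal_fwd` replaces the default "no metal implementation" error.

```rust
use candle_core::{CpuStorage, CustomOp1, Layout, MetalStorage, Result, Shape};

struct Noop;

impl CustomOp1 for Noop {
    fn name(&self) -> &'static str {
        "noop"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        Ok((s.clone(), l.shape().clone()))
    }

    // Without this override, the default above returns
    // Err(Error::Metal("no metal implementation for noop")).
    fn metal_fwd(&self, s: &MetalStorage, l: &Layout) -> Result<(MetalStorage, Shape)> {
        Ok((s.clone(), l.shape().clone()))
    }
}
```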
candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,7 @@
 //! Support for the GGML file format.

 use super::{k_quants, GgmlDType};
-use crate::Result;
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;

@@ -121,11 +121,12 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
     size_in_bytes: usize,
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    super::QTensor::new(data.to_vec(), dims)
+    super::QTensor::new(data.to_vec(), dims, device)
 }

 /// Creates a [Tensor] from a raw GGML tensor.
@@ -133,6 +134,7 @@ pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
     let blck_size = ggml_dtype.blck_size();
@@ -144,18 +146,38 @@ pub fn qtensor_from_ggml(
     let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();

     match ggml_dtype {
-        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
-        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::Q4_0 => {
+            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4_1 => {
+            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_0 => {
+            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_1 => {
+            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q8_0 => {
+            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q2K => {
+            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q3K => {
+            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4K => {
+            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5K => {
+            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q6K => {
+            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
+        }
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
@@ -163,6 +185,7 @@ pub fn qtensor_from_ggml(
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
+    device: &Device,
 ) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
@@ -187,7 +210,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
@@ -201,7 +224,10 @@ pub struct Content {
 }

 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(
+        reader: &mut R,
+        device: &Device,
+    ) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -211,7 +237,7 @@ impl Content {
         let mut tensors = HashMap::new();

         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic)?;
+            let (name, tensor) = read_one_tensor(reader, magic, device)?;
             tensors.insert(name, tensor);
         }
         Ok(Self {
candle-core/src/quantized/gguf_file.rs
@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md

 use super::{GgmlDType, QTensor};
-use crate::Result;
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;

@@ -57,6 +57,7 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
         let blck_size = self.ggml_dtype.blck_size();
@@ -69,7 +70,12 @@ impl TensorInfo {
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
+        super::ggml_file::qtensor_from_ggml(
+            self.ggml_dtype,
+            &raw_data,
+            self.shape.dims().to_vec(),
+            device,
+        )
     }
 }

@@ -450,12 +456,13 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
             None => crate::bail!("cannot find tensor-infor for {name}"),
         };
-        tensor_info.read(reader, self.tensor_data_offset)
+        tensor_info.read(reader, self.tensor_data_offset, device)
     }
 }
candle-core/src/quantized/mod.rs
@@ -1,4 +1,5 @@
 use crate::{Device, Result, Shape, Tensor};
+use tracing::debug;

 #[cfg(target_feature = "avx")]
 pub mod avx;
@@ -14,6 +15,7 @@ pub mod utils;
 pub use k_quants::GgmlType;

 pub struct QTensor {
+    device: Device,
     data: Box<dyn QuantizedType>,
     shape: Shape,
 }
@@ -170,17 +172,20 @@ impl QTensor {
     pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
         data: Vec<T>,
         shape: S,
+        device: &Device,
     ) -> Result<Self> {
         let shape = shape.into();
         check_shape::<T>(&shape)?;
         Ok(Self {
             data: Box::new(data),
             shape,
+            device: device.clone(),
         })
     }

     pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
         let shape = src.shape();
+        let device = src.device();
         check_shape::<T>(shape)?;
         let src = src
             .to_dtype(crate::DType::F32)?
@@ -197,6 +202,7 @@ impl QTensor {
         Ok(Self {
             data: Box::new(data),
             shape: shape.clone(),
+            device: device.clone(),
         })
     }

@@ -212,7 +218,12 @@ impl QTensor {
         &self.shape
     }

+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+
     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
+        // TODO Skip the CPU part on metal
         let mut f32_data = vec![0f32; self.shape.elem_count()];
         self.data.to_float(&mut f32_data)?;
         Tensor::from_vec(f32_data, &self.shape, device)
@@ -305,6 +316,46 @@ impl crate::CustomOp1 for QTensor {
         )?;
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }

+    fn metal_fwd(
+        &self,
+        storage: &crate::MetalStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::MetalStorage, Shape)> {
+        debug!("TODO qmatmul");
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        // self is transposed so n is first then k.
+        let (n, k) = self.shape.dims2()?;
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let mut dst_shape = src_shape.dims().to_vec();
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        // let storage = storage.as_slice::<f32>()?;
+        // let storage =
+        //     &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        let dst_storage = vec![0f32; dst_shape.elem_count()];
+        // self.matmul_t(
+        //     (dst_shape.elem_count() / n, k, n),
+        //     storage,
+        //     &mut dst_storage,
+        // )?;
+        let cpu_storage = crate::CpuStorage::F32(dst_storage);
+        use crate::backend::{BackendDevice, BackendStorage};
+        if let Device::Metal(device) = &self.device {
+            Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
+        } else {
+            crate::bail!("qtensor not on metal device")
+        }
+    }
 }

 impl QMatMul {
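Usage note: `QTensor` now remembers its device, and `dequantize` still goes through a CPU f32 buffer (see the TODO above) before landing on the target device. A minimal round-trip sketch, assuming the tensor's last dimension is a multiple of the Q4_0 block size (32); `roundtrip` is a made-up helper:

```rust
use candle_core::quantized::{k_quants, QTensor};
use candle_core::Tensor;

fn roundtrip(t: &Tensor) -> candle_core::Result<Tensor> {
    // Quantize on the tensor's device, then dequantize back onto it.
    let q = QTensor::quantize::<k_quants::BlockQ4_0>(t)?;
    q.dequantize(q.device())
}
```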
candle-core/src/storage.rs
@@ -1,6 +1,6 @@
 use crate::backend::BackendStorage;
 use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
-use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};

 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
@@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
 pub enum Storage {
     Cpu(CpuStorage),
     Cuda(CudaStorage),
+    Metal(MetalStorage),
 }

 impl Storage {
@@ -18,6 +19,10 @@ impl Storage {
                 let storage = storage.try_clone(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.try_clone(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -25,6 +30,7 @@ impl Storage {
         match self {
             Self::Cpu(_) => Device::Cpu,
             Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
+            Self::Metal(storage) => Device::Metal(storage.device().clone()),
         }
     }

@@ -32,6 +38,7 @@ impl Storage {
         match self {
             Self::Cpu(storage) => storage.dtype(),
             Self::Cuda(storage) => storage.dtype(),
+            Self::Metal(storage) => storage.dtype(),
         }
     }

@@ -65,6 +72,10 @@ impl Storage {
                 let storage = storage.affine(layout, mul, add)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.affine(layout, mul, add)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -78,6 +89,10 @@ impl Storage {
                 let storage = storage.powf(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.powf(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -91,6 +106,10 @@ impl Storage {
                 let storage = storage.elu(layout, alpha)?;
                 Ok(Self::Cuda(storage))
            }
+            Self::Metal(storage) => {
+                let storage = storage.elu(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -112,6 +131,10 @@ impl Storage {
                 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -135,6 +158,10 @@ impl Storage {
                 let storage = storage.reduce_op(op, layout, s)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.reduce_op(op, layout, s)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -148,6 +175,10 @@ impl Storage {
                 let storage = storage.to_dtype(layout, dtype)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.to_dtype(layout, dtype)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -161,6 +192,10 @@ impl Storage {
                 let (storage, shape) = c.cuda_fwd(storage, l)?;
                 Ok((Self::Cuda(storage), shape))
             }
+            Self::Metal(storage) => {
+                let (storage, shape) = c.metal_fwd(storage, l)?;
+                Ok((Self::Metal(storage), shape))
+            }
         }
     }

@@ -181,6 +216,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -205,6 +244,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -219,6 +262,10 @@ impl Storage {
                 let storage = storage.unary_impl::<B>(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.unary_impl::<B>(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -239,6 +286,10 @@ impl Storage {
                 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -270,6 +321,10 @@ impl Storage {
                 let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -297,6 +352,10 @@ impl Storage {
                 let s = inp.conv2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -324,6 +383,10 @@ impl Storage {
                 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -348,6 +411,10 @@ impl Storage {
                 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -366,6 +433,10 @@ impl Storage {
                 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -379,6 +450,10 @@ impl Storage {
                 let storage = storage.upsample_nearest1d(layout, sz)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -392,6 +467,10 @@ impl Storage {
                 let storage = storage.upsample_nearest2d(layout, h, w)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest2d(layout, h, w)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -415,6 +494,10 @@ impl Storage {
                 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
+                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
+                Ok(Self::Metal(storage))
+            }
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -441,6 +524,10 @@ impl Storage {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes)) => {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -465,6 +552,10 @@ impl Storage {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -489,6 +580,10 @@ impl Storage {
|
||||
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -510,6 +605,10 @@ impl Storage {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -537,6 +636,10 @@ impl Storage {
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -556,6 +659,9 @@ impl Storage {
|
||||
match (self, dst) {
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
|
||||
(Self::Metal(src), Self::Metal(dst)) => {
|
||||
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
|
@@ -6,7 +6,7 @@ use crate::op::{
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};

/// Unique identifier for tensors.
@@ -523,6 +523,7 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -1448,6 +1449,7 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -1478,6 +1480,7 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -1518,6 +1521,7 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -1837,6 +1841,9 @@ impl Tensor {
                Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
            }
            (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
            _ => {
                bail!("not implemented yet")
            }
        };
        let op = BackpropOp::new1(self, Op::ToDevice);
        let tensor_ = Tensor_ {
@@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
    cfg!(feature = "cuda")
}

pub fn metal_is_available() -> bool {
    cfg!(feature = "metal")
}

pub fn with_avx() -> bool {
    cfg!(target_feature = "avx")
}
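As a sketch of how these two probes are meant to be combined when picking a device (the helper name `pick_device` is hypothetical; the examples crate further down ships essentially this logic):

use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result};

// Prefer CUDA, then Metal, then fall back to the CPU backend.
fn pick_device() -> Result<Device> {
    if cuda_is_available() {
        Device::new_cuda(0)
    } else if metal_is_available() {
        Device::new_metal(0)
    } else {
        Ok(Device::Cpu)
    }
}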
@@ -51,6 +51,7 @@ anyhow = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
@@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;

use candle_transformers::models::quantized_llama as model;
@@ -232,11 +232,13 @@ fn main() -> anyhow::Result<()> {
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let device = candle_examples::device(false)?;
    let temperature = if args.temperature == 0. {
        None
    } else {
        Some(args.temperature)
    };
    tracing_subscriber::fmt::init();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
@@ -276,10 +278,10 @@ fn main() -> anyhow::Result<()> {
                &format_size(total_size_in_bytes),
                start.elapsed().as_secs_f32(),
            );
            ModelWeights::from_gguf(model, &mut file)?
            ModelWeights::from_gguf(model, &mut file, &device)?
        }
        Some("ggml" | "bin") | Some(_) | None => {
            let model = ggml_file::Content::read(&mut file)?;
            let model = ggml_file::Content::read(&mut file, &device)?;
            let mut total_size_in_bytes = 0;
            for (_, tensor) in model.tensors.iter() {
                let elem_count = tensor.shape().elem_count();
@@ -307,7 +309,7 @@ fn main() -> anyhow::Result<()> {
                | Which::L70b
                | Which::L70bChat => 8,
            };
            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
        }
    };
    println!("model built");
@@ -366,10 +368,13 @@ fn main() -> anyhow::Result<()> {

    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = {
        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
        // TODO Remove this once implementation is finished.
        let logits = logits.ones_like()?;
        // logits_processor.sample(&logits)?
        15043
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
@@ -379,7 +384,7 @@ fn main() -> anyhow::Result<()> {

    let start_post_prompt = std::time::Instant::now();
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, prompt_tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
@@ -392,7 +397,10 @@ fn main() -> anyhow::Result<()> {
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        // TODO Remove this once implementation is finished.
        // let logits = logits.ones_like()?;
        // next_token = logits_processor.sample(&logits)?;
        let next_token = 15043;
        all_tokens.push(next_token);
        print_token(next_token, &tokenizer);
        if next_token == eos_token {
@@ -2,17 +2,30 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;

use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};

pub fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else {
        let device = Device::cuda_if_available(0)?;
        if !device.is_cuda() {
            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
        if cuda_is_available() {
            Ok(Device::new_cuda(0)?)
        } else if metal_is_available() {
            Ok(Device::new_metal(0)?)
        } else {
            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
            {
                println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
            }
            #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
            {
                println!(
                    "Running on CPU, to run on GPU, build this example with `--features cuda`"
                );
            }
            Ok(Device::Cpu)
        }
        Ok(device)
    }
}
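A quick usage sketch for this helper (`args.cpu` below stands in for whatever CLI flag an example exposes; treat it as illustrative):

// In an example's main(): honor a --cpu flag, otherwise take the best GPU backend.
let device = candle_examples::device(args.cpu)?;
let t = candle::Tensor::new(&[1f32, 2.0, 3.0], &device)?;
println!("tensor lives on {:?}", t.device());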
candle-metal-kernels/Cargo.toml (new file)
@@ -0,0 +1,17 @@
[package]
name = "candle-metal-kernels"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
metal = { workspace = true }
once_cell = "1.18.0"
thiserror = { workspace = true }

[dev-dependencies]
half = { workspace = true }
candle-metal-kernels/README.md (new file)
@@ -0,0 +1,3 @@
# candle-metal-kernels

This crate contains Metal kernels used from candle.
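A minimal end-to-end sketch of the API this crate exposes; the buffer setup mirrors the crate's own tests further down, so treat it as an illustration rather than a stable interface:

use candle_metal_kernels::{call_unary_contiguous, unary, Kernels};
use metal::{Device, MTLResourceOptions};

fn main() -> Result<(), candle_metal_kernels::MetalKernelError> {
    let device = Device::system_default().expect("no Metal device");
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    // Upload three f32 values and allocate a matching output buffer.
    let v = [1.0f32, 2.0, 3.0];
    let options = MTLResourceOptions::StorageModeManaged;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        (v.len() * core::mem::size_of::<f32>()) as u64,
        options,
    );
    let mut output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, options);

    // Encode the contiguous cos kernel over the three elements, then run it.
    call_unary_contiguous(
        &device,
        &command_buffer,
        &kernels,
        unary::contiguous::cos::FLOAT,
        v.len(),
        &input,
        &mut output,
    )?;
    command_buffer.commit();
    command_buffer.wait_until_completed();
    Ok(())
}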
candle-metal-kernels/src/affine.metal (new file)
@@ -0,0 +1,44 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

using namespace metal;

#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    constant float &mul, \
    constant float &add, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
    const size_t start = thread_index * length; \
    const size_t stop = min(start + length, dim); \
    for (size_t i = start; i < stop; i++){ \
        output[i] = input[i] * mul + add; \
    } \
} \

AFFINE(affine_float, float)
AFFINE(affine_half, half)


#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif
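The `get_strided_index` helper above is the piece every kernel in this crate shares: it converts a linear element index into a storage offset from a row-major dims/strides pair, walking the last (fastest) dimension first. A small Rust rendering of the same loop, for reference:

// Map a linear index over `dims` to a physical offset given `strides`
// (same walk as the Metal helper, last dimension fastest).
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for d in (0..dims.len()).rev() {
        strided_i += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided_i
}

// E.g. a transposed [3, 2] view with strides [1, 3]: element (row 1, col 1),
// i.e. linear index 3, lands at offset 1 * 1 + 1 * 3 = 4.
assert_eq!(get_strided_index(3, &[3, 2], &[1, 3]), 4);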
candle-metal-kernels/src/binary.metal (new file)
@@ -0,0 +1,78 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

using namespace metal;

#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
    const size_t start = thread_index * length; \
    const size_t stop = min(start + length, dim); \
    for (size_t i = start; i < stop; i++){ \
        TYPENAME x = left[i]; \
        TYPENAME y = right[i]; \
        output[i] = OUT_TYPENAME(FN); \
    } \
}\
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *left_strides, \
    constant size_t *right_strides, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
    const size_t start = thread_index * length; \
    const size_t stop = min(start + length, dim); \
    for (size_t i = start; i < stop; i++){ \
        TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
        TYPENAME y = right[get_strided_index(i, num_dims, dims, right_strides)]; \
        output[i] = OUT_TYPENAME(FN); \
    } \
}

#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);

#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);


BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)

#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif
candle-metal-kernels/src/cast.metal (new file)
@@ -0,0 +1,58 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}


using namespace metal;

#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
    const size_t start = thread_index * length; \
    const size_t stop = min(start + length, dim); \
    for (size_t i = start; i < stop; i++){ \
        output[i] = RIGHT_TYPENAME(input[i]); \
    } \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
    const size_t start = thread_index * length; \
    const size_t stop = min(start + length, dim); \
    for (size_t i = start; i < stop; i++){ \
        output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
    } \
}

CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)

#if __METAL_VERSION__ >= 310
#endif
candle-metal-kernels/src/indexing.metal (new file)
@@ -0,0 +1,75 @@
#include <metal_stdlib>
using namespace metal;

template <typename T, typename I>
void index_add(
    device I *ids [[buffer(0)]],
    device T *inp [[buffer(1)]],
    device T *out [[buffer(2)]],

    constant uint &ids_dim_size,
    constant uint &left_size,
    constant uint &dst_dim_size,
    constant uint &right_size,

    uint threadgroup_size [[threads_per_threadgroup]],
    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
    uint thread_index [[thread_index_in_threadgroup]]
) {

    const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
    if (gid >= left_size * right_size) {
        return;
    }

    const uint i = gid;
    const uint pre = i / right_size;
    const uint post = i % right_size;

    for (uint j = 0; j < ids_dim_size; j++) {
        const uint idx = ids[j];
        const uint src_i = (pre * ids_dim_size + j) * right_size + post;
        const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
        out[dst_i] += inp[src_i];
    }
}

#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    device INDEX_TYPENAME *ids [[buffer(0)]], \
    device TYPENAME *inp [[buffer(1)]], \
    device TYPENAME *out [[buffer(2)]], \
    constant uint &ids_dim_size, \
    constant uint &left_size, \
    constant uint &dst_dim_size, \
    constant uint &right_size, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \



#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif

IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)

IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)

IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)

IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)
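For reference, here is the update the kernel performs, written out on the CPU (a sketch with a hypothetical name, specialized to f32/usize for readability):

// CPU reference for index_add: for each (pre, post) pair handled by one GPU thread,
// scatter-add the j-th source slice into destination slice ids[j].
fn index_add_ref(
    ids: &[usize],
    inp: &[f32],
    out: &mut [f32],
    left_size: usize,
    dst_dim_size: usize,
    right_size: usize,
) {
    let ids_dim_size = ids.len();
    for pre in 0..left_size {
        for post in 0..right_size {
            for (j, &idx) in ids.iter().enumerate() {
                let src_i = (pre * ids_dim_size + j) * right_size + post;
                let dst_i = (pre * dst_dim_size + idx) * right_size + post;
                out[dst_i] += inp[src_i];
            }
        }
    }
}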
candle-metal-kernels/src/lib.rs (new file)
@@ -0,0 +1,862 @@
use metal::{
    Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
    MTLSize,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    Affine,
    Indexing,
    Unary,
    Binary,
    Cast,
}

macro_rules! ops{
    ($($name:ident),+) => {

        pub mod contiguous {
            pub struct Kernel(pub(crate) &'static str);
            $(
            pub mod $name {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
                pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
                pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
            }
            )+
        }

        pub mod strided {
            pub struct Kernel(pub(crate) &'static str);
            $(
            pub mod $name {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
                pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
                pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
            }
            )+
        }
    };
}

pub mod unary {
    ops!(cos, sin, exp, sqr, sqrt, neg, copy);
}
pub mod binary {
    ops!(add, sub, mul, div);
}

// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
//     let mut l = HashMap::new();
//     l.insert("affine", AFFINE);
//     l.insert("indexing", INDEXING);
//     l.insert("unary", UNARY);
//     l
// });
//
#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
    #[error("Could not lock kernel map: {0}")]
    LockError(String),
    #[error("Error while loading library: {0}")]
    LoadLibraryError(String),
    #[error("Error while loading function: {0}")]
    LoadFunctionError(String),
}

impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
    fn from(e: std::sync::PoisonError<T>) -> Self {
        Self::LockError(e.to_string())
    }
}

type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
type Functions = KernelMap<Function>;

#[derive(Debug)]
pub struct Kernels {
    libraries: RwLock<Libraries>,
    funcs: RwLock<Functions>,
}

impl Kernels {
    pub fn new() -> Self {
        let libraries = RwLock::new(Libraries::new());
        let funcs = RwLock::new(Functions::new());
        Self { libraries, funcs }
    }

    // pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
    //     let kernels = Self::new();
    //     kernels.load_libraries(device)?;
    //     Ok(kernels)
    // }

    // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
    //     for name in LIBRARY_SOURCES.keys() {
    //         self.load_library(device, name)?;
    //     }
    //     Ok(())
    // }

    fn get_library_source(&self, source: Source) -> &'static str {
        // LIBRARY_SOURCES.get(name).cloned()
        match source {
            Source::Affine => AFFINE,
            Source::Unary => UNARY,
            Source::Binary => BINARY,
            Source::Indexing => INDEXING,
            Source::Cast => CAST,
        }
    }

    pub fn load_library(
        &self,
        device: &Device,
        source: Source,
    ) -> Result<Library, MetalKernelError> {
        let mut libraries = self.libraries.write()?;
        if let Some(lib) = libraries.get(&source) {
            Ok(lib.clone())
        } else {
            let source_content = self.get_library_source(source);
            let lib = device
                .new_library_with_source(source_content, &CompileOptions::new())
                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
            libraries.insert(source, lib.clone());
            Ok(lib)
        }
    }

    pub fn load_function(
        &self,
        device: &Device,
        source: Source,
        name: &'static str,
    ) -> Result<Function, MetalKernelError> {
        let mut funcs = self.funcs.write()?;
        if let Some(func) = funcs.get(name) {
            Ok(func.clone())
        } else {
            let func = self
                .load_library(device, source)?
                .get_function(name, None)
                .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
            funcs.insert(name, func.clone());
            Ok(func)
        }
    }
}

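// Usage sketch (illustrative, not part of the file): a single `Kernels` value can
// be shared; only the first `load_function` call for a given name compiles anything,
// later calls are served from the `RwLock`-guarded caches above. Note the function
// cache is keyed by name alone, so kernel names must be unique across sources.
//
//     let device = metal::Device::system_default().expect("no Metal device");
//     let kernels = Kernels::new();
//     let cos = kernels.load_function(&device, Source::Unary, "cos_float")?; // compiles unary.metal
//     let cos_again = kernels.load_function(&device, Source::Unary, "cos_float")?; // cache hit
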
pub fn call_unary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: &Buffer,
    output: &mut Buffer,
) -> Result<(), MetalKernelError> {
    // println!("Kernel {:?}", kernel_name.0);
    // assert_eq!(input.length(), output.length());
    let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
    pipeline_state_descriptor.set_compute_function(Some(&func));

    let pipeline = device
        .new_compute_pipeline_state_with_function(
            pipeline_state_descriptor.compute_function().unwrap(),
        )
        .unwrap();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    // The kernel declares `dim` as a `size_t`, so pass the full width of `usize`
    // rather than a hard-coded 4 bytes.
    encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
    encoder.set_buffer(1, Some(&input), 0);
    encoder.set_buffer(2, Some(&output), 0);

    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}
pub fn call_unary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: unary::strided::Kernel,
    shape: &[usize],
    input: &Buffer,
    strides: &[usize],
    offset: usize,
    output: &mut Buffer,
    output_offset: usize,
) -> Result<(), MetalKernelError> {
    let func = kernels.load_function(device, Source::Unary, name.0)?;
    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
    pipeline_state_descriptor.set_compute_function(Some(&func));

    let pipeline = device
        .new_compute_pipeline_state_with_function(
            pipeline_state_descriptor.compute_function().unwrap(),
        )
        .unwrap();

    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();
    encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
    encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
    encoder.set_bytes(
        2,
        (shape.len() * std::mem::size_of::<usize>()) as u64,
        shape.as_ptr() as *const c_void,
    );
    encoder.set_bytes(
        3,
        (strides.len() * std::mem::size_of::<usize>()) as u64,
        strides.as_ptr() as *const c_void,
    );

    encoder.set_buffer(4, Some(&input), offset as u64);
    encoder.set_buffer(5, Some(&output), output_offset as u64);

    let width = output.length();

    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

pub fn call_binary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: binary::contiguous::Kernel,
    length: usize,
    left: &Buffer,
    right: &Buffer,
    output: &mut Buffer,
) -> Result<(), MetalKernelError> {
    // println!("Kernel {:?}", kernel_name.0);
    // assert_eq!(input.length(), output.length());
    let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
    pipeline_state_descriptor.set_compute_function(Some(&func));

    let pipeline = device
        .new_compute_pipeline_state_with_function(
            pipeline_state_descriptor.compute_function().unwrap(),
        )
        .unwrap();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    // `dim` is a `size_t` on the kernel side, hence the full `usize` width here too.
    encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
    encoder.set_buffer(1, Some(&left), 0);
    encoder.set_buffer(2, Some(&right), 0);
    encoder.set_buffer(3, Some(&output), 0);

    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

pub fn call_binary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: binary::strided::Kernel,
    shape: &[usize],
    left_input: &Buffer,
    left_strides: &[usize],
    left_offset: usize,
    right_input: &Buffer,
    right_strides: &[usize],
    right_offset: usize,
    output: &mut Buffer,
) -> Result<(), MetalKernelError> {
    let func = kernels.load_function(device, Source::Binary, name.0)?;
    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
    pipeline_state_descriptor.set_compute_function(Some(&func));

    let pipeline = device
        .new_compute_pipeline_state_with_function(
            pipeline_state_descriptor.compute_function().unwrap(),
        )
        .unwrap();

    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();
    encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
    encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
    encoder.set_bytes(
        2,
        (shape.len() * std::mem::size_of::<usize>()) as u64,
        shape.as_ptr() as *const c_void,
    );
    encoder.set_bytes(
        3,
        (left_strides.len() * std::mem::size_of::<usize>()) as u64,
        left_strides.as_ptr() as *const c_void,
    );
    encoder.set_bytes(
        4,
        (right_strides.len() * std::mem::size_of::<usize>()) as u64,
        right_strides.as_ptr() as *const c_void,
    );

    encoder.set_buffer(5, Some(&left_input), left_offset as u64);
    encoder.set_buffer(6, Some(&right_input), right_offset as u64);
    encoder.set_buffer(7, Some(&output), 0);

    let width = output.length();

    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

pub fn call_cast_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    input: &Buffer,
    output: &mut Buffer,
) -> Result<(), MetalKernelError> {
    // println!("Kernel {:?}", kernel_name.0);
    // assert_eq!(input.length(), output.length());
    let func = kernels.load_function(device, Source::Cast, kernel_name)?;
    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
    pipeline_state_descriptor.set_compute_function(Some(&func));

    let pipeline = device
        .new_compute_pipeline_state_with_function(
            pipeline_state_descriptor.compute_function().unwrap(),
        )
        .unwrap();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    // `dim` is a `size_t` on the kernel side, hence the full `usize` width here too.
    encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
    encoder.set_buffer(1, Some(&input), 0);
    encoder.set_buffer(2, Some(&output), 0);

    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

pub fn void_ptr<T>(v: &T) -> *const c_void {
    (v as *const T).cast()
}

pub fn call_affine(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    size: usize,
    input: &Buffer,
    output: &mut Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let func = kernels.load_function(device, Source::Affine, "affine_float")?;
    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
    pipeline_state_descriptor.set_compute_function(Some(&func));

    let pipeline = device
        .new_compute_pipeline_state_with_function(
            pipeline_state_descriptor.compute_function().unwrap(),
        )
        .unwrap();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
    encoder.set_bytes(1, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
    encoder.set_bytes(2, core::mem::size_of::<f32>() as u64, void_ptr(&add));
    encoder.set_buffer(3, Some(&input), 0);
    encoder.set_buffer(4, Some(&output), 0);

    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;
    use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
    use std::mem;

    fn device() -> Device {
        Device::system_default().unwrap()
    }

    fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
        let b = 10f32.powi(digits);
        v.iter().map(|t| f32::round(t * b) / b).collect()
    }

    fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
        let b = 10f32.powi(digits);
        v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
    }

    fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
        let device = device();
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let options = MTLResourceOptions::StorageModeManaged;
        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            (v.len() * core::mem::size_of::<T>()) as u64,
            options,
        );
        let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
        call_unary_contiguous(
            &device,
            &command_buffer,
            &kernels,
            name,
            v.len(),
            &input,
            &mut output,
        )
        .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        output.read_to_vec::<T>(v.len())
    }

    fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
        let device = device();
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let options = MTLResourceOptions::StorageModeManaged;
        let left = device.new_buffer_with_data(
            x.as_ptr() as *const core::ffi::c_void,
            (x.len() * core::mem::size_of::<T>()) as u64,
            options,
        );
        let right = device.new_buffer_with_data(
            y.as_ptr() as *const core::ffi::c_void,
            (y.len() * core::mem::size_of::<T>()) as u64,
            options,
        );
        let mut output = device.new_buffer((x.len() * core::mem::size_of::<T>()) as u64, options);
        call_binary_contiguous(
            &device,
            &command_buffer,
            &kernels,
            name,
            x.len(),
            &left,
            &right,
            &mut output,
        )
        .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        output.read_to_vec::<T>(x.len())
    }

    fn run_strided<T: Clone>(
        v: &[T],
        kernel: unary::strided::Kernel,
        shape: &[usize],
        strides: &[usize],
        offset: usize,
    ) -> Vec<T> {
        let device = device();
        let options = MTLResourceOptions::StorageModeManaged;
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            (v.len() * core::mem::size_of::<T>()) as u64,
            options,
        );
        let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
        let kernels = Kernels::new();
        call_unary_strided(
            &device,
            &command_buffer,
            &kernels,
            kernel,
            shape,
            &input,
            strides,
            offset,
            &mut output,
            0,
        )
        .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        output.read_to_vec::<T>(v.len())
    }

    #[test]
    fn cos_f32() {
        let v = vec![1.0f32, 2.0, 3.0];
        let results = run(&v, unary::contiguous::cos::FLOAT);
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
        assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);

        let v = vec![1.0f32; 10_000];
        let results = run(&v, unary::contiguous::cos::FLOAT);
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
        assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
    }

    #[test]
    fn cos_f32_strided() {
        let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        // Shape = [6], strides = [1];
        let shape = vec![6];
        let strides = vec![1];
        let offset = 0;
        let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(
            approx(results, 4),
            vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
        );
        assert_eq!(
            approx(expected, 4),
            vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
        );

        // Contiguous
        let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![3, 2];
        let strides = vec![2, 1];
        let offset = 0;
        let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(
            approx(results, 4),
            vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
        );
        assert_eq!(
            approx(expected, 4),
            vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
        );

        // Transposed
        let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![3, 2];
        let strides = vec![1, 3];
        let offset = 0;
        let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(
            approx(results, 4),
            vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
        );
        assert_eq!(
            approx(expected, 4),
            vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
        );

        // Very large
        let v = vec![1.0f32; 10_000];
        let shape = vec![2, 5_000];
        let strides = vec![2, 1];
        let offset = 0;
        let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
        assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
    }

    #[test]
    fn binary_add_f32() {
        let left = vec![1.0f32, 2.0, 3.0];
        let right = vec![2.0f32, 3.1, 4.2];
        let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
        let expected: Vec<_> = left
            .iter()
            .zip(right.iter())
            .map(|(&x, &y)| x + y)
            .collect();
        assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
        assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
    }

    fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
        let device = device();
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let options = MTLResourceOptions::StorageModeManaged;
        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            (v.len() * core::mem::size_of::<T>()) as u64,
            options,
        );
        let mut output = device.new_buffer((v.len() * core::mem::size_of::<U>()) as u64, options);
        call_cast_contiguous(
            &device,
            &command_buffer,
            &kernels,
            name,
            v.len(),
            &input,
            &mut output,
        )
        .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        output.read_to_vec::<U>(v.len())
    }

    #[test]
    fn cast_u32_f32() {
        let v = vec![1u32, 2, 3];
        let results = cast(&v, "cast_u32_f32");
        let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
        assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
        assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);

        let v = vec![1.0f32; 10_000];
        let results = run(&v, unary::contiguous::cos::FLOAT);
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
        assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
    }

    fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
        let device = device();
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let options = MTLResourceOptions::StorageModeManaged;

        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            (v.len() * core::mem::size_of::<T>()) as u64,
            options,
        );
        let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);

        let size = v.len();

        call_affine(
            &device,
            &command_buffer,
            &kernels,
            size,
            &input,
            &mut output,
            mul as f32,
            add as f32,
        )
        .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        output.read_to_vec::<T>(v.len())
    }

    #[test]
    fn affine() {
        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mul = 1.5;
        let add = 1.1;
        let result = run_affine(&input, mul, add);
        assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);

        let input = [1.0f32; 40_000];
        let mul = 1.5;
        let add = 1.1;
        let result = run_affine(&input, mul, add);
        assert_eq!(result, vec![2.6; 40_000]);
    }

    #[test]
    fn index_add() {
        let device = Device::system_default().expect("no device found");

        let options = CompileOptions::new();
        let library = device.new_library_with_source(INDEXING, &options).unwrap();

        let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let right = [1.0f32; 15];
        let index = [0u32, 4, 2];
        let ids_dim_size = index.len() as u32;
        let dst_dim_size: u32 = 15;
        let left_size: u32 = 3;
        let right_size: u32 = 3;

        let function = library.get_function("ia_u32_f32", None).unwrap();
        let pipeline = device
            .new_compute_pipeline_state_with_function(&function)
            .unwrap();
        let options = MTLResourceOptions::StorageModeManaged;

        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();

        let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
        let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
        let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;

        encoder.set_compute_pipeline_state(&pipeline);
        encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);

        let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
        let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
        let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);

        encoder.set_buffer(0, Some(&index_buffer), 0);
        encoder.set_buffer(1, Some(&inputs_buffer), 0);
        encoder.set_buffer(2, Some(&outputs_buffer), 0);

        encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
        encoder.set_bytes(4, 4, void_ptr(&left_size));
        encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
        encoder.set_bytes(6, 4, void_ptr(&right_size));

        let grid_size = MTLSize {
            width: right.len() as NSUInteger,
            height: 1,
            depth: 1,
        };

        let thread_group_size = MTLSize {
            width: pipeline.max_total_threads_per_threadgroup(),
            height: 1,
            depth: 1,
        };

        encoder.dispatch_thread_groups(grid_size, thread_group_size);
        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        let expected = vec![
            2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
        ];
        let result = outputs_buffer.read_to_vec::<f32>(right.len());
        assert_eq!(result, expected);
    }

    #[test]
    fn cos_f16() {
        let v: Vec<f16> = [1.0f32, 2.0, 3.0]
            .iter()
            .map(|v| f16::from_f32(*v))
            .collect();
        let results = run(&v, unary::contiguous::cos::HALF);
        let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
        assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
        assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
    }
}
candle-metal-kernels/src/unary.metal (new file)
@@ -0,0 +1,82 @@
#include <metal_stdlib>

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }


using namespace metal;

#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
    const size_t start = thread_index * length; \
    const size_t stop = min(start + length, dim); \
    for (size_t i = start; i < stop; i++){ \
        output[i] = TYPENAME(FN(input[i])); \
    } \
}\
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
    const size_t start = thread_index * length; \
    const size_t stop = min(start + length, dim); \
    for (size_t i = start; i < stop; i++){ \
        output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
    } \
}

#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);

#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);


UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)

#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
#endif
@@ -14,6 +14,7 @@ accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
@@ -28,4 +29,5 @@ clap = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
metal = ["candle/metal"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
@ -1,5 +1,6 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
use tracing::debug;

/// Applies the softmax function to the input tensor, rescaling the elements so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@ -191,6 +192,16 @@ impl candle::CustomOp1 for SoftmaxLastDim {
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        debug!("TODO softmax-last-dim");
        Ok((storage.clone(), layout.shape().clone()))
    }
}

pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
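
Note that metal_fwd here is only a stub: it logs a TODO and returns the input storage unchanged, so Metal results will be wrong until the kernel lands. For reference, the computation the kernel must eventually perform is the usual numerically stabilized softmax over the last dimension; a slice-level sketch in plain Rust (illustrative, not the candle API):

// Softmax over rows of length `row_len`: subtract the row max for numerical
// stability, exponentiate, then normalize so each row sums to 1.
fn softmax_last_dim(data: &mut [f32], row_len: usize) {
    for row in data.chunks_mut(row_len) {
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0f32;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}
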
@ -81,6 +81,7 @@ impl PyDevice {
        match device {
            Device::Cpu => Self::Cpu,
            Device::Cuda(_) => Self::Cuda,
            Device::Metal(_) => unimplemented!(),
        }
    }

@ -28,5 +28,6 @@ wav = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
metal = ["candle/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]

@ -16,7 +16,7 @@ struct RmsNorm {
impl RmsNorm {
    fn new(scale: QTensor, eps: f32) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let scale = scale.dequantize(&Device::Cpu)?;
        let scale = scale.dequantize(scale.device())?;
        let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
        Ok(Self { inner, span })
    }
@ -112,6 +112,7 @@ impl LayerWeights {
        let q = self.attention_wq.forward(x)?;
        let k = self.attention_wk.forward(x)?;
        let v = self.attention_wv.forward(x)?;
        // println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype());

        let q = q
            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
@ -145,9 +146,12 @@ impl LayerWeights {
        let v = self.repeat_kv(v)?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        // println!("att {:?}", att.dtype());
        let mask = mask.broadcast_as(att.shape())?;
        // println!("mask {:?}", mask.dtype());
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
        let att = candle_nn::ops::softmax_last_dim(&att)?;
        // println!("att {:?} v {:?}", att.dtype(), v.dtype());
        // Convert to contiguous as matmul doesn't support strided vs for now.
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
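
masked_fill is defined elsewhere in this file; its effect in the hunk above is to overwrite every attention score at a masked position with f32::NEG_INFINITY so the following softmax assigns it zero weight. Element-wise, the semantics are roughly as follows (slice-level sketch, names illustrative):

// Wherever the (broadcast) mask is nonzero, replace the score with `value`.
fn masked_fill(scores: &mut [f32], mask: &[u8], value: f32) {
    for (s, &m) in scores.iter_mut().zip(mask.iter()) {
        if m != 0 {
            *s = value;
        }
    }
}
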
@ -181,13 +185,23 @@ pub struct ModelWeights {
    span_output: tracing::Span,
}

fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
    let theta = Tensor::new(theta.as_slice(), device)?;
    let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
    // let idx_theta = Tensor::new(range.as_slice(), device)?
    //     .reshape((MAX_SEQ_LEN, 1))?
    //     .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    // TODO This change avoids allocating on Metal and then casting since allocating directly on
    // CPU as f32 seems just as fast
    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((MAX_SEQ_LEN, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
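
Apart from threading the device through, precomput_freqs_cis is unchanged mathematically: it builds the inverse frequencies theta[k] = 1 / freq_base^(2k/head_dim) and takes their outer product with positions 0..MAX_SEQ_LEN, caching the cos/sin of the resulting angles for the rotary embeddings. A standalone sketch of that table with plain Vecs (illustrative only, rope_angles is a name chosen here):

// angles[pos][k] = pos * theta[k]; the model caches cos/sin of these angles.
fn rope_angles(head_dim: usize, freq_base: f32, max_seq_len: usize) -> Vec<Vec<f32>> {
    let theta: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    (0..max_seq_len)
        .map(|pos| theta.iter().map(|t| pos as f32 * t).collect())
        .collect()
}
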
@ -197,12 +211,11 @@ fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tenso
}

impl ModelWeights {
    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
        let cpu = &Device::Cpu;
    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
        let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@ -257,8 +270,8 @@ impl ModelWeights {
    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
        let cpu = &Device::Cpu;
        let md_get = |s: &str| match ct.metadata.get(s) {
            None => candle::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
@ -276,24 +289,31 @@ impl ModelWeights {
        let rope_freq_base = md_get("llama.rope.freq_base")
            .and_then(|m| m.to_f32())
            .unwrap_or(10000f32);
        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;

        let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
        let output = ct.tensor(reader, "output.weight")?;
        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let norm = RmsNorm::new(
            ct.tensor(reader, "output_norm.weight", device)?,
            rms_norm_eps,
        )?;
        let output = ct.tensor(reader, "output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);
        for layer_idx in 0..block_count {
            let prefix = format!("blk.{layer_idx}");
            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
            let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
            let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
            let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo =
                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
            let feed_forward_w1 =
                ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
            let feed_forward_w2 =
                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
            let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
            let attention_norm =
                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
@ -331,14 +351,14 @@ impl ModelWeights {
        })
    }

    fn mask(&mut self, t: usize) -> Result<Tensor> {
    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
        if let Some(mask) = self.masks.get(&t) {
            Ok(mask.clone())
        } else {
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
            let mask = Tensor::from_slice(&mask, (t, t), device)?;
            self.masks.insert(t, mask.clone());
            Ok(mask)
        }
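
The cached mask above is the standard causal mask: entry (i, j) is 1 exactly when j > i, i.e. when a query at position i would otherwise attend to a future position j. An equivalent plain-Rust construction for reference (illustrative only):

// causal_mask(3) == [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
fn causal_mask(t: usize) -> Vec<Vec<u8>> {
    (0..t)
        .map(|i| (0..t).map(|j| u8::from(j > i)).collect())
        .collect()
}
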
@ -346,7 +366,7 @@ impl ModelWeights {

    pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len) = x.dims2()?;
        let mask = self.mask(seq_len)?;
        let mask = self.mask(seq_len, x.device())?;
        let _enter = self.span.enter();
        let mut layer_in = self.tok_embeddings.forward(x)?;
        for layer in self.layers.iter_mut() {

@ -10,12 +10,12 @@ pub struct VarBuilder {
}

impl VarBuilder {
    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
        let mut file = std::fs::File::open(p)?;
        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
        let mut data = std::collections::HashMap::new();
        for tensor_name in content.tensor_infos.keys() {
            let tensor = content.tensor(&mut file, tensor_name)?;
            let tensor = content.tensor(&mut file, tensor_name, device)?;
            data.insert(tensor_name.to_string(), Arc::new(tensor));
        }
        Ok(Self {
@ -25,12 +25,12 @@ impl VarBuilder {
        })
    }

    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
    pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
        let mut cursor = std::io::Cursor::new(buffer);
        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
        let mut data = std::collections::HashMap::new();
        for tensor_name in content.tensor_infos.keys() {
            let tensor = content.tensor(&mut cursor, tensor_name)?;
            let tensor = content.tensor(&mut cursor, tensor_name, device)?;
            data.insert(tensor_name.to_string(), Arc::new(tensor));
        }
        Ok(Self {
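
With both constructors now taking a device, a call site would look roughly like the following (hypothetical example: the module path is assumed and the file name is a placeholder):

use candle::Device;
use candle_transformers::quantized_var_builder::VarBuilder;

// Load a quantized GGUF checkpoint directly onto the caller's device instead
// of always going through the CPU. "model-q4_0.gguf" is a placeholder path.
fn load_weights(device: &Device) -> candle::Result<VarBuilder> {
    VarBuilder::from_gguf("model-q4_0.gguf", device)
}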