diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 0807176e..2b9ad619 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -130,6 +130,10 @@ impl Device { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } + pub fn new_metal(ordinal: usize) -> Result { + Ok(Self::Metal(crate::MetalDevice::new(ordinal)?)) + } + pub fn set_seed(&self, seed: u64) -> Result<()> { match self { Self::Cpu => CpuDevice.set_seed(seed), @@ -297,11 +301,10 @@ impl Device { let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } - Device::Metal(_device) => { - // let storage = S::to_cpu_storage_owned(data); - // let storage = device.storage_from_cpu_storage(&storage)?; - // Ok(Storage::Metal(storage)) - bail!("Metal storage_owned not implemented") + Device::Metal(device) => { + let storage = S::to_cpu_storage_owned(data); + let storage = device.storage_from_cpu_storage(&storage)?; + Ok(Storage::Metal(storage)) } } } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 209ec9a7..8a629e46 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -49,7 +49,8 @@ mod device; pub mod display; mod dtype; mod dummy_cuda_backend; -mod dummy_metal_backend; +#[cfg(feature = "metal")] +pub mod metal_backend; pub mod error; mod indexer; pub mod layout; @@ -71,9 +72,6 @@ pub mod test_utils; pub mod utils; mod variable; -#[cfg(not(feature = "cuda"))] -pub use dummy_metal_backend::{MetalDevice, MetalStorage}; - pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; pub use dtype::{DType, FloatDType, IntDType, WithDType}; @@ -93,6 +91,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage}; #[cfg(not(feature = "cuda"))] pub use dummy_cuda_backend::{CudaDevice, CudaStorage}; +#[cfg(feature = "metal")] +pub use metal_backend::{MetalDevice, MetalStorage}; + +#[cfg(not(feature = "metal"))] +pub use dummy_metal_backend::{MetalDevice, MetalStorage}; + #[cfg(feature = "mkl")] extern crate intel_mkl_src; 
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 00d236e3..2cbd6cea 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -198,20 +198,21 @@ impl BackendStorage for MetalStorage { impl BackendDevice for MetalDevice { type Storage = MetalStorage; - fn new(_ordinal: usize) -> Result { - todo!() + fn new(ordinal: usize) -> Result { + let device = metal::Device::all().into_iter().nth(ordinal).ok_or_else(|| crate::Error::Metal(format!("no metal device found for ordinal {ordinal}").into()))?; + Ok(Self { device }) } fn set_seed(&self, _seed: u64) -> Result<()> { - todo!() + todo!("set_seed") } fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Metal } - fn same_device(&self, _rhs: &Self) -> bool { - todo!() + fn same_device(&self, rhs: &Self) -> bool { + self.device.registry_id() == rhs.device.registry_id() } fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { @@ -223,7 +224,7 @@ impl BackendDevice for MetalDevice { } fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { - todo!() + todo!("Storage") } fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index e1168c2e..9ea50bf6 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,5 +1,5 @@ #![allow(clippy::redundant_closure_call)] -use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor}; +use crate::{CpuStorage, CudaStorage, MetalStorage, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; use num_traits::float::Float; @@ -174,6 +174,14 @@ pub trait CustomOp1 { )) } + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
+ fn metal_fwd(&self, _storage: &MetalStorage, _layout: &Layout) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + /// This function takes as argument the argument `arg` used in the forward pass, the result /// produced by the forward operation `res` and the gradient of the result `grad_res`. /// The function should return the gradient of the argument. @@ -209,6 +217,20 @@ pub trait CustomOp2 { )) } + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + fn bwd( &self, _arg1: &Tensor, @@ -251,6 +273,22 @@ pub trait CustomOp3 { )) } + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + fn bwd( &self, _arg1: &Tensor, diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 1dd3d9c0..5b2e3b64 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -1,7 +1,7 @@ //! Support for the GGML file format. 
use super::{k_quants, GgmlDType}; -use crate::Result; +use crate::{Result, Device}; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; @@ -121,11 +121,12 @@ fn from_raw_data( raw_data: &[u8], size_in_bytes: usize, dims: Vec, + device: &Device, ) -> Result { let raw_data_ptr = raw_data.as_ptr(); let n_blocks = size_in_bytes / std::mem::size_of::(); let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; - super::QTensor::new(data.to_vec(), dims) + super::QTensor::new(data.to_vec(), dims, device) } /// Creates a [Tensor] from a raw GGML tensor. @@ -133,6 +134,7 @@ pub fn qtensor_from_ggml( ggml_dtype: GgmlDType, raw_data: &[u8], dims: Vec, + device: &Device, ) -> Result { let tensor_elems = dims.iter().product::(); let blck_size = ggml_dtype.blck_size(); @@ -144,18 +146,18 @@ pub fn qtensor_from_ggml( let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size(); match ggml_dtype { - GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q4_0 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q4_1 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q5_0 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q5_1 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q8_0 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q2K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q3K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q4K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q5K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q6K => from_raw_data::(raw_data, size_in_bytes, dims), + GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q4_0 => from_raw_data::(raw_data, size_in_bytes, dims, 
device), + GgmlDType::Q4_1 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q5_0 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q5_1 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q8_0 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q2K => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q3K => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q4K => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q5K => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q6K => from_raw_data::(raw_data, size_in_bytes, dims, device), _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } @@ -163,6 +165,7 @@ pub fn qtensor_from_ggml( fn read_one_tensor( reader: &mut R, magic: VersionedMagic, + device: &Device, ) -> Result<(String, super::QTensor)> { let n_dims = reader.read_u32::()?; let name_len = reader.read_u32::()?; @@ -187,7 +190,7 @@ fn read_one_tensor( // TODO: Mmap version to avoid copying the data around? let mut raw_data = vec![0u8; size_in_bytes]; reader.read_exact(&mut raw_data)?; - match qtensor_from_ggml(ggml_dtype, &raw_data, dims) { + match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) { Ok(tensor) => Ok((name, tensor)), Err(e) => crate::bail!("Error creating tensor {name}: {e}"), } @@ -201,7 +204,7 @@ pub struct Content { } impl Content { - pub fn read(reader: &mut R) -> Result { + pub fn read(reader: &mut R, device: &Device) -> Result { // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 let last_position = reader.seek(std::io::SeekFrom::End(0))?; reader.seek(std::io::SeekFrom::Start(0))?; @@ -211,7 +214,7 @@ impl Content { let mut tensors = HashMap::new(); while reader.stream_position()? 
!= last_position { - let (name, tensor) = read_one_tensor(reader, magic)?; + let (name, tensor) = read_one_tensor(reader, magic, device)?; tensors.insert(name, tensor); } Ok(Self { diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index 3a5f2030..7af5d394 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -3,7 +3,7 @@ //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md use super::{GgmlDType, QTensor}; -use crate::Result; +use crate::{Result, Device}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::collections::HashMap; @@ -57,6 +57,7 @@ impl TensorInfo { &self, reader: &mut R, tensor_data_offset: u64, + device: &Device, ) -> Result { let tensor_elems = self.shape.elem_count(); let blck_size = self.ggml_dtype.blck_size(); @@ -69,7 +70,7 @@ impl TensorInfo { let mut raw_data = vec![0u8; size_in_bytes]; reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?; reader.read_exact(&mut raw_data)?; - super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec()) + super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec(), device) } } @@ -450,12 +451,13 @@ impl Content { &self, reader: &mut R, name: &str, + device: &Device, ) -> Result { let tensor_info = match self.tensor_infos.get(name) { Some(tensor_info) => tensor_info, None => crate::bail!("cannot find tensor-infor for {name}"), }; - tensor_info.read(reader, self.tensor_data_offset) + tensor_info.read(reader, self.tensor_data_offset, device) } } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 58f261b4..4998b114 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -14,6 +14,7 @@ pub mod utils; pub use k_quants::GgmlType; pub struct QTensor { + device: Device, data: Box, shape: Shape, } @@ -170,17 +171,20 @@ impl QTensor { pub fn new, T: 
k_quants::GgmlType + Send + Sync + 'static>( data: Vec, shape: S, + device: &Device, ) -> Result { let shape = shape.into(); check_shape::(&shape)?; Ok(Self { data: Box::new(data), shape, + device: device.clone() }) } pub fn quantize(src: &Tensor) -> Result { let shape = src.shape(); + let device = src.device(); check_shape::(shape)?; let src = src .to_dtype(crate::DType::F32)? @@ -197,6 +201,7 @@ impl QTensor { Ok(Self { data: Box::new(data), shape: shape.clone(), + device: device.clone() }) } @@ -212,7 +217,12 @@ impl QTensor { &self.shape } + pub fn device(&self) -> &Device { + &self.device + } + pub fn dequantize(&self, device: &Device) -> Result { + // TODO Skip the CPU part on metal let mut f32_data = vec![0f32; self.shape.elem_count()]; self.data.to_float(&mut f32_data)?; Tensor::from_vec(f32_data, &self.shape, device) diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 9bd1fed6..c7897243 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -1,6 +1,6 @@ use crate::backend::BackendStorage; use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp}; -use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape}; +use crate::{CpuStorage, CudaStorage, MetalStorage, DType, Device, Error, Layout, Result, Shape}; // We do not want to implement Clone on Storage as cloning may fail because of // out of memory. Instead try_clone should be used. 
@@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape pub enum Storage { Cpu(CpuStorage), Cuda(CudaStorage), + Metal(MetalStorage), } impl Storage { @@ -18,6 +19,10 @@ impl Storage { let storage = storage.try_clone(layout)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.try_clone(layout)?; + Ok(Self::Metal(storage)) + } } } @@ -25,6 +30,7 @@ impl Storage { match self { Self::Cpu(_) => Device::Cpu, Self::Cuda(storage) => Device::Cuda(storage.device().clone()), + Self::Metal(storage) => Device::Metal(storage.device().clone()), } } @@ -32,6 +38,7 @@ impl Storage { match self { Self::Cpu(storage) => storage.dtype(), Self::Cuda(storage) => storage.dtype(), + Self::Metal(storage) => storage.dtype(), } } @@ -65,6 +72,10 @@ impl Storage { let storage = storage.affine(layout, mul, add)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.affine(layout, mul, add)?; + Ok(Self::Metal(storage)) + } } } @@ -78,6 +89,10 @@ impl Storage { let storage = storage.powf(layout, alpha)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.powf(layout, alpha)?; + Ok(Self::Metal(storage)) + } } } @@ -91,6 +106,10 @@ impl Storage { let storage = storage.elu(layout, alpha)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.elu(layout, alpha)?; + Ok(Self::Metal(storage)) + } } } @@ -112,6 +131,10 @@ impl Storage { let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => { // Should not happen because of the same device check above but we're defensive // anyway. 
@@ -135,6 +158,10 @@ impl Storage { let storage = storage.reduce_op(op, layout, s)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.reduce_op(op, layout, s)?; + Ok(Self::Metal(storage)) + } } } @@ -148,6 +175,10 @@ impl Storage { let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.to_dtype(layout, dtype)?; + Ok(Self::Metal(storage)) + } } } @@ -161,6 +192,10 @@ impl Storage { let (storage, shape) = c.cuda_fwd(storage, l)?; Ok((Self::Cuda(storage), shape)) } + Self::Metal(storage) => { + let (storage, shape) = c.metal_fwd(storage, l)?; + Ok((Self::Metal(storage), shape)) + } } } @@ -181,6 +216,10 @@ impl Storage { let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?; Ok((Self::Cuda(s), shape)) } + (Self::Metal(s1), Self::Metal(s2)) => { + let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?; + Ok((Self::Metal(s), shape)) + } _ => unreachable!(), } } @@ -205,6 +244,10 @@ impl Storage { let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?; Ok((Self::Cuda(s), shape)) } + (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => { + let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?; + Ok((Self::Metal(s), shape)) + } _ => unreachable!(), } } @@ -219,6 +262,10 @@ impl Storage { let storage = storage.unary_impl::(layout)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.unary_impl::(layout)?; + Ok(Self::Metal(storage)) + } } } @@ -239,6 +286,10 @@ impl Storage { let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => { // Should not happen because of the same device check above but we're defensive // anyway. 
@@ -270,6 +321,10 @@ impl Storage { let s = inp.conv1d(l, kernel, kernel_l, params)?; Ok(Self::Cuda(s)) } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv1d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -297,6 +352,10 @@ impl Storage { let s = inp.conv2d(l, kernel, kernel_l, params)?; Ok(Self::Cuda(s)) } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -324,6 +383,10 @@ impl Storage { let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; Ok(Self::Cuda(s)) } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -348,6 +411,10 @@ impl Storage { let storage = storage.avg_pool2d(layout, kernel_size, stride)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.avg_pool2d(layout, kernel_size, stride)?; + Ok(Self::Metal(storage)) + } } } @@ -366,6 +433,10 @@ impl Storage { let storage = storage.max_pool2d(layout, kernel_size, stride)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Metal(storage)) + } } } @@ -379,6 +450,10 @@ impl Storage { let storage = storage.upsample_nearest1d(layout, sz)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Metal(storage)) + } } } @@ -392,6 +467,10 @@ impl Storage { let storage = storage.upsample_nearest2d(layout, h, w)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = 
storage.upsample_nearest2d(layout, h, w)?; + Ok(Self::Metal(storage)) + } } } @@ -415,6 +494,10 @@ impl Storage { let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cuda(storage)) } + (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => { + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; + Ok(Self::Metal(storage)) + } (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -441,6 +524,10 @@ impl Storage { let storage = s.gather(l, indexes, indexes_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(s), Self::Metal(indexes)) => { + let storage = s.gather(l, indexes, indexes_l, d)?; + Ok(Self::Metal(storage)) + } _ => unreachable!(), } } @@ -465,6 +552,10 @@ impl Storage { let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { + let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Metal(storage)) + } _ => unreachable!(), } } @@ -489,6 +580,10 @@ impl Storage { let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { + let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Metal(storage)) + } _ => unreachable!(), } } @@ -510,6 +605,10 @@ impl Storage { let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -537,6 +636,10 @@ impl Storage { let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.matmul(rhs, bmnk, lhs_layout, 
rhs_layout)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -556,6 +659,7 @@ impl Storage { match (self, dst) { (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l), (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), + (Self::Metal(src), Self::Metal(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index fb3c82bb..96f75cf2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -523,6 +523,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -1448,6 +1449,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -1478,6 +1480,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -1518,6 +1521,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs index a9c2df0b..78c45a9a 100644 --- a/candle-core/src/utils.rs +++ b/candle-core/src/utils.rs @@ -23,6 +23,10 @@ pub fn 
cuda_is_available() -> bool { cfg!(feature = "cuda") } +pub fn metal_is_available() -> bool { + cfg!(feature = "metal") +} + pub fn with_avx() -> bool { cfg!(target_feature = "avx") } diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 2e9e7f07..0b1e15b5 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -232,6 +232,7 @@ fn main() -> anyhow::Result<()> { use tracing_subscriber::prelude::*; let args = Args::parse(); + let device = candle_examples::device(false)?; let temperature = if args.temperature == 0. { None } else { @@ -276,10 +277,10 @@ fn main() -> anyhow::Result<()> { &format_size(total_size_in_bytes), start.elapsed().as_secs_f32(), ); - ModelWeights::from_gguf(model, &mut file)? + ModelWeights::from_gguf(model, &mut file, &device)? } Some("ggml" | "bin") | Some(_) | None => { - let model = ggml_file::Content::read(&mut file)?; + let model = ggml_file::Content::read(&mut file, &device)?; let mut total_size_in_bytes = 0; for (_, tensor) in model.tensors.iter() { let elem_count = tensor.shape().elem_count(); diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 4ef97f88..049ea98c 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -3,16 +3,27 @@ pub mod imagenet; pub mod token_output_stream; use candle::{Device, Result, Tensor}; +use candle::utils::{cuda_is_available, metal_is_available}; pub fn device(cpu: bool) -> Result { if cpu { Ok(Device::Cpu) } else { - let device = Device::cuda_if_available(0)?; - if !device.is_cuda() { - println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); + if cuda_is_available(){ + Ok(Device::new_cuda(0)?) + }else if metal_is_available(){ + Ok(Device::new_metal(0)?) 
+ }else{ + #[cfg(all(target_os="macos", target_arch="aarch64"))] + { + println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`"); + } + #[cfg(not(all(target_os="macos", target_arch="aarch64")))] + { + println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); + } + Ok(Device::Cpu) } - Ok(device) } } diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 44d89f40..678c5800 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -16,7 +16,7 @@ struct RmsNorm { impl RmsNorm { fn new(scale: QTensor, eps: f32) -> Result { let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let scale = scale.dequantize(&Device::Cpu)?; + let scale = scale.dequantize(scale.device())?; let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); Ok(Self { inner, span }) } @@ -257,8 +257,8 @@ impl ModelWeights { pub fn from_gguf( ct: gguf_file::Content, reader: &mut R, + device: &Device ) -> Result { - let cpu = &Device::Cpu; let md_get = |s: &str| match ct.metadata.get(s) { None => candle::bail!("cannot find {s} in metadata"), Some(v) => Ok(v), @@ -278,22 +278,22 @@ impl ModelWeights { .unwrap_or(10000f32); let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; - let output = ct.tensor(reader, "output.weight")?; + let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings.dequantize(device)?; + let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps)?; + let output = ct.tensor(reader, "output.weight", device)?; let mut layers = Vec::with_capacity(block_count); for layer_idx in 0..block_count { let 
prefix = format!("blk.{layer_idx}"); - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; - let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; - let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; - let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index 810802e8..763d72bd 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -10,12 
+10,12 @@ pub struct VarBuilder { } impl VarBuilder { - pub fn from_gguf>(p: P) -> Result { + pub fn from_gguf>(p: P, device: &Device) -> Result { let mut file = std::fs::File::open(p)?; let content = candle::quantized::gguf_file::Content::read(&mut file)?; let mut data = std::collections::HashMap::new(); for tensor_name in content.tensor_infos.keys() { - let tensor = content.tensor(&mut file, tensor_name)?; + let tensor = content.tensor(&mut file, tensor_name, device)?; data.insert(tensor_name.to_string(), Arc::new(tensor)); } Ok(Self { @@ -25,12 +25,12 @@ impl VarBuilder { }) } - pub fn from_gguf_buffer(buffer: &[u8]) -> Result { + pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result { let mut cursor = std::io::Cursor::new(buffer); let content = candle::quantized::gguf_file::Content::read(&mut cursor)?; let mut data = std::collections::HashMap::new(); for tensor_name in content.tensor_infos.keys() { - let tensor = content.tensor(&mut cursor, tensor_name)?; + let tensor = content.tensor(&mut cursor, tensor_name, device)?; data.insert(tensor_name.to_string(), Arc::new(tensor)); } Ok(Self {