diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 2b9ad619..292545c5 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -223,10 +223,9 @@ impl Device {
                 let storage = device.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(_device) => {
-                // let storage = device.rand_normal(shape, dtype, mean, std)?;
-                // Ok(Storage::Metal(storage))
-                bail!("Metal rand_normal not implemented")
+            Device::Metal(device) => {
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
+                Ok(Storage::Metal(storage))
             }
         }
     }
@@ -250,10 +249,9 @@ impl Device {
                 let storage = device.ones_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(_device) => {
-                // let storage = device.ones_impl(shape, dtype)?;
-                // Ok(Storage::Metal(storage))
-                bail!("Metal ones not implemented")
+            Device::Metal(device) => {
+                let storage = device.ones_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
             }
         }
     }
@@ -268,10 +266,9 @@ impl Device {
                 let storage = device.zeros_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(_device) => {
-                // let storage = device.zeros_impl(shape, dtype)?;
-                // Ok(Storage::Metal(storage))
-                bail!("Metal zeros not implemented")
+            Device::Metal(device) => {
+                let storage = device.zeros_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
             }
         }
     }
@@ -284,11 +281,10 @@ impl Device {
                 let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(_device) => {
-                // let storage = array.to_cpu_storage();
-                // let storage = device.storage_from_cpu_storage(&storage)?;
-                // Ok(Storage::Metal(storage))
-                bail!("Metal storage not implemented")
+            Device::Metal(device) => {
+                let storage = array.to_cpu_storage();
+                let storage = device.storage_from_cpu_storage(&storage)?;
+                Ok(Storage::Metal(storage))
             }
         }
     }
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 2cbd6cea..06e88755 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,6 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 pub use candle_metal;
 use metal;
+use core::mem;
+use half::{f16, bf16};
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -43,8 +45,10 @@ impl MetalDevice {
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
-    pub buffer: metal::Buffer,
-    pub device: metal::Device,
+    buffer: metal::Buffer,
+    device: MetalDevice,
+    dtype: DType,
+
 }
 
 impl BackendStorage for MetalStorage {
@@ -55,11 +59,11 @@
     }
 
     fn dtype(&self) -> DType {
-        todo!()
+        self.dtype
     }
 
     fn device(&self) -> &Self::Device {
-        todo!()
+        &self.device
     }
 
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
@@ -86,8 +90,8 @@
         todo!()
     }
 
-    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
-        todo!()
+    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
+        todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
     }
 
     fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
@@ -182,12 +186,19 @@
 
     fn matmul(
         &self,
-        _: &Self,
-        _: (usize, usize, usize, usize),
-        _: &Layout,
-        _: &Layout,
+        rhs: &Self,
+        (b, m, n, k): (usize, usize, usize, usize),
+        lhs_l: &Layout,
+        rhs_l: &Layout,
     ) -> Result<Self> {
-        todo!()
+        let elem_count = b * m * n;
+        let dev = &self.device;
+        match (self.dtype, rhs.dtype) {
+            (DType::F32, DType::F32) => {
+                todo!("MATMUL {b} {m} {n} {k}");
+            }
+            _ => todo!("Unimplemented matmul for this pair"),
+        }
     }
 
     fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
@@ -223,8 +234,60 @@ impl BackendDevice for MetalDevice {
         todo!()
     }
 
-    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
-        todo!("Storage")
+    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
+        let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
+        let buffer = match storage {
+            CpuStorage::U8(storage) => {
+                self.device.new_buffer_with_data(
+                    storage.as_ptr() as *const core::ffi::c_void,
+                    (storage.len() * mem::size_of::<u8>()) as u64,
+                    option,
+                )
+            }
+            CpuStorage::U32(storage) => {
+                self.device.new_buffer_with_data(
+                    storage.as_ptr() as *const core::ffi::c_void,
+                    (storage.len() * mem::size_of::<u32>()) as u64,
+                    option,
+                )
+            }
+            CpuStorage::I64(storage) => {
+                self.device.new_buffer_with_data(
+                    storage.as_ptr() as *const core::ffi::c_void,
+                    (storage.len() * mem::size_of::<i64>()) as u64,
+                    option,
+                )
+            }
+            CpuStorage::BF16(storage) => {
+                self.device.new_buffer_with_data(
+                    storage.as_ptr() as *const core::ffi::c_void,
+                    (storage.len() * mem::size_of::<bf16>()) as u64,
+                    option,
+                )
+            }
+            CpuStorage::F16(storage) => {
+                self.device.new_buffer_with_data(
+                    storage.as_ptr() as *const core::ffi::c_void,
+                    (storage.len() * mem::size_of::<f16>()) as u64,
+                    option,
+                )
+            }
+            CpuStorage::F32(storage) => {
+                self.device.new_buffer_with_data(
+                    storage.as_ptr() as *const core::ffi::c_void,
+                    (storage.len() * mem::size_of::<f32>()) as u64,
+                    option,
+                )
+            }
+            CpuStorage::F64(storage) => {
+                self.device.new_buffer_with_data(
+                    storage.as_ptr() as *const core::ffi::c_void,
+                    (storage.len() * mem::size_of::<f64>()) as u64,
+                    option,
+                )
+            }
+        };
+        Ok(Self::Storage { buffer, device: self.clone(), dtype: storage.dtype() })
     }
 
     fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 0b1e15b5..1fb32701 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -308,7 +308,7 @@ fn main() -> anyhow::Result<()> {
                 | Which::L70b
                 | Which::L70bChat => 8,
             };
-            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
+            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
         }
     };
     println!("model built");
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 678c5800..5d259072 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 4096;
@@ -181,28 +181,31 @@ pub struct ModelWeights {
     span_output: tracing::Span,
 }
 
-fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
+fn precomput_freqs_cis(head_dim: usize, freq_base: f32, device: &Device) -> Result<(Tensor, Tensor)> {
     let theta: Vec<_> = (0..head_dim)
         .step_by(2)
         .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
         .collect();
-    let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
-        .to_dtype(DType::F32)?
-        .reshape((MAX_SEQ_LEN, 1))?
-        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
+    let idx_theta = Tensor::new(range.as_slice(), device)?.reshape((MAX_SEQ_LEN, 1))?.matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    // TODO This change avoids allocating on Metal and then casting, since allocating directly on
+    // CPU as f32 seems just as fast.
+    // let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+    //     .to_dtype(DType::F32)?
+    //     .reshape((MAX_SEQ_LEN, 1))?
+    //     .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let cos = idx_theta.cos()?;
     let sin = idx_theta.sin()?;
     Ok((cos, sin))
 }
 
 impl ModelWeights {
-    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
-        let cpu = &Device::Cpu;
+    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
-        let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
+        let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
-        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+        let tok_embeddings = tok_embeddings.dequantize(device)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
         let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@@ -276,7 +279,7 @@
         let rope_freq_base = md_get("llama.rope.freq_base")
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
-        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
+        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -331,14 +334,14 @@
         })
     }
 
-    fn mask(&mut self, t: usize) -> Result<Tensor> {
+    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
         if let Some(mask) = self.masks.get(&t) {
             Ok(mask.clone())
         } else {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+            let mask = Tensor::from_slice(&mask, (t, t), device)?;
             self.masks.insert(t, mask.clone());
             Ok(mask)
         }
@@ -346,7 +349,7 @@
 
     pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
-        let mask = self.mask(seq_len)?;
+        let mask = self.mask(seq_len, x.device())?;
         let _enter = self.span.enter();
         let mut layer_in = self.tok_embeddings.forward(x)?;
        for layer in self.layers.iter_mut() {
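// ---------------------------------------------------------------------------
// Reference sketch, not part of the diff above: the rotary-embedding tables
// built by precomput_freqs_cis reduce to an outer product of positions and
// inverse frequencies followed by element-wise cos/sin. The plain-Vec version
// below mirrors that math without candle tensors; the function name
// `freqs_cis_reference` and the explicit `max_seq_len` parameter are
// hypothetical, introduced only for illustration.

fn freqs_cis_reference(
    head_dim: usize,
    freq_base: f32,
    max_seq_len: usize,
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    // Inverse frequencies 1 / freq_base^(i / head_dim) for even i, as in the diff.
    let theta: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    // Outer product of positions 0..max_seq_len with theta. The diff expresses this
    // as a (MAX_SEQ_LEN, 1) x (1, head_dim / 2) matmul built directly in f32, which
    // is what lets it skip the to_dtype cast on the Metal device.
    let idx_theta: Vec<Vec<f32>> = (0..max_seq_len)
        .map(|p| theta.iter().map(|t| p as f32 * t).collect())
        .collect();
    let cos = idx_theta
        .iter()
        .map(|row| row.iter().map(|v| v.cos()).collect())
        .collect();
    let sin = idx_theta
        .iter()
        .map(|row| row.iter().map(|v| v.sin()).collect())
        .collect();
    (cos, sin)
}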