diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 73eb9640..3eb7f8b7 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
-    Metal,
+    Metal { gpu_id: usize },
 }

 #[derive(Debug, Clone)]
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index 215c28f6..4f5a390e 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -14,7 +14,9 @@ impl Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
-            _ => todo!(),
+            crate::DeviceLocation::Metal { gpu_id } => {
+                format!(", metal:{}", gpu_id)
+            }
         };
         write!(f, "Tensor[")?;
@@ -477,7 +479,9 @@ impl std::fmt::Display for Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
-            crate::DeviceLocation::Metal => todo!(),
+            crate::DeviceLocation::Metal { gpu_id } => {
+                format!(", metal:{}", gpu_id)
+            }
         };
         write!(
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index ed592240..6687534d 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -100,11 +100,30 @@ impl BackendStorage for MetalStorage {
     }

     fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        // TODO Is this necessary
+        // self.buffer.synchronize();
         match self.dtype {
+            DType::U8 => Ok(CpuStorage::U8(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 1),
+            )),
+            DType::U32 => Ok(CpuStorage::U32(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 4),
+            )),
+            DType::I64 => Ok(CpuStorage::I64(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 8),
+            )),
+            DType::F16 => Ok(CpuStorage::F16(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 2),
+            )),
+            DType::BF16 => Ok(CpuStorage::BF16(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 2),
+            )),
             DType::F32 => Ok(CpuStorage::F32(
                 self.buffer.read_to_vec(self.buffer.length() as usize / 4),
             )),
-            dtype => todo!("Unsupported dtype {dtype:?}"),
+            DType::F64 => Ok(CpuStorage::F64(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 8),
+            )),
         }
     }

@@ -132,6 +151,7 @@ impl BackendStorage for MetalStorage {
         )
         .unwrap();
         command_buffer.commit();
+        command_buffer.wait_until_completed();
         return Ok(Self {
             buffer,
             device: device.clone(),
@@ -200,6 +220,7 @@ impl BackendStorage for MetalStorage {
         )
         .map_err(MetalError::from)?;
         command_buffer.commit();
+        command_buffer.wait_until_completed();

         Ok(Self {
             buffer,
@@ -242,6 +263,7 @@ impl BackendStorage for MetalStorage {
         }

         command_buffer.commit();
+        command_buffer.wait_until_completed();
         // command_buffer.wait_until_scheduled();
         // debug!(
         //     "cast {:?} - {:?} - {:?}",
@@ -289,6 +311,7 @@ impl BackendStorage for MetalStorage {
             todo!("TODO Implement the kernel calling {}", B::KERNEL);
         }
         command_buffer.commit();
+        command_buffer.wait_until_completed();

         Ok(Self {
             buffer,
@@ -361,6 +384,7 @@ impl BackendStorage for MetalStorage {
             .map_err(MetalError::from)?;
         }
         command_buffer.commit();
+        command_buffer.wait_until_completed();

         Ok(Self {
             buffer,
@@ -400,6 +424,7 @@ impl BackendStorage for MetalStorage {
         )
         .map_err(MetalError::from)?;
         command_buffer.commit();
+        command_buffer.wait_until_completed();
         Ok(Self {
             buffer,
             device,
@@ -489,6 +514,7 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let device = self.device();
         let mut buffer = device.new_buffer(dst_el, dtype);
+        let out = self.to_cpu_storage().unwrap();
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
             (left, right) => todo!("index select metal {left:?} {right:?}"),
@@ -508,6 +534,7 @@ impl BackendStorage for MetalStorage {
         )
         .map_err(MetalError::from)?;
         command_buffer.commit();
+        command_buffer.wait_until_completed();
         Ok(Self {
             buffer,
             device: device.clone(),
@@ -556,39 +583,42 @@ impl BackendStorage for MetalStorage {
         if el_count == 0 {
             return Ok(());
         }
-        if src_l.is_contiguous() {
-            let command_buffer = self.device.command_queue.new_command_buffer();
-            let blip = command_buffer.new_blit_command_encoder();
-            blip.copy_from_buffer(
-                &self.buffer,
-                src_l.start_offset() as u64,
-                &dst.buffer,
-                dst_offset as u64,
-                self.buffer.length(),
-            );
-        } else {
-            let command_buffer = self.device.command_queue.new_command_buffer();
-            let kernel_name = match self.dtype {
-                DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
-                DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
-                DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
-                dtype => todo!("copy_strided not implemented for {dtype:?}"),
-            };
-            candle_metal_kernels::call_unary_strided(
-                &self.device.device,
-                &command_buffer,
-                &self.device.kernels,
-                kernel_name,
-                src_l.dims(),
-                &self.buffer,
-                &src_l.stride(),
-                src_l.start_offset(),
-                &mut dst.buffer,
-                dst_offset,
-            )
-            .map_err(MetalError::from)?;
-            command_buffer.commit();
-        }
+        // todo!("Copy strided {:?}", src_l.is_contiguous());
+        // if src_l.is_contiguous() {
+        //     let command_buffer = self.device.command_queue.new_command_buffer();
+        //     let blip = command_buffer.new_blit_command_encoder();
+        //     blip.copy_from_buffer(
+        //         &self.buffer,
+        //         src_l.start_offset() as u64,
+        //         &dst.buffer,
+        //         dst_offset as u64,
+        //         self.buffer.length(),
+        //     );
+        // } else {
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        let kernel_name = match self.dtype {
+            DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+            DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+            DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+            dtype => todo!("copy_strided not implemented for {dtype:?}"),
+        };
+        candle_metal_kernels::call_unary_strided(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            kernel_name,
+            src_l.dims(),
+            &self.buffer,
+            &src_l.stride(),
+            src_l.start_offset(),
+            &mut dst.buffer,
+            dst_offset,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        // todo!("Output {:?}", dst.buffer.read_to_vec::<f32>(10));
+        // }
         Ok(())
     }
 }
@@ -616,28 +646,29 @@ impl MetalStorage {
         match (self.dtype, rhs.dtype) {
             (DType::F32, DType::F32) => {
                 let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
-                if b != 1 {
-                    // debug!("TODO implement batched matmul for B={b}");
-                    // bail!("Didn't implemented strided matmul yet");
-                    return Ok(Self {
-                        buffer: out_buffer,
-                        device: self.device.clone(),
-                        dtype: self.dtype(),
-                    });
-                }
-                if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
-                    // debug!(
-                    //     "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
-                    //     lhs_l.is_contiguous(),
-                    //     rhs_l.is_contiguous(),
-                    //     rhs_l
-                    // );
-                    return Ok(Self {
-                        buffer: out_buffer,
-                        device: self.device.clone(),
-                        dtype: self.dtype(),
-                    });
-                }
+                // if b != 1 {
+                //     // debug!("TODO implement batched matmul for B={b}");
+                //     crate::bail!("Didn't implemented strided matmul yet");
+                //     return Ok(Self {
+                //         buffer: out_buffer,
+                //         device: self.device.clone(),
+                //         dtype: self.dtype(),
+                //     });
+                //}
+                // if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
+                //     // debug!(
+                //     //     "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
+                //     //     lhs_l.is_contiguous(),
+                //     //     rhs_l.is_contiguous(),
+                //     //     rhs_l
+                //     // );
+                //     crate::bail!("No not contiguous matmul");
+                //     return Ok(Self {
+                //         buffer: out_buffer,
+                //         device: self.device.clone(),
+                //         dtype: self.dtype(),
+                //     });
+                // }

                 // debug!("TODO GEMM");
                 let command_buffer = self.device.command_queue.new_command_buffer();
@@ -659,7 +690,15 @@ impl MetalStorage {
                 .map_err(MetalError::from)?;
                 command_buffer.commit();
+                command_buffer.wait_until_completed();
                 // command_buffer.wait_until_scheduled();
+                //
+                let left = self.buffer.read_to_vec::<f32>(10);
+                let right = rhs.buffer.read_to_vec::<f32>(10);
+                let out = out_buffer.read_to_vec::<f32>(10);
+
+                println!("{b} {m} {n} {k} ");
+                println!("{left:?} {right:?} {out:?}");

                 Ok(Self {
                     buffer: out_buffer,
@@ -709,7 +748,9 @@ impl BackendDevice for MetalDevice {
     }

     fn location(&self) -> crate::DeviceLocation {
-        crate::DeviceLocation::Metal
+        crate::DeviceLocation::Metal {
+            gpu_id: self.registry_id() as usize,
+        }
     }

     fn same_device(&self, rhs: &Self) -> bool {
@@ -767,6 +808,8 @@ impl BackendDevice for MetalDevice {
                 option,
             ),
         };
+        // TODO is that necessary ?
+        // buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
         // debug!("Allocate 2 - buffer size {}", buffer.length());
         Ok(Self::Storage {
             buffer,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f7f66668..3965a2ed 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -157,6 +157,8 @@ pub(crate) fn from_storage<S: Into<Shape>>(
 ) -> Tensor {
     let dtype = storage.dtype();
     let device = storage.device();
+    let shape = shape.into();
+    // println!("{:?} {storage:?}", shape);
     let tensor_ = Tensor_ {
         id: TensorId::new(),
         storage: Arc::new(RwLock::new(storage)),
@@ -166,7 +168,11 @@
         dtype,
         device,
     };
-    Tensor(Arc::new(tensor_))
+    let result = Tensor(Arc::new(tensor_));
+    // todo!(" from_storage");
+    // let result = result.to_device(&Device::Cpu).unwrap();
+    // todo!(" {result}");
+    result
 }

 impl Tensor {
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 0ceb27af..11381fbc 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -329,14 +329,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .get_ids()
         .to_vec();

+    println!("{tokens:?}");
+
     let start_gen = std::time::Instant::now();
-    for index in 0.. {
+    for index in 0..1 {
         if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
+        // println!("Input {}", input);
+        // println!("Input {}", input.to_device(&candle::Device::Cpu)?);
         let logits = model.forward(&input, index_pos)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index ff5ede1a..2585ca62 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -17,3 +17,4 @@ tracing = "0.1.37"

 [dev-dependencies]
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+rand = "0.8.5"
diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/examples/affine.rs
new file mode 100644
index 00000000..b8005dc0
--- /dev/null
+++ b/candle-metal-kernels/examples/affine.rs
@@ -0,0 +1,75 @@
+use candle_metal_kernels::{call_affine, Kernels};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_affine_bench(&device, &kernels, &f32_1k);
+    run_affine_bench(&device, &kernels, &f32_10k);
+    run_affine_bench(&device, &kernels, &f32_100k);
+}
+
+fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 10000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    let mul: f32 = 1.2345;
+    let add: f32 = 2.3456;
+    let total_time = autoreleasepool(|| {
+        let command_buffer = command_queue.new_command_buffer();
+        let start = Instant::now();
+        for _ in 0..iterations {
+            call_affine(
+                &device,
+                command_buffer,
+                &kernels,
+                v.len(),
+                &input,
+                &mut output,
+                mul,
+                add,
+            )
+            .unwrap();
+        }
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        start.elapsed()
+    });
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+        type_name::<T>().split("::").last().unwrap(),
+        "affine",
+        v.len(),
+        iterations,
+        total_time,
+        total_time / iterations
+    );
+}
diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/examples/binary.rs
new file mode 100644
index 00000000..af5a8bdc
--- /dev/null
+++ b/candle-metal-kernels/examples/binary.rs
@@ -0,0 +1,182 @@
+use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
+use half::{bf16, f16};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
+    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+    let f16_1k = f16_map(&f32_1k);
+    let f16_10k = f16_map(&f32_10k);
+    let f16_100k = f16_map(&f32_100k);
+
+    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+    let bf16_1k = bf16_map(&f32_1k);
+    let bf16_10k = bf16_map(&f32_10k);
+    let bf16_100k = bf16_map(&f32_100k);
+
+    let f32_ckernels = [
+        binary::contiguous::add::FLOAT,
+        binary::contiguous::sub::FLOAT,
+        binary::contiguous::mul::FLOAT,
+        binary::contiguous::div::FLOAT,
+    ];
+    let f32_skernels = [
+        binary::strided::add::FLOAT,
+        binary::strided::sub::FLOAT,
+        binary::strided::mul::FLOAT,
+        binary::strided::div::FLOAT,
+    ];
+    let f16_ckernels = [
+        binary::contiguous::add::HALF,
+        binary::contiguous::sub::HALF,
+        binary::contiguous::mul::HALF,
+        binary::contiguous::div::HALF,
+    ];
+    let f16_skernels = [
+        binary::strided::add::HALF,
+        binary::strided::sub::HALF,
+        binary::strided::mul::HALF,
+        binary::strided::div::HALF,
+    ];
+    let bf16_ckernels = [
+        binary::contiguous::add::BFLOAT,
+        binary::contiguous::sub::BFLOAT,
+        binary::contiguous::mul::BFLOAT,
+        binary::contiguous::div::BFLOAT,
+    ];
+    let bf16_skernels = [
+        binary::strided::add::BFLOAT,
+        binary::strided::sub::BFLOAT,
+        binary::strided::mul::BFLOAT,
+        binary::strided::div::BFLOAT,
+    ];
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
+    run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
+    run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
+
+    // f16
+    run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
+    run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
+    run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
+
+    // bf16
+    run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
+    run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
+    run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
+}
+
+fn run_binary_bench<T: Clone>(
+    device: &Device,
+    kernels: &Kernels,
+    v: &[T],
+    contiguous: [binary::contiguous::Kernel; 4],
+    strided: [binary::strided::Kernel; 4],
+) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 1000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    // Contiguous
+    for kernel_name in contiguous {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_binary_contiguous(
+                    device,
+                    &command_buffer,
+                    kernels,
+                    kernel_name,
+                    v.len(),
+                    &input,
+                    &input,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+
+    // Strided
+    let shape = vec![2, 5_000];
+    let strides = vec![2, 1];
+    let offset = 0;
+    for kernel_name in strided {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_binary_strided(
+                    device,
+                    command_buffer,
+                    &kernels,
+                    kernel_name,
+                    &shape,
+                    &input,
+                    &strides,
+                    offset,
+                    &input,
+                    &strides,
+                    offset,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+}
diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/examples/cast.rs
new file mode 100644
index 00000000..090f510d
--- /dev/null
+++ b/candle-metal-kernels/examples/cast.rs
@@ -0,0 +1,84 @@
+use candle_metal_kernels::{call_cast_contiguous, Kernels};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
+    let contiguous_kernels = ["cast_u32_f32"];
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
+    run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
+    run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
+}
+
+fn run_cast_bench<T: Clone>(
+    device: &Device,
+    kernels: &Kernels,
+    v: &[T],
+    contiguous: &[&'static str],
+) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 1000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    // Contiguous
+    for kernel_name in contiguous {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_cast_contiguous(
+                    device,
+                    &command_buffer,
+                    kernels,
+                    kernel_name,
+                    v.len(),
+                    &input,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+
+    // Strided?
+}
diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/examples/unary.rs
new file mode 100644
index 00000000..7039c098
--- /dev/null
+++ b/candle-metal-kernels/examples/unary.rs
@@ -0,0 +1,197 @@
+use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
+use half::{bf16, f16};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
+    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+    let f16_1k = f16_map(&f32_1k);
+    let f16_10k = f16_map(&f32_10k);
+    let f16_100k = f16_map(&f32_100k);
+
+    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+    let bf16_1k = bf16_map(&f32_1k);
+    let bf16_10k = bf16_map(&f32_10k);
+    let bf16_100k = bf16_map(&f32_100k);
+
+    let f32_ckernels = [
+        unary::contiguous::sin::FLOAT,
+        unary::contiguous::cos::FLOAT,
+        unary::contiguous::exp::FLOAT,
+        unary::contiguous::sqr::FLOAT,
+        unary::contiguous::sqrt::FLOAT,
+        unary::contiguous::neg::FLOAT,
+        unary::contiguous::copy::FLOAT,
+    ];
+    let f32_skernels = [
+        unary::strided::sin::FLOAT,
+        unary::strided::cos::FLOAT,
+        unary::strided::exp::FLOAT,
+        unary::strided::sqr::FLOAT,
+        unary::strided::sqrt::FLOAT,
+        unary::strided::neg::FLOAT,
+        unary::strided::copy::FLOAT,
+    ];
+    let f16_ckernels = [
+        unary::contiguous::sin::HALF,
+        unary::contiguous::cos::HALF,
+        unary::contiguous::exp::HALF,
+        unary::contiguous::sqr::HALF,
+        unary::contiguous::sqrt::HALF,
+        unary::contiguous::neg::HALF,
+        unary::contiguous::copy::HALF,
+    ];
+    let f16_skernels = [
+        unary::strided::sin::HALF,
+        unary::strided::cos::HALF,
+        unary::strided::exp::HALF,
+        unary::strided::sqr::HALF,
+        unary::strided::sqrt::HALF,
+        unary::strided::neg::HALF,
+        unary::strided::copy::HALF,
+    ];
+    let bf16_ckernels = [
+        unary::contiguous::sin::BFLOAT,
+        unary::contiguous::cos::BFLOAT,
+        unary::contiguous::exp::BFLOAT,
+        unary::contiguous::sqr::BFLOAT,
+        unary::contiguous::sqrt::BFLOAT,
+        unary::contiguous::neg::BFLOAT,
+        unary::contiguous::copy::BFLOAT,
+    ];
+    let bf16_skernels = [
+        unary::strided::sin::BFLOAT,
+        unary::strided::cos::BFLOAT,
+        unary::strided::exp::BFLOAT,
+        unary::strided::sqr::BFLOAT,
+        unary::strided::sqrt::BFLOAT,
+        unary::strided::neg::BFLOAT,
+        unary::strided::copy::BFLOAT,
+    ];
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
+    run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
+    run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
+
+    // f16
+    run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
+    run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
+    run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
+
+    // bf16
+    run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
+    run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
+    run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
+}
+
+fn run_unary_bench<T: Clone>(
+    device: &Device,
+    kernels: &Kernels,
+    v: &[T],
+    contiguous: [unary::contiguous::Kernel; 7],
+    strided: [unary::strided::Kernel; 7],
+) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 10000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    // Contiguous
+    for kernel_name in contiguous {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_unary_contiguous(
+                    device,
+                    &command_buffer,
+                    kernels,
+                    kernel_name,
+                    v.len(),
+                    &input,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+
+    // Strided
+    let shape = vec![2, 5_000];
+    let strides = vec![2, 1];
+    let offset = 0;
+    for kernel_name in strided {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_unary_strided(
+                    device,
+                    command_buffer,
+                    &kernels,
+                    kernel_name,
+                    &shape,
+                    &input,
+                    &strides,
+                    offset,
+                    &mut output,
+                    0,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+}
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index 37bc0bae..f18cdbb0 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -47,7 +47,7 @@ kernel void FN_NAME_STRIDED( \
         return; \
     } \
     TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
-    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
+    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
     output[thread_position_in_grid] = OUT_TYPENAME(FN); \
 }
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 83fbe833..e5c9fbae 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -112,7 +112,13 @@ macro_rules! ops{
     ($($name:ident),+) => {

         pub mod contiguous {
+        #[derive(Clone, Copy)]
         pub struct Kernel(pub(crate) &'static str);
+        impl std::fmt::Display for Kernel {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                write!(f, "{}", self.0)
+            }
+        }
         $(
         pub mod $name {
             use super::Kernel;
@@ -124,7 +130,13 @@ macro_rules! ops{
     }

     pub mod strided {
+        #[derive(Clone, Copy)]
         pub struct Kernel(pub(crate) &'static str);
+        impl std::fmt::Display for Kernel {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                write!(f, "{}", self.0)
+            }
+        }
         $(
         pub mod $name {
             use super::Kernel;
@@ -859,6 +871,30 @@ mod tests {
         assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
     }

+    #[test]
+    fn cos_strided_random() {
+        let v: Vec<_> = (0..10_000).map(|i| rand::random::<f32>()).collect();
+        let shape = vec![5_000, 2];
+        let strides = vec![1, 5_000];
+        let offset = 0;
+        let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
+        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+        assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
+        assert_eq!(
+            approx(vec![results[1]], 4),
+            approx(vec![expected[5_000]], 4)
+        );
+        assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
+        assert_eq!(
+            approx(vec![results[3]], 4),
+            approx(vec![expected[5_001]], 4)
+        );
+        assert_eq!(
+            approx(vec![results[5_000]], 4),
+            approx(vec![expected[2_500]], 4)
+        );
+    }
+
     #[test]
     fn binary_add_f32() {
         let left = vec![1.0f32, 2.0, 3.0];
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index 52968bc2..2daac224 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -9,6 +9,7 @@ pub struct Embedding {

 impl Embedding {
     pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
+        // todo!("Embedding {embeddings}");
         Self {
             embeddings,
             hidden_size,
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index 753770fb..24182b72 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -165,6 +165,7 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
+        todo!("X {q}");
        let k = self.k_proj.forward(x)?;
         let v = self.v_proj.forward(x)?;

@@ -295,6 +296,7 @@ impl Block {
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+        todo!("---X {}", x);
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
@@ -327,6 +329,7 @@ impl Llama {
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, _seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
+        //println!("Embeddings {}", self.wte.embeddings());
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;
         }