From 4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Sat, 11 Nov 2023 01:02:15 +0100
Subject: [PATCH 01/32] Starting to fix some tests.

Few fixes.

Going back on remote metal-rs.

Reusing a single buffer (for now) to speed things up.

Adding some half kernels.

All tests are panicking instead of random failure.

Putting back f16 index select.

Add erf.

Working version for llama2-c.

Fixes + cache compute_pipeline_state.

BF16 metal fix.

Remove some prints.

new_owned -> new()..to_owned().

Better batched matmul.

Metal operational.

Reuse buffers on our own reference counts.

Tmp gemm.

Revert "Tmp gemm."

This reverts commit c65f68e98814b65daa596696bda076a73303dd82.

Interleave committing.

Speeding up copies using blit.

Fmt.

Fmt.

Remove the assert!

Fmt all.

Fixes after big rebase.

Add softmax for half and bfloat + tests

Fixing Llama example + accumulate softmax in float.
---
 candle-core/src/metal_backend.rs        | 702 ++++++++++++------
 candle-examples/Cargo.toml              |   1 +
 candle-metal-kernels/src/affine.metal   |  18 +
 candle-metal-kernels/src/cast.metal     |  18 +-
 candle-metal-kernels/src/indexing.metal |   9 +-
 candle-metal-kernels/src/lib.rs         | 305 ++++----
 candle-metal-kernels/src/reduce.metal   | 156 ++--
 candle-metal-kernels/src/ternary.metal  |   3 +
 candle-metal-kernels/src/tests.rs       | 158 +++-
 candle-metal-kernels/src/unary.metal    |  48 +-
 .../{examples => tmp}/affine.rs         |   1 +
 .../{examples => tmp}/binary.rs         |   0
 .../{examples => tmp}/cast.rs           |   0
 .../{examples => tmp}/unary.rs          |   6 +-
 candle-nn/Cargo.toml                    |   2 +
 candle-nn/src/ops.rs                    |  40 +
 16 files changed, 989 insertions(+), 478 deletions(-)
 rename candle-metal-kernels/{examples => tmp}/affine.rs (98%)
 rename candle-metal-kernels/{examples => tmp}/binary.rs (100%)
 rename candle-metal-kernels/{examples => tmp}/cast.rs (100%)
 rename candle-metal-kernels/{examples => tmp}/unary.rs (98%)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0b72f080..12f56d50 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,11 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
 use candle_metal_kernels::Kernels;
-use core::mem;
-use half::{bf16, f16};
+use half::f16;
 use metal;
-use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::sync::Arc;
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::collections::HashMap;
+use std::path::Path;
+use std::sync::{Arc, RwLock};
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -36,7 +38,9 @@ impl From<String> for MetalError {
 pub struct MetalDevice {
     device: metal::Device,
     command_queue: metal::CommandQueue,
+    command_buffer: Arc<RwLock<CommandBuffer>>,
     kernels: Arc<Kernels>,
+    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -58,10 +62,48 @@ impl MetalDevice {
         self.registry_id()
     }
 
+    pub fn metal_device(&self) -> &metal::Device {
+        &self.device
+    }
+
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
 
+    pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
+        self.command_buffer.try_read().unwrap()
+    }
+
+    pub fn commit(&self) {
+        let mut old = self.command_buffer.try_write().unwrap();
+        match old.status() {
+            metal::MTLCommandBufferStatus::NotEnqueued
+            | metal::MTLCommandBufferStatus::Enqueued => {
+                old.commit();
+                let command_buffer = self.command_queue.new_command_buffer().to_owned();
+                *old = command_buffer;
+            }
+            _ => {}
+        }
+    }
+
+    pub fn wait_until_completed(&self) {
+        let mut old = self.command_buffer.try_write().unwrap();
+        match old.status() {
+            metal::MTLCommandBufferStatus::NotEnqueued
+            | metal::MTLCommandBufferStatus::Enqueued => {
+                old.commit();
+                old.wait_until_completed();
+            }
+            metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => {
+                old.wait_until_completed();
+            }
+            _ => {}
+        }
+        let command_buffer = self.command_queue.new_command_buffer().to_owned();
+        *old = command_buffer;
+    }
+
     pub fn kernels(&self) -> &Kernels {
         &self.kernels
     }
@@ -70,16 +112,107 @@ impl MetalDevice {
         &self.device
     }
 
-    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.device
-            .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+        self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
+    }
+
+    fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> {
+        let mut buffers = self.buffers.try_write().unwrap();
+        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
+
+        for sub in &mut *subbuffers {
+            if Arc::strong_count(sub) == 1 {
+                return sub.clone();
+            }
+        }
+        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
+        let new_buffer = Arc::new(new_buffer);
+        subbuffers.push(new_buffer.clone());
+        new_buffer
+    }
+
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+        self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
+    }
+
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+        let size = core::mem::size_of_val(data) as NSUInteger;
+        let tmp = self.device.new_buffer_with_data(
+            data.as_ptr() as *const core::ffi::c_void,
+            size,
+            metal::MTLResourceOptions::StorageModeManaged,
+        );
+        let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate);
+        {
+            let command = self.command_buffer();
+            let blit = command.new_blit_command_encoder();
+            blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+            blit.end_encoding();
+        }
+        // This wait is necessary for mmapped safetensors, because of the
+        // unsafe slice cast we're doing: the slice might not live long
+        // enough for Metal to actually fill the GPU buffer. The wait below
+        // forces the GPU buffer to be filled with the actual data, allowing
+        // the CPU storage to deallocate properly.
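+        // (Example: weights loaded from an mmapped safetensors file arrive
+        // here as such a borrowed slice; unmapping the file before the blit
+        // has executed would leave `real` only partially filled.)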
+        self.wait_until_completed();
+        real
+    }
+
+    pub fn new_matrix(
+        &self,
+        (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
+        size: NSUInteger,
+        type_id: u32,
+        dtype: DType,
+    ) -> Result<(Matrix, Arc<Buffer>)> {
+        let elem_count = (b * m * n) as usize;
+        let out_buffer = self.new_buffer(elem_count, dtype);
+
+        let result_descriptor =
+            MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
+        let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+        Ok((result_matrix, out_buffer))
+    }
+
+    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+        let capture = metal::CaptureManager::shared();
+        let descriptor = metal::CaptureDescriptor::new();
+        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+        descriptor.set_capture_device(&self.device);
+        descriptor.set_output_url(path);
+
+        capture
+            .start_capture(&descriptor)
+            .map_err(MetalError::from)?;
+        Ok(())
+    }
 }
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
-    buffer: metal::Buffer,
+    buffer: Arc<Buffer>,
+    matrices: Arc<
+        RwLock<
+            HashMap<
+                (
+                    NSUInteger,
+                    NSUInteger,
+                    NSUInteger,
+                    bool,
+                    NSUInteger,
+                    NSUInteger,
+                    u32,
+                ),
+                Matrix,
+            >,
+        >,
+    >,
     device: MetalDevice,
     dtype: DType,
 }
@@ -108,14 +241,23 @@ impl BackendStorage for MetalStorage {
                 self.dtype
             );
         }
+
+        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        let command_buffer = self.device.command_buffer();
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+        blit.end_encoding();
+        drop(command_buffer);
+        self.device.wait_until_completed();
+
         match self.dtype {
-            DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
-            DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
-            DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
-            DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
-            DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
-            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
-            DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
+            DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
+            DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
+            DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
+            DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
+            DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
+            DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))),
+            DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
         }
     }
 
@@ -126,30 +268,48 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
-            crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
+        let buffer = device.new_buffer(el, self.dtype);
+        let command_buffer = self.device.command_buffer();
+        if layout.is_contiguous() && layout.start_offset() == 0 {
+            let name = match self.dtype {
+                DType::F32 => "affine_float",
+                DType::F16 => "affine_half",
+                dtype => crate::bail!("Affine {dtype:?}"),
+            };
+            candle_metal_kernels::call_affine(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                name,
+                el,
+                &self.buffer,
+                &buffer,
+                mul as f32,
+                add as f32,
+            )
+            .map_err(MetalError::from)?;
+        } else {
+            let name = match self.dtype {
+                DType::F32 => "affine_float_strided",
+                DType::F16 => "affine_half_strided",
+                dtype => crate::bail!("Affine {dtype:?}"),
+            };
+            candle_metal_kernels::call_affine_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                name,
+                layout.dims(),
+                &self.buffer,
+                layout.stride(),
+                layout.start_offset() * dtype.size_in_bytes(),
+                &buffer,
+                mul as f32,
+                add as f32,
+            )
+            .map_err(MetalError::from)?;
         }
-
-        let mut buffer = device.new_buffer(el, self.dtype);
-        let command_buffer = self.device.command_queue.new_command_buffer();
-        candle_metal_kernels::call_affine(
-            &device.device,
-            &command_buffer,
-            &device.kernels,
-            el,
-            &self.buffer,
-            &mut buffer,
-            mul as f32,
-            add as f32,
-        )
-        .map_err(MetalError::from)?;
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        return Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        });
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -163,11 +323,11 @@ impl BackendStorage for MetalStorage {
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         if !(sum_dims.len() == 1
             && sum_dims[0] == layout.shape().rank() - 1
-            && layout.is_contiguous()
-            && layout.start_offset() == 0)
+            && layout.stride()[sum_dims[0]] == 1)
         {
-            crate::bail!("Non contiguous reduce op not supported yet");
+            crate::bail!("Non last dim reduce op not supported yet");
         }
+
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
@@ -202,8 +362,11 @@ impl BackendStorage for MetalStorage {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
-        let mut buffer = device.new_buffer(dst_el, dtype);
-        let command_buffer = self.device.command_queue.new_command_buffer();
+        if dtype == DType::U32 {
+            crate::bail!("Implement return index reduce op");
+        }
+        let buffer = device.new_buffer(dst_el, dtype);
+        let command_buffer = self.device.command_buffer();
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
             &command_buffer,
@@ -212,17 +375,12 @@ impl BackendStorage for MetalStorage {
             src_el,
             dst_el,
             &self.buffer,
-            &mut buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
+            &buffer,
         )
         .map_err(MetalError::from)?;
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
 
-        Ok(Self {
-            buffer,
-            device,
-            dtype,
-        })
+        Ok(Self::new(buffer, device, dtype))
     }
 
     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -233,11 +391,15 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let mut buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_queue.new_command_buffer();
+        let buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_buffer();
         if layout.is_contiguous() {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
+                (DType::U32, DType::U8) => "cast_u32_u8",
+                (DType::U8, DType::U32) => "cast_u8_u32",
+                (DType::F32, DType::F16) => "cast_f32_f16",
+                (DType::F16, DType::F32) => "cast_f16_f32",
                 (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
             };
             candle_metal_kernels::call_cast_contiguous(
                 &device.device,
                 &command_buffer,
                 &device.kernels,
                 kernel_name,
                 el_count,
                 &self.buffer,
-                &mut buffer,
+                layout.start_offset() * self.dtype.size_in_bytes(),
+                &buffer,
             )
.map_err(MetalError::from)?; } else { - crate::bail!( - "TODO Implement the kernel calling cast {:?}-{:?}", - self.dtype, - dtype - ); + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::F32) => "cast_u32_f32_strided", + (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::F32, DType::F16) => "cast_f32_f16_strided", + (DType::F16, DType::F32) => "cast_f16_f32_strided", + (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), + }; + candle_metal_kernels::call_cast_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + Ok(Self::new(buffer, device.clone(), dtype)) } fn unary_impl(&self, layout: &Layout) -> Result { @@ -272,8 +444,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -285,6 +457,25 @@ impl BackendStorage for MetalStorage { ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("ulog", DType::F16) => contiguous::log::HALF, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uceil", DType::F16) => contiguous::ceil::HALF, + ("ufloor", DType::F16) => contiguous::floor::HALF, + ("uround", DType::F16) => contiguous::round::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -294,20 +485,58 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { - crate::bail!("TODO Implement the kernel calling {}", B::KERNEL); + use candle_metal_kernels::unary::strided; + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => strided::cos::FLOAT, + ("usin", DType::F32) => strided::sin::FLOAT, + ("usqr", DType::F32) => strided::sqr::FLOAT, + ("usqrt", DType::F32) => strided::sqrt::FLOAT, + ("uneg", DType::F32) => strided::neg::FLOAT, + ("uexp", DType::F32) => strided::exp::FLOAT, + ("ulog", DType::F32) => strided::log::FLOAT, + ("ugelu", DType::F32) => strided::gelu::FLOAT, + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", 
DType::F32) => strided::erf::FLOAT, + ("uceil", DType::F32) => strided::ceil::FLOAT, + ("ufloor", DType::F32) => strided::floor::FLOAT, + ("uround", DType::F32) => strided::round::FLOAT, + ("ucos", DType::F16) => strided::cos::HALF, + ("usin", DType::F16) => strided::sin::HALF, + ("usqr", DType::F16) => strided::sqr::HALF, + ("usqrt", DType::F16) => strided::sqrt::HALF, + ("uneg", DType::F16) => strided::neg::HALF, + ("uexp", DType::F16) => strided::exp::HALF, + ("ulog", DType::F16) => strided::log::HALF, + ("ugelu", DType::F16) => strided::gelu::HALF, + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, + ("uceil", DType::F16) => strided::ceil::HALF, + ("ufloor", DType::F16) => strided::floor::HALF, + ("uround", DType::F16) => strided::round::HALF, + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + 0, + ) + .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + command_buffer.set_label("unary"); + drop(command_buffer); + self.device.commit(); + Ok(Self::new(buffer, device.clone(), dtype)) } fn binary_impl( @@ -320,8 +549,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = lhs_l.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_buffer(); if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) { @@ -336,6 +565,14 @@ impl BackendStorage for MetalStorage { ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, ("bdiv", DType::F32) => contiguous::div::FLOAT, + ("add", DType::F16) => contiguous::add::HALF, + ("badd", DType::F16) => contiguous::add::HALF, + ("sub", DType::F16) => contiguous::sub::HALF, + ("bsub", DType::F16) => contiguous::sub::HALF, + ("mul", DType::F16) => contiguous::mul::HALF, + ("bmul", DType::F16) => contiguous::mul::HALF, + ("div", DType::F16) => contiguous::div::HALF, + ("bdiv", DType::F16) => contiguous::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( @@ -346,7 +583,7 @@ impl BackendStorage for MetalStorage { el_count, &self.buffer, &rhs.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { @@ -357,6 +594,10 @@ impl BackendStorage for MetalStorage { ("bsub", DType::F32) => strided::sub::FLOAT, ("bmul", DType::F32) => strided::mul::FLOAT, ("bdiv", DType::F32) => strided::div::FLOAT, + ("badd", DType::F16) => strided::add::HALF, + ("bsub", DType::F16) => strided::sub::HALF, + ("bmul", DType::F16) => strided::mul::HALF, + ("bdiv", DType::F16) => strided::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_strided( @@ -366,23 +607,19 @@ impl BackendStorage for MetalStorage { kernel_name, lhs_l.dims(), &self.buffer, - &lhs_l.stride(), + lhs_l.stride(), lhs_l.start_offset() * self.dtype.size_in_bytes(), &rhs.buffer, - &rhs_l.stride(), + rhs_l.stride(), rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &mut 
buffer, + &buffer, ) .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + command_buffer.set_label("binary"); + drop(command_buffer); + self.device.commit(); + Ok(Self::new(buffer, device.clone(), dtype)) } fn where_cond( @@ -398,14 +635,22 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let mut buffer = self.device.new_buffer(el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let buffer = self.device.new_buffer(el, dtype); + let command_buffer = self.device.command_buffer(); + if t.dtype() != f.dtype() { + crate::bail!("Invalid ternary different dtypes for values"); + } + let name = match (self.dtype, t.dtype()) { + (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::F16) => "where_u8_f16", + (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"), + }; candle_metal_kernels::call_where_cond_strided( &device.device, &command_buffer, &device.kernels, - "where_u8_f32", - &dims, + name, + dims, &self.buffer, ( layout.stride(), @@ -415,16 +660,10 @@ impl BackendStorage for MetalStorage { (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), &f.buffer, (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device, - dtype, - }) + Ok(Self::new(buffer, device, dtype)) } fn conv1d( @@ -513,12 +752,13 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let mut buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype); let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", + (DType::U32, DType::F16) => "is_u32_f16", (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -529,16 +769,10 @@ impl BackendStorage for MetalStorage { dim, &self.buffer, &ids.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + Ok(Self::new(buffer, device.clone(), dtype)) } fn index_add( @@ -561,11 +795,18 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result { // Create descriptors - use metal::mps::matrix::*; - let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32; - let size = core::mem::size_of::() as NSUInteger; - let elem_count = b * m * n; + let (type_id, size) = match self.dtype { + DType::F32 => ( + metal::mps::MPS_FLOATBIT_ENCODING | 32, + core::mem::size_of::() as NSUInteger, + ), + DType::F16 => ( + metal::mps::MPS_FLOATBIT_ENCODING | 16, + core::mem::size_of::() as NSUInteger, + ), + dtype => todo!("Dtype for matmul {dtype:?} is not supported"), + }; let lhs_stride = lhs_l.stride(); let rhs_stride = rhs_l.stride(); @@ -596,39 +837,30 @@ impl BackendStorage for MetalStorage { mnk: (m, n, k), })? 
}; - let b = b as NSUInteger; let m = m as NSUInteger; let n = n as NSUInteger; let k = k as NSUInteger; - let left_descriptor = if transpose_left { - MatrixDescriptor::init_single(k, m, m * size, type_id) - } else { - MatrixDescriptor::init_single(m, k, k * size, type_id) - }; - let right_descriptor = if transpose_right { - MatrixDescriptor::init_single(n, k, k * size, type_id) - } else { - MatrixDescriptor::init_single(k, n, n * size, type_id) - }; - let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); + let left_matrix = self.matrix( + (b, m, k), + transpose_left, + size, + lhs_l.start_offset() as NSUInteger * size, + type_id, + )?; + let right_matrix = rhs.matrix( + (b, k, n), + transpose_right, + size, + rhs_l.start_offset() as NSUInteger * size, + type_id, + )?; + let (result_matrix, out_buffer) = + self.device + .new_matrix((b, m, n), size, type_id, self.dtype)?; - // Create matrix objects - let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - - let out_buffer = self.device.new_buffer(elem_count, self.dtype); - let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; + let command_buffer = self.device.command_buffer(); let alpha = 1.0f64; let beta = 0.0f64; @@ -647,70 +879,112 @@ impl BackendStorage for MetalStorage { MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?; - matrix_multiplication.set_batch_size(b); - // Encode kernel to command buffer - let command_buffer = self.device.command_queue.new_command_buffer(); matrix_multiplication.encode_to_command_buffer( - command_buffer, + &command_buffer, &left_matrix, &right_matrix, &result_matrix, ); - command_buffer.commit(); - command_buffer.wait_until_completed(); + command_buffer.set_label("matmul"); + drop(command_buffer); + self.device.commit(); - Ok(Self { - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }) + Ok(Self::new(out_buffer, self.device.clone(), self.dtype())) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { - let src_shape = src_l.shape(); - let el_count = src_shape.elem_count(); - if el_count == 0 { - return Ok(()); + let command_buffer = self.device.command_buffer(); + if src_l.is_contiguous() && self.dtype == dst.dtype() { + command_buffer.set_label("copy_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer( + &self.buffer, + src_offset, + dst.buffer(), + dst_offset, + self.buffer.length() - src_offset, + ); + blit.end_encoding(); + } else { + let src_shape = src_l.shape(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::U32 => 
candle_metal_kernels::unary::strided::copy::U32, + DType::U8 => candle_metal_kernels::unary::strided::copy::U8, + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + src_l.stride(), + src_l.start_offset() * self.dtype.size_in_bytes(), + &dst.buffer, + dst_offset * dst.dtype.size_in_bytes(), + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy_strided"); } - let command_buffer = self.device.command_queue.new_command_buffer(); - let kernel_name = match self.dtype { - DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, - DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, - DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), - }; - candle_metal_kernels::call_unary_strided( - &self.device.device, - &command_buffer, - &self.device.kernels, - kernel_name, - src_l.dims(), - &self.buffer, - &src_l.stride(), - src_l.start_offset() * self.dtype.size_in_bytes(), - &mut dst.buffer, - dst_offset, - ) - .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + drop(command_buffer); + self.device.commit(); Ok(()) } } impl MetalStorage { - pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self { + pub fn new(buffer: Arc, device: MetalDevice, dtype: DType) -> Self { + let matrices = Arc::new(RwLock::new(HashMap::new())); Self { buffer, device, dtype, + matrices, } } pub fn buffer(&self) -> &Buffer { &self.buffer } + + fn matrix( + &self, + (b, m, n): (NSUInteger, NSUInteger, NSUInteger), + transpose: bool, + size: NSUInteger, + offset: NSUInteger, + type_id: u32, + ) -> Result { + let key = (b, m, n, transpose, size, offset, type_id); + + let mut matrices = self.matrices.try_write().unwrap(); + if let Some(matrix) = matrices.get(&key) { + Ok(matrix.clone()) + } else { + let descriptor = if transpose { + MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) + } else { + MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) + }; + let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + matrices.insert(key, matrix.clone()); + Ok(matrix) + } + } } impl BackendDevice for MetalDevice { @@ -720,10 +994,14 @@ impl BackendDevice for MetalDevice { let device = metal::Device::all().swap_remove(ordinal); let command_queue = device.new_command_queue(); + let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let kernels = Arc::new(Kernels::new()); + let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, command_queue, + command_buffer, + buffers, kernels, }) } @@ -743,9 +1021,8 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - // TODO Is there a faster way ? 
- let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + let buffer = self.new_buffer(shape.elem_count(), dtype); + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { @@ -755,49 +1032,20 @@ impl BackendDevice for MetalDevice { } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { - let option = metal::MTLResourceOptions::StorageModeManaged; let buffer = match storage { - CpuStorage::U8(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::U32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::I64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::BF16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::F16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::F32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), - CpuStorage::F64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as NSUInteger, - option, - ), + CpuStorage::U8(storage) => self.new_buffer_with_data(storage), + CpuStorage::U32(storage) => self.new_buffer_with_data(storage), + CpuStorage::I64(storage) => self.new_buffer_with_data(storage), + CpuStorage::BF16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F32(storage) => self.new_buffer_with_data(storage), + CpuStorage::F64(storage) => self.new_buffer_with_data(storage), }; - Ok(Self::Storage { - buffer, - device: self.clone(), - dtype: storage.dtype(), - }) + Ok(Self::Storage::new( + buffer.into(), + self.clone(), + storage.dtype(), + )) } fn rand_uniform( diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 38d26ead..adfa529e 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] nccl = ["cuda", "cudarc/nccl", "dep:half"] onnx = ["candle-onnx"] +metal = ["candle/metal", "candle-nn/metal"] [[example]] name = "llama_multiprocess" diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index e5f0a841..a08bfbc0 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -33,6 +33,24 @@ kernel void FN_NAME( \ const TYPENAME a = TYPENAME(add); \ output[id] = input[id] * m + a; \ } \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + constant float &add, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid 
]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME m = TYPENAME(mul); \ + const TYPENAME a = TYPENAME(add); \ + output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \ +} \ AFFINE(affine_float, float) AFFINE(affine_half, half) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index d1788253..4398e9d4 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -23,12 +23,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \ + output[tid] = RIGHT_TYPENAME(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint i [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (i >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ } \ -CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) +CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) +CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) +CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) #if __METAL_VERSION__ >= 310 #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 444fa322..312b27c7 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -16,16 +16,16 @@ kernel void NAME( \ if (gid >= dst_size) { \ return; \ } \ - const size_t id_i = gid / right_size / left_size; \ + const size_t id_i = (gid / right_size) % ids_size; \ + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ const size_t right_rank_i = gid % right_size; \ - const size_t left_rank_i = gid % left_size; \ + const size_t left_rank_i = gid / right_size / ids_size; \ /* \ // Force prevent out of bounds indexing \ // since there doesn't seem to be a good way to force crash \ // No need to check for zero we're only allowing unsized. 
\ */ \ - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ - const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \ output[gid] = input[src_i]; \ } @@ -75,6 +75,7 @@ kernel void FN_NAME( \ INDEX_OP(is_u32_f32, uint, float) +INDEX_OP(is_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5a6bd41b..a0b852a4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - ComputePipelineState, Device, Function, Library, MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, + Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -59,8 +59,8 @@ impl EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - (core::mem::size_of::() * data.len()) as u64, - data.as_ptr() as *const T as *const c_void, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, ); } } @@ -111,13 +111,7 @@ macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -126,16 +120,18 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_float"); + pub const HALF: Kernel = Kernel("copy_half"); + pub const BFLOAT: Kernel = Kernel("copy_bfloat"); + pub const U32: Kernel = Kernel("copy_u32"); + pub const U8: Kernel = Kernel("copy_u8"); + } } pub mod strided { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -144,12 +140,20 @@ macro_rules! 
ops{
             pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
         }
         )+
+        pub mod copy {
+            use super::Kernel;
+            pub const FLOAT: Kernel = Kernel("copy_float_strided");
+            pub const HALF: Kernel = Kernel("copy_half_strided");
+            pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
+            pub const U32: Kernel = Kernel("copy_u32_strided");
+            pub const U8: Kernel = Kernel("copy_u8_strided");
+        }
         }
     };
 }
 
 pub mod unary {
-    ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
+    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
 }
 pub mod binary {
     ops!(add, sub, mul, div);
@@ -161,8 +165,12 @@ pub enum MetalKernelError {
     LockError(String),
     #[error("Error while loading library: {0}")]
     LoadLibraryError(String),
-    #[error("Error while loading function: {0}")]
+    #[error("Error while loading function: {0:?}")]
     LoadFunctionError(String),
+    #[error("Failed to create compute function")]
+    FailedToCreateComputeFunction,
+    #[error("Failed to create pipeline")]
+    FailedToCreatePipeline(String),
 }
 
 impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -173,19 +181,22 @@
 
 type KernelMap<T> = HashMap<&'static str, T>;
 type Libraries = HashMap<Source, Library>;
-type Functions = KernelMap<Function>;
+type Pipelines = KernelMap<ComputePipelineState>;
 
 #[derive(Debug, Default)]
 pub struct Kernels {
     libraries: RwLock<Libraries>,
-    funcs: RwLock<Functions>,
+    pipelines: RwLock<Pipelines>,
 }
 
 impl Kernels {
     pub fn new() -> Self {
         let libraries = RwLock::new(Libraries::new());
-        let funcs = RwLock::new(Functions::new());
-        Self { libraries, funcs }
+        let pipelines = RwLock::new(Pipelines::new());
+        Self {
+            libraries,
+            pipelines,
+        }
     }
 
     fn get_library_source(&self, source: Source) -> &'static str {
@@ -218,22 +229,43 @@ impl Kernels {
         }
     }
 
-    pub fn load_function(
+    fn load_function(
        &self,
         device: &Device,
         source: Source,
         name: &'static str,
     ) -> Result<Function, MetalKernelError> {
-        let mut funcs = self.funcs.write()?;
-        if let Some(func) = funcs.get(name) {
-            Ok(func.clone())
+        let func = self
+            .load_library(device, source)?
+            .get_function(name, None)
+            .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+        Ok(func)
+        // let mut funcs = self.funcs.write()?;
+        // if let Some(func) = funcs.get(name) {
+        //     Ok(func.clone())
+        // } else {
+        //     funcs.insert(name, func.clone());
+        //     Ok(func)
+        // }
+    }
+
+    pub fn load_pipeline(
+        &self,
+        device: &Device,
+        source: Source,
+        name: &'static str,
+    ) -> Result<ComputePipelineState, MetalKernelError> {
+        let mut pipelines = self.pipelines.write()?;
+        if let Some(pipeline) = pipelines.get(name) {
+            Ok(pipeline.clone())
         } else {
-            let func = self
-                .load_library(device, source)?
- .get_function(name, None) - .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; - funcs.insert(name, func.clone()); - Ok(func) + let func = self.load_function(device, source, name)?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert(name, pipeline.clone()); + + Ok(pipeline) } } } @@ -246,18 +278,9 @@ pub fn call_unary_contiguous( kernel_name: unary::contiguous::Kernel, length: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -279,18 +302,10 @@ pub fn call_unary_strided( input: &Buffer, strides: &[usize], offset: usize, - output: &mut Buffer, + output: &Buffer, output_offset: usize, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -326,17 +341,9 @@ pub fn call_binary_contiguous( length: usize, left: &Buffer, right: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -363,17 +370,9 @@ pub fn call_binary_strided( right_input: &Buffer, right_strides: &[usize], right_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -411,22 +410,52 @@ pub fn call_cast_contiguous( kernel_name: &'static str, length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Cast, 
kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, output)); + set_params!(encoder, (length, (input, input_offset), output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + input: &Buffer, + input_strides: &[usize], + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + + set_params!( + encoder, + ( + length, + shape.len(), + shape, + input_strides, + (input, input_offset), + output + ) + ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); @@ -435,7 +464,6 @@ pub fn call_cast_contiguous( Ok(()) } -#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -444,24 +472,19 @@ pub fn call_reduce_contiguous( length: usize, out_length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let thread_group_count = MTLSize { width: out_length as u64, @@ -495,18 +518,9 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -542,21 +556,14 @@ pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, + 
name: &'static str, size: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Affine, "affine_float")?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -570,6 +577,45 @@ pub fn call_affine( } #[allow(clippy::too_many_arguments)] +pub fn call_affine_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + add, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -582,17 +628,9 @@ pub fn call_where_cond_strided( (left_stride, left_offset): (&[usize], usize), right: &Buffer, (right_stride, right_offset): (&[usize], usize), - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Ternary, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -634,17 +672,14 @@ pub fn call_index_select( dim: usize, input: &Buffer, ids: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let src_dim_size = shape[dim]; let dst_el = ids_size * left_size * right_size; - let func = kernels.load_function(device, Source::Indexing, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index c6984474..867877fb 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1,6 +1,8 @@ #include using namespace metal; +#define MAX(x, y) ((x) > (y) ? 
(x) : (y)) + METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index( return strided_i; } -constant int THREADGROUP_SIZE = 256; +constant int THREADGROUP_SIZE = 1024; -# define REDUCE(FN, NAME, TYPENAME) \ +# define REDUCE(FN, NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ constant size_t &el_to_sum_per_block, \ - device const TYPENAME *src, \ - device TYPENAME *dst, \ + device const T *src, \ + device T *dst, \ uint id [[ thread_position_in_grid ]], \ uint tid [[ thread_index_in_threadgroup ]], \ uint dst_id [[ threadgroup_position_in_grid ]], \ - uint blockDim [[ threads_per_threadgroup ]] \ + uint block_dim [[ threads_per_threadgroup ]] \ ) { \ \ threadgroup float shared_memory[THREADGROUP_SIZE]; \ @@ -45,10 +47,10 @@ kernel void NAME( \ // TODO: Fast version for the contiguous case. \ // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ */ \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = src[idx]; \ + T x = shared_memory[tid]; \ + T y = src[idx]; \ shared_memory[tid] = FN; \ - idx += blockDim; \ + idx += block_dim; \ } \ \ threadgroup_barrier(mem_flags::mem_none); \ @@ -56,10 +58,10 @@ kernel void NAME( \ /* \ // reduction in shared memory \ */ \ - for (uint s = blockDim / 2; s > 0; s >>= 1) { \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ if (tid < s) { \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = shared_memory[tid + s]; \ + T x = shared_memory[tid]; \ + T y = shared_memory[tid + s]; \ shared_memory[tid] = FN; \ } \ threadgroup_barrier(mem_flags::mem_none); \ @@ -68,72 +70,74 @@ kernel void NAME( \ dst[dst_id] = shared_memory[0]; \ } \ -kernel void softmax_float( - constant size_t &src_numel, - constant size_t &el_to_sum_per_block, - device const float *src, - device float *dst, - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint blockDim [[ threads_per_threadgroup ]] -) { - - threadgroup float shared_memory[THREADGROUP_SIZE]; - - shared_memory[tid] = -INFINITY; - // Elements summed in this block range from dst_id * el_to_sum_per_block - // to (dst_id + 1) * el_to_sum_per_block. - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; - - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. - shared_memory[tid] = max(shared_memory[tid], src[idx]); - idx += blockDim; - } - - threadgroup_barrier(mem_flags::mem_none); - - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]); - } - threadgroup_barrier(mem_flags::mem_none); - } - - float max = shared_memory[0]; - - shared_memory[tid] = 0; - - // Restart - idx = start_idx + tid; - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. 
- const float val = exp(src[idx] - max); - dst[idx] = val; - shared_memory[tid] += val; - idx += blockDim; - } - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] += shared_memory[tid + s]; - } - threadgroup_barrier(mem_flags::mem_none); - } - - const float inv_acc = 1/shared_memory[0]; - idx = start_idx + tid; - while (idx < stop_idx) { - dst[idx] *= inv_acc; - idx += blockDim; - } -} - REDUCE(x + y, fast_sum_float, float) REDUCE(x * y, fast_mul_float, float) REDUCE(max(x, y), fast_max_float, float) + +#define SOFTMAX(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = -INFINITY; \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t idx = start_idx + tid; \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + while (idx < stop_idx) { \ + shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ + } \ + } \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + float _max = shared_memory[0]; \ + \ + shared_memory[tid] = 0; \ + \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + const T val = T(exp(src[idx] - _max)); \ + dst[idx] = val; \ + shared_memory[tid] += val; \ + idx += block_dim; \ + } \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] += shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + } \ + \ + const T inv_acc = T(1/shared_memory[0]); \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + dst[idx] *= inv_acc; \ + idx += block_dim; \ + } \ +} \ + +SOFTMAX(softmax_float, float) +SOFTMAX(softmax_half, half) +#if __METAL_VERSION__ >= 310 +SOFTMAX(softmax_bfloat, bfloat) +#endif diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 0945b355..1f9cb38a 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -32,6 +32,9 @@ kernel void FN_NAME( \ device TYPENAME *out ,\ uint i [[ thread_position_in_grid ]] \ ) { \ + if (i >= numel){ \ + return; \ + } \ uint strided_i = get_strided_index(i, num_dims, dims, strides); \ uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 2330d48d..66dc8d01 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,5 +1,5 @@ use super::*; -use half::f16; +use half::{bf16, f16}; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; fn new_buffer(device: &Device, data: &[T]) -> Buffer { @@ -23,13 +23,18 @@ fn approx_f16(v: Vec, digits: i32) -> Vec { v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() } +fn approx_bf16(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + 
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -37,7 +42,7 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { name, v.len(), &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -53,7 +58,7 @@ fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V let options = MTLResourceOptions::StorageModeManaged; let left = new_buffer(&device, x); let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + let output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, command_buffer, @@ -62,7 +67,7 @@ fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V x.len(), &left, &right, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -81,7 +86,7 @@ fn run_strided( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let kernels = Kernels::new(); call_unary_strided( &device, @@ -92,7 +97,7 @@ fn run_strided( &input, strides, offset, - &mut output, + &output, 0, ) .unwrap(); @@ -220,7 +225,9 @@ fn cast(v: &[T], name: &'static str) -> Vec { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; + let size = (v.len() * std::mem::size_of::()) as u64; + let output = device.new_buffer(size, options); call_cast_contiguous( &device, @@ -229,7 +236,8 @@ fn cast(v: &[T], name: &'static str) -> Vec { name, v.len(), &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); @@ -245,11 +253,17 @@ fn cast_u32_f32() { assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32, 2.0, 3.0]; + let input: Vec = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec = cast(&input, "cast_f16_f32"); + assert_eq!(results, vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + let input: Vec = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec = cast(&input, "cast_f16_f32"); + assert_eq!(results.len(), 10_000); + assert_eq!(&results[..10], vec![1.0f32; 10]); + assert_eq!(results, vec![1.0f32; 10_000]); } fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { @@ -259,7 +273,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let size = v.len(); @@ -267,9 +281,45 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { &device, command_buffer, &kernels, + "affine_float", size, &input, - &mut 
output, + &output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(v.len()) +} + +fn _run_affine_strided( + v: &[T], + shape: &[usize], + strides: &[usize], + mul: f64, + add: f64, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + + call_affine_strided( + &device, + command_buffer, + &kernels, + "affine_float", + shape, + &input, + strides, + 0, + &output, mul as f32, add as f32, ) @@ -295,6 +345,16 @@ fn affine() { assert_eq!(result, vec![2.6; 40_000]); } +// #[test] +// fn affine_strided() { +// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; +// let mul = 1.5; +// let add = 1.1; +// let result = run_affine_(&input, mul, add); +// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + +// } + #[test] fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; @@ -313,7 +373,26 @@ fn index_select() { result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] ); +} +#[test] +fn index_select_f16() { + let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + .into_iter() + .map(|x| f16::from_f32(x)) + .collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + approx_f16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; @@ -321,7 +400,7 @@ fn index_select() { let result = run_index_select(&embedding, &shape, &ids, dim); assert_eq!( result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] ); } @@ -341,20 +420,26 @@ fn run_index_select( let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let name = match core::mem::size_of::() { + 4 => "is_u32_f32", + 2 => "is_u32_f16", + _ => unimplemented!(), + }; let kernels = Kernels::new(); call_index_select( &device, &command_buffer, &kernels, - "is_u32_f32", + name, shape, ids.len(), dim, &embeddings_buffer, &ids_buffer, - &mut dst_buffer, + &dst_buffer, ) .unwrap(); @@ -451,7 +536,7 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec()) as u64, options); + let output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); call_reduce_contiguous( &device, command_buffer, @@ -460,7 +545,8 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec(v: &[T], last_dim: usize, name: &'sta let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_last_softmax( &device, command_buffer, @@ -484,7 +570,7 @@ fn run_softmax(v: &[T], last_dim: usize, name: &'sta v.len(), last_dim, &input, - &mut output, + &output, ) 
    .unwrap();
    command_buffer.commit();
@@ -536,6 +622,28 @@ fn softmax() {
         approx(results, 4),
         vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
     );
+
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+        .iter()
+        .map(|v| f16::from_f32(*v))
+        .collect::<Vec<f16>>();
+    let last_dim = 6;
+    let results = run_softmax(&v, last_dim, "softmax_half");
+    assert_eq!(
+        approx_f16(results, 4),
+        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
+    );
+
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+        .iter()
+        .map(|v| bf16::from_f32(*v))
+        .collect::<Vec<bf16>>();
+    let last_dim = 6;
+    let results = run_softmax(&v, last_dim, "softmax_bfloat");
+    assert_eq!(
+        approx_bf16(results, 4),
+        vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
+    );
 }

 fn run_where_cond<I: Clone, T: Clone>(
@@ -571,7 +679,7 @@ fn run_where_cond(
         options,
     );

-    let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
     call_where_cond_strided(
@@ -584,7 +692,7 @@ fn run_where_cond(
         (&left_stride, left_offset),
         &right,
         (&cond_stride, cond_offset),
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index eb6424e8..88139af9 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -1,4 +1,7 @@
 #include <metal_stdlib>
+#include <metal_math>
+#
+using namespace metal;

 METAL_FUNC uint get_strided_index(
     uint idx,
@@ -17,10 +20,39 @@
 template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
 template <typename T> METAL_FUNC T neg(T in){ return -in; }
+template <typename T> METAL_FUNC T erf(T in){
+    float x = (float) in;
+    // constants
+    float a1 = 0.254829592;
+    float a2 = -0.284496736;
+    float a3 = 1.421413741;
+    float a4 = -1.453152027;
+    float a5 = 1.061405429;
+    float p = 0.3275911;
+
+    // Save the sign of x
+    int sign = 1;
+    if (x < 0)
+        sign = -1;
+    x = fabs(x);
+
+    // A&S formula 7.1.26
+    float t = 1.0/(1.0 + p*x);
+    float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
+
+    return T(sign*y);
+}
 template <typename T> METAL_FUNC T id(T in){ return in; }
+template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
+template <typename T> METAL_FUNC T gelu(T x){
+    T x_sq = x * x;
+    T x_cube = x_sq * x;
+    T alpha = x + static_cast<T>(0.044715) * x_cube;
+    T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
+    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
+}

-using namespace metal;

 #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
 kernel void FN_NAME( \
@@ -64,8 +96,16 @@ UNARY_OP(sqrt)
 UNARY_OP(neg)
 UNARY_OP(exp)
 UNARY_OP(log)
+UNARY_OP(gelu)
+UNARY_OP(ceil)
+UNARY_OP(floor)
+UNARY_OP(round)
+UNARY_OP(gelu_erf)
+UNARY_OP(erf)
 UNARY(id, float, copy_float, copy_float_strided)
 UNARY(id, half, copy_half, copy_half_strided)
+UNARY(id, uint8_t, copy_u8, copy_u8_strided)
+UNARY(id, uint32_t, copy_u32, copy_u32_strided)

 #if __METAL_VERSION__ >= 310
 BFLOAT_UNARY_OP(cos)
@@ -75,6 +115,12 @@ BFLOAT_UNARY_OP(sqrt)
 BFLOAT_UNARY_OP(neg)
 BFLOAT_UNARY_OP(exp)
 BFLOAT_UNARY_OP(log)
+BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(ceil)
+BFLOAT_UNARY_OP(floor)
+BFLOAT_UNARY_OP(round)
+BFLOAT_UNARY_OP(gelu_erf)
+BFLOAT_UNARY_OP(erf)
 UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)

 #endif
diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/tmp/affine.rs
similarity index 98%
rename from candle-metal-kernels/examples/affine.rs
rename to candle-metal-kernels/tmp/affine.rs
index 
b8005dc0..cd019056 100644 --- a/candle-metal-kernels/examples/affine.rs +++ b/candle-metal-kernels/tmp/affine.rs @@ -50,6 +50,7 @@ fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { &device, command_buffer, &kernels, + "affine_float", v.len(), &input, &mut output, diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/tmp/binary.rs similarity index 100% rename from candle-metal-kernels/examples/binary.rs rename to candle-metal-kernels/tmp/binary.rs diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/tmp/cast.rs similarity index 100% rename from candle-metal-kernels/examples/cast.rs rename to candle-metal-kernels/tmp/cast.rs diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/tmp/unary.rs similarity index 98% rename from candle-metal-kernels/examples/unary.rs rename to candle-metal-kernels/tmp/unary.rs index 7039c098..66cf25c0 100644 --- a/candle-metal-kernels/examples/unary.rs +++ b/candle-metal-kernels/tmp/unary.rs @@ -147,7 +147,7 @@ fn run_unary_bench( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time, @@ -159,7 +159,7 @@ fn run_unary_bench( let shape = vec![2, 5_000]; let strides = vec![2, 1]; let offset = 0; - for kernel_name in strided { + for kernel_name in &strided { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); let start = Instant::now(); @@ -187,7 +187,7 @@ fn run_unary_bench( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time, diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index d3f43c73..45298907 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,7 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -29,3 +30,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] +metal = ["candle/metal", "dep:candle-metal-kernels"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index a0269e59..350bc663 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; Ok((dst, layout.shape().clone())) } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &candle::MetalStorage, + layout: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::{backend::BackendStorage, DType}; + let device = storage.device(); + let command_buffer = device.command_buffer(); + let kernels = device.kernels(); + let name = match storage.dtype() { + DType::F32 => "softmax_float", + DType::F16 => "softmax_half", + DType::BF16 => "softmax_bfloat", + dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), + }; + + let n = layout.stride().len(); + if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { + candle::bail!("Non contiguous softmax-last-dim is not implemented"); + } + + let last_dim = layout.dims()[layout.shape().rank() - 1]; + let elem_count = layout.shape().elem_count(); + let mut output = device.new_buffer(elem_count, 
storage.dtype()); + candle_metal_kernels::call_last_softmax( + device.metal_device(), + &command_buffer, + &kernels, + name, + elem_count, + last_dim, + storage.buffer(), + &mut output, + ) + .unwrap(); + let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); + Ok((newstorage, layout.shape().clone())) + } } pub fn softmax_last_dim(xs: &Tensor) -> Result { From 2ca086939f91f5d8ccec745e47648f74fa520988 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 30 Nov 2023 11:40:39 +0100 Subject: [PATCH 02/32] Put back affine strided tests --- candle-metal-kernels/src/tests.rs | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 66dc8d01..59f54fa9 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -295,7 +295,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { output.read_to_vec::(v.len()) } -fn _run_affine_strided( +fn run_affine_strided( v: &[T], shape: &[usize], strides: &[usize], @@ -314,7 +314,7 @@ fn _run_affine_strided( &device, command_buffer, &kernels, - "affine_float", + "affine_float_strided", shape, &input, strides, @@ -327,7 +327,8 @@ fn _run_affine_strided( command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + let len: usize = shape.iter().product(); + output.read_to_vec::(len) } #[test] @@ -345,15 +346,17 @@ fn affine() { assert_eq!(result, vec![2.6; 40_000]); } -// #[test] -// fn affine_strided() { -// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; -// let mul = 1.5; -// let add = 1.1; -// let result = run_affine_(&input, mul, add); -// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); - -// } +#[test] +fn affine_strided() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let shape = [4]; + let strides = [2]; + let result = run_affine_strided(&input, &shape, &strides, mul, add); + // 1 on 2 + assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); +} #[test] fn index_select() { From 6e25822d4fcd3321f1e078706683b990780ba1ae Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Wed, 6 Dec 2023 09:59:44 -0500 Subject: [PATCH 03/32] Fix gelu for large x --- candle-metal-kernels/src/tests.rs | 23 +++++++++++++++++++++-- candle-metal-kernels/src/unary.metal | 11 ++++++++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 59f54fa9..37b07167 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -205,6 +205,25 @@ fn cos_strided_random() { ); } +#[test] +fn gelu_f16() { + let v: Vec = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn gelu_f32() { + let v: Vec = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + #[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; @@ -527,8 +546,8 @@ fn cos_f16() { .collect(); let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); 
- assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]); + assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 88139af9..529162bd 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -42,9 +42,14 @@ template METAL_FUNC T erf(T in){ return T(sign*y); } -template METAL_FUNC T id(T in){ return in; } -template METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } -template METAL_FUNC T gelu(T x){ +template METAL_FUNC T id(T in) { return in; } +template METAL_FUNC T gelu_erf(T x) { + return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); +} +template METAL_FUNC T gelu(T x) { + if (x > 5) { + return x; + } T x_sq = x * x; T x_cube = x_sq * x; T alpha = x + static_cast(0.044715) * x_cube; From 803ac8405b49fbfc4e5aacca6e70f7955386df39 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 30 Nov 2023 11:40:39 +0100 Subject: [PATCH 04/32] Put back affine strided tests Co-Authored-By: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> --- candle-metal-kernels/src/tests.rs | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 66dc8d01..59f54fa9 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -295,7 +295,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { output.read_to_vec::(v.len()) } -fn _run_affine_strided( +fn run_affine_strided( v: &[T], shape: &[usize], strides: &[usize], @@ -314,7 +314,7 @@ fn _run_affine_strided( &device, command_buffer, &kernels, - "affine_float", + "affine_float_strided", shape, &input, strides, @@ -327,7 +327,8 @@ fn _run_affine_strided( command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + let len: usize = shape.iter().product(); + output.read_to_vec::(len) } #[test] @@ -345,15 +346,17 @@ fn affine() { assert_eq!(result, vec![2.6; 40_000]); } -// #[test] -// fn affine_strided() { -// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; -// let mul = 1.5; -// let add = 1.1; -// let result = run_affine_(&input, mul, add); -// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); - -// } +#[test] +fn affine_strided() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let shape = [4]; + let strides = [2]; + let result = run_affine_strided(&input, &shape, &strides, mul, add); + // 1 on 2 + assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); +} #[test] fn index_select() { From 87dc559817db11f8d8c409cda959528e57e1db31 Mon Sep 17 00:00:00 2001 From: nicolas Date: Tue, 12 Dec 2023 17:41:56 +0100 Subject: [PATCH 05/32] Lots of updates including some stack of command buffers. 
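The core idea is a fixed stack of pre-enqueued command buffers handed out one by one; pre-enqueueing makes Metal preserve submission order, and once the stack runs dry we wait for the in-flight buffers and start a fresh batch. Below is a minimal CPU-only sketch of that cycling scheme for review purposes; FakeCommandBuffer and Pool are made-up stand-ins, not the metal-rs API.

use std::sync::{Arc, RwLock};

#[derive(Clone, Debug, PartialEq)]
enum Status { Enqueued, Committed, Completed }

#[derive(Clone)]
struct FakeCommandBuffer { status: Arc<RwLock<Status>> }

impl FakeCommandBuffer {
    fn new() -> Self { Self { status: Arc::new(RwLock::new(Status::Enqueued)) } }
    fn commit(&self) { *self.status.write().unwrap() = Status::Committed; }
    fn wait_until_completed(&self) {
        // The real call blocks until the GPU is done; here we only flip the state.
        *self.status.write().unwrap() = Status::Completed;
    }
    fn status(&self) -> Status { self.status.read().unwrap().clone() }
}

struct Pool { buffers: Vec<FakeCommandBuffer>, index: usize }

impl Pool {
    fn new(n: usize) -> Self {
        Pool { buffers: (0..n).map(|_| FakeCommandBuffer::new()).collect(), index: 0 }
    }

    // Mirrors the shape of MetalDevice::command_buffer: hand out the next
    // enqueued buffer; once every slot has been given out, wait on the
    // in-flight ones and refill the whole stack.
    fn command_buffer(&mut self) -> FakeCommandBuffer {
        let n = self.buffers.len();
        if self.index == n {
            for cb in &self.buffers {
                // Every buffer handed out earlier must have been committed by now.
                assert_ne!(cb.status(), Status::Enqueued, "buffer not committed during cycling");
                cb.wait_until_completed();
            }
            self.buffers = (0..n).map(|_| FakeCommandBuffer::new()).collect();
            self.index = 0;
        }
        let out = self.buffers[self.index].clone();
        self.index += 1;
        out
    }
}

fn main() {
    let mut pool = Pool::new(2);
    for _ in 0..5 {
        let cb = pool.command_buffer();
        // ... encode kernels on `cb` here ...
        cb.commit();
    }
}

The real implementation below does the same dance against metal::CommandBuffer, with extra status checks and labels for debugging.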
--- candle-core/src/metal_backend.rs | 389 +++++++++++++++----- candle-core/src/tensor.rs | 2 +- candle-metal-kernels/src/affine.metal | 75 +++- candle-metal-kernels/src/lib.rs | 126 ++++++- candle-metal-kernels/src/reduce.metal | 2 +- candle-metal-kernels/src/unary.metal | 6 +- candle-nn/Cargo.toml | 3 +- candle-nn/src/ops.rs | 4 +- candle-transformers/Cargo.toml | 1 + candle-transformers/src/models/mixformer.rs | 46 ++- 10 files changed, 537 insertions(+), 117 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 12f56d50..4354422c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -38,7 +38,8 @@ impl From for MetalError { pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, - command_buffer: Arc>, + command_buffers: Arc>>, + command_buffer_index: Arc>, kernels: Arc, buffers: Arc>>>>, } @@ -70,38 +71,69 @@ impl MetalDevice { &self.command_queue } - pub fn command_buffer(&self) -> std::sync::RwLockReadGuard { - self.command_buffer.try_read().unwrap() - } - - pub fn commit(&self) { - let mut old = self.command_buffer.try_write().unwrap(); - match old.status() { - metal::MTLCommandBufferStatus::NotEnqueued - | metal::MTLCommandBufferStatus::Enqueued => { - old.commit(); - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - *old = command_buffer; + pub fn command_buffer(&self) -> CommandBuffer { + let mut command_buffers = self.command_buffers.try_write().unwrap(); + let mut index = self.command_buffer_index.try_write().unwrap(); + let n = command_buffers.len(); + if *index == n { + // todo!("Cycle buffers"); + for i in 0..n { + let command_buffer = &command_buffers[i]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled => { + // println!("Wait during cycling {i}"); + // println!("Command {i} / {n}: {:?}", command_buffer.status()); + command_buffer.wait_until_completed(); + } + metal::MTLCommandBufferStatus::Completed => {} + _ => { + panic!("Command buffer {i} not committed during cycling"); + } + } } - _ => {} + let new_buffers = (0..n) + .map(|i| { + // println!("Creating command buffer {i}"); + let command_buffer = self.command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + command_buffer + }) + .collect(); + *command_buffers = new_buffers; + *index = 0; + // println!("Reset"); } + // println!("Giving buffer {} / {n}", *index); + let out = &command_buffers[*index]; + assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued); + *index += 1; + out.to_owned() } pub fn wait_until_completed(&self) { - let mut old = self.command_buffer.try_write().unwrap(); - match old.status() { - metal::MTLCommandBufferStatus::NotEnqueued - | metal::MTLCommandBufferStatus::Enqueued => { - old.commit(); - old.wait_until_completed(); + let command_buffers = self.command_buffers.try_write().unwrap(); + let index = self.command_buffer_index.try_write().unwrap(); + let n = command_buffers.len(); + // for i in 0..*index { + // let command_buffer = &command_buffers[i]; + // println!("Command {i} / {n}: {:?}", command_buffer.status()); + // } + for i in 0..*index { + let command_buffer = &command_buffers[i]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled => {} + metal::MTLCommandBufferStatus::Completed => {} + _ => { + panic!("Command buffer not committed"); + } } - metal::MTLCommandBufferStatus::Committed | 
metal::MTLCommandBufferStatus::Scheduled => { - old.wait_until_completed(); - } - _ => {} + // println!("Wait {i}"); + command_buffer.wait_until_completed(); + // println!("Ok {i}"); + // command_buffer.wait_until_completed(); } - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - *old = command_buffer; } pub fn kernels(&self) -> &Kernels { @@ -112,28 +144,40 @@ impl MetalDevice { &self.device } - pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc { + pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self._new_buffer(size, MTLResourceOptions::StorageModePrivate) + self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name) } - fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc { + fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc { + // println!("Creating new buffer {name}"); let mut buffers = self.buffers.try_write().unwrap(); let subbuffers = buffers.entry((size, option)).or_insert(vec![]); for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { - return sub.clone(); + // println!("Reusing tensor {size} {name}"); + // return sub.clone(); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); let new_buffer = Arc::new(new_buffer); - subbuffers.push(new_buffer.clone()); + // subbuffers.push(new_buffer.clone()); + // println!("Created tensor {size} {name}"); + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(s) > 1) + .map(|s| Arc::clone(s)) + .collect(); + *subbuffers = newbuffers; + } + new_buffer } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { - self._new_buffer(size, MTLResourceOptions::StorageModeManaged) + self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { @@ -143,13 +187,20 @@ impl MetalDevice { size, metal::MTLResourceOptions::StorageModeManaged, ); - let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate); - { - let command = self.command_buffer(); - let blit = command.new_blit_command_encoder(); - blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); - blit.end_encoding(); - } + let real = self._new_buffer( + size, + metal::MTLResourceOptions::StorageModePrivate, + "with_data", + ); + let command = self.command_buffer(); + let blit = command.new_blit_command_encoder(); + blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.end_encoding(); + command.commit(); + real.did_modify_range(metal::NSRange::new(0, real.length())); + // println!("Command {:?}", command.status()); + + // self.commit(); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. 
// The slice might not live long enough for metal @@ -169,7 +220,7 @@ impl MetalDevice { dtype: DType, ) -> Result<(Matrix, Arc)> { let elem_count = (b * m * n) as usize; - let out_buffer = self.new_buffer(elem_count, dtype); + let out_buffer = self.new_buffer(elem_count, dtype, "matrix"); let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); @@ -241,13 +292,18 @@ impl BackendStorage for MetalStorage { self.dtype ); } - + self.device.wait_until_completed(); + self.buffer + .did_modify_range(metal::NSRange::new(0, self.buffer.length())); let buffer = self.device.new_buffer_managed(self.buffer.length()); - let command_buffer = self.device.command_buffer(); - let blit = command_buffer.new_blit_command_encoder(); - blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); - blit.end_encoding(); - drop(command_buffer); + { + let command_buffer = self.device.command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + } self.device.wait_until_completed(); match self.dtype { @@ -256,7 +312,11 @@ impl BackendStorage for MetalStorage { DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))), DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))), DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))), - DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))), + DType::F32 => { + let vec = buffer.read_to_vec(length / size); + // println!("Got back {:?}", &vec[..1]); + Ok(CpuStorage::F32(vec)) + } DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))), } } @@ -268,7 +328,7 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype); + let buffer = device.new_buffer(el, self.dtype, "affine"); let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { @@ -309,15 +369,111 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } - fn powf(&self, _: &Layout, _: f64) -> Result { - crate::bail!("powf metal") + fn powf(&self, layout: &Layout, pow: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "powf"); + let command_buffer = self.device.command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "powf_float", + DType::F16 => "powf_half", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "powf_float_strided", + DType::F16 => "powf_half_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + pow 
as f32, + ) + .map_err(MetalError::from)?; + } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + Ok(Self::new(buffer, device.clone(), dtype)) } - fn elu(&self, _: &Layout, _: f64) -> Result { - crate::bail!("elu metal") + fn elu(&self, layout: &Layout, alpha: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "elu"); + let command_buffer = self.device.command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "elu_float", + DType::F16 => "elu_half", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "elu_float_strided", + DType::F16 => "elu_half_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + Ok(Self::new(buffer, device.clone(), dtype)) } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { @@ -365,7 +521,7 @@ impl BackendStorage for MetalStorage { if dtype == DType::U32 { crate::bail!("Implement return index reduce op"); } - let buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "reduce"); let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_reduce_contiguous( &device.device, @@ -379,6 +535,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -391,7 +549,7 @@ impl BackendStorage for MetalStorage { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, "todtype"); let command_buffer = device.command_buffer(); if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { @@ -435,6 +593,8 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -444,7 +604,7 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -463,6 +623,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F32) => contiguous::ceil::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, ("ucos", DType::F16) => contiguous::cos::HALF, ("usin", DType::F16) 
=> contiguous::sin::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF, @@ -476,6 +637,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F16) => contiguous::ceil::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF, ("uround", DType::F16) => contiguous::round::HALF, + ("utanh", DType::F16) => contiguous::tanh::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -534,8 +696,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("unary"); - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -549,30 +711,31 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = lhs_l.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + && &B::KERNEL[..1] != "b" { use candle_metal_kernels::binary::contiguous; let kernel_name = match (B::KERNEL, dtype) { ("add", DType::F32) => contiguous::add::FLOAT, - ("badd", DType::F32) => contiguous::add::FLOAT, + // ("badd", DType::F32) => contiguous::add::FLOAT, ("sub", DType::F32) => contiguous::sub::FLOAT, - ("bsub", DType::F32) => contiguous::sub::FLOAT, + //("bsub", DType::F32) => contiguous::sub::FLOAT, ("mul", DType::F32) => contiguous::mul::FLOAT, - ("bmul", DType::F32) => contiguous::mul::FLOAT, + // ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, - ("bdiv", DType::F32) => contiguous::div::FLOAT, + // ("bdiv", DType::F32) => contiguous::div::FLOAT, ("add", DType::F16) => contiguous::add::HALF, - ("badd", DType::F16) => contiguous::add::HALF, + // ("badd", DType::F16) => contiguous::add::HALF, ("sub", DType::F16) => contiguous::sub::HALF, - ("bsub", DType::F16) => contiguous::sub::HALF, + // ("bsub", DType::F16) => contiguous::sub::HALF, ("mul", DType::F16) => contiguous::mul::HALF, - ("bmul", DType::F16) => contiguous::mul::HALF, + // ("bmul", DType::F16) => contiguous::mul::HALF, ("div", DType::F16) => contiguous::div::HALF, - ("bdiv", DType::F16) => contiguous::div::HALF, + // ("bdiv", DType::F16) => contiguous::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( @@ -617,8 +780,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("binary"); - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -635,7 +798,7 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let buffer = self.device.new_buffer(el, dtype); + let buffer = self.device.new_buffer(el, dtype, "where"); let command_buffer = self.device.command_buffer(); if t.dtype() != f.dtype() { crate::bail!("Invalid ternary different dtypes for values"); @@ -663,6 +826,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -752,7 +917,7 @@ impl 
BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "index_select"); let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", @@ -772,6 +937,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -887,9 +1054,9 @@ impl BackendStorage for MetalStorage { &result_matrix, ); command_buffer.set_label("matmul"); - drop(command_buffer); - self.device.commit(); - + command_buffer.commit(); + out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length())); + // println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer)); Ok(Self::new(out_buffer, self.device.clone(), self.dtype())) } @@ -899,14 +1066,9 @@ impl BackendStorage for MetalStorage { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; - blit.copy_from_buffer( - &self.buffer, - src_offset, - dst.buffer(), - dst_offset, - self.buffer.length() - src_offset, - ); + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); blit.end_encoding(); } else { let src_shape = src_l.shape(); @@ -937,8 +1099,9 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); } - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + dst.buffer + .did_modify_range(metal::NSRange::new(0, dst.buffer.length())); Ok(()) } } @@ -968,22 +1131,22 @@ impl MetalStorage { ) -> Result { let key = (b, m, n, transpose, size, offset, type_id); - let mut matrices = self.matrices.try_write().unwrap(); - if let Some(matrix) = matrices.get(&key) { - Ok(matrix.clone()) + // let mut matrices = self.matrices.try_write().unwrap(); + // if let Some(matrix) = matrices.get(&key) { + // Ok(matrix.clone()) + // } else { + let descriptor = if transpose { + MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) } else { - let descriptor = if transpose { - MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) - } else { - MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) - }; - let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - matrices.insert(key, matrix.clone()); - Ok(matrix) - } + MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) + }; + let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + // matrices.insert(key, matrix.clone()); + Ok(matrix) + // } } } @@ -991,16 +1154,28 @@ impl BackendDevice for MetalDevice { type Storage = MetalStorage; fn new(ordinal: usize) -> Result { + // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); + let n = 50; let command_queue = 
device.new_command_queue(); - let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); + + let command_buffers = (0..n) + .map(|_| { + let command_buffer = command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + command_buffer + }) + .collect(); + let command_buffers = Arc::new(RwLock::new(command_buffers)); + let command_buffer_index = Arc::new(RwLock::new(0)); let kernels = Arc::new(Kernels::new()); let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, command_queue, - command_buffer, + command_buffers, + command_buffer_index, buffers, kernels, }) @@ -1021,7 +1196,21 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - let buffer = self.new_buffer(shape.elem_count(), dtype); + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros"); + let command_buffer = self.command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 87323a84..73a0cc7a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1864,7 +1864,7 @@ impl Tensor { } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Metal(storage), Device::Cpu) => { - println!("{storage:?} - {:?}", storage.to_cpu_storage()?); + // println!("{storage:?} - {:?}", storage.to_cpu_storage()?); Storage::Cpu(storage.to_cpu_storage()?) } (Storage::Cuda(storage), Device::Cuda(cuda)) => { diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index a08bfbc0..18adb457 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -29,9 +29,7 @@ kernel void FN_NAME( \ if (id >= dim) { \ return; \ } \ - const TYPENAME m = TYPENAME(mul); \ - const TYPENAME a = TYPENAME(add); \ - output[id] = input[id] * m + a; \ + output[id] = TYPENAME(float(input[id]) * mul + add); \ } \ kernel void FN_NAME##_strided( \ constant size_t &dim, \ @@ -47,15 +45,80 @@ kernel void FN_NAME##_strided( \ if (id >= dim) { \ return; \ } \ - const TYPENAME m = TYPENAME(mul); \ - const TYPENAME a = TYPENAME(add); \ - output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \ + output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \ +} + +#define POWF(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \ } \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \ +} + +#define ELU(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME 
*input, \
+    device TYPENAME *output, \
+    uint id [[ thread_position_in_grid ]] \
+) { \
+    if (id >= dim) { \
+        return; \
+    } \
+    const TYPENAME x = input[id]; \
+    output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
+} \
+kernel void FN_NAME##_strided( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant float &mul, \
+    device const TYPENAME *input, \
+    device TYPENAME *output, \
+    uint id [[ thread_position_in_grid ]] \
+) { \
+    if (id >= dim) { \
+        return; \
+    } \
+    const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
+    output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
+} \
+

 AFFINE(affine_float, float)
 AFFINE(affine_half, half)
+POWF(powf_float, float)
+POWF(powf_half, half)
+ELU(elu_float, float)
+ELU(elu_half, half)

 #if __METAL_VERSION__ >= 310
 AFFINE(affine_bfloat, bfloat);
+POWF(powf_bfloat, bfloat);
+ELU(elu_bfloat, bfloat);
 #endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a0b852a4..237bd858 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -153,7 +153,7 @@ macro_rules! ops{
 }

 pub mod unary {
-    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
+    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
 }
 pub mod binary {
     ops!(add, sub, mul, div);
@@ -616,6 +616,130 @@ pub fn call_affine_strided(
     Ok(())
 }

+#[allow(clippy::too_many_arguments)]
+pub fn call_powf(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    size: usize,
+    input: &Buffer,
+    output: &Buffer,
+    mul: f32,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(encoder, (size, mul, input, output));
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_powf_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    input: &Buffer,
+    input_stride: &[usize],
+    input_offset: usize,
+    output: &Buffer,
+    mul: f32,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+    let size: usize = shape.iter().product();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            size,
+            shape.len(),
+            shape,
+            input_stride,
+            mul,
+            (input, input_offset),
+            output
+        )
+    );
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_elu(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    size: usize,
+    input: &Buffer,
+    output: &Buffer,
+    mul: f32,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(encoder, (size, mul, input, output));
+
+    let 
(thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 867877fb..3a402427 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -18,7 +18,7 @@ METAL_FUNC uint get_strided_index( return strided_i; } -constant int THREADGROUP_SIZE = 1024; +constant int THREADGROUP_SIZE = 2048; # define REDUCE(FN, NAME, T) \ kernel void NAME( \ diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 529162bd..765b14a5 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -69,7 +69,7 @@ kernel void FN_NAME( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -83,7 +83,7 @@ kernel void FN_NAME_STRIDED( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \ } #define UNARY_OP(NAME) \ @@ -107,6 +107,7 @@ UNARY_OP(floor) UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) +UNARY_OP(tanh) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -126,6 +127,7 @@ BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) +BFLOAT_UNARY_OP(tanh) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 45298907..03622752 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,7 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +metal = { workspace = true, optional = true } candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] @@ -30,4 +31,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] -metal = ["candle/metal", 
"dep:candle-metal-kernels"] +metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 350bc663..14dd10de 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { let last_dim = layout.dims()[layout.shape().rank() - 1]; let elem_count = layout.shape().elem_count(); - let mut output = device.new_buffer(elem_count, storage.dtype()); + let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax"); candle_metal_kernels::call_last_softmax( device.metal_device(), &command_buffer, @@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim { &mut output, ) .unwrap(); + command_buffer.commit(); + output.did_modify_range(metal::NSRange::new(0, output.length())); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); Ok((newstorage, layout.shape().clone())) } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index af4e04b7..e72cab69 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] +metal = ["candle/metal", "candle-nn/metal"] diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e822ca14..c8dae511 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -142,10 +142,9 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, - }) + let sin = freqs.sin()?; + let cos = freqs.cos()?; + Ok(Self { sin, cos }) } fn apply_rotary_emb_qkv( @@ -273,6 +272,10 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { + let view = xs.to_string(); + if view.contains("NaN") { + panic!("NaN"); + } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self @@ -408,3 +411,38 @@ impl MixFormerSequentialForCausalLM { self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_rotary() { + let dev = Device::new_metal(0).unwrap(); + for i in 0..10000 { + let dim = 8; + let max_seq_len = 12; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap(); + let t = Tensor::arange(0u32, max_seq_len as u32, &dev) + .unwrap() + .to_dtype(DType::F32) + .unwrap() + .reshape((max_seq_len, 1)) + .unwrap(); + let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 1.0); + let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let freqs = t.matmul(&inv_freq).unwrap(); + let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let sin = freqs.sin().unwrap().contiguous().unwrap(); + let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.099833414); + } + } +} From a9d06574320591f5cd966c8840237dc4a1e72ab3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 13 Dec 2023 12:09:20 +0100 Subject: [PATCH 06/32] Better version ? 
--- candle-core/src/metal_backend.rs | 68 ++++++++++++++------- candle-transformers/src/models/mixformer.rs | 9 +-- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 4354422c..f745342d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -96,6 +96,7 @@ impl MetalDevice { .map(|i| { // println!("Creating command buffer {i}"); let command_buffer = self.command_queue.new_command_buffer().to_owned(); + command_buffer.set_label(&format!("num {i}")); command_buffer.enqueue(); command_buffer }) @@ -157,7 +158,7 @@ impl MetalDevice { for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { // println!("Reusing tensor {size} {name}"); - // return sub.clone(); + return sub.clone(); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); @@ -177,7 +178,7 @@ impl MetalDevice { } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { - self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed") } pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { @@ -185,19 +186,22 @@ impl MetalDevice { let tmp = self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, size, - metal::MTLResourceOptions::StorageModeManaged, + metal::MTLResourceOptions::StorageModeShared, ); let real = self._new_buffer( size, metal::MTLResourceOptions::StorageModePrivate, "with_data", ); - let command = self.command_buffer(); - let blit = command.new_blit_command_encoder(); + let command_buffer = self.command_buffer(); + command_buffer.set_label("with_data"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("with_data_blit"); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); blit.end_encoding(); - command.commit(); - real.did_modify_range(metal::NSRange::new(0, real.length())); + command_buffer.commit(); + drop(command_buffer); + // real.did_modify_range(metal::NSRange::new(0, real.length())); // println!("Command {:?}", command.status()); // self.commit(); @@ -220,15 +224,29 @@ impl MetalDevice { dtype: DType, ) -> Result<(Matrix, Arc)> { let elem_count = (b * m * n) as usize; - let out_buffer = self.new_buffer(elem_count, dtype, "matrix"); + let buffer = self.new_buffer(elem_count, dtype, "matrix"); + let command_buffer = self.command_buffer(); + command_buffer.set_label("zeros_matmul"); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); - let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor) + let result_matrix = Matrix::init_with_buffer_descriptor(&buffer, 0, &result_descriptor) .ok_or_else(|| { MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?; - Ok((result_matrix, out_buffer)) + Ok((result_matrix, buffer)) } pub fn capture>(&self, path: P) -> Result<()> { @@ -298,11 +316,13 @@ impl BackendStorage for MetalStorage { let buffer = self.device.new_buffer_managed(self.buffer.length()); { let command_buffer = self.device.command_buffer(); + command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); + 
blit.set_label("blit_to_cpu"); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); blit.end_encoding(); + command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); } self.device.wait_until_completed(); @@ -550,8 +570,9 @@ impl BackendStorage for MetalStorage { let shape = layout.shape(); let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, "todtype"); + device.wait_until_completed(); let command_buffer = device.command_buffer(); - if layout.is_contiguous() { + if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", @@ -593,8 +614,10 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.set_label("to_dtype"); command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + device.wait_until_completed(); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -606,6 +629,7 @@ impl BackendStorage for MetalStorage { let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); + command_buffer.set_label(B::KERNEL); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -695,7 +719,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.set_label("unary"); command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) @@ -962,7 +985,6 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result { // Create descriptors - let (type_id, size) = match self.dtype { DType::F32 => ( metal::mps::MPS_FLOATBIT_ENCODING | 32, @@ -1028,9 +1050,11 @@ impl BackendStorage for MetalStorage { .new_matrix((b, m, n), size, type_id, self.dtype)?; let command_buffer = self.device.command_buffer(); + command_buffer.set_label("matmul"); let alpha = 1.0f64; - let beta = 0.0f64; + // let beta = f64::MIN; + let beta = 1.0; // Create kernel let matrix_multiplication = MatrixMultiplication::init( &self.device, @@ -1045,6 +1069,8 @@ impl BackendStorage for MetalStorage { .ok_or_else(|| { MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?; + matrix_multiplication.set_batch_size(b); + matrix_multiplication.set_batch_start(0); // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( @@ -1053,7 +1079,6 @@ impl BackendStorage for MetalStorage { &right_matrix, &result_matrix, ); - command_buffer.set_label("matmul"); command_buffer.commit(); out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length())); // println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer)); @@ -1062,9 +1087,11 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let command_buffer = self.device.command_buffer(); + // println!("Copy strided"); if src_l.is_contiguous() && self.dtype == dst.dtype() { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy_contiguous"); let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; @@ 
-1100,8 +1127,6 @@ impl BackendStorage for MetalStorage { command_buffer.set_label("copy_strided"); } command_buffer.commit(); - dst.buffer - .did_modify_range(metal::NSRange::new(0, dst.buffer.length())); Ok(()) } } @@ -1157,13 +1182,14 @@ impl BackendDevice for MetalDevice { // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); - let n = 50; + let n = 64; let command_queue = device.new_command_queue(); let command_buffers = (0..n) - .map(|_| { + .map(|i| { let command_buffer = command_queue.new_command_buffer().to_owned(); command_buffer.enqueue(); + command_buffer.set_label(&format!("num {i}")); command_buffer }) .collect(); @@ -1198,6 +1224,7 @@ impl BackendDevice for MetalDevice { fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros"); let command_buffer = self.command_buffer(); + command_buffer.set_label("zeros"); let blit = command_buffer.new_blit_command_encoder(); blit.fill_buffer( &buffer, @@ -1208,7 +1235,6 @@ impl BackendDevice for MetalDevice { 0, ); blit.end_encoding(); - command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index c8dae511..8e16e6a9 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -144,6 +144,7 @@ impl RotaryEmbedding { let freqs = t.matmul(&inv_freq)?; let sin = freqs.sin()?; let cos = freqs.cos()?; + // todo!("{}", sin); Ok(Self { sin, cos }) } @@ -272,10 +273,10 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { - let view = xs.to_string(); - if view.contains("NaN") { - panic!("NaN"); - } + // let view = xs.to_string(); + // if view.contains("NaN") { + // panic!("NaN"); + // } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self From 0404a3eb5bf646679d1fff1f59738c2b7d4b3fa5 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 13 Dec 2023 16:21:48 +0100 Subject: [PATCH 07/32] Removed MPSMatrix entirely (buggy). 
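
MPS batched matmul was giving wrong results, hence the removal. matmul now
goes through the gemm kernels of the bundled MetalFlashAttention metallib,
specialized at pipeline-creation time via function constants. The gemm kernel
takes transposition flags instead of strided views, so call_gemm recovers
a_trans / b_trans from the two innermost layout strides. A standalone sketch
of that check, with illustrative names rather than the actual candle
signatures:

    // For an operand viewed as (rows, cols), decide whether the gemm
    // kernel has to read it transposed, from its two innermost strides.
    fn needs_transpose(stride: &[usize], rows: usize, cols: usize) -> Option<bool> {
        let s1 = stride[stride.len() - 1]; // innermost stride
        let s2 = stride[stride.len() - 2];
        if s1 == 1 && s2 == cols {
            Some(false) // row-major contiguous: pass as-is
        } else if s1 == rows && s2 == 1 {
            Some(true) // transposed view: set the *_trans constant
        } else {
            None // neither: the operand must be made contiguous first
        }
    }

    fn main() {
        let (m, k) = (4usize, 3usize);
        // lhs of an (m, k) x (k, n) matmul:
        assert_eq!(needs_transpose(&[k, 1], m, k), Some(false));
        assert_eq!(needs_transpose(&[1, m], m, k), Some(true));
        assert_eq!(needs_transpose(&[7, 2], m, k), None);
    }

Compiled pipelines are keyed on (kernel name, constant values), so each
distinct (m, n, k, transpose, batch) combination is built once and cached.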
--- candle-core/src/metal_backend.rs | 199 ++--------- candle-metal-kernels/src/lib.rs | 322 ++++++++++++++++-- .../src/libMetalFlashAttention.metallib | Bin 0 -> 102760 bytes 3 files changed, 319 insertions(+), 202 deletions(-) create mode 100644 candle-metal-kernels/src/libMetalFlashAttention.metallib diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index f745342d..92c486d6 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,9 +4,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; -use half::f16; use metal; -use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::path::Path; @@ -115,7 +113,7 @@ impl MetalDevice { pub fn wait_until_completed(&self) { let command_buffers = self.command_buffers.try_write().unwrap(); let index = self.command_buffer_index.try_write().unwrap(); - let n = command_buffers.len(); + // let n = command_buffers.len(); // for i in 0..*index { // let command_buffer = &command_buffers[i]; // println!("Command {i} / {n}: {:?}", command_buffer.status()); @@ -216,39 +214,6 @@ impl MetalDevice { real } - pub fn new_matrix( - &self, - (b, m, n): (NSUInteger, NSUInteger, NSUInteger), - size: NSUInteger, - type_id: u32, - dtype: DType, - ) -> Result<(Matrix, Arc)> { - let elem_count = (b * m * n) as usize; - let buffer = self.new_buffer(elem_count, dtype, "matrix"); - let command_buffer = self.command_buffer(); - command_buffer.set_label("zeros_matmul"); - let blit = command_buffer.new_blit_command_encoder(); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: buffer.length(), - }, - 0, - ); - blit.end_encoding(); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); - - let result_descriptor = - MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); - let result_matrix = Matrix::init_with_buffer_descriptor(&buffer, 0, &result_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - Ok((result_matrix, buffer)) - } - pub fn capture>(&self, path: P) -> Result<()> { let capture = metal::CaptureManager::shared(); let descriptor = metal::CaptureDescriptor::new(); @@ -266,22 +231,6 @@ impl MetalDevice { #[derive(Debug, Clone)] pub struct MetalStorage { buffer: Arc, - matrices: Arc< - RwLock< - HashMap< - ( - NSUInteger, - NSUInteger, - NSUInteger, - bool, - NSUInteger, - NSUInteger, - u32, - ), - Matrix, - >, - >, - >, device: MetalDevice, dtype: DType, } @@ -976,7 +925,6 @@ impl BackendStorage for MetalStorage { ) -> Result { crate::bail!("index_add metal") } - fn matmul( &self, rhs: &Self, @@ -985,104 +933,37 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result { // Create descriptors - let (type_id, size) = match self.dtype { - DType::F32 => ( - metal::mps::MPS_FLOATBIT_ENCODING | 32, - core::mem::size_of::() as NSUInteger, - ), - DType::F16 => ( - metal::mps::MPS_FLOATBIT_ENCODING | 16, - core::mem::size_of::() as NSUInteger, - ), - dtype => todo!("Dtype for matmul {dtype:?} is not supported"), - }; - let lhs_stride = lhs_l.stride(); - let rhs_stride = rhs_l.stride(); - let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; - let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; - let lhs_m1 = 
lhs_stride[lhs_stride.len() - 1]; - let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - // The a tensor has dims batching, k, n (rhs) - let transpose_left = if lhs_m1 == 1 && lhs_m2 == k { - false - } else if lhs_m1 == m && lhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul"); + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } }; - let transpose_right = if rhs_m1 == 1 && rhs_m2 == n { - false - } else if rhs_m1 == k && rhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? - }; - let b = b as NSUInteger; - let m = m as NSUInteger; - let n = n as NSUInteger; - let k = k as NSUInteger; - - let left_matrix = self.matrix( - (b, m, k), - transpose_left, - size, - lhs_l.start_offset() as NSUInteger * size, - type_id, - )?; - let right_matrix = rhs.matrix( - (b, k, n), - transpose_right, - size, - rhs_l.start_offset() as NSUInteger * size, - type_id, - )?; - let (result_matrix, out_buffer) = - self.device - .new_matrix((b, m, n), size, type_id, self.dtype)?; let command_buffer = self.device.command_buffer(); command_buffer.set_label("matmul"); - - let alpha = 1.0f64; - // let beta = f64::MIN; - let beta = 1.0; - // Create kernel - let matrix_multiplication = MatrixMultiplication::init( - &self.device, - transpose_left, - transpose_right, - m, - n, - k, - alpha, - beta, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - matrix_multiplication.set_batch_size(b); - matrix_multiplication.set_batch_start(0); - - // Encode kernel to command buffer - matrix_multiplication.encode_to_command_buffer( + candle_metal_kernels::call_gemm( + &self.device.device, &command_buffer, - &left_matrix, - &right_matrix, - &result_matrix, - ); + &self.device.kernels, + name, + (b, m, n, k), + &lhs_l.stride(), + lhs_l.start_offset(), + &self.buffer, + &rhs_l.stride(), + rhs_l.start_offset(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + // Create kernel command_buffer.commit(); - out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length())); - // println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer)); - Ok(Self::new(out_buffer, self.device.clone(), self.dtype())) + + Ok(Self::new(buffer, self.device.clone(), self.dtype())) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { @@ -1133,46 +1014,16 @@ impl BackendStorage for MetalStorage { impl MetalStorage { pub fn new(buffer: Arc, device: MetalDevice, dtype: DType) -> Self { - let matrices = Arc::new(RwLock::new(HashMap::new())); Self { buffer, device, dtype, - matrices, } } pub fn buffer(&self) -> &Buffer { &self.buffer } - - fn matrix( - &self, - (b, m, n): (NSUInteger, NSUInteger, NSUInteger), - transpose: bool, - size: NSUInteger, - offset: NSUInteger, - type_id: u32, - ) -> Result { - let key = (b, m, n, transpose, size, offset, type_id); - - // let mut matrices = self.matrices.try_write().unwrap(); - // if let Some(matrix) = matrices.get(&key) { - // Ok(matrix.clone()) - // } else { - let descriptor = if transpose { - MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) - } else { - 
MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
-        };
-        let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor)
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
-        // matrices.insert(key, matrix.clone());
-        Ok(matrix)
-        // }
-    }
 }
 
 impl BackendDevice for MetalDevice {
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 237bd858..b80dcb79 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,6 @@
 use metal::{
     Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
-    Device, Function, Library, MTLSize,
+    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
 };
 use std::collections::HashMap;
 use std::ffi::c_void;
@@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
+const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 
 fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
     let size = length as u64;
@@ -105,6 +106,7 @@ pub enum Source {
     Ternary,
     Cast,
     Reduce,
+    Mfa,
 }
 
 macro_rules! ops{
@@ -179,9 +181,8 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
     }
 }
 
-type KernelMap<T> = HashMap<&'static str, T>;
 type Libraries = HashMap<Source, Library>;
-type Pipelines = KernelMap<ComputePipelineState>;
+type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
 
 #[derive(Debug, Default)]
 pub struct Kernels {
@@ -208,9 +209,9 @@ impl Kernels {
             Source::Indexing => INDEXING,
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
+            Source::Mfa => panic!("Invalid lib"),
         }
     }
-
     pub fn load_library(
         &self,
         device: &Device,
         source: Source,
     ) -> Result<Library, MetalKernelError> {
         let mut libraries = self.libraries.write()?;
         if let Some(lib) = libraries.get(&source) {
             Ok(lib.clone())
         } else {
-            let source_content = self.get_library_source(source);
-            let lib = device
-                .new_library_with_source(source_content, &CompileOptions::new())
-                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+            let lib = match source {
+                Source::Mfa => {
+                    let source_data = MFA;
+                    device
+                        .new_library_with_data(source_data)
+                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                }
+                source => {
+                    let source_content = self.get_library_source(source);
+                    device
+                        .new_library_with_source(source_content, &CompileOptions::new())
+                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                }
+            };
             libraries.insert(source, lib.clone());
             Ok(lib)
         }
     }
@@ -234,19 +245,41 @@ impl Kernels {
         device: &Device,
         source: Source,
         name: &'static str,
+        constants: Option<FunctionConstantValues>,
     ) -> Result<Function, MetalKernelError> {
         let func = self
             .load_library(device, source)?
-            .get_function(name, None)
+            .get_function(name, constants)
             .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
         Ok(func)
-        // let mut funcs = self.funcs.write()?;
-        // if let Some(func) = funcs.get(name) {
-        //     Ok(func.clone())
-        // } else {
-        //     funcs.insert(name, func.clone());
-        //     Ok(func)
-        // }
+    }
+
+    fn load_pipeline_with_constants(
+        &self,
+        device: &Device,
+        source: Source,
+        name: &'static str,
+        constants: Option<ConstantValues>,
+    ) -> Result<ComputePipelineState, MetalKernelError> {
+        let mut pipelines = self.pipelines.write()?;
+        let key = (name, constants);
+        if let Some(pipeline) = pipelines.get(&key) {
+            Ok(pipeline.clone())
+        } else {
+            let (name, constants) = key;
+            let func = self.load_function(
+                device,
+                source,
+                name,
+                constants.as_ref().map(|c| c.function_constant_values()),
+            )?;
+            let pipeline = device
+                .new_compute_pipeline_state_with_function(&func)
+                .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
+            pipelines.insert((name, constants), pipeline.clone());
+
+            Ok(pipeline)
+        }
     }
 
     pub fn load_pipeline(
         &self,
         device: &Device,
         source: Source,
         name: &'static str,
     ) -> Result<ComputePipelineState, MetalKernelError> {
-        let mut pipelines = self.pipelines.write()?;
-        if let Some(pipeline) = pipelines.get(name) {
-            Ok(pipeline.clone())
-        } else {
-            let func = self.load_function(device, source, name)?;
-            let pipeline = device
-                .new_compute_pipeline_state_with_function(&func)
-                .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
-            pipelines.insert(name, pipeline.clone());
-
-            Ok(pipeline)
-        }
+        self.load_pipeline_with_constants(device, source, name, None)
     }
 }
@@ -830,5 +852,249 @@ pub fn call_index_select(
     Ok(())
 }
 
+#[derive(Debug, PartialEq)]
+pub enum Value {
+    USize(usize),
+    Bool(bool),
+    F32(f32),
+    U16(u16),
+}
+
+impl std::hash::Hash for Value {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        match self {
+            Value::F32(v) => v.to_bits().hash(state),
+            Value::USize(v) => v.hash(state),
+            Value::U16(v) => v.hash(state),
+            Value::Bool(v) => v.hash(state),
+        }
+    }
+}
+
+impl Value {
+    fn data_type(&self) -> MTLDataType {
+        match self {
+            Value::USize(_) => MTLDataType::UInt,
+            Value::F32(_) => MTLDataType::Float,
+            Value::U16(_) => MTLDataType::UShort,
+            Value::Bool(_) => MTLDataType::Bool,
+        }
+    }
+}
+
+/// Not true, good enough for our purposes.
+impl Eq for Value {} + +#[derive(Debug, Eq, PartialEq, Hash)] +struct ConstantValues(Vec<(usize, Value)>); + +impl ConstantValues { + pub fn new(values: Vec<(usize, Value)>) -> Self { + Self(values) + } + + fn function_constant_values(&self) -> FunctionConstantValues { + let f = FunctionConstantValues::new(); + for (index, value) in &self.0 { + let ty = value.data_type(); + match value { + Value::USize(v) => { + f.set_constant_value_at_index( + v as *const usize as *const c_void, + ty, + *index as u64, + ); + } + Value::F32(v) => { + f.set_constant_value_at_index( + v as *const f32 as *const c_void, + ty, + *index as u64, + ); + } + Value::U16(v) => { + f.set_constant_value_at_index( + v as *const u16 as *const c_void, + ty, + *index as u64, + ); + } + Value::Bool(v) => { + f.set_constant_value_at_index( + v as *const bool as *const c_void, + ty, + *index as u64, + ); + } + } + } + f + } +} + +#[allow(clippy::too_many_arguments)] +pub fn call_gemm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + assert!(rhs_stride.len() >= 2); + assert!(lhs_stride.len() >= 2); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + let a_trans = if lhs_m1 == 1 && lhs_m2 == k { + false + } else if lhs_m1 == m && lhs_m2 == 1 { + true + } else { + todo!(); + // Err(MetalError::MatMulNonContiguous { + // lhs_stride: lhs_stride.to_vec(), + // rhs_stride: rhs_stride.to_vec(), + // mnk: (m, n, k), + // })? + }; + let b_trans = if rhs_m1 == 1 && rhs_m2 == n { + false + } else if rhs_m1 == k && rhs_m2 == 1 { + true + } else { + todo!(); + // Err(MetalError::MatMulNonContiguous { + // lhs_stride: lhs_stride.to_vec(), + // rhs_stride: rhs_stride.to_vec(), + // mnk: (m, n, k), + // })? 
+    };
+    let d_trans = false;
+    let alpha = 1.0f32;
+    let beta = 0.0f32;
+    let batched = b > 1;
+    let fused_activation = false;
+    let fused_bias = false;
+    let m_simd = 16;
+    let n_simd = 16;
+    let k_simd = 16;
+    let m_splits = 2;
+    let n_splits = 2;
+    let constants = Some(ConstantValues::new(vec![
+        (0, Value::USize(m)),
+        (1, Value::USize(n)),
+        (2, Value::USize(k)),
+        (10, Value::Bool(a_trans)),
+        (11, Value::Bool(b_trans)),
+        (13, Value::Bool(d_trans)),
+        (20, Value::F32(alpha)),
+        (21, Value::F32(beta)),
+        (100, Value::Bool(batched)),
+        (101, Value::Bool(fused_activation)),
+        // Garbage
+        (102, Value::Bool(false)),
+        (103, Value::Bool(false)),
+        (113, Value::Bool(false)),
+        (50_000, Value::Bool(false)),
+        // End garbage
+        (200, Value::U16(m_simd)),
+        (201, Value::U16(n_simd)),
+        (202, Value::U16(k_simd)),
+        (210, Value::U16(m_splits)),
+        (211, Value::U16(n_splits)),
+        (50_001, Value::Bool(fused_bias)),
+    ]));
+    // println!("Constants {constants:?}");
+    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
+    let m_group = m_simd * m_splits;
+    let n_group = n_simd * n_splits;
+
+    let a_block_length = m_group * k_simd;
+    let b_block_length = k_simd * n_group;
+
+    let mut block_elements = a_block_length + b_block_length;
+    if (m % 8 != 0) && (n % 8 != 0) {
+        let c_block_length = m_group * n_group;
+        block_elements = std::cmp::max(c_block_length, block_elements)
+    }
+    if fused_bias {
+        if d_trans {
+            block_elements = std::cmp::max(block_elements, m_group);
+        } else {
+            block_elements = std::cmp::max(block_elements, n_group);
+        }
+    }
+    // TODO adapt for f16
+    let bytes = match name {
+        "sgemm" => 4,
+        "hgemm" => 2,
+        other => {
+            return Err(MetalKernelError::LoadLibraryError(format!(
+                "{other} is not a valid kernel for gemm"
+            )));
+        }
+    };
+    let block_bytes = block_elements * bytes;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+    // println!("Threadgroup {block_bytes}");
+    encoder.set_threadgroup_memory_length(0, block_bytes.into());
+    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
+    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
+    encoder.set_buffer(2, Some(output), 0);
+    // TODO Tensor D
+
+    let grid_z = b;
+    if batched {
+        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
+        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
+        let byte_stride_c = m * n * bytes as usize;
+        // TODO byte_stride_d
+        let byte_stride_d = 0;
+
+        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+        for i in 0..b {
+            buffer.push((i * byte_stride_a) as u64);
+            buffer.push((i * byte_stride_b) as u64);
+            buffer.push((i * byte_stride_c) as u64);
+            buffer.push((i * byte_stride_d) as u64);
+        }
+        encoder.set_bytes(
+            10,
+            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+            buffer.as_ptr() as *const NSUInteger as *const c_void,
+        );
+    }
+
+    let grid_size = MTLSize {
+        width: divide(n, n_group.into()),
+        height: divide(m, m_group.into()),
+        depth: grid_z as NSUInteger,
+    };
+    let group_size = MTLSize {
+        width: 32 * (m_splits as u64) * (n_splits as u64),
+        height: 1,
+        depth: 1,
+    };
+    // println!("grid size {grid_size:?} group size {group_size:?}");
+    encoder.dispatch_thread_groups(grid_size, group_size);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
+fn divide(m: usize, b: usize) -> NSUInteger {
+    ((m + b - 1) / b) as NSUInteger
+}
+
 #[cfg(test)]
 mod tests;
diff --git
a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
new file mode 100644
index 0000000000000000000000000000000000000000..f5116ca6a34d35f01e031cc006a2ab1b2ce86964
GIT binary patch
literal 102760
[102760 bytes of base85-encoded binary data elided: the precompiled libMetalFlashAttention.metallib shader archive is not human-readable]
z{dP3;WzgsOPAQ-qr97ed={Fu9NfF|vvb(9IygC8?QI>WgEg6(AQmKq|DXnxZa?zqs zZU3lYY1F187^$n6F2xHqNeANffrR=F>RGf%E#(W{qXs<;^){ue_b)2Vb56;nf3ZDjd+#}3U zF$|p}$HNrkn1@NlG37Un4yk_886EzovYT|378CwGcOX0yoVTurr+H%kL!Kof&6|*?5DHin$p-cVQ#qWdJ z9$|#{Yl)P<-xffh+8N>fezeDfNYdmH?yz{R3y{kD~7Vino zm*ihD#^Ui_JjSB-guk=>caO2EJjPfT(0G$AYmgdo@lOZ;9~ocSp-(RvbZ=ynqI>&S z+OMV}zG-mN;by=^!_9=FGNHd1xK4}_EpX9r!)T;?;7Z}B!EQ${2bV?Z9 zad;QoyE6YWYKxeNenpLIP%U|I}El+T(id;@FeQg2M5!6=$3v zlI-=+?-%u{+(`d5I21YT@`BB&*F#q=j)59Q-5ya_;G8h$;GxUMc-bg#iuZEkCB4Tu zvQ&(Jsu;$GaAecH#59?t*dyFtG29^P5nZbsj>a+~+9YZA(CrmNVbSZBr z%_6uY(RhpvGezBWQJ2z7x+1u4eBiM#Y4C@~g!%CpPY-1@8%%!{{AfIP zf6I=R37vn$jz)Wj{4;j6gRBh@>}W#SOXW1TLia~$CP(E!K9vRVzG_UHR^gD1?Rj79 zMm$nzP#vIF#B*+l;-WgIVNU{&i{_z(WA0OV0o`KP;`X>K_RMk{%~aMtaT^VV6D%~k z5jzvv#prrezj*rT+H{PrOF}p-C7#*O2+^HCUu31wH;%cMZp^h11nKp(2TA4hY|yp1 z Date: Wed, 13 Dec 2023 16:58:36 +0100 Subject: [PATCH 08/32] Fixing tests + matmul from MFA --- Cargo.toml | 2 +- candle-core/src/metal_backend.rs | 28 +++-- candle-metal-kernels/Cargo.toml | 2 +- candle-metal-kernels/src/tests.rs | 118 ++++++++++++++++++-- candle-transformers/src/models/mixformer.rs | 1 - 5 files changed, 128 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ba09b1d4..7c2e3a7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } +metal = { version = "0.27.0", features = ["mps"]} [profile.release-with-debug] inherits = "release" diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 92c486d6..9866f1ca 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -276,17 +276,17 @@ impl BackendStorage for MetalStorage { self.device.wait_until_completed(); match self.dtype { - DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))), - DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))), - DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))), - DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))), - DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))), + DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))), + DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))), + DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))), + DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))), + DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))), DType::F32 => { - let vec = buffer.read_to_vec(length / size); + let vec = read_to_vec(&buffer, length / size); // println!("Got back {:?}", &vec[..1]); Ok(CpuStorage::F32(vec)) } - DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))), + DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))), } } @@ -944,6 +944,8 @@ impl BackendStorage for MetalStorage { }; let command_buffer = self.device.command_buffer(); + // println!("MATMUL {b} {m} {n} {k}"); + // println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride()); command_buffer.set_label("matmul"); candle_metal_kernels::call_gemm( &self.device.device, @@ -952,16 +954,17 @@ impl BackendStorage for MetalStorage { name, (b, m, n, k), &lhs_l.stride(), - lhs_l.start_offset(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), &self.buffer, &rhs_l.stride(), - rhs_l.start_offset(), + 
rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &rhs.buffer, &buffer, ) .map_err(MetalError::from)?; // Create kernel command_buffer.commit(); + self.device.wait_until_completed(); Ok(Self::new(buffer, self.device.clone(), self.dtype())) } @@ -1138,3 +1141,10 @@ impl BackendDevice for MetalDevice { self.storage_from_cpu_storage(&cpu_storage) } } + +fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 186f3209..012695dd 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -10,7 +10,7 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } +metal = { version = "0.27.0", features = ["mps"]} once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 37b07167..8f3e2d43 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -2,6 +2,13 @@ use super::*; use half::{bf16, f16}; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; +fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} + fn new_buffer(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; let ptr = data.as_ptr() as *const core::ffi::c_void; @@ -47,7 +54,7 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + read_to_vec(&output, v.len()) } fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { @@ -72,7 +79,7 @@ fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(x.len()) + read_to_vec(&output, x.len()) } fn run_strided( @@ -103,7 +110,7 @@ fn run_strided( .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -261,7 +268,7 @@ fn cast(v: &[T], name: &'static str) -> Vec { .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -311,7 +318,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + read_to_vec(&output, v.len()) } fn run_affine_strided( @@ -347,7 +354,7 @@ fn run_affine_strided( command_buffer.wait_until_completed(); let len: usize = shape.iter().product(); - output.read_to_vec::(len) + read_to_vec(&output, len) } #[test] @@ -468,7 +475,7 @@ fn run_index_select( command_buffer.commit(); command_buffer.wait_until_completed(); - dst_buffer.read_to_vec::(dst_el) + read_to_vec(&dst_buffer, dst_el) } #[test] @@ -534,7 +541,7 @@ fn index_add() { let expected = vec![ 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, ]; - let result = outputs_buffer.read_to_vec::(right.len()); + let result: Vec = read_to_vec(&outputs_buffer, right.len()); assert_eq!(result, expected); } @@ -574,7 +581,7 @@ 
fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec(out_length) + read_to_vec(&output, out_length) } fn run_softmax(v: &[T], last_dim: usize, name: &'static str) -> Vec { @@ -598,7 +605,7 @@ fn run_softmax(v: &[T], last_dim: usize, name: &'sta command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -720,7 +727,7 @@ fn run_where_cond( command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(length) + read_to_vec(&output, length) } #[test] @@ -744,3 +751,92 @@ fn where_cond() { ); assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } + +fn run_gemm( + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_stride: Vec, + lhs_offset: usize, + rhs: &[T], + rhs_stride: Vec, + rhs_offset: usize, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(rhs) as u64, + options, + ); + let length = b * m * n; + let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + call_gemm( + &device, + command_buffer, + &kernels, + "sgemm", + (b, m, n, k), + &lhs_stride, + lhs_offset, + &lhs, + &rhs_stride, + rhs_offset, + &rhs, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +#[test] +fn gemm() { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + assert_eq!( + approx(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + assert_eq!( + approx(results, 4), + vec![ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, + 518.0, 548.0, 578.0 + ] + ); + + // OFFSET + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 + let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4); + assert_eq!( + approx(results, 4), + vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] + ); +} diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 8e16e6a9..3f9aa47d 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -144,7 +144,6 @@ impl RotaryEmbedding { let freqs = t.matmul(&inv_freq)?; let sin = freqs.sin()?; let cos = 
freqs.cos()?; - // todo!("{}", sin); Ok(Self { sin, cos }) } From 361f2ad2af52ccf1750e274f1649fb8c90f80a86 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 14 Dec 2023 16:05:33 +0100 Subject: [PATCH 09/32] Working with merging encoders and using fences. --- candle-core/src/metal_backend.rs | 120 ++++------------ candle-core/tests/tensor_tests.rs | 2 + candle-metal-kernels/src/lib.rs | 40 +++++- candle-metal-kernels/src/test.swift | 209 ++++++++++++++++++++++++++++ candle-nn/src/ops.rs | 2 - 5 files changed, 279 insertions(+), 94 deletions(-) create mode 100644 candle-metal-kernels/src/test.swift diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 9866f1ca..4bc80823 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -38,6 +38,7 @@ pub struct MetalDevice { command_queue: metal::CommandQueue, command_buffers: Arc>>, command_buffer_index: Arc>, + fence: metal::Fence, kernels: Arc, buffers: Arc>>>>, } @@ -71,68 +72,32 @@ impl MetalDevice { pub fn command_buffer(&self) -> CommandBuffer { let mut command_buffers = self.command_buffers.try_write().unwrap(); + let mut command_buffer = command_buffers[0].to_owned(); let mut index = self.command_buffer_index.try_write().unwrap(); - let n = command_buffers.len(); - if *index == n { - // todo!("Cycle buffers"); - for i in 0..n { - let command_buffer = &command_buffers[i]; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled => { - // println!("Wait during cycling {i}"); - // println!("Command {i} / {n}: {:?}", command_buffer.status()); - command_buffer.wait_until_completed(); - } - metal::MTLCommandBufferStatus::Completed => {} - _ => { - panic!("Command buffer {i} not committed during cycling"); - } - } - } - let new_buffers = (0..n) - .map(|i| { - // println!("Creating command buffer {i}"); - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - command_buffer.set_label(&format!("num {i}")); - command_buffer.enqueue(); - command_buffer - }) - .collect(); - *command_buffers = new_buffers; + if *index > 20 { + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *command_buffers = vec![command_buffer.clone()]; *index = 0; - // println!("Reset"); } - // println!("Giving buffer {} / {n}", *index); - let out = &command_buffers[*index]; - assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued); *index += 1; - out.to_owned() + command_buffer } pub fn wait_until_completed(&self) { - let command_buffers = self.command_buffers.try_write().unwrap(); - let index = self.command_buffer_index.try_write().unwrap(); - // let n = command_buffers.len(); - // for i in 0..*index { - // let command_buffer = &command_buffers[i]; - // println!("Command {i} / {n}: {:?}", command_buffer.status()); - // } - for i in 0..*index { - let command_buffer = &command_buffers[i]; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled => {} - metal::MTLCommandBufferStatus::Completed => {} - _ => { - panic!("Command buffer not committed"); - } + let mut command_buffers = self.command_buffers.try_write().unwrap(); + let command_buffer = &command_buffers[0]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Alredy committed"); } - // println!("Wait {i}"); - 
command_buffer.wait_until_completed(); - // println!("Ok {i}"); - // command_buffer.wait_until_completed(); + _ => {} } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()]; } pub fn kernels(&self) -> &Kernels { @@ -176,7 +141,7 @@ impl MetalDevice { } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { - self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed") + self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { @@ -184,7 +149,7 @@ impl MetalDevice { let tmp = self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, size, - metal::MTLResourceOptions::StorageModeShared, + metal::MTLResourceOptions::StorageModeManaged, ); let real = self._new_buffer( size, @@ -194,15 +159,15 @@ impl MetalDevice { let command_buffer = self.command_buffer(); command_buffer.set_label("with_data"); let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); blit.set_label("with_data_blit"); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.update_fence(&self.fence); blit.end_encoding(); - command_buffer.commit(); - drop(command_buffer); + // drop(command_buffer); // real.did_modify_range(metal::NSRange::new(0, real.length())); // println!("Command {:?}", command.status()); - // self.commit(); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. // The slice might not live long enough for metal @@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage { self.dtype ); } - self.device.wait_until_completed(); - self.buffer - .did_modify_range(metal::NSRange::new(0, self.buffer.length())); let buffer = self.device.new_buffer_managed(self.buffer.length()); { let command_buffer = self.device.command_buffer(); command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); + blit.wait_for_fence(&self.device.fence); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.update_fence(&self.device.fence); blit.end_encoding(); - - command_buffer.commit(); } self.device.wait_until_completed(); @@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + // buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage { let shape = layout.shape(); let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, "todtype"); - device.wait_until_completed(); let command_buffer = device.command_buffer(); if 
layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { @@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("to_dtype"); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); - device.wait_until_completed(); - Ok(Self::new(buffer, device.clone(), dtype)) } @@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("binary"); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; // Create kernel - command_buffer.commit(); - self.device.wait_until_completed(); Ok(Self::new(buffer, self.device.clone(), self.dtype())) } @@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); } - command_buffer.commit(); Ok(()) } } @@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice { // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); - let n = 64; + let n = 1; let command_queue = device.new_command_queue(); let command_buffers = (0..n) @@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice { .collect(); let command_buffers = Arc::new(RwLock::new(command_buffers)); let command_buffer_index = Arc::new(RwLock::new(0)); - let kernels = Arc::new(Kernels::new()); + let fence = device.new_fence(); + let kernels = Arc::new(Kernels::new(fence.clone())); let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, + fence, command_queue, command_buffers, command_buffer_index, @@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice { 0, ); blit.end_encoding(); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c871dc96..a77f9c3a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -900,7 +900,9 @@ fn matmul(device: &Device) -> Result<()> { let b = Tensor::from_slice(&data, (2, 2), device)?; let c = a.matmul(&b)?; + let d = a.matmul(&c)?; assert_eq!(c.to_vec2::()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + assert_eq!(d.to_vec2::()?, &[[37.0, 54.0], [81.0, 118.0]]); let data = vec![1.0f32, 2.0]; let a = Tensor::from_slice(&data, (2, 1), device)?; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index b80dcb79..01432ccb 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -184,19 +184,21 @@ impl From> for MetalKernelError { type Libraries = HashMap; 
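// Compiled pipeline states are cached keyed on kernel name plus optional
// function constants, so each Metal function is specialized only once.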
type Pipelines = HashMap<(&'static str, Option), ComputePipelineState>; -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Kernels { libraries: RwLock, pipelines: RwLock, + fence: metal::Fence, } impl Kernels { - pub fn new() -> Self { + pub fn new(fence: metal::Fence) -> Self { let libraries = RwLock::new(Libraries::new()); let pipelines = RwLock::new(Pipelines::new()); Self { libraries, pipelines, + fence, } } @@ -304,12 +306,14 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -331,6 +335,7 @@ pub fn call_unary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -350,6 +355,7 @@ pub fn call_unary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -368,6 +374,7 @@ pub fn call_binary_contiguous( let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); @@ -375,6 +382,7 @@ pub fn call_binary_contiguous( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -399,6 +407,7 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -420,6 +429,7 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -438,12 +448,14 @@ pub fn call_cast_contiguous( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, (input, input_offset), output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -463,6 +475,7 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); 
encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -482,6 +495,7 @@ pub fn call_cast_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -501,6 +515,7 @@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -527,6 +542,7 @@ pub fn call_reduce_contiguous( }; encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -544,6 +560,7 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, elements_to_sum, input, output)); @@ -569,6 +586,7 @@ pub fn call_last_softmax( }; encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -588,12 +606,14 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -616,6 +636,7 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -634,6 +655,7 @@ pub fn call_affine_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -652,12 +674,14 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -679,6 +703,7 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -696,6 +721,7 @@ pub fn call_powf_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -714,12 +740,14 @@ pub fn call_elu( let pipeline = 
kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -741,6 +769,7 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -758,6 +787,7 @@ pub fn call_elu_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -779,6 +809,7 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -803,6 +834,7 @@ pub fn call_where_cond_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -829,6 +861,7 @@ pub fn call_index_select( let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -848,6 +881,7 @@ pub fn call_index_select( let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1045,6 +1079,7 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); // println!("Threadgroup {block_bytes}"); encoder.set_threadgroup_memory_length(0, block_bytes.into()); @@ -1087,6 +1122,7 @@ pub fn call_gemm( }; // println!("grid size {grid_size:?} group size {group_size:?}"); encoder.dispatch_thread_groups(grid_size, group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) diff --git a/candle-metal-kernels/src/test.swift b/candle-metal-kernels/src/test.swift new file mode 100644 index 00000000..f9bb9f91 --- /dev/null +++ b/candle-metal-kernels/src/test.swift @@ -0,0 +1,209 @@ + +import Metal +import MetalPerformanceShadersGraph + + + +let type = MTLDataType.float; +let dataType = type; +var B = 2; +var M = 2; +var N = 2; +var K = 2; +var A_trans = false; +var B_trans = false; +var D_trans = false; +var alpha = Float(1.0); +var beta = Float(0.0); +var batched = B > 1; +var fused_activation = false; +var fused_bias = false; +let constants = MTLFunctionConstantValues() +constants.setConstantValue(&M, type: .uint, index: 0) +constants.setConstantValue(&N, type: .uint, index: 1) +constants.setConstantValue(&K, type: .uint, index: 2) +constants.setConstantValue(&A_trans, type: .bool, index: 10) +constants.setConstantValue(&B_trans, type: .bool, index: 11) +constants.setConstantValue(&D_trans, 
type: .bool, index: 13) +constants.setConstantValue(&alpha, type: .float, index: 20) +constants.setConstantValue(&beta, type: .float, index: 21) +constants.setConstantValue(&batched, type: .bool, index: 100) +constants.setConstantValue(&fused_activation, type: .bool, index: 101) +constants.setConstantValue(&fused_bias, type: .bool, index: 50001) + + +var M_simd = UInt16(16) +var N_simd = UInt16(16) +var K_simd = UInt16(32) +var M_splits = UInt16(2) +var N_splits = UInt16(2) +constants.setConstantValue(&M_simd, type: .ushort, index: 200) +constants.setConstantValue(&N_simd, type: .ushort, index: 201) +constants.setConstantValue(&K_simd, type: .ushort, index: 202) +constants.setConstantValue(&M_splits, type: .ushort, index: 210) +constants.setConstantValue(&N_splits, type: .ushort, index: 211) + +let M_group = M_simd * M_splits +let N_group = N_simd * N_splits + +// Satisfy Metal API validation. +#if DEBUG +do { + var garbage: SIMD4 = .zero + constants.setConstantValue(&garbage, type: .bool, index: 102) + constants.setConstantValue(&garbage, type: .bool, index: 103) + constants.setConstantValue(&garbage, type: .bool, index: 113) + constants.setConstantValue(&garbage, type: .bool, index: 50000) +} +#endif + +let device = MTLCopyAllDevices().first! +device.shouldMaximizeConcurrentCompilation = true + +var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!; +libraryURL.append(component: "src") +libraryURL.append(component: "libMetalFlashAttention.metallib") +let library = try! device.makeLibrary(URL: libraryURL) + +var name: String + switch dataType { + case .half: name = "hgemm" + case .float: name = "sgemm" + default: fatalError() + } +let function = try! library.makeFunction( + name: name, constantValues: constants) + +let A_block_length = M_group * K_simd +let B_block_length = K_simd * N_group + +var blockElements = A_block_length + B_block_length; +if (M % 8 != 0) && (N % 8 != 0) { + let C_block_length = M_group * N_group; + blockElements = max(C_block_length, blockElements) +} +if fused_bias { + if D_trans { + blockElements = max(blockElements, M_group) + } else { + blockElements = max(blockElements, N_group) + } +} +// let blockBytes = blockElements * UInt16(dataType.size) +let elementSize = 4 +let blockBytes = blockElements * UInt16(elementSize) + +func ceilDivide(target: Int, granularity: UInt16) -> Int { + (target + Int(granularity) - 1) / Int(granularity) +} +var gridSize = MTLSize( + width: ceilDivide(target: N, granularity: N_group), + height: ceilDivide(target: M, granularity: M_group), + depth: 1) +let groupSize = MTLSize( + width: Int(32 * M_splits * N_splits), + height: 1, + depth: 1) + +let commandQueue = device.makeCommandQueue()! + +let threadgroupMemoryLength = blockBytes; + +let rowsA = M; +let columnsA = K; +let rowsB = K; +let columnsB = N; +let rowsC = M; +let columnsC = N; +var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA) + +var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB) + +var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC) +var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC) +for i in 0...stride, options: [])! + +let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout.stride, options: [])! + +let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout.stride, options: [])! +let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout.stride, options: [])! 
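+// bufferC receives the first matmul (C = A*B) and bufferD the second
+// (D = A*C); if synchronization between the two command buffers is wrong,
+// the "This should be filled" print at the end shows bufferD unwritten.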
+ + +let pipeline = try device.makeComputePipelineState(function: function) + +func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){ + let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)! + encoder.setComputePipelineState(pipeline) + encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0) + + encoder.setBuffer(bufferA, offset: 0, index: 0) + encoder.setBuffer(bufferB, offset: 0, index: 1) + encoder.setBuffer(bufferC, offset: 0, index: 2) + let gridZ: Int = B + if batched{ + func byteStride(shape: [Int]) -> Int { + let rank = shape.count + var output = elementSize * shape[rank - 2] * shape[rank - 1] + if shape.dropLast(2).reduce(1, *) == 1 { + output = 0 + } + return output + } + let byteStrideA = M*K*elementSize + let byteStrideB = N*K*elementSize + let byteStrideC = M*N*elementSize + + let byteStrideD = 0 + withUnsafeTemporaryAllocation( + of: SIMD4.self, capacity: gridZ + ) { buffer in + for i in 0..>.stride + assert(MemoryLayout>.stride == 8 * 4) + encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10) + } + } + gridSize.depth = gridZ + + + encoder.dispatchThreadgroups( + gridSize, threadsPerThreadgroup: groupSize + ) + encoder.endEncoding() +} + +var commandBuffer = commandQueue.makeCommandBuffer()! +call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC) +commandBuffer.commit() +commandBuffer = commandQueue.makeCommandBuffer()! +commandBuffer.encodeWaitForEvent(event, value: 2) +call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD) +commandBuffer.commit() + +commandBuffer.waitUntilCompleted() +var contents = bufferC.contents(); +var count = B * rowsA * columnsB; +var typedPointer = contents.bindMemory(to: Float.self, capacity: count) +var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) +print("First matmul is OK", Array(bufferedPointer)) + +contents = bufferD.contents(); +count = B * rowsA * columnsB; +typedPointer = contents.bindMemory(to: Float.self, capacity: count) +bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) +print("This should be filled", Array(bufferedPointer)) diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 14dd10de..e002d931 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -238,8 +238,6 @@ impl candle::CustomOp1 for SoftmaxLastDim { &mut output, ) .unwrap(); - command_buffer.commit(); - output.did_modify_range(metal::NSRange::new(0, output.length())); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); Ok((newstorage, layout.shape().clone())) } From f419a38e1ad431cac245e0d7525b2c278660df18 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 14 Dec 2023 16:52:37 +0100 Subject: [PATCH 10/32] Fix use resource. 
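Background on this fix: `use_resource` tells Metal which buffers an encoder
reads and writes, so the driver keeps them resident and orders access
correctly now that synchronization is managed manually with fences. Below is
a minimal sketch of the per-encoder pattern the series converges on; the
helper name and signature are illustrative, mirroring the call_* functions
changed in this patch rather than any API it introduces:

    fn encode_pass(
        command_buffer: &metal::CommandBufferRef,
        pipeline: &metal::ComputePipelineState,
        fence: &metal::FenceRef,
        input: &metal::Buffer,
        output: &metal::Buffer,
        thread_group_count: metal::MTLSize,
        thread_group_size: metal::MTLSize,
    ) {
        let encoder = command_buffer.new_compute_command_encoder();
        // Order this pass after the last encoder that signalled the fence.
        encoder.wait_for_fence(fence);
        encoder.set_compute_pipeline_state(pipeline);
        encoder.set_buffer(0, Some(input), 0);
        encoder.set_buffer(1, Some(output), 0);
        // Declare usage before dispatching; this is what the patch adds.
        encoder.use_resource(input, metal::MTLResourceUsage::Read);
        encoder.use_resource(output, metal::MTLResourceUsage::Write);
        encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
        // Signal the fence so the next encoder's wait sees these writes.
        encoder.update_fence(fence);
        encoder.end_encoding();
    }
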
--- candle-metal-kernels/src/lib.rs | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 01432ccb..0c383dec 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -312,6 +312,8 @@ pub fn call_unary_contiguous( set_params!(encoder, (length, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -354,6 +356,8 @@ pub fn call_unary_strided( let width: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -381,6 +385,9 @@ pub fn call_binary_contiguous( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -428,6 +435,9 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(left_input, metal::MTLResourceUsage::Read); + encoder.use_resource(right_input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -454,6 +464,8 @@ pub fn call_cast_contiguous( set_params!(encoder, (length, (input, input_offset), output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -494,6 +506,8 @@ pub fn call_cast_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -541,6 +555,8 @@ pub fn call_reduce_contiguous( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -585,6 +601,8 @@ pub fn call_last_softmax( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -612,6 
+630,8 @@ pub fn call_affine( set_params!(encoder, (size, mul, add, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -654,6 +674,8 @@ pub fn call_affine_strided( ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -680,6 +702,8 @@ pub fn call_powf( set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -720,6 +744,8 @@ pub fn call_powf_strided( ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -746,6 +772,8 @@ pub fn call_elu( set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -786,6 +814,8 @@ pub fn call_elu_strided( ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -833,6 +863,10 @@ pub fn call_where_cond_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(cond, metal::MTLResourceUsage::Read); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -880,6 +914,9 @@ pub fn call_index_select( let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -1121,6 +1158,9 @@ pub fn call_gemm( depth: 1, }; // println!("grid size {grid_size:?} group size {group_size:?}"); + encoder.use_resource(lhs_buffer, 
metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); From 4eeaf205d6d0577805a41dc7ae2457be1862726a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 14 Dec 2023 19:37:03 +0100 Subject: [PATCH 11/32] Fix softmax for long sequences (missing barrier). --- candle-core/src/metal_backend.rs | 2 +- candle-metal-kernels/src/reduce.metal | 15 ++++---- candle-metal-kernels/src/tests.rs | 51 +++++++++++++++++++++------ 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 4bc80823..d38796a1 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -126,7 +126,7 @@ impl MetalDevice { } let new_buffer = self.device.new_buffer(size as NSUInteger, option); let new_buffer = Arc::new(new_buffer); - // subbuffers.push(new_buffer.clone()); + subbuffers.push(new_buffer.clone()); // println!("Created tensor {size} {name}"); for subbuffers in buffers.values_mut() { let newbuffers = subbuffers diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 3a402427..53e4664a 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -32,7 +32,7 @@ kernel void NAME( \ uint block_dim [[ threads_per_threadgroup ]] \ ) { \ \ - threadgroup float shared_memory[THREADGROUP_SIZE]; \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ \ shared_memory[tid] = 0; \ /* \ @@ -67,6 +67,7 @@ kernel void NAME( \ threadgroup_barrier(mem_flags::mem_none); \ } \ \ + threadgroup_barrier(mem_flags::mem_none); \ dst[dst_id] = shared_memory[0]; \ } \ @@ -95,10 +96,12 @@ kernel void NAME( \ threadgroup_barrier(mem_flags::mem_threadgroup); \ \ + float tmp = 0; \ while (idx < stop_idx) { \ - shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \ + tmp = MAX(tmp, src[idx]); \ idx += block_dim; \ } \ + shared_memory[tid] = tmp; \ \ threadgroup_barrier(mem_flags::mem_threadgroup); \ \ @@ -112,12 +115,13 @@ kernel void NAME( \ float _max = shared_memory[0]; \ \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ shared_memory[tid] = 0; \ \ idx = start_idx + tid; \ while (idx < stop_idx) { \ - const T val = T(exp(src[idx] - _max)); \ - dst[idx] = val; \ + const float val = exp(float(src[idx]) - _max); \ + dst[idx] = T(val); \ shared_memory[tid] += val; \ idx += block_dim; \ } \ @@ -125,10 +129,9 @@ kernel void NAME( if (tid < s) { \ shared_memory[tid] += shared_memory[tid + s]; \ } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ } \ \ - const T inv_acc = T(1/shared_memory[0]); \ + const T inv_acc = T(1.0/shared_memory[0]); \ idx = start_idx + tid; \ while (idx < stop_idx) { \ dst[idx] *= inv_acc; \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 8f3e2d43..75c2f013 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -37,7 +37,8 @@ fn approx_bf16(v: Vec, digits: i32) -> Vec { fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -59,7 +60,8 @@ fn run(v: &[T], name: 
unary::contiguous::Kernel) -> Vec { fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -94,7 +96,8 @@ fn run_strided( let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); let output = new_buffer(&device, v); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); call_unary_strided( &device, command_buffer, @@ -247,7 +250,8 @@ fn binary_add_f32() { fn cast(v: &[T], name: &'static str) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -294,7 +298,8 @@ fn cast_u32_f32() { fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -329,7 +334,8 @@ fn run_affine_strided( add: f64, ) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -457,7 +463,8 @@ fn run_index_select( _ => unimplemented!(), }; - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); call_index_select( &device, &command_buffer, @@ -559,7 +566,8 @@ fn cos_f16() { fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -586,7 +594,8 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec(v: &[T], last_dim: usize, name: &'static str) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -636,6 +645,24 @@ fn softmax() { vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] ); + let last_dim = 4096; + let n = 200; + let mut v = vec![0.0; n * last_dim]; + for i in 0..n { + v[i * last_dim] = 20.0; + } + let results = run_softmax(&v, last_dim, "softmax_float"); + let results = approx(results, 4); + println!("{results:?}"); + assert_eq!( + results.iter().map(|&s| s.round() as usize).sum::(), + n + ); + assert_eq!(results[0], 1.0); + assert_eq!(results[1], 0.0); + assert_eq!(results[last_dim], 1.0); + assert_eq!(results[2 * last_dim], 1.0); + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; let last_dim = 6; let results = run_softmax(&v, last_dim, "softmax_float"); @@ -686,7 +713,8 @@ fn run_where_cond( name: &'static str, ) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = 
Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -762,7 +790,8 @@ fn run_gemm( rhs_offset: usize, ) -> Vec { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; From ece4c69a681215837fd5a008e2ee652394daa8ed Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 01:35:08 +0100 Subject: [PATCH 12/32] Fixing softmax. --- candle-core/src/metal_backend.rs | 10 ++++++---- candle-metal-kernels/src/reduce.metal | 11 +++++++---- candle-nn/src/ops.rs | 2 +- candle-transformers/src/models/mixformer.rs | 4 ---- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d38796a1..b8b951f0 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -113,21 +113,23 @@ impl MetalDevice { self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name) } - fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc { - // println!("Creating new buffer {name}"); + fn _new_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + _name: &str, + ) -> Arc { let mut buffers = self.buffers.try_write().unwrap(); let subbuffers = buffers.entry((size, option)).or_insert(vec![]); for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { - // println!("Reusing tensor {size} {name}"); return sub.clone(); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); let new_buffer = Arc::new(new_buffer); subbuffers.push(new_buffer.clone()); - // println!("Created tensor {size} {name}"); for subbuffers in buffers.values_mut() { let newbuffers = subbuffers .iter() diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 53e4664a..3633fdcf 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -67,7 +67,6 @@ kernel void NAME( \ threadgroup_barrier(mem_flags::mem_none); \ } \ \ - threadgroup_barrier(mem_flags::mem_none); \ dst[dst_id] = shared_memory[0]; \ } \ @@ -94,11 +93,10 @@ kernel void NAME( size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ size_t idx = start_idx + tid; \ \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ \ - float tmp = 0; \ + float tmp = -INFINITY; \ while (idx < stop_idx) { \ - tmp = MAX(tmp, src[idx]); \ + tmp = MAX(tmp, float(src[idx])); \ idx += block_dim; \ } \ shared_memory[tid] = tmp; \ @@ -109,12 +107,15 @@ kernel void NAME( if (tid < s) { \ shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ } \ \ + /* wait for shared_memory[0] to be filled */ \ threadgroup_barrier(mem_flags::mem_threadgroup); \ \ float _max = shared_memory[0]; \ \ + /* prevent tid=0 from overwriting _max before other threads have written it */ \ threadgroup_barrier(mem_flags::mem_threadgroup); \ shared_memory[tid] = 0; \ \ @@ -125,10 +126,12 @@ kernel void NAME( shared_memory[tid] += val; \ idx += block_dim; \ } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ for (uint s = block_dim / 2; s > 0; s >>= 1) { \ if (tid < s) { \ shared_memory[tid] += shared_memory[tid + s]; \ } \ + 
threadgroup_barrier(mem_flags::mem_threadgroup); \ } \ \ const T inv_acc = T(1.0/shared_memory[0]); \ diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index e002d931..f00d8e2f 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; let n = layout.stride().len(); - if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { + if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { candle::bail!("Non contiguous softmax-last-dim is not implemented"); } diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 3f9aa47d..e4e4f619 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -272,10 +272,6 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { - // let view = xs.to_string(); - // if view.contains("NaN") { - // panic!("NaN"); - // } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self From 40c3e1bd5ae27134c413e993695a1d343e2b270b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 01:41:14 +0100 Subject: [PATCH 13/32] cleanup. --- candle-core/src/metal_backend.rs | 31 ++++--------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index b8b951f0..b24db020 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -91,7 +91,7 @@ impl MetalDevice { metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled | metal::MTLCommandBufferStatus::Completed => { - panic!("Alredy committed"); + panic!("Already committed"); } _ => {} } @@ -166,9 +166,6 @@ impl MetalDevice { blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); blit.update_fence(&self.fence); blit.end_encoding(); - // drop(command_buffer); - // real.did_modify_range(metal::NSRange::new(0, real.length())); - // println!("Command {:?}", command.status()); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. 
@@ -245,11 +242,7 @@ impl BackendStorage for MetalStorage { DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))), DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))), DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))), - DType::F32 => { - let vec = read_to_vec(&buffer, length / size); - // println!("Got back {:?}", &vec[..1]); - Ok(CpuStorage::F32(vec)) - } + DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))), DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))), } } @@ -302,7 +295,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - // buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -401,7 +393,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -644,21 +635,13 @@ impl BackendStorage for MetalStorage { let kernel_name = match (B::KERNEL, dtype) { ("add", DType::F32) => contiguous::add::FLOAT, - // ("badd", DType::F32) => contiguous::add::FLOAT, ("sub", DType::F32) => contiguous::sub::FLOAT, - //("bsub", DType::F32) => contiguous::sub::FLOAT, ("mul", DType::F32) => contiguous::mul::FLOAT, - // ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, - // ("bdiv", DType::F32) => contiguous::div::FLOAT, ("add", DType::F16) => contiguous::add::HALF, - // ("badd", DType::F16) => contiguous::add::HALF, ("sub", DType::F16) => contiguous::sub::HALF, - // ("bsub", DType::F16) => contiguous::sub::HALF, ("mul", DType::F16) => contiguous::mul::HALF, - // ("bmul", DType::F16) => contiguous::mul::HALF, ("div", DType::F16) => contiguous::div::HALF, - // ("bdiv", DType::F16) => contiguous::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( @@ -877,8 +860,6 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - // Create descriptors - let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul"); let name = match self.dtype { DType::F32 => "sgemm", @@ -889,8 +870,6 @@ impl BackendStorage for MetalStorage { }; let command_buffer = self.device.command_buffer(); - // println!("MATMUL {b} {m} {n} {k}"); - // println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride()); command_buffer.set_label("matmul"); candle_metal_kernels::call_gemm( &self.device.device, @@ -907,14 +886,11 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - // Create kernel - Ok(Self::new(buffer, self.device.clone(), self.dtype())) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let command_buffer = self.device.command_buffer(); - // println!("Copy strided"); if src_l.is_contiguous() && self.dtype == dst.dtype() { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); @@ -975,7 +951,6 @@ impl BackendDevice for MetalDevice { type Storage = MetalStorage; fn new(ordinal: usize) -> Result { - // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); let n = 1; @@ -1024,6 +999,7 @@ impl BackendDevice for MetalDevice { let command_buffer = self.command_buffer(); command_buffer.set_label("zeros"); let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); blit.fill_buffer( &buffer, 
metal::NSRange { @@ -1032,6 +1008,7 @@ impl BackendDevice for MetalDevice { }, 0, ); + blit.update_fence(&self.fence); blit.end_encoding(); Ok(MetalStorage::new(buffer, self.clone(), dtype)) } From cf27868b57f20cba869b4ca425b0b9c09c724822 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 01:44:22 +0100 Subject: [PATCH 14/32] More cleanup. --- candle-metal-kernels/src/lib.rs | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 0c383dec..60f9b8a6 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -173,6 +173,12 @@ pub enum MetalKernelError { FailedToCreateComputeFunction, #[error("Failed to create pipeline")] FailedToCreatePipeline(String), + #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Vec, + rhs_stride: Vec, + mnk: (usize, usize, usize), + }, } impl From> for MetalKernelError { @@ -1029,24 +1035,22 @@ pub fn call_gemm( } else if lhs_m1 == m && lhs_m2 == 1 { true } else { - todo!(); - // Err(MetalError::MatMulNonContiguous { - // lhs_stride: lhs_stride.to_vec(), - // rhs_stride: rhs_stride.to_vec(), - // mnk: (m, n, k), - // })? + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; }; let b_trans = if rhs_m1 == 1 && rhs_m2 == n { false } else if rhs_m1 == k && rhs_m2 == 1 { true } else { - todo!(); - // Err(MetalError::MatMulNonContiguous { - // lhs_stride: lhs_stride.to_vec(), - // rhs_stride: rhs_stride.to_vec(), - // mnk: (m, n, k), - // })? + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; }; let d_trans = false; let alpha = 1.0f32; @@ -1083,7 +1087,6 @@ pub fn call_gemm( (211, Value::U16(n_splits)), (50_001, Value::Bool(fused_bias)), ])); - // println!("Constants {constants:?}"); let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; let m_group = m_simd * m_splits; let n_group = n_simd * n_splits; @@ -1103,7 +1106,6 @@ pub fn call_gemm( block_elements = std::cmp::max(block_elements, n_group); } } - // TODO adapt for f16 let bytes = match name { "sgemm" => 4, "hgemm" => 2, @@ -1118,7 +1120,6 @@ pub fn call_gemm( let encoder = command_buffer.new_compute_command_encoder(); encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - // println!("Threadgroup {block_bytes}"); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); From 243e83f2b936d49219c92aa3c2f73a0c06c8cb13 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 11:02:41 +0100 Subject: [PATCH 15/32] Adding a bunch of docs ! 
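
The new comments document the single-command-buffer scheduling strategy and
the strong-count based buffer allocator. As a minimal sketch of the reuse
rule those docs describe (simplified stand-in types, not the actual candle
signatures):

    use std::collections::HashMap;
    use std::sync::Arc;

    // Buffers are bucketed by (size, storage option). A bucket entry whose
    // Arc strong count is 1 is referenced only by the allocator itself, so
    // it is free to be handed out again.
    type BufferMap = HashMap<(u64, u32), Vec<Arc<Vec<u8>>>>;

    fn try_reuse(buffers: &mut BufferMap, size: u64, option: u32) -> Option<Arc<Vec<u8>>> {
        let subbuffers = buffers.entry((size, option)).or_insert_with(Vec::new);
        subbuffers.iter().find(|sub| Arc::strong_count(sub) == 1).cloned()
    }

The number of compute encoders per command buffer defaults to 20 and can be
tuned through the CANDLE_METAL_COMPUTE_PER_BUFFER environment variable
introduced here, e.g. `CANDLE_METAL_COMPUTE_PER_BUFFER=50 cargo test`.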
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
---
 candle-core/src/metal_backend.rs | 170 ++++++++++++++++++++-----------
 candle-metal-kernels/src/lib.rs  |  17 ++++
 2 files changed, 128 insertions(+), 59 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index b24db020..d8518b3e 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -34,12 +34,48 @@ impl From<String> for MetalError {
 #[derive(Clone)]
 pub struct MetalDevice {
+    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
     device: metal::Device,
+
+    /// Single command queue for the entire device.
     command_queue: metal::CommandQueue,
-    command_buffers: Arc<RwLock<Vec<CommandBuffer>>>,
+    /// One command buffer at a time.
+    /// The scheduler works by allowing multiple
+    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
+    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
+    /// prevents overlapping of CPU and GPU commands (because the command buffer needs to be
+    /// committed to start to work).
+    /// Despite what the documentation says, command buffers are NOT ordered: they are ordered
+    /// by their START time, but there's no guarantee that command buffer 1 will finish before
+    /// command buffer 2 starts (or there are Metal bugs there).
+    command_buffer: Arc<RwLock<CommandBuffer>>,
+    /// Keeps track of the current number of compute command encoders on the current
+    /// command buffer.
+    /// Arc and RwLock because of the interior mutability.
     command_buffer_index: Arc<RwLock<usize>>,
+    /// The maximum number of [compute command encoders](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc).
+    compute_per_buffer: usize,
+    /// Every compute command encoder (and blit encoder) is guarded by this Fence, forcing the
+    /// execution order to be linear.
+    /// It could be relaxed in some circumstances, by managing the dependencies in the compute
+    /// graph ourselves.
     fence: metal::Fence,
+    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
+    /// Heavily used by [`candle_metal_kernels`]; both fences need to match.
     kernels: Arc<Kernels>,
+    /// Simple allocator struct.
+    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
+    /// We store the buffers in [`Arc`] because it's much faster than Obj-C's internal ref counting
+    /// (which could be linked to FFI communication overhead).
+    ///
+    /// Whenever a buffer has a strong_count of 1, we can reuse it: it was dropped in the
+    /// graph calculation, and only we, the allocator, kept a reference to it, so it's free
+    /// to be reused. However, in order for this to work, we need to guarantee the order of
+    /// operations, so that this buffer is not being used by another kernel at the same time.
+    /// Arc is the CPU reference count; it doesn't mean anything on the GPU side of things.
+    ///
+    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused
+    /// buffers (strong_count == 1).
+    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
 }
 
@@ -71,13 +107,13 @@ impl MetalDevice {
     }
 
     pub fn command_buffer(&self) -> CommandBuffer {
-        let mut command_buffers = self.command_buffers.try_write().unwrap();
-        let mut command_buffer = command_buffers[0].to_owned();
+        let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
+        let mut command_buffer = command_buffer_lock.to_owned();
         let mut index = self.command_buffer_index.try_write().unwrap();
-        if *index > 20 {
+        if *index > self.compute_per_buffer {
             command_buffer.commit();
             command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffers = vec![command_buffer.clone()];
+            *command_buffer_lock = command_buffer.clone();
             *index = 0;
         }
         *index += 1;
@@ -85,8 +121,7 @@ impl MetalDevice {
     }
 
     pub fn wait_until_completed(&self) {
-        let mut command_buffers = self.command_buffers.try_write().unwrap();
-        let command_buffer = &command_buffers[0];
+        let mut command_buffer = self.command_buffer.try_write().unwrap();
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled
@@ -97,7 +132,7 @@ impl MetalDevice {
         }
         command_buffer.commit();
         command_buffer.wait_until_completed();
-        *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
+        *command_buffer = self.command_queue.new_command_buffer().to_owned();
     }
 
     pub fn kernels(&self) -> &Kernels {
@@ -108,12 +143,65 @@ impl MetalDevice {
         &self.device
     }
 
+    /// Creates a new buffer (not necessarily zeroed).
+    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode),
+    /// meaning the buffer data cannot be read on the CPU directly.
+    ///
+    /// [`name`] is only used to keep track of the resource origin in case of bugs.
     pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
+        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
 
-    fn _new_buffer(
+    /// Creates a new buffer (not necessarily zeroed).
+    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode),
+    /// meaning the buffer can be read on the CPU but will require manual
+    /// synchronization when the CPU memory is modified.
+    /// Used as a bridge to gather data back from the GPU.
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
+    }
+
+    /// Creates a new buffer from data.
+    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode).
+    ///
+    /// This method will block the computation because of the
+    /// lack of lifetime management through the GPU.
+    /// See the internal comment for technical details.
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+        let size = core::mem::size_of_val(data) as NSUInteger;
+        let tmp = self.device.new_buffer_with_data(
+            data.as_ptr() as *const core::ffi::c_void,
+            size,
+            metal::MTLResourceOptions::StorageModeManaged,
+        );
+        let real = self.allocate_buffer(
+            size,
+            metal::MTLResourceOptions::StorageModePrivate,
+            "with_data",
+        );
+        let command_buffer = self.command_buffer();
+        command_buffer.set_label("with_data");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.wait_for_fence(&self.fence);
+        blit.set_label("with_data_blit");
+        blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+        blit.update_fence(&self.fence);
+        blit.end_encoding();
+
+        // This is necessary for mmaped safetensors,
+        // because of the unsafe slice cast we're doing:
+        // the slice might not live long enough for Metal
+        // to actually fill the GPU buffer.
+        // Putting this wait forces the GPU buffer to be filled
+        // with the actual data, allowing the CPU storage to
+        // deallocate properly.
+        self.wait_until_completed();
+        real
+    }
+
+    /// The critical allocator algorithm.
+    fn allocate_buffer(
         &self,
         size: NSUInteger,
         option: MTLResourceOptions,
         _name: &str,
     ) -> Arc<Buffer> {
@@ -142,42 +230,7 @@ impl MetalDevice {
         new_buffer
     }
 
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
-        self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
-    }
-
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
-        let size = core::mem::size_of_val(data) as NSUInteger;
-        let tmp = self.device.new_buffer_with_data(
-            data.as_ptr() as *const core::ffi::c_void,
-            size,
-            metal::MTLResourceOptions::StorageModeManaged,
-        );
-        let real = self._new_buffer(
-            size,
-            metal::MTLResourceOptions::StorageModePrivate,
-            "with_data",
-        );
-        let command_buffer = self.command_buffer();
-        command_buffer.set_label("with_data");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.wait_for_fence(&self.fence);
-        blit.set_label("with_data_blit");
-        blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-        blit.update_fence(&self.fence);
-        blit.end_encoding();
-
-        // This is necessary, for mmaped safetensors
-        // Because of the unsafe slice cast we're doing.
-        // The slice might not live long enough for metal
-        // To actually fill the GPU buffer.
-        // Putting this wait forces the GPU buffer to be filled
-        // with the actual data allowing the CPU storage todo
-        // deallocate properly.
-        self.wait_until_completed();
-        real
-    }
-
+    /// Create a metal GPU capture trace on [`path`].
     pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
         let capture = metal::CaptureManager::shared();
         let descriptor = metal::CaptureDescriptor::new();
@@ -194,8 +247,11 @@ impl MetalDevice {
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
+    /// The actual buffer containing the data.
     buffer: Arc<Buffer>,
+    /// A reference to the device owning this buffer.
     device: MetalDevice,
+    /// The dtype is kept since buffers are untyped.
    dtype: DType,
 }
 
@@ -952,29 +1008,25 @@ impl BackendDevice for MetalDevice {
 
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
-
-        let n = 1;
         let command_queue = device.new_command_queue();
-
-        let command_buffers = (0..n)
-            .map(|i| {
-                let command_buffer = command_queue.new_command_buffer().to_owned();
-                command_buffer.enqueue();
-                command_buffer.set_label(&format!("num {i}"));
-                command_buffer
-            })
-            .collect();
-        let command_buffers = Arc::new(RwLock::new(command_buffers));
+        let command_buffer = command_queue.new_command_buffer().to_owned();
+        command_buffer.enqueue();
+        let command_buffer = Arc::new(RwLock::new(command_buffer));
         let command_buffer_index = Arc::new(RwLock::new(0));
         let fence = device.new_fence();
         let kernels = Arc::new(Kernels::new(fence.clone()));
         let buffers = Arc::new(RwLock::new(HashMap::new()));
+        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+            Ok(val) => val.parse()?,
+            _ => 20,
+        };
         Ok(Self {
             device,
             fence,
             command_queue,
-            command_buffers,
+            command_buffer,
             command_buffer_index,
+            compute_per_buffer,
             buffers,
             kernels,
         })
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 60f9b8a6..2fa571bc 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -15,6 +15,10 @@ const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 
+/// Most kernels apply similarly across the tensors.
+/// This creates a strategy that uses the maximum number of threads per threadgroup (capped at the
+/// actual total buffer length).
+/// Then kernels can just do their op on their single point in the buffer.
 fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
     let size = length as u64;
     let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
@@ -36,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
 fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {

    <P as EncoderParam>::set_param(encoder, position, data)
 }
+
+/// Helper functions to create the various objects on the compute command encoder
+/// on a single line.
+/// Prevents getting some argument numbers wrong and mixing up length and size in bytes.
 trait EncoderParam {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
 }
@@ -220,6 +228,9 @@ impl Kernels {
             Source::Mfa => panic!("Invalid lib"),
         }
     }
+
+    /// Load the given library from its [`source`].
+    /// If this has been previously loaded it will just fetch it from cache.
     pub fn load_library(
         &self,
         device: &Device,
@@ -262,6 +273,9 @@ impl Kernels {
         Ok(func)
     }
 
+    /// Load the given pipeline:
+    /// loads the library from source, then gets the function [`name`] from
+    /// that source.
     fn load_pipeline_with_constants(
         &self,
         device: &Device,
@@ -290,6 +304,9 @@ impl Kernels {
         }
     }
 
+    /// Load the given pipeline:
+    /// loads the library from source, then gets the function [`name`] from
+    /// that source (without constants).
     pub fn load_pipeline(
         &self,
         device: &Device,
From 916a8c54646fab67f3d886717a48abfe55d89e39 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 15 Dec 2023 11:15:21 +0100
Subject: [PATCH 16/32] Revert candle-transformers.

---
 candle-transformers/src/models/mixformer.rs | 42 ++------------------
 1 file changed, 4 insertions(+), 38 deletions(-)

diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index e4e4f619..e822ca14 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -142,9 +142,10 @@ impl RotaryEmbedding {
             .to_dtype(DType::F32)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let sin = freqs.sin()?;
-        let cos = freqs.cos()?;
-        Ok(Self { sin, cos })
+        Ok(Self {
+            sin: freqs.sin()?,
+            cos: freqs.cos()?,
+        })
     }
 
     fn apply_rotary_emb_qkv(
@@ -407,38 +408,3 @@ impl MixFormerSequentialForCausalLM {
         self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    #[test]
-    fn test_rotary() {
-        let dev = Device::new_metal(0).unwrap();
-        for i in 0..10000 {
-            let dim = 8;
-            let max_seq_len = 12;
-            let inv_freq: Vec<_> = (0..dim)
-                .step_by(2)
-                .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
-                .collect();
-            let inv_freq_len = inv_freq.len();
-            let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap();
-            let t = Tensor::arange(0u32, max_seq_len as u32, &dev)
-                .unwrap()
-                .to_dtype(DType::F32)
-                .unwrap()
-                .reshape((max_seq_len, 1))
-                .unwrap();
-            let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap();
-            assert_eq!(x, 1.0);
-            let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap();
-            assert_eq!(x, 0.1);
-            let freqs = t.matmul(&inv_freq).unwrap();
-            let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap();
-            assert_eq!(x, 0.1);
-            let sin = freqs.sin().unwrap().contiguous().unwrap();
-            let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap();
-            assert_eq!(x, 0.099833414);
-        }
-    }
-}
From 77197379cccda06f895da0c1b3fcb35d19e90bd3 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 15 Dec 2023 11:17:05 +0100
Subject: [PATCH 17/32] More cleanup.

---
 candle-core/src/tensor.rs         | 5 +----
 candle-core/tests/tensor_tests.rs | 2 --
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 73a0cc7a..e478869a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1863,10 +1863,7 @@ impl Tensor {
                 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
            }
             (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
-            (Storage::Metal(storage), Device::Cpu) => {
-                // println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
-                Storage::Cpu(storage.to_cpu_storage()?)
-            }
+            (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
             (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                 // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                 // are the same.
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index a77f9c3a..c871dc96 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -900,9 +900,7 @@ fn matmul(device: &Device) -> Result<()> {
     let b = Tensor::from_slice(&data, (2, 2), device)?;
 
     let c = a.matmul(&b)?;
-    let d = a.matmul(&c)?;
     assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
-    assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);
 
     let data = vec![1.0f32, 2.0];
     let a = Tensor::from_slice(&data, (2, 1), device)?;
From 34d83377f6e63ef428c82448515dbba0047fdfae Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 15 Dec 2023 11:18:54 +0100
Subject: [PATCH 18/32] Better error message on older macOS

---
 candle-metal-kernels/src/lib.rs | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 2fa571bc..514cf33e 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -243,9 +243,11 @@ impl Kernels {
         let lib = match source {
             Source::Mfa => {
                 let source_data = MFA;
-                device
-                    .new_library_with_data(source_data)
-                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                device.new_library_with_data(source_data).map_err(|e| {
+                    MetalKernelError::LoadLibraryError(format!(
+                        "Candle metal requires macOS 13.0 or higher, cannot load mfa: {e}"
+                    ))
+                })?
             }
             source => {
                 let source_content = self.get_library_source(source);
From 26540641c1f0a7b351f5e3d3c3c165221ae1d9ed Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 15 Dec 2023 11:24:47 +0100
Subject: [PATCH 19/32] Renamed all kernel names.
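
Kernel names now use the candle dtype as suffix (_f32 / _f16 / _bf16)
instead of the Metal scalar type name (_float / _half / _bfloat), so a
kernel name can be derived mechanically from a DType. A minimal sketch of
the convention (hypothetical helper for illustration, not part of this
patch):

    // e.g. "affine_f32", "softmax_bf16"; previously "affine_float",
    // "softmax_bfloat".
    fn kernel_suffix(dtype: DType) -> &'static str {
        match dtype {
            DType::F32 => "f32",
            DType::F16 => "f16",
            DType::BF16 => "bf16",
            _ => unimplemented!("no kernels for this dtype"),
        }
    }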
--- candle-core/src/metal_backend.rs | 34 +++++++++++++-------------- candle-metal-kernels/src/affine.metal | 18 +++++++------- candle-metal-kernels/src/binary.metal | 6 ++--- candle-metal-kernels/src/lib.rs | 24 +++++++++---------- candle-metal-kernels/src/reduce.metal | 12 +++++----- candle-metal-kernels/src/unary.metal | 12 +++++----- candle-nn/src/ops.rs | 6 ++--- 7 files changed, 56 insertions(+), 56 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d8518b3e..b4a490cd 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "affine_float", - DType::F16 => "affine_half", + DType::F32 => "affine_f32", + DType::F16 => "affine_f16", dtype => crate::bail!("Affine {dtype:?}"), }; candle_metal_kernels::call_affine( @@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "affine_float_strided", - DType::F16 => "affine_half_strided", + DType::F32 => "affine_f32_strided", + DType::F16 => "affine_f16_strided", dtype => crate::bail!("Affine {dtype:?}"), }; candle_metal_kernels::call_affine_strided( @@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "powf_float", - DType::F16 => "powf_half", + DType::F32 => "powf_f32", + DType::F16 => "powf_f16", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_powf( @@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "powf_float_strided", - DType::F16 => "powf_half_strided", + DType::F32 => "powf_f32_strided", + DType::F16 => "powf_f16_strided", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_powf_strided( @@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "elu_float", - DType::F16 => "elu_half", + DType::F32 => "elu_f32", + DType::F16 => "elu_f16", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_elu( @@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "elu_float_strided", - DType::F16 => "elu_half_strided", + DType::F32 => "elu_f32_strided", + DType::F16 => "elu_f16_strided", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_elu_strided( @@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage { // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. 
let (name, check_empty, return_index) = match (op, self.dtype) { - (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), - (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), - (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), - (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), - (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), + (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true), _ => crate::bail!("Reduce op for non float"), }; if check_empty && layout.shape().elem_count() == 0 { diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 18adb457..4166d811 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \ } \ -AFFINE(affine_float, float) -AFFINE(affine_half, half) -POWF(powf_float, float) -POWF(powf_half, half) -ELU(elu_float, float) -ELU(elu_half, half) +AFFINE(affine_f32, float) +AFFINE(affine_f16, half) +POWF(powf_f32, float) +POWF(powf_f16, half) +ELU(elu_f32, float) +ELU(elu_f16, half) #if __METAL_VERSION__ >= 310 -AFFINE(affine_bfloat, bfloat); -POWF(powf_bfloat, bfloat); -ELU(elu_bfloat, bfloat); +AFFINE(affine_bf16, bfloat); +POWF(powf_bf16, bfloat); +ELU(elu_bf16, bfloat); #endif diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index f18cdbb0..ea21bb34 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -52,11 +52,11 @@ kernel void FN_NAME_STRIDED( \ } #define BINARY_OP(FN, NAME) \ -BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ -BINARY(FN, half, half, NAME##_half, NAME##_half_strided); +BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ -BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); BINARY_OP(x + y, add) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 514cf33e..a23aa47c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -125,16 +125,16 @@ macro_rules! ops{ $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); } )+ pub mod copy { use super::Kernel; - pub const FLOAT: Kernel = Kernel("copy_float"); - pub const HALF: Kernel = Kernel("copy_half"); - pub const BFLOAT: Kernel = Kernel("copy_bfloat"); + pub const FLOAT: Kernel = Kernel("copy_f32"); + pub const HALF: Kernel = Kernel("copy_f16"); + pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -145,16 +145,16 @@ macro_rules! 
ops{ $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); } )+ pub mod copy { use super::Kernel; - pub const FLOAT: Kernel = Kernel("copy_float_strided"); - pub const HALF: Kernel = Kernel("copy_half_strided"); - pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided"); + pub const FLOAT: Kernel = Kernel("copy_f32_strided"); + pub const HALF: Kernel = Kernel("copy_f16_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 3633fdcf..62443660 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -71,9 +71,9 @@ kernel void NAME( \ } \ -REDUCE(x + y, fast_sum_float, float) -REDUCE(x * y, fast_mul_float, float) -REDUCE(max(x, y), fast_max_float, float) +REDUCE(x + y, fast_sum_f32, float) +REDUCE(x * y, fast_mul_f32, float) +REDUCE(max(x, y), fast_max_f32, float) #define SOFTMAX(NAME, T) \ kernel void NAME( \ @@ -142,8 +142,8 @@ kernel void NAME( } \ } \ -SOFTMAX(softmax_float, float) -SOFTMAX(softmax_half, half) +SOFTMAX(softmax_f32, float) +SOFTMAX(softmax_f16, half) #if __METAL_VERSION__ >= 310 -SOFTMAX(softmax_bfloat, bfloat) +SOFTMAX(softmax_bf16, bfloat) #endif diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 765b14a5..553bc506 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -87,11 +87,11 @@ kernel void FN_NAME_STRIDED( \ } #define UNARY_OP(NAME) \ -UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ -UNARY(NAME, half, NAME##_half, NAME##_half_strided); +UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \ +UNARY(NAME, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_UNARY_OP(NAME) \ -UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); UNARY_OP(cos) @@ -108,8 +108,8 @@ UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) -UNARY(id, float, copy_float, copy_float_strided) -UNARY(id, half, copy_half, copy_half_strided) +UNARY(id, float, copy_f32, copy_f32_strided) +UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided) @@ -129,5 +129,5 @@ BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) -UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) +UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index f00d8e2f..ca23f90e 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -213,9 +213,9 @@ impl candle::CustomOp1 for SoftmaxLastDim { let command_buffer = device.command_buffer(); let kernels = device.kernels(); let name = match storage.dtype() { - DType::F32 => "softmax_float", - DType::F16 => "softmax_half", - DType::BF16 => "softmax_bfloat", + DType::F32 => "softmax_f32", + DType::F16 => "softmax_f16", + DType::BF16 => "softmax_bf16", 
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), }; From 8b5059e95178cd0bf369906717319b8eef2cd8a8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 11:55:30 +0100 Subject: [PATCH 20/32] Remove test file. --- candle-metal-kernels/src/test.swift | 209 ---------------------------- 1 file changed, 209 deletions(-) delete mode 100644 candle-metal-kernels/src/test.swift diff --git a/candle-metal-kernels/src/test.swift b/candle-metal-kernels/src/test.swift deleted file mode 100644 index f9bb9f91..00000000 --- a/candle-metal-kernels/src/test.swift +++ /dev/null @@ -1,209 +0,0 @@ - -import Metal -import MetalPerformanceShadersGraph - - - -let type = MTLDataType.float; -let dataType = type; -var B = 2; -var M = 2; -var N = 2; -var K = 2; -var A_trans = false; -var B_trans = false; -var D_trans = false; -var alpha = Float(1.0); -var beta = Float(0.0); -var batched = B > 1; -var fused_activation = false; -var fused_bias = false; -let constants = MTLFunctionConstantValues() -constants.setConstantValue(&M, type: .uint, index: 0) -constants.setConstantValue(&N, type: .uint, index: 1) -constants.setConstantValue(&K, type: .uint, index: 2) -constants.setConstantValue(&A_trans, type: .bool, index: 10) -constants.setConstantValue(&B_trans, type: .bool, index: 11) -constants.setConstantValue(&D_trans, type: .bool, index: 13) -constants.setConstantValue(&alpha, type: .float, index: 20) -constants.setConstantValue(&beta, type: .float, index: 21) -constants.setConstantValue(&batched, type: .bool, index: 100) -constants.setConstantValue(&fused_activation, type: .bool, index: 101) -constants.setConstantValue(&fused_bias, type: .bool, index: 50001) - - -var M_simd = UInt16(16) -var N_simd = UInt16(16) -var K_simd = UInt16(32) -var M_splits = UInt16(2) -var N_splits = UInt16(2) -constants.setConstantValue(&M_simd, type: .ushort, index: 200) -constants.setConstantValue(&N_simd, type: .ushort, index: 201) -constants.setConstantValue(&K_simd, type: .ushort, index: 202) -constants.setConstantValue(&M_splits, type: .ushort, index: 210) -constants.setConstantValue(&N_splits, type: .ushort, index: 211) - -let M_group = M_simd * M_splits -let N_group = N_simd * N_splits - -// Satisfy Metal API validation. -#if DEBUG -do { - var garbage: SIMD4 = .zero - constants.setConstantValue(&garbage, type: .bool, index: 102) - constants.setConstantValue(&garbage, type: .bool, index: 103) - constants.setConstantValue(&garbage, type: .bool, index: 113) - constants.setConstantValue(&garbage, type: .bool, index: 50000) -} -#endif - -let device = MTLCopyAllDevices().first! -device.shouldMaximizeConcurrentCompilation = true - -var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!; -libraryURL.append(component: "src") -libraryURL.append(component: "libMetalFlashAttention.metallib") -let library = try! device.makeLibrary(URL: libraryURL) - -var name: String - switch dataType { - case .half: name = "hgemm" - case .float: name = "sgemm" - default: fatalError() - } -let function = try! 
library.makeFunction( - name: name, constantValues: constants) - -let A_block_length = M_group * K_simd -let B_block_length = K_simd * N_group - -var blockElements = A_block_length + B_block_length; -if (M % 8 != 0) && (N % 8 != 0) { - let C_block_length = M_group * N_group; - blockElements = max(C_block_length, blockElements) -} -if fused_bias { - if D_trans { - blockElements = max(blockElements, M_group) - } else { - blockElements = max(blockElements, N_group) - } -} -// let blockBytes = blockElements * UInt16(dataType.size) -let elementSize = 4 -let blockBytes = blockElements * UInt16(elementSize) - -func ceilDivide(target: Int, granularity: UInt16) -> Int { - (target + Int(granularity) - 1) / Int(granularity) -} -var gridSize = MTLSize( - width: ceilDivide(target: N, granularity: N_group), - height: ceilDivide(target: M, granularity: M_group), - depth: 1) -let groupSize = MTLSize( - width: Int(32 * M_splits * N_splits), - height: 1, - depth: 1) - -let commandQueue = device.makeCommandQueue()! - -let threadgroupMemoryLength = blockBytes; - -let rowsA = M; -let columnsA = K; -let rowsB = K; -let columnsB = N; -let rowsC = M; -let columnsC = N; -var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA) - -var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB) - -var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC) -var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC) -for i in 0...stride, options: [])! - -let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout.stride, options: [])! - -let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout.stride, options: [])! -let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout.stride, options: [])! - - -let pipeline = try device.makeComputePipelineState(function: function) - -func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){ - let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)! - encoder.setComputePipelineState(pipeline) - encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0) - - encoder.setBuffer(bufferA, offset: 0, index: 0) - encoder.setBuffer(bufferB, offset: 0, index: 1) - encoder.setBuffer(bufferC, offset: 0, index: 2) - let gridZ: Int = B - if batched{ - func byteStride(shape: [Int]) -> Int { - let rank = shape.count - var output = elementSize * shape[rank - 2] * shape[rank - 1] - if shape.dropLast(2).reduce(1, *) == 1 { - output = 0 - } - return output - } - let byteStrideA = M*K*elementSize - let byteStrideB = N*K*elementSize - let byteStrideC = M*N*elementSize - - let byteStrideD = 0 - withUnsafeTemporaryAllocation( - of: SIMD4.self, capacity: gridZ - ) { buffer in - for i in 0..>.stride - assert(MemoryLayout>.stride == 8 * 4) - encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10) - } - } - gridSize.depth = gridZ - - - encoder.dispatchThreadgroups( - gridSize, threadsPerThreadgroup: groupSize - ) - encoder.endEncoding() -} - -var commandBuffer = commandQueue.makeCommandBuffer()! -call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC) -commandBuffer.commit() -commandBuffer = commandQueue.makeCommandBuffer()! 
-commandBuffer.encodeWaitForEvent(event, value: 2) -call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD) -commandBuffer.commit() - -commandBuffer.waitUntilCompleted() -var contents = bufferC.contents(); -var count = B * rowsA * columnsB; -var typedPointer = contents.bindMemory(to: Float.self, capacity: count) -var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) -print("First matmul is OK", Array(bufferedPointer)) - -contents = bufferD.contents(); -count = B * rowsA * columnsB; -typedPointer = contents.bindMemory(to: Float.self, capacity: count) -bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) -print("This should be filled", Array(bufferedPointer)) From aa040150985e78079bcc05df86266e447c23b4fc Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 12:23:28 +0100 Subject: [PATCH 21/32] Remove `unwrap()`. --- candle-core/src/metal_backend.rs | 121 +++++++++++++++++++------------ candle-nn/src/ops.rs | 4 +- 2 files changed, 77 insertions(+), 48 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index b4a490cd..f570d2c5 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -8,7 +8,26 @@ use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::path::Path; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, RwLock, TryLockError}; + +/// Simple way to catch lock error without +/// depending on T +#[derive(thiserror::Error, Debug)] +pub enum LockError { + #[error("{0}")] + Poisoned(String), + #[error("Would block")] + WouldBlock, +} + +impl From> for MetalError { + fn from(value: TryLockError) -> Self { + match value { + TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())), + TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock), + } + } +} /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -24,6 +43,8 @@ pub enum MetalError { rhs_stride: Vec, mnk: (usize, usize, usize), }, + #[error("{0:?}")] + LockError(LockError), } impl From for MetalError { @@ -106,10 +127,13 @@ impl MetalDevice { &self.command_queue } - pub fn command_buffer(&self) -> CommandBuffer { - let mut command_buffer_lock = self.command_buffer.try_write().unwrap(); + pub fn command_buffer(&self) -> Result { + let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; let mut command_buffer = command_buffer_lock.to_owned(); - let mut index = self.command_buffer_index.try_write().unwrap(); + let mut index = self + .command_buffer_index + .try_write() + .map_err(MetalError::from)?; if *index > self.compute_per_buffer { command_buffer.commit(); command_buffer = self.command_queue.new_command_buffer().to_owned(); @@ -117,11 +141,11 @@ impl MetalDevice { *index = 0; } *index += 1; - command_buffer + Ok(command_buffer) } - pub fn wait_until_completed(&self) { - let mut command_buffer = self.command_buffer.try_write().unwrap(); + pub fn wait_until_completed(&self) -> Result<()> { + let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; match command_buffer.status() { metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled @@ -133,6 +157,7 @@ impl MetalDevice { command_buffer.commit(); command_buffer.wait_until_completed(); *command_buffer = self.command_queue.new_command_buffer().to_owned(); + Ok(()) } pub fn kernels(&self) -> &Kernels { @@ -148,7 +173,12 @@ impl MetalDevice 
{ /// This means the buffer data cannot be read on the CPU directly. /// /// [`name`] is only used to keep track of the resource origin in case of bugs - pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc { + pub fn new_buffer( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result> { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) } @@ -158,7 +188,7 @@ impl MetalDevice { /// This means the buffer can be read on the CPU but will require manual /// synchronization when the CPU memory is modified /// Used as a bridge to gather data back from the GPU - pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { + pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } @@ -168,7 +198,7 @@ impl MetalDevice { /// This method will block the computation because of the /// lack of lifetime management through the GPU. /// Internal comment for technical details. - pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { + pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { let size = core::mem::size_of_val(data) as NSUInteger; let tmp = self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, @@ -179,8 +209,8 @@ impl MetalDevice { size, metal::MTLResourceOptions::StorageModePrivate, "with_data", - ); - let command_buffer = self.command_buffer(); + )?; + let command_buffer = self.command_buffer()?; command_buffer.set_label("with_data"); let blit = command_buffer.new_blit_command_encoder(); blit.wait_for_fence(&self.fence); @@ -196,8 +226,8 @@ impl MetalDevice { // Putting this wait forces the GPU buffer to be filled // with the actual data allowing the CPU storage todo // deallocate properly. - self.wait_until_completed(); - real + self.wait_until_completed()?; + Ok(real) } /// The critical allocator algorithm @@ -206,13 +236,13 @@ impl MetalDevice { size: NSUInteger, option: MTLResourceOptions, _name: &str, - ) -> Arc { - let mut buffers = self.buffers.try_write().unwrap(); + ) -> Result> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; let subbuffers = buffers.entry((size, option)).or_insert(vec![]); for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { - return sub.clone(); + return Ok(sub.clone()); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); @@ -226,8 +256,7 @@ impl MetalDevice { .collect(); *subbuffers = newbuffers; } - - new_buffer + Ok(new_buffer) } /// Create a metal GPU capture trace on [`path`]. 
@@ -279,9 +308,9 @@ impl BackendStorage for MetalStorage { self.dtype ); } - let buffer = self.device.new_buffer_managed(self.buffer.length()); + let buffer = self.device.new_buffer_managed(self.buffer.length())?; { - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); @@ -290,7 +319,7 @@ impl BackendStorage for MetalStorage { blit.update_fence(&self.device.fence); blit.end_encoding(); } - self.device.wait_until_completed(); + self.device.wait_until_completed()?; match self.dtype { DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))), @@ -310,8 +339,8 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype, "affine"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(el, self.dtype, "affine")?; + let command_buffer = self.device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { DType::F32 => "affine_f32", @@ -361,8 +390,8 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype, "powf"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(el, self.dtype, "powf")?; + let command_buffer = self.device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { DType::F32 => "powf_f32", @@ -410,8 +439,8 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype, "elu"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(el, self.dtype, "elu")?; + let command_buffer = self.device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { DType::F32 => "elu_f32", @@ -497,8 +526,8 @@ impl BackendStorage for MetalStorage { if dtype == DType::U32 { crate::bail!("Implement return index reduce op"); } - let buffer = device.new_buffer(dst_el, dtype, "reduce"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(dst_el, dtype, "reduce")?; + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_reduce_contiguous( &device.device, &command_buffer, @@ -523,8 +552,8 @@ impl BackendStorage for MetalStorage { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, "todtype"); - let command_buffer = device.command_buffer(); + let buffer = device.new_buffer(el_count, dtype, "todtype")?; + let command_buffer = device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", @@ -576,8 +605,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, B::KERNEL); - let command_buffer = device.command_buffer(); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; + let command_buffer = device.command_buffer()?; command_buffer.set_label(B::KERNEL); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; 
@@ -681,8 +710,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = lhs_l.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, B::KERNEL); - let command_buffer = device.command_buffer(); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; + let command_buffer = device.command_buffer()?; if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) && &B::KERNEL[..1] != "b" @@ -758,8 +787,8 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let buffer = self.device.new_buffer(el, dtype, "where"); - let command_buffer = self.device.command_buffer(); + let buffer = self.device.new_buffer(el, dtype, "where")?; + let command_buffer = self.device.command_buffer()?; if t.dtype() != f.dtype() { crate::bail!("Invalid ternary different dtypes for values"); } @@ -875,13 +904,13 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype, "index_select"); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -916,7 +945,7 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul"); + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; let name = match self.dtype { DType::F32 => "sgemm", DType::F16 => "hgemm", @@ -925,7 +954,7 @@ impl BackendStorage for MetalStorage { } }; - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; command_buffer.set_label("matmul"); candle_metal_kernels::call_gemm( &self.device.device, @@ -946,7 +975,7 @@ impl BackendStorage for MetalStorage { } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; if src_l.is_contiguous() && self.dtype == dst.dtype() { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); @@ -1047,8 +1076,8 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros"); - let command_buffer = self.command_buffer(); + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; + let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); let blit = command_buffer.new_blit_command_encoder(); blit.wait_for_fence(&self.fence); @@ -1080,7 +1109,7 @@ impl BackendDevice for MetalDevice { CpuStorage::F16(storage) => self.new_buffer_with_data(storage), CpuStorage::F32(storage) => self.new_buffer_with_data(storage), CpuStorage::F64(storage) => self.new_buffer_with_data(storage), - }; + }?; Ok(Self::Storage::new( buffer.into(), self.clone(), diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index ca23f90e..94380f12 100644 --- a/candle-nn/src/ops.rs +++ 
b/candle-nn/src/ops.rs @@ -210,7 +210,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { ) -> Result<(candle::MetalStorage, Shape)> { use candle::{backend::BackendStorage, DType}; let device = storage.device(); - let command_buffer = device.command_buffer(); + let command_buffer = device.command_buffer()?; let kernels = device.kernels(); let name = match storage.dtype() { DType::F32 => "softmax_f32", @@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { let last_dim = layout.dims()[layout.shape().rank() - 1]; let elem_count = layout.shape().elem_count(); - let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax"); + let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?; candle_metal_kernels::call_last_softmax( device.metal_device(), &command_buffer, From 6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 13:06:04 +0100 Subject: [PATCH 22/32] Addressing a lot of comments. --- candle-core/src/metal_backend.rs | 23 +++++++++++++++-------- candle-metal-kernels/src/lib.rs | 6 +++++- candle-metal-kernels/src/tests.rs | 21 +++++++++++---------- candle-nn/src/ops.rs | 3 ++- 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index f570d2c5..424b29d9 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -482,11 +482,14 @@ impl BackendStorage for MetalStorage { } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { - if !(sum_dims.len() == 1 - && sum_dims[0] == layout.shape().rank() - 1 - && layout.stride()[sum_dims[0]] == 1) - { - crate::bail!("Non last dim reduce op not supported yet"); + if sum_dims.len() != 1 { + crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet."); + } + if sum_dims[0] != layout.shape().rank() - 1 { + crate::bail!("Non last dim reduce op {op:?} not implemented yet"); + } + if layout.stride()[sum_dims[0]] != 1 { + crate::bail!("Non contiguous reduce op {op:?} not implemented yet"); } let device = self.device.clone(); @@ -524,7 +527,7 @@ impl BackendStorage for MetalStorage { } let dtype = if return_index { DType::U32 } else { self.dtype }; if dtype == DType::U32 { - crate::bail!("Implement return index reduce op"); + crate::bail!("reduce op {name} is not implemented yet."); } let buffer = device.new_buffer(dst_el, dtype, "reduce")?; let command_buffer = self.device.command_buffer()?; @@ -790,12 +793,16 @@ impl BackendStorage for MetalStorage { let buffer = self.device.new_buffer(el, dtype, "where")?; let command_buffer = self.device.command_buffer()?; if t.dtype() != f.dtype() { - crate::bail!("Invalid ternary different dtypes for values"); + crate::bail!( + "Invalid where: different dtypes for values {:?} != {:?}", + t.dtype(), + f.dtype() + ); } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", (DType::U8, DType::F16) => "where_u8_f16", - (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"), + (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"), }; candle_metal_kernels::call_where_cond_strided( &device.device, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a23aa47c..f2db171e 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -597,6 +597,7 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, + input_offset: usize, output: 
&Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; @@ -604,7 +605,10 @@ pub fn call_last_softmax( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let out_length = length / elements_to_sum; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 75c2f013..9c9475a2 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -312,7 +312,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { &device, command_buffer, &kernels, - "affine_float", + "affine_f32", size, &input, &output, @@ -346,7 +346,7 @@ fn run_affine_strided( &device, command_buffer, &kernels, - "affine_float_strided", + "affine_f32_strided", shape, &input, strides, @@ -608,6 +608,7 @@ fn run_softmax(v: &[T], last_dim: usize, name: &'sta v.len(), last_dim, &input, + 0, &output, ) .unwrap(); @@ -622,7 +623,7 @@ fn reduce_sum() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 1; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32"); assert_eq!(approx(results, 4), vec![21.0]); } @@ -631,7 +632,7 @@ fn reduce_sum2() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 2; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32"); assert_eq!(approx(results, 4), vec![6.0, 15.0]); } @@ -639,7 +640,7 @@ fn reduce_sum2() { fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] @@ -651,7 +652,7 @@ fn softmax() { for i in 0..n { v[i * last_dim] = 20.0; } - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); let results = approx(results, 4); println!("{results:?}"); assert_eq!( @@ -665,7 +666,7 @@ fn softmax() { let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] @@ -673,7 +674,7 @@ fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 3; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] @@ -684,7 +685,7 @@ fn softmax() { .map(|v| f16::from_f32(*v)) .collect::>(); let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_half"); + let results = run_softmax(&v, last_dim, "softmax_f16"); assert_eq!( approx_f16(results, 4), vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] @@ -695,7 +696,7 @@ fn softmax() { .map(|v| bf16::from_f32(*v)) .collect::>(); let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_bfloat"); + let results = run_softmax(&v, last_dim, "softmax_bf16"); assert_eq!( approx_bf16(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 94380f12..816eff42 100644 --- 
a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; let n = layout.stride().len(); - if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { + if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) { candle::bail!("Non contiguous softmax-last-dim is not implemented"); } @@ -235,6 +235,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { elem_count, last_dim, storage.buffer(), + layout.start_offset() * storage.dtype().size_in_bytes(), &mut output, ) .unwrap(); From 972903021c50bc3b1534a3436c6828bfcc157e6e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 17 Dec 2023 19:07:00 +0100 Subject: [PATCH 23/32] Finish reduce kernels. --- candle-core/src/metal_backend.rs | 54 +++++---- candle-core/tests/tensor_tests.rs | 1 + candle-metal-kernels/src/binary.metal | 7 ++ candle-metal-kernels/src/lib.rs | 60 +++++++++- candle-metal-kernels/src/reduce.metal | 163 ++++++++++++++++++++++++-- candle-metal-kernels/src/tests.rs | 12 +- 6 files changed, 258 insertions(+), 39 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 424b29d9..047313d1 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -482,20 +482,9 @@ impl BackendStorage for MetalStorage { } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { - if sum_dims.len() != 1 { - crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet."); - } - if sum_dims[0] != layout.shape().rank() - 1 { - crate::bail!("Non last dim reduce op {op:?} not implemented yet"); - } - if layout.stride()[sum_dims[0]] != 1 { - crate::bail!("Non contiguous reduce op {op:?} not implemented yet"); - } - let device = self.device.clone(); let src_stride = layout.stride(); let src_dims = layout.shape().dims(); - let src_el: usize = src_dims.iter().product(); // Source dims and strides with the sum dims at the end. let mut dims = vec![]; let mut stride = vec![]; @@ -515,28 +504,41 @@ impl BackendStorage for MetalStorage { // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. 
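        // e.g. a contiguous (2, 3, 4) tensor reduced over dim 1 reaches the
        // kernel as dims [2, 4, 3] / strides [12, 1, 4]: the reduced dim is
        // moved last, one threadgroup owns each of the 2 * 4 = 8 output
        // elements, and each walks its 3 inputs through `get_strided_index`.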
         let (name, check_empty, return_index) = match (op, self.dtype) {
-            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
-            (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
-            (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
-            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
-            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
-            _ => crate::bail!("Reduce op for non float"),
+            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
+            (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
+            (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
+            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
+            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
+            (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
+            (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
+            (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
+            (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
+            (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
+            (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
+            (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
+            (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
+            (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
+            (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
+            (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
+            (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
+            (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
+            (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
+            (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
+            (k, dtype) => crate::bail!("Reduce op {k:?} {dtype:?} not implemented"),
         };
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
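            // Only min/max/argmin/argmax are guarded here: they have no
            // identity element, so reducing an empty tensor is an error rather
            // than a defined value.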
} let dtype = if return_index { DType::U32 } else { self.dtype }; - if dtype == DType::U32 { - crate::bail!("reduce op {name} is not implemented yet."); - } let buffer = device.new_buffer(dst_el, dtype, "reduce")?; let command_buffer = self.device.command_buffer()?; - candle_metal_kernels::call_reduce_contiguous( + candle_metal_kernels::call_reduce_strided( &device.device, &command_buffer, &device.kernels, name, - src_el, + &dims, + &stride, dst_el, &self.buffer, layout.start_offset() * self.dtype.size_in_bytes(), @@ -730,7 +732,7 @@ impl BackendStorage for MetalStorage { ("sub", DType::F16) => contiguous::sub::HALF, ("mul", DType::F16) => contiguous::mul::HALF, ("div", DType::F16) => contiguous::div::HALF, - (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), }; candle_metal_kernels::call_binary_contiguous( &device.device, @@ -751,11 +753,15 @@ impl BackendStorage for MetalStorage { ("bsub", DType::F32) => strided::sub::FLOAT, ("bmul", DType::F32) => strided::mul::FLOAT, ("bdiv", DType::F32) => strided::div::FLOAT, + ("bminimum", DType::F32) => strided::min::FLOAT, + ("bmaximum", DType::F32) => strided::max::FLOAT, ("badd", DType::F16) => strided::add::HALF, ("bsub", DType::F16) => strided::sub::HALF, ("bmul", DType::F16) => strided::mul::HALF, ("bdiv", DType::F16) => strided::div::HALF, - (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + ("bminimum", DType::F16) => strided::min::HALF, + ("bmaximum", DType::F16) => strided::max::HALF, + (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), }; candle_metal_kernels::call_binary_strided( &device.device, diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c871dc96..06891748 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -543,6 +543,7 @@ fn argmax(device: &Device) -> Result<()> { let t1 = tensor.reshape((190, 5, 4))?; let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; for tensor in [t1, t2] { + println!("{}", tensor.argmax_keepdim(0)?.argmax_keepdim(2)?); assert_eq!( tensor .argmax_keepdim(0)? diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index ea21bb34..f13589c1 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -1,5 +1,8 @@ #include +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? 
(x) : (y)) + METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -63,10 +66,14 @@ BINARY_OP(x + y, add) BINARY_OP(x - y, sub) BINARY_OP(x * y, mul) BINARY_OP(x / y, div) +BINARY_OP(MIN(x, y), min) +BINARY_OP(MAX(x, y), max) #if __METAL_VERSION__ >= 310 BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x * y, mul) BFLOAT_BINARY_OP(x / y, div) +BFLOAT_BINARY_OP(MIN(x, y), min) +BFLOAT_BINARY_OP(MAX(x, y), max) #endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index f2db171e..c34e34fe 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -166,7 +166,7 @@ pub mod unary { ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); } pub mod binary { - ops!(add, sub, mul, div); + ops!(add, sub, mul, div, min, max); } #[derive(thiserror::Error, Debug)] @@ -588,6 +588,64 @@ pub fn call_reduce_contiguous( Ok(()) } +pub fn call_reduce_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + strides: &[usize], + out_length: usize, + input: &Buffer, + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let length: usize = shape.iter().product(); + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let elements_to_sum = length / out_length; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + shape.len(), + shape, + strides, + elements_to_sum, + (input, input_offset), + output + ) + ); + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_last_softmax( device: &Device, diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 62443660..2d584917 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -2,6 +2,7 @@ using namespace metal; #define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? 
(x) : (y)) METAL_FUNC uint get_strided_index( uint idx, @@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index( constant int THREADGROUP_SIZE = 2048; -# define REDUCE(FN, NAME, T) \ + +#define ARGMIN(NAME, T, MAXVALUE) \ kernel void NAME( \ - constant size_t &src_numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = MAXVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + bool notset = true; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ + size_t idx = start_idx + tid; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. \ + */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + if (notset || src[strided_i] < shared_memory[tid]) { \ + shared_memory[tid] = src[strided_i]; \ + /* Assume that the reduction takes place over the last dimension which is contiguous. */ \ + shared_indices[tid] = idx % dims[num_dims - 1]; \ + notset = false; \ + } \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \ + shared_indices[tid] = shared_indices[tid + s]; \ + shared_memory[tid] = shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + if (tid == 0){ \ + dst[dst_id] = shared_indices[0]; \ + } \ +} \ + + +#define ARGMAX(NAME, T, MINVALUE) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = MINVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ + size_t idx = start_idx + tid; \ + bool notset = true; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. 
\ + */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + if (notset || shared_memory[tid] < src[strided_i]) { \ + shared_memory[tid] = src[strided_i]; \ + shared_indices[tid] = idx % dims[num_dims - 1]; \ + notset = false; \ + } \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \ + shared_indices[tid] = shared_indices[tid + s]; \ + shared_memory[tid] = shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + if (tid == 0){ \ + dst[dst_id] = shared_indices[0]; \ + } \ +} \ + +#define REDUCE(FN, NAME, T, START) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ constant size_t &el_to_sum_per_block, \ device const T *src, \ device T *dst, \ @@ -34,21 +156,21 @@ kernel void NAME( \ \ threadgroup T shared_memory[THREADGROUP_SIZE]; \ \ - shared_memory[tid] = 0; \ + shared_memory[tid] = START; \ /* \ // Elements summed in this block range from dst_id * el_to_sum_per_block \ // to (dst_id + 1) * el_to_sum_per_block. \ */ \ size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ size_t idx = start_idx + tid; \ while (idx < stop_idx) { \ /* \ // TODO: Fast version for the contiguous case. \ - // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ T x = shared_memory[tid]; \ - T y = src[idx]; \ + T y = src[strided_i]; \ shared_memory[tid] = FN; \ idx += block_dim; \ } \ @@ -71,10 +193,6 @@ kernel void NAME( \ } \ -REDUCE(x + y, fast_sum_f32, float) -REDUCE(x * y, fast_mul_f32, float) -REDUCE(max(x, y), fast_max_f32, float) - #define SOFTMAX(NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ @@ -142,8 +260,33 @@ kernel void NAME( } \ } \ +REDUCE(x + y, fast_sum_f32_strided, float, 0) +REDUCE(x + y, fast_sum_u32_strided, uint, 0) +REDUCE(x + y, fast_sum_f16_strided, half, 0) +REDUCE(x * y, fast_mul_f32_strided, float, 1) +REDUCE(x * y, fast_mul_u32_strided, uint, 1) +REDUCE(x * y, fast_mul_f16_strided, half, 1) +REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF) +REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0) +REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH) +REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF) +REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF) +REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH) +ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) +ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) +ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) +ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) +ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) +ARGMAX(fast_argmax_u32_strided, uint, 0) + SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) #if __METAL_VERSION__ >= 310 +REDUCE(x + y, fast_sum_bf16, bfloat, 0) +REDUCE(x * y, fast_mul_bf16, bfloat, 1) +REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) +REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF) +ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) +ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) SOFTMAX(softmax_bf16, bfloat) #endif diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 9c9475a2..8d5a2624 100644 --- 
a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -574,12 +574,16 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec()) as u64, options); - call_reduce_contiguous( + let num_dims = 1; + let dims = vec![v.len()]; + let strides = vec![1]; + call_reduce_strided( &device, command_buffer, &kernels, name, - v.len(), + &dims, + &strides, out_length, &input, 0, @@ -623,7 +627,7 @@ fn reduce_sum() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 1; - let results = run_reduce(&v, out_length, "fast_sum_f32"); + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); assert_eq!(approx(results, 4), vec![21.0]); } @@ -632,7 +636,7 @@ fn reduce_sum2() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 2; - let results = run_reduce(&v, out_length, "fast_sum_f32"); + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); assert_eq!(approx(results, 4), vec![6.0, 15.0]); } From 0a6e0a8c9ae056684c0edd888acb3c691f889d33 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 17 Dec 2023 19:09:08 +0100 Subject: [PATCH 24/32] Implement randn (CPU-> device) --- candle-core/src/device.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 3eb7f8b7..1e33021b 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -201,10 +201,9 @@ impl Device { Ok(Storage::Cuda(storage)) } } - Device::Metal(_device) => { - // let storage = device.rand_uniform(shape, dtype, lo, up)?; - // Ok(Storage::Metal(storage)) - crate::bail!("Metal rand_uniform not implemented") + Device::Metal(device) => { + let storage = device.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Metal(storage)) } } } From e4b0cc59f5651bb4370598a902e43cd8b0af5976 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 17 Dec 2023 22:32:25 +0100 Subject: [PATCH 25/32] Adding CMP --- candle-core/src/metal_backend.rs | 188 ++++++++++++++++---------- candle-metal-kernels/src/binary.metal | 35 +++-- candle-metal-kernels/src/lib.rs | 2 +- 3 files changed, 140 insertions(+), 85 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 047313d1..6f82b0cc 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -549,8 +549,16 @@ impl BackendStorage for MetalStorage { Ok(Self::new(buffer, device, dtype)) } - fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { - crate::bail!("cmp metal") + fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result { + let name = match op { + CmpOp::Eq => "eq", + CmpOp::Ne => "ne", + CmpOp::Le => "le", + CmpOp::Ge => "ge", + CmpOp::Lt => "lt", + CmpOp::Gt => "gt", + }; + self.binary(name, rhs, lhs_l, rhs_l) } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { @@ -711,76 +719,7 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let device = self.device(); - let dtype = self.dtype; - let shape = lhs_l.shape(); - let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; - let command_buffer = device.command_buffer()?; - if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) - && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) - && &B::KERNEL[..1] != "b" - { - use candle_metal_kernels::binary::contiguous; - - let kernel_name = match (B::KERNEL, dtype) { - ("add", DType::F32) => contiguous::add::FLOAT, - ("sub", DType::F32) => 
contiguous::sub::FLOAT, - ("mul", DType::F32) => contiguous::mul::FLOAT, - ("div", DType::F32) => contiguous::div::FLOAT, - ("add", DType::F16) => contiguous::add::HALF, - ("sub", DType::F16) => contiguous::sub::HALF, - ("mul", DType::F16) => contiguous::mul::HALF, - ("div", DType::F16) => contiguous::div::HALF, - (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), - }; - candle_metal_kernels::call_binary_contiguous( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - el_count, - &self.buffer, - &rhs.buffer, - &buffer, - ) - .map_err(MetalError::from)?; - } else { - use candle_metal_kernels::binary::strided; - - let kernel_name = match (B::KERNEL, dtype) { - ("badd", DType::F32) => strided::add::FLOAT, - ("bsub", DType::F32) => strided::sub::FLOAT, - ("bmul", DType::F32) => strided::mul::FLOAT, - ("bdiv", DType::F32) => strided::div::FLOAT, - ("bminimum", DType::F32) => strided::min::FLOAT, - ("bmaximum", DType::F32) => strided::max::FLOAT, - ("badd", DType::F16) => strided::add::HALF, - ("bsub", DType::F16) => strided::sub::HALF, - ("bmul", DType::F16) => strided::mul::HALF, - ("bdiv", DType::F16) => strided::div::HALF, - ("bminimum", DType::F16) => strided::min::HALF, - ("bmaximum", DType::F16) => strided::max::HALF, - (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), - }; - candle_metal_kernels::call_binary_strided( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - lhs_l.dims(), - &self.buffer, - lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &rhs.buffer, - rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &buffer, - ) - .map_err(MetalError::from)?; - } - command_buffer.set_label("binary"); - Ok(Self::new(buffer, device.clone(), dtype)) + self.binary(B::KERNEL, rhs, lhs_l, rhs_l) } fn where_cond( @@ -1043,6 +982,111 @@ impl MetalStorage { pub fn buffer(&self) -> &Buffer { &self.buffer } + + pub fn binary( + &self, + op: &'static str, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let device = self.device(); + let shape = lhs_l.shape(); + let el_count = shape.elem_count(); + let command_buffer = device.command_buffer()?; + let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) + && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + && &op[..1] != "b" + { + use candle_metal_kernels::binary::contiguous; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype), + ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype), + ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype), + ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype), + ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8), + ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8), + ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), + ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), + ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), + ("div", DType::F16) => (contiguous::div::HALF, self.dtype), + ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8), + ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8), + ("le", DType::F16) => (contiguous::le::HALF, DType::U8), + ("lt", DType::F16) => 
(contiguous::lt::HALF, DType::U8), + ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), + ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + } else { + use candle_metal_kernels::binary::strided; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("badd", DType::F32) => (strided::add::FLOAT, self.dtype), + ("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype), + ("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype), + ("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype), + ("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype), + ("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype), + ("eq", DType::F32) => (strided::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (strided::ne::FLOAT, DType::U8), + ("le", DType::F32) => (strided::le::FLOAT, DType::U8), + ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), + ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), + ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), + ("bdiv", DType::F16) => (strided::div::HALF, self.dtype), + ("bminimum", DType::F16) => (strided::min::HALF, self.dtype), + ("bmaximum", DType::F16) => (strided::max::HALF, self.dtype), + ("eq", DType::F16) => (strided::eq::HALF, DType::U8), + ("ne", DType::F16) => (strided::ne::HALF, DType::U8), + ("le", DType::F16) => (strided::le::HALF, DType::U8), + ("lt", DType::F16) => (strided::lt::HALF, DType::U8), + ("ge", DType::F16) => (strided::ge::HALF, DType::U8), + ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + &self.buffer, + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &rhs.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + }; + command_buffer.set_label("binary"); + Ok(Self::new(buffer, device.clone(), dtype)) + } } impl BackendDevice for MetalDevice { diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index f13589c1..8c3b4a8c 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -25,15 +25,15 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const TYPENAME *left, \ device const TYPENAME *right, \ - device TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - TYPENAME x = left[thread_position_in_grid]; \ - TYPENAME y = right[thread_position_in_grid]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + TYPENAME x = left[tid]; \ + TYPENAME y = right[tid]; \ + output[tid] = OUT_TYPENAME(FN); \ }\ kernel void 
FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -43,15 +43,15 @@ kernel void FN_NAME_STRIDED( \ constant size_t *right_strides, \ device const TYPENAME *left, \ device const TYPENAME *right, \ - device TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ - TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \ + output[tid] = OUT_TYPENAME(FN); \ } #define BINARY_OP(FN, NAME) \ @@ -61,6 +61,10 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); +#define BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); + BINARY_OP(x + y, add) BINARY_OP(x - y, sub) @@ -69,6 +73,13 @@ BINARY_OP(x / y, div) BINARY_OP(MIN(x, y), min) BINARY_OP(MAX(x, y), max) +BINARY_OP_OUT(eq, x == y) +BINARY_OP_OUT(ne, x != y) +BINARY_OP_OUT(le, x <= y) +BINARY_OP_OUT(lt, x < y) +BINARY_OP_OUT(ge, x >= y) +BINARY_OP_OUT(gt, x > y) + #if __METAL_VERSION__ >= 310 BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c34e34fe..7485ba72 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -166,7 +166,7 @@ pub mod unary { ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); } pub mod binary { - ops!(add, sub, mul, div, min, max); + ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); } #[derive(thiserror::Error, Debug)] From 586b6f6fff01f02cf5275f9ede47a0fe10206210 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 17 Dec 2023 23:34:12 +0100 Subject: [PATCH 26/32] Adding gather op. 
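
For reference, the new gather_u32_f32 / gather_u32_f16 kernels compute the same
thing as this CPU sketch (`gather_ref` is an illustrative name, not part of this
patch): for every output element, the coordinate along `dim` is replaced by the
value read from `ids` at the same position, mirroring the `gather` template
added below.

    // Pure-Rust reference for the gather kernel's indexing scheme.
    fn gather_ref(src: &[f32], src_dims: &[usize], ids: &[u32], dim: usize) -> Vec<f32> {
        let left: usize = src_dims[..dim].iter().product();
        let right: usize = src_dims[dim + 1..].iter().product();
        let src_dim = src_dims[dim];
        // `ids` has the same shape as `src` except along `dim`.
        let ids_size = ids.len() / (left * right);
        let mut out = vec![0.0f32; ids.len()];
        for tid in 0..out.len() {
            let left_i = tid / right / ids_size;
            let right_i = tid % right;
            let idx = ids[tid] as usize; // assumed < src_dim, as on the GPU side
            out[tid] = src[(left_i * src_dim + idx) * right + right_i];
        }
        out
    }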
--- candle-core/src/metal_backend.rs | 34 +++++++++- candle-metal-kernels/src/indexing.metal | 90 ++++++++++++++++++++----- candle-metal-kernels/src/lib.rs | 50 ++++++++++++++ 3 files changed, 157 insertions(+), 17 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6f82b0cc..227bcfb0 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -826,8 +826,38 @@ impl BackendStorage for MetalStorage { crate::bail!("upsample_nearest2d metal") } - fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { - crate::bail!("gather metal") + fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { + let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let left_size: usize = src_l.dims()[..dim].iter().product(); + let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + let ids_el = ids_l.dims()[dim]; + let dst_el = ids_l.shape().elem_count(); + let dtype = self.dtype; + let device = self.device(); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "gather_u32_f32", + (DType::U32, DType::F16) => "gather_u32_f16", + (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"), + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_gather( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + &self.buffer, + &ids.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, device.clone(), dtype)) } fn scatter_add( diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 312b27c7..96adb4c4 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,6 +1,34 @@ #include using namespace metal; +template +METAL_FUNC void index( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t id_i = (tid / right_size) % ids_size; + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. 
+ */ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; + output[tid] = input[src_i]; +} + # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ kernel void NAME( \ constant size_t &dst_size, \ @@ -11,22 +39,52 @@ kernel void NAME( \ const device TYPENAME *input, \ const device INDEX_TYPENAME *input_ids, \ device TYPENAME *output, \ - uint gid [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (gid >= dst_size) { \ - return; \ - } \ - const size_t id_i = (gid / right_size) % ids_size; \ - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ - const size_t right_rank_i = gid % right_size; \ - const size_t left_rank_i = gid / right_size / ids_size; \ - /* \ - // Force prevent out of bounds indexing \ - // since there doesn't seem to be a good way to force crash \ - // No need to check for zero we're only allowing unsized. \ - */ \ - const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \ - output[gid] = input[src_i]; \ + index(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ +} + + +template +METAL_FUNC void gather( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const INDEX_TYPENAME input_i = input_ids[tid]; + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. 
+ */ + const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; + output[tid] = input[src_i]; +} + +# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &ids_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + gather(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ } @@ -76,6 +134,8 @@ kernel void FN_NAME( \ INDEX_OP(is_u32_f32, uint, float) INDEX_OP(is_u32_f16, uint, half) +GATHER_OP(gather_u32_f32, uint, float) +GATHER_OP(gather_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7485ba72..45929aa3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1010,6 +1010,56 @@ pub fn call_index_select( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_gather( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + ids_size: usize, + dim: usize, + input: &Buffer, + ids: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + input, + ids, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + #[derive(Debug, PartialEq)] pub enum Value { USize(usize), From 6a3ca7da0cfb06e80d5c75ee98a1291843092e06 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 18 Dec 2023 10:32:22 +0100 Subject: [PATCH 27/32] Scatter add. 
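
Semantics, as a CPU sketch (`scatter_add_ref` is an illustrative name, not part
of this patch): every source element is accumulated into the destination slot
selected by `ids` along `dim`. Each GPU thread owns one (left, right) coordinate
and loops over the source dim itself, which is why the kernel can use a plain
`+=` with no atomics.

    fn scatter_add_ref(
        dst: &mut [f32],
        dst_dims: &[usize],
        src: &[f32],
        src_dims: &[usize],
        ids: &[u32],
        dim: usize,
    ) {
        let left: usize = src_dims[..dim].iter().product();
        let right: usize = src_dims[dim + 1..].iter().product();
        let (src_dim, dst_dim) = (src_dims[dim], dst_dims[dim]);
        for tid in 0..left * right {
            let (left_i, right_i) = (tid / right, tid % right);
            for j in 0..src_dim {
                let src_i = (left_i * src_dim + j) * right + right_i;
                let idx = ids[src_i] as usize; // assumed < dst_dim
                dst[(left_i * dst_dim + idx) * right + right_i] += src[src_i];
            }
        }
    }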
--- candle-core/src/metal_backend.rs | 60 ++++++++++++++++++++----- candle-metal-kernels/src/indexing.metal | 46 ++++++++++++++++--- candle-metal-kernels/src/lib.rs | 58 +++++++++++++++++++++++- 3 files changed, 147 insertions(+), 17 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 227bcfb0..b26477fc 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -45,6 +45,12 @@ pub enum MetalError { }, #[error("{0:?}")] LockError(LockError), + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { + msg: &'static str, + expected: DType, + got: DType, + }, } impl From for MetalError { @@ -827,12 +833,10 @@ impl BackendStorage for MetalStorage { } fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { - let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { + let (ids_o1, _) = match ids_l.contiguous_offsets() { Some(o12) => o12, None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, }; - let left_size: usize = src_l.dims()[..dim].iter().product(); - let right_size: usize = src_l.dims()[dim + 1..].iter().product(); let ids_el = ids_l.dims()[dim]; let dst_el = ids_l.shape().elem_count(); let dtype = self.dtype; @@ -853,7 +857,9 @@ impl BackendStorage for MetalStorage { ids_el, dim, &self.buffer, + src_l.start_offset() * dtype.size_in_bytes(), &ids.buffer, + ids_o1 * ids.dtype.size_in_bytes(), &buffer, ) .map_err(MetalError::from)?; @@ -862,14 +868,48 @@ impl BackendStorage for MetalStorage { fn scatter_add( &self, - _: &Layout, - _: &Self, - _: &Layout, - _: &Self, - _: &Layout, - _: usize, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, ) -> Result { - crate::bail!("scatter_add metal") + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + let (ids_offset, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let src_offset = match src_l.contiguous_offsets() { + Some((o1, _)) => o1, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "sa_u32_f32", + _ => Err(MetalError::UnexpectedDType { + msg: "scatter-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_scatter_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + dim, + &src.buffer, + src_offset * src.dtype.size_in_bytes(), + &ids.buffer, + ids_offset * ids.dtype.size_in_bytes(), + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 96adb4c4..72a3a348 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -63,11 +63,6 @@ METAL_FUNC void gather( const INDEX_TYPENAME input_i = input_ids[tid]; const size_t right_rank_i = tid % right_size; const size_t left_rank_i = tid / right_size / ids_size; - /* - // Force prevent out of bounds indexing - // since there doesn't seem to be a good way to force crash - // No need to check for zero we're only allowing unsized. 
- */ const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; output[tid] = input[src_i]; } @@ -87,6 +82,45 @@ kernel void NAME( \ gather(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ } +template +METAL_FUNC void scatter_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const INDEX_TYPENAME idx = input_ids[src_i]; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; + } +} + +# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &dst_dim_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + scatter_add(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \ +} template @@ -136,6 +170,8 @@ INDEX_OP(is_u32_f32, uint, float) INDEX_OP(is_u32_f16, uint, half) GATHER_OP(gather_u32_f32, uint, float) GATHER_OP(gather_u32_f16, uint, half) +SCATTER_ADD_OP(sa_u32_f32, uint, float) +SCATTER_ADD_OP(sa_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 45929aa3..ddc04d05 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1020,7 +1020,9 @@ pub fn call_gather( ids_size: usize, dim: usize, input: &Buffer, + input_offset: usize, ids: &Buffer, + ids_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1043,8 +1045,60 @@ pub fn call_gather( src_dim_size, right_size, ids_size, - input, - ids, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_scatter_add( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = 
command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + (input, input_offset), + (ids, ids_offset), output ) ); From 8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 18 Dec 2023 10:46:01 +0100 Subject: [PATCH 28/32] Index add. --- candle-core/src/metal_backend.rs | 49 +++++++++-- candle-metal-kernels/src/indexing.metal | 111 ++++++++++++------------ candle-metal-kernels/src/lib.rs | 54 ++++++++++++ 3 files changed, 151 insertions(+), 63 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index b26477fc..21a8967b 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -951,14 +951,49 @@ impl BackendStorage for MetalStorage { fn index_add( &self, - _: &Layout, - _: &Self, - _: &Layout, - _: &Self, - _: &Layout, - _: usize, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, ) -> Result { - crate::bail!("index_add metal") + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + let (ids_offset, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let src_offset = match src_l.contiguous_offsets() { + Some((o1, _)) => o1, + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "ia_u32_f32", + _ => Err(MetalError::UnexpectedDType { + msg: "index-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_index_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + ids_l.dims(), + dim, + &src.buffer, + src_offset * src.dtype.size_in_bytes(), + &ids.buffer, + ids_offset * ids.dtype.size_in_bytes(), + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) } fn matmul( &self, diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 72a3a348..63357428 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -122,48 +122,47 @@ kernel void NAME( \ scatter_add(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \ } - -template -void index_add( - device I *ids [[buffer(0)]], - device T *inp [[buffer(1)]], - device T *out [[buffer(2)]], - - constant uint &ids_dim_size, - constant uint &left_size, - constant uint &dst_dim_size, - constant uint &right_size, - - uint gid [[ thread_position_in_grid ]] \ -) { - - if (gid >= left_size * right_size) { - return; - } - - const uint i = gid; - const uint pre = i / right_size; - const uint post = i % right_size; - - for (uint j = 0; j < ids_dim_size; j++) { - const uint idx = ids[j]; - const uint src_i = (pre * ids_dim_size + j) * right_size + post; - const uint dst_i = (pre * dst_dim_size + idx) * right_size + post; - out[dst_i] += inp[src_i]; +template +METAL_FUNC void index_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + constant size_t &ids_dim_size, + const device TYPENAME 
*input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const INDEX_TYPENAME idx = input_ids[j]; + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; } } -#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ -kernel void FN_NAME( \ - device INDEX_TYPENAME *ids [[buffer(0)]], \ - device TYPENAME *inp [[buffer(1)]], \ - device TYPENAME *out [[buffer(2)]], \ - constant uint &ids_dim_size, \ - constant uint &left_size, \ - constant uint &dst_dim_size, \ - constant uint &right_size, \ - uint gid [[ thread_position_in_grid ]] \ -) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \ +# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &dst_dim_size, \ + constant size_t &ids_dim_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + index_add(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \ +} INDEX_OP(is_u32_f32, uint, float) @@ -175,25 +174,25 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 -IA_OP(bfloat, int64_t, ia_i64_bf16) -IA_OP(bfloat, uint32_t, ia_u32_bf16) -IA_OP(bfloat, uint8_t, ia_u8_bf16) +INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) +INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) +INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) #endif -IA_OP(half, uint32_t, ia_u32_f16) -IA_OP(half, uint8_t, ia_u8_f16) +INDEX_ADD_OP(ia_u32_f16, uint32_t, half) +INDEX_ADD_OP(ia_u8_f16, uint8_t, half) -IA_OP(float, int64_t, ia_i64_f32) -IA_OP(uint8_t, int64_t, ia_i64_u8) -IA_OP(int64_t, int64_t, ia_i64_i64) -IA_OP(uint32_t, int64_t, ia_i64_u32) +INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) +INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) +INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) -IA_OP(float, uint32_t, ia_u32_f32) -IA_OP(uint8_t, uint32_t, ia_u32_u8) -IA_OP(int64_t, uint32_t, ia_u32_i64) -IA_OP(uint32_t, uint32_t, ia_u32_u32) +INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) +INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) +INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) -IA_OP(float, uint8_t, ia_u8_f32) -IA_OP(uint8_t, uint8_t, ia_u8_u8) -IA_OP(uint32_t, uint8_t, ia_u8_u32) -IA_OP(int64_t, uint8_t, ia_u8_i64) +INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) +INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) +INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index ddc04d05..0bd7d8cb 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1114,6 +1114,60 @@ pub fn call_scatter_add( Ok(()) } +pub fn call_index_add( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + ids_shape: &[usize], + dim: usize, + input: &Buffer, + input_offset: 
usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + let ids_dim_size = ids_shape[0]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + ids_dim_size, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + #[derive(Debug, PartialEq)] pub enum Value { USize(usize), From e8ee253ee0766c33ac69f08bb0bcd6601f47ca6f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 18 Dec 2023 11:01:18 +0100 Subject: [PATCH 29/32] Missing cast. --- candle-core/src/metal_backend.rs | 2 ++ candle-metal-kernels/src/cast.metal | 1 + 2 files changed, 3 insertions(+) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 21a8967b..0af11a3d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -578,6 +578,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U8, DType::U32) => "cast_u8_u32", + (DType::U8, DType::F32) => "cast_u8_f32", (DType::F32, DType::F16) => "cast_f32_f16", (DType::F16, DType::F32) => "cast_f16_f32", (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), @@ -598,6 +599,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::U8, DType::F32) => "cast_u8_f32_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 4398e9d4..8481389d 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -48,6 +48,7 @@ kernel void FN_NAME_STRIDED( \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f32_f16, cast_f32_f16_strided, float, half) From 064ba17bd7de0e3c6f18f93cfe14db97e7ebca0b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 18 Dec 2023 11:04:16 +0100 Subject: [PATCH 30/32] Remove print. 
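
Drops a leftover debug println! from the argmax test in
candle-core/tests/tensor_tests.rs.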
---
 candle-core/tests/tensor_tests.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 06891748..c871dc96 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -543,7 +543,6 @@ fn argmax(device: &Device) -> Result<()> {
     let t1 = tensor.reshape((190, 5, 4))?;
     let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
     for tensor in [t1, t2] {
-        println!("{}", tensor.argmax_keepdim(0)?.argmax_keepdim(2)?);
         assert_eq!(
             tensor
                 .argmax_keepdim(0)?

From 03641293eeb1dd0ff3d5a93e85c7f9eb289704e4 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 18 Dec 2023 15:22:43 +0100
Subject: [PATCH 31/32] Clippy pass.

---
 candle-core/src/metal_backend.rs  | 18 ++++++++----------
 candle-metal-kernels/src/tests.rs |  1 -
 candle-nn/src/ops.rs              |  6 +++---
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0af11a3d..27b2824f 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -59,6 +59,8 @@ impl From<String> for MetalError {
     }
 }

+type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
+
 #[derive(Clone)]
 pub struct MetalDevice {
     /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
@@ -103,7 +105,7 @@ pub struct MetalDevice {
     ///
     /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
     /// (strong_count = 1).
-    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
+    buffers: AllocatedBuffers,
 }

 impl std::fmt::Debug for MetalDevice {
@@ -258,7 +260,7 @@
             let newbuffers = subbuffers
                 .iter()
                 .filter(|s| Arc::strong_count(s) > 1)
-                .map(|s| Arc::clone(s))
+                .map(Arc::clone)
                 .collect();
             *subbuffers = newbuffers;
         }
@@ -270,7 +272,7 @@
         let capture = metal::CaptureManager::shared();
         let descriptor = metal::CaptureDescriptor::new();
         descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        descriptor.set_capture_device(&self);
+        descriptor.set_capture_device(self);
         descriptor.set_output_url(path);

         capture
@@ -1021,10 +1023,10 @@
             &self.device.kernels,
             name,
             (b, m, n, k),
-            &lhs_l.stride(),
+            lhs_l.stride(),
             lhs_l.start_offset() * self.dtype.size_in_bytes(),
             &self.buffer,
-            &rhs_l.stride(),
+            rhs_l.stride(),
             rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
             &rhs.buffer,
             &buffer,
@@ -1274,11 +1276,7 @@
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
         }?;
-        Ok(Self::Storage::new(
-            buffer.into(),
-            self.clone(),
-            storage.dtype(),
-        ))
+        Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
     }

     fn rand_uniform(
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8d5a2624..1b3153b1 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -574,7 +574,6 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<
     let input = new_buffer(&device, v);
     let options = MTLResourceOptions::StorageModeManaged;
     let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
-    let num_dims = 1;
     let dims = vec![v.len()];
     let strides = vec![1];
     call_reduce_strided(
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 816eff42..abe33350 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -226,17 +226,17 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         let last_dim = layout.dims()[layout.shape().rank() - 1];
         let elem_count = layout.shape().elem_count();
-        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+        let output = device.new_buffer(elem_count, 
storage.dtype(), "softmax")?;
         candle_metal_kernels::call_last_softmax(
             device.metal_device(),
             &command_buffer,
-            &kernels,
+            kernels,
             name,
             elem_count,
             last_dim,
             storage.buffer(),
             layout.start_offset() * storage.dtype().size_in_bytes(),
-            &mut output,
+            &output,
         )
         .unwrap();
         let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());

From 9b5e4843a63180a2803b1e836b4ca90f14281d03 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 20 Dec 2023 09:54:19 +0100
Subject: [PATCH 32/32] Optimizing decode matmul (Phi at 28 tok/s on M3).

Adding benchmarks to help evaluate matmul performance.
---
 Cargo.toml                      |  1 +
 candle-core/Cargo.toml          |  7 ++++++
 candle-core/benches/matmul.rs   | 43 +++++++++++++++++++++++++++++++++
 candle-metal-kernels/src/lib.rs | 20 +++++++++++----
 4 files changed, 66 insertions(+), 5 deletions(-)
 create mode 100644 candle-core/benches/matmul.rs

diff --git a/Cargo.toml b/Cargo.toml
index 7c2e3a7d..9fda5fba 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
+criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.9.14", features = ["f16"] }
 gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index e7d3ab6a..0f8c1a9f 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -34,6 +34,8 @@ zip = { workspace = true }
 [dev-dependencies]
 anyhow = { workspace = true }
 clap = { workspace = true }
+criterion = { workspace = true }
+

 [features]
 default = []
@@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
 metal = ["dep:metal", "dep:candle-metal-kernels"]
+
+[[bench]]
+name = "matmul"
+harness = false
+
diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs
new file mode 100644
index 00000000..8732f451
--- /dev/null
+++ b/candle-core/benches/matmul.rs
@@ -0,0 +1,43 @@
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor) {
+    a.matmul(&b.t().unwrap()).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let b = 1;
+    let m = 1;
+    let n = 2048;
+    let k = 2048;
+
+    let device = Device::new_metal(0).unwrap();
+    let dtype = DType::F32;
+    let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+    let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
+
+    let flops = b * m * n * k;
+
+    let mut group = c.benchmark_group("matmul_metal");
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&lhs), black_box(&rhs));
+            }
+            if let Device::Metal(device) = &device {
+                device.wait_until_completed().unwrap();
+            } else {
+                panic!("Expected metal device");
+            }
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
+
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 0bd7d8cb..0418c96c 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1297,11 +1297,21 @@ pub fn call_gemm(
     let batched = b > 1;
     let fused_activation = false;
     let fused_bias = 
false; - let m_simd = 16; - let n_simd = 16; - let k_simd = 16; - let m_splits = 2; - let n_splits = 2; + let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + let m_simd = 16; + let n_simd = 8; + let k_simd = 64; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + } else { + let m_simd = 40; + let n_simd = 40; + let k_simd = 8; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + }; let constants = Some(ConstantValues::new(vec![ (0, Value::USize(m)), (1, Value::USize(n)),
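        // The `Value` pairs below appear to be function-constant values used to
        // specialize the gemm pipeline: indices 0 and 1 carry the output
        // dimensions m and n.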