diff --git a/Cargo.toml b/Cargo.toml
index 1a8145ba..6cf99174 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -61,7 +61,9 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# NOTE(review): local checkout used while debugging; restore the git dependency before merging.
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../metal-rs", features = ["mps"] }
 
 [profile.release-with-debug]
 inherits = "release"
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 6159ffc0..ae7987dd 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -67,12 +67,22 @@ impl MetalDevice {
         self.command_buffer.read().unwrap()
     }
 
-    pub fn wait_until_completed(&self) {
-        let mut old = self.command_buffer.write().unwrap();
-        old.commit();
-        old.wait_until_completed();
-        let command_buffer = self.command_queue.new_owned_command_buffer();
+    /// Commits the pending command buffer, blocks until it has completed, and
+    /// installs a fresh command buffer for the work that follows.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the command buffer lock is poisoned or cannot be taken immediately.
+    pub fn commit_wait_until_completed(&self) {
+        let mut old = self.command_buffer.try_write().unwrap();
+        // A buffer that has already completed is not committed again; it is only
+        // swapped out so that later encodes never target a finished buffer.
+        if old.status() != metal::MTLCommandBufferStatus::Completed {
+            old.commit();
+            old.wait_until_completed();
+        }
+        let command_buffer = self.command_queue.new_command_buffer().to_owned();
         *old = command_buffer;
         // self.command_buffer.replace_with(|_| command_buffer)
     }
 
@@ -86,18 +96,19 @@ impl MetalDevice {
 
+    /// Allocates a shared-storage heap buffer for `element_count` values of `dtype`.
+    /// Panics if the heap does not return a buffer.
     pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
         self.heap
             .new_buffer(size, MTLResourceOptions::StorageModeShared)
-            .expect(" New buffer")
+            .expect("New buffer")
     }
 
+    /// Creates a shared-storage buffer on the device from the bytes of `data`.
     pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
+        let size = core::mem::size_of_val(data) as NSUInteger;
         let option = metal::MTLResourceOptions::StorageModeShared;
-        self.device.new_buffer_with_data(
-            data.as_ptr() as *const core::ffi::c_void,
-            core::mem::size_of_val(data) as NSUInteger,
-            option,
-        )
+        self.device
+            .new_buffer_with_data(data.as_ptr() as *const core::ffi::c_void, size, option)
     }
 }
 
@@ -124,7 +135,7 @@ impl BackendStorage for MetalStorage {
     }
 
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
-        self.device.wait_until_completed();
+        self.device.commit_wait_until_completed();
 
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(
@@ -335,93 +346,96 @@ impl BackendStorage for MetalStorage {
         let shape = layout.shape();
         let el_count = shape.elem_count();
         let mut buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_buffer();
-        if layout.is_contiguous() && layout.start_offset() == 0 {
-            use candle_metal_kernels::unary::contiguous;
+        // Scope the command buffer read guard so it is dropped before the result is built.
+        {
+            let command_buffer = device.command_buffer();
+            if layout.is_contiguous() && layout.start_offset() == 0 {
+                use candle_metal_kernels::unary::contiguous;
 
-            let kernel_name = match (B::KERNEL, dtype) {
-                ("ucos", DType::F32) => contiguous::cos::FLOAT,
-                ("usin", DType::F32) => contiguous::sin::FLOAT,
-                ("usqr", DType::F32) => contiguous::sqr::FLOAT,
-                ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
-                ("uneg", DType::F32) => contiguous::neg::FLOAT,
-                ("uexp", DType::F32) => contiguous::exp::FLOAT,
-                ("ulog", DType::F32) => contiguous::log::FLOAT,
-                ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
-                ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
-                ("uerf", DType::F32) => contiguous::erf::FLOAT,
-                ("uceil", DType::F32) => contiguous::ceil::FLOAT,
-                ("ufloor", DType::F32) => contiguous::floor::FLOAT,
-                ("uround", DType::F32) => contiguous::round::FLOAT,
-                ("ucos", DType::F16) => contiguous::cos::HALF,
-                ("usin", DType::F16) => contiguous::sin::HALF,
-                ("usqr", DType::F16) => contiguous::sqr::HALF,
-                ("usqrt", DType::F16) => contiguous::sqrt::HALF,
-                ("uneg", DType::F16) => contiguous::neg::HALF,
-                ("uexp", DType::F16) => contiguous::exp::HALF,
-                ("ulog", DType::F16) => contiguous::log::HALF,
-                ("ugelu", DType::F16) => contiguous::gelu::HALF,
-                ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
-                ("uerf", DType::F16) => contiguous::erf::HALF,
-                ("uceil", DType::F16) => contiguous::ceil::HALF,
-                ("ufloor", DType::F16) => contiguous::floor::HALF,
-                ("uround", DType::F16) => contiguous::round::HALF,
-                (name, dtype) => todo!("Match {name} - {dtype:?}"),
-            };
-            candle_metal_kernels::call_unary_contiguous(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                kernel_name,
-                el_count,
-                &self.buffer,
-                &mut buffer,
-            )
-            .map_err(MetalError::from)?;
-        } else {
-            use candle_metal_kernels::unary::strided;
-            let kernel_name = match (B::KERNEL, dtype) {
-                ("ucos", DType::F32) => strided::cos::FLOAT,
-                ("usin", DType::F32) => strided::sin::FLOAT,
-                ("usqr", DType::F32) => strided::sqr::FLOAT,
-                ("usqrt", DType::F32) => strided::sqrt::FLOAT,
-                ("uneg", DType::F32) => strided::neg::FLOAT,
-                ("uexp", DType::F32) => strided::exp::FLOAT,
-                ("ulog", DType::F32) => strided::log::FLOAT,
-                ("ugelu", DType::F32) => strided::gelu::FLOAT,
-                ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
-                ("uerf", DType::F32) => strided::erf::FLOAT,
-                ("uceil", DType::F32) => strided::ceil::FLOAT,
-                ("ufloor", DType::F32) => strided::floor::FLOAT,
-                ("uround", DType::F32) => strided::round::FLOAT,
-                ("ucos", DType::F16) => strided::cos::HALF,
-                ("usin", DType::F16) => strided::sin::HALF,
-                ("usqr", DType::F16) => strided::sqr::HALF,
-                ("usqrt", DType::F16) => strided::sqrt::HALF,
-                ("uneg", DType::F16) => strided::neg::HALF,
-                ("uexp", DType::F16) => strided::exp::HALF,
-                ("ulog", DType::F16) => strided::log::HALF,
-                ("ugelu", DType::F16) => strided::gelu::HALF,
-                ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
-                ("uerf", DType::F16) => strided::erf::HALF,
-                ("uceil", DType::F16) => strided::ceil::HALF,
-                ("ufloor", DType::F16) => strided::floor::HALF,
-                ("uround", DType::F16) => strided::round::HALF,
-                (name, dtype) => todo!("Match {name} - {dtype:?}"),
-            };
-            candle_metal_kernels::call_unary_strided(
-                &device.device,
-                &command_buffer,
-                &device.kernels,
-                kernel_name,
-                layout.dims(),
-                &self.buffer,
-                layout.stride(),
-                layout.start_offset() * self.dtype.size_in_bytes(),
-                &mut buffer,
-                0,
-            )
-            .map_err(MetalError::from)?;
+                let kernel_name = match (B::KERNEL, dtype) {
+                    ("ucos", DType::F32) => contiguous::cos::FLOAT,
+                    ("usin", DType::F32) => contiguous::sin::FLOAT,
+                    ("usqr", DType::F32) => contiguous::sqr::FLOAT,
+                    ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
+                    ("uneg", DType::F32) => contiguous::neg::FLOAT,
+                    ("uexp", DType::F32) => contiguous::exp::FLOAT,
+                    ("ulog", DType::F32) => contiguous::log::FLOAT,
+                    ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
+                    ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
+                    ("uerf", DType::F32) => contiguous::erf::FLOAT,
+                    ("uceil", DType::F32) => contiguous::ceil::FLOAT,
+                    ("ufloor", DType::F32) => contiguous::floor::FLOAT,
+                    ("uround", DType::F32) => contiguous::round::FLOAT,
+                    ("ucos", DType::F16) => contiguous::cos::HALF,
+                    ("usin", DType::F16) => contiguous::sin::HALF,
+                    ("usqr", DType::F16) => contiguous::sqr::HALF,
+                    ("usqrt", DType::F16) => contiguous::sqrt::HALF,
+                    ("uneg", DType::F16) => contiguous::neg::HALF,
+                    ("uexp", DType::F16) => contiguous::exp::HALF,
+                    ("ulog", DType::F16) => contiguous::log::HALF,
+                    ("ugelu", DType::F16) => contiguous::gelu::HALF,
+                    ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
+                    ("uerf", DType::F16) => contiguous::erf::HALF,
+                    ("uceil", DType::F16) => contiguous::ceil::HALF,
+                    ("ufloor", DType::F16) => contiguous::floor::HALF,
+                    ("uround", DType::F16) => contiguous::round::HALF,
+                    (name, dtype) => todo!("Match {name} - {dtype:?}"),
+                };
+                candle_metal_kernels::call_unary_contiguous(
+                    &device.device,
+                    &command_buffer,
+                    &device.kernels,
+                    kernel_name,
+                    el_count,
+                    &self.buffer,
+                    &mut buffer,
+                )
+                .map_err(MetalError::from)?;
+            } else {
+                use candle_metal_kernels::unary::strided;
+                let kernel_name = match (B::KERNEL, dtype) {
+                    ("ucos", DType::F32) => strided::cos::FLOAT,
+                    ("usin", DType::F32) => strided::sin::FLOAT,
+                    ("usqr", DType::F32) => strided::sqr::FLOAT,
+                    ("usqrt", DType::F32) => strided::sqrt::FLOAT,
+                    ("uneg", DType::F32) => strided::neg::FLOAT,
+                    ("uexp", DType::F32) => strided::exp::FLOAT,
+                    ("ulog", DType::F32) => strided::log::FLOAT,
+                    ("ugelu", DType::F32) => strided::gelu::FLOAT,
+                    ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
+                    ("uerf", DType::F32) => strided::erf::FLOAT,
+                    ("uceil", DType::F32) => strided::ceil::FLOAT,
+                    ("ufloor", DType::F32) => strided::floor::FLOAT,
+                    ("uround", DType::F32) => strided::round::FLOAT,
+                    ("ucos", DType::F16) => strided::cos::HALF,
+                    ("usin", DType::F16) => strided::sin::HALF,
+                    ("usqr", DType::F16) => strided::sqr::HALF,
+                    ("usqrt", DType::F16) => strided::sqrt::HALF,
+                    ("uneg", DType::F16) => strided::neg::HALF,
+                    ("uexp", DType::F16) => strided::exp::HALF,
+                    ("ulog", DType::F16) => strided::log::HALF,
+                    ("ugelu", DType::F16) => strided::gelu::HALF,
+                    ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
+                    ("uerf", DType::F16) => strided::erf::HALF,
+                    ("uceil", DType::F16) => strided::ceil::HALF,
+                    ("ufloor", DType::F16) => strided::floor::HALF,
+                    ("uround", DType::F16) => strided::round::HALF,
+                    (name, dtype) => todo!("Match {name} - {dtype:?}"),
+                };
+                candle_metal_kernels::call_unary_strided(
+                    &device.device,
+                    &command_buffer,
+                    &device.kernels,
+                    kernel_name,
+                    layout.dims(),
+                    &self.buffer,
+                    layout.stride(),
+                    layout.start_offset() * self.dtype.size_in_bytes(),
+                    &mut buffer,
+                    0,
+                )
+                .map_err(MetalError::from)?;
+            }
         }
         Ok(Self {
             buffer,
@@ -769,59 +783,62 @@ impl BackendStorage for MetalStorage {
 
         let out_buffer = self.device.new_buffer(elem_count, self.dtype);
 
-        let command_buffer = self.device.command_buffer();
-        for bi in 0..b {
-            // Create matrix objects
-            let left_matrix = Matrix::init_with_buffer_descriptor(
-                &self.buffer,
-                (bi * stride_left + lhs_l.start_offset() as u64) * size,
-                &left_descriptor,
-            )
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
-            let right_matrix = Matrix::init_with_buffer_descriptor(
-                &rhs.buffer,
-                (bi * stride_right + rhs_l.start_offset() as u64) * size,
-                &right_descriptor,
-            )
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
+        // Scope the command buffer read guard so it is dropped before the result is built.
+        {
+            let command_buffer = self.device.command_buffer();
+            for bi in 0..b {
+                // Create matrix objects
+                let left_matrix = Matrix::init_with_buffer_descriptor(
+                    &self.buffer,
+                    (bi * stride_left + lhs_l.start_offset() as u64) * size,
+                    &left_descriptor,
+                )
+                .ok_or_else(|| {
+                    MetalError::from("Failed to create matrix multiplication kernel".to_string())
+                })?;
+                let right_matrix = Matrix::init_with_buffer_descriptor(
+                    &rhs.buffer,
+                    (bi * stride_right + rhs_l.start_offset() as u64) * size,
+                    &right_descriptor,
+                )
+                .ok_or_else(|| {
+                    MetalError::from("Failed to create matrix multiplication kernel".to_string())
+                })?;
 
-            let result_matrix = Matrix::init_with_buffer_descriptor(
-                &out_buffer,
-                bi * m * n * size,
-                &result_descriptor,
-            )
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
+                let result_matrix = Matrix::init_with_buffer_descriptor(
+                    &out_buffer,
+                    bi * m * n * size,
+                    &result_descriptor,
+                )
+                .ok_or_else(|| {
+                    MetalError::from("Failed to create matrix multiplication kernel".to_string())
+                })?;
 
-            let alpha = 1.0f64;
-            let beta = 0.0f64;
-            // Create kernel
-            let matrix_multiplication = MatrixMultiplication::init(
-                &self.device,
-                transpose_left,
-                transpose_right,
-                m,
-                n,
-                k,
-                alpha,
-                beta,
-            )
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
+                let alpha = 1.0f64;
+                let beta = 0.0f64;
+                // Create kernel
+                let matrix_multiplication = MatrixMultiplication::init(
+                    &self.device,
+                    transpose_left,
+                    transpose_right,
+                    m,
+                    n,
+                    k,
+                    alpha,
+                    beta,
+                )
+                .ok_or_else(|| {
+                    MetalError::from("Failed to create matrix multiplication kernel".to_string())
+                })?;
 
-            // Encode kernel to command buffer
-            matrix_multiplication.encode_to_command_buffer(
-                &command_buffer,
-                &left_matrix,
-                &right_matrix,
-                &result_matrix,
-            );
+                // Encode kernel to command buffer
+                matrix_multiplication.encode_to_command_buffer(
+                    &command_buffer,
+                    &left_matrix,
+                    &right_matrix,
+                    &result_matrix,
+                );
+            }
         }
 
         Ok(Self {
@@ -891,7 +908,7 @@ impl BackendDevice for MetalDevice {
         descriptor.set_size(size.size);
         descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
         let heap = device.new_heap(&descriptor);
-        let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
+        let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
         let kernels = Arc::new(Kernels::new());
         Ok(Self {
             device,
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 2585ca62..2d2742ab 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -10,7 +10,9 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# NOTE(review): local checkout used while debugging; restore the git dependency before merging.
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../../metal-rs", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a0227119..d6851a69 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1133,6 +1133,7 @@ mod tests {
         let device = Device::system_default().expect("no device found");
 
         let options = CompileOptions::new();
+        options.set_fast_math_enabled(true);
         let library = device.new_library_with_source(INDEXING, &options).unwrap();
 
         let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];