From 7e49e0af9669a1caa362043a76b2a0f4664ea9bc Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 16 Nov 2023 12:50:41 +0100 Subject: [PATCH] Tmp for allocator. --- Cargo.toml | 4 ++- candle-core/Cargo.toml | 1 + candle-core/src/metal_backend.rs | 56 +++++++++++++++++++++++++------- candle-metal-kernels/Cargo.toml | 3 +- 4 files changed, 50 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fd3ec2e1..c1234a42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,8 +61,10 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { path = "../metal-rs", features = ["mps"] } dispatch = "0.2.0" +rustc-hash = "1.1" [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index c5a071e2..9a13da91 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -31,6 +31,7 @@ thiserror = { workspace = true } yoke = { workspace = true } zip = { workspace = true } dispatch = { workspace = true, optional = true } +rustc-hash = { workspace = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 727749e8..bc15bf79 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -8,6 +8,8 @@ use half::f16; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::sync::{Arc, RwLock}; +use std::collections::HashMap; +use rustc_hash::FxHashMap; use dispatch::{Queue, QueueAttribute}; /// Metal related errors @@ -37,6 +39,7 @@ pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, command_buffer: Arc>, + buffers: Arc>>>, queue : Queue, kernels: Arc, } @@ -86,7 +89,23 @@ impl MetalDevice { } pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { - let size = (element_count * dtype.size_in_bytes()) as NSUInteger; + let size = element_count * dtype.size_in_bytes(); + let mut buffers = self.buffers.try_write().unwrap(); + let subbuffers = buffers.entry(size).or_insert(vec![]); + + for sub in &mut *subbuffers{ + if sub.retain_count() == 1{ + return sub.clone(); + // println!("{size } {:?}", sub.retain_count()); + } + } + let new_buffer = self.device + .new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate); + subbuffers.push(new_buffer.clone()); + new_buffer + } + + pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer { self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) } @@ -124,29 +143,39 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { + let buffer = self.device.new_buffer_managed(self.buffer.length()); + { + let command = self.device.command_buffer(); + let blit = command.new_blit_command_encoder(); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + + } + + self.device.wait_until_completed(); match self.dtype { DType::U8 => Ok(CpuStorage::U8( - self.buffer.read_to_vec(self.buffer.length() as usize), + buffer.read_to_vec(buffer.length() as usize), )), DType::U32 => Ok(CpuStorage::U32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), + buffer.read_to_vec(buffer.length() as usize / 4), )), DType::I64 => Ok(CpuStorage::I64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), + buffer.read_to_vec(buffer.length() as usize / 8), )), DType::F16 => Ok(CpuStorage::F16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), + buffer.read_to_vec(buffer.length() as usize / 2), )), DType::BF16 => Ok(CpuStorage::BF16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), + buffer.read_to_vec(buffer.length() as usize / 2), )), DType::F32 => Ok(CpuStorage::F32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), + buffer.read_to_vec(buffer.length() as usize / 4), )), DType::F64 => Ok(CpuStorage::F64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), + buffer.read_to_vec(buffer.length() as usize / 8), )), } } @@ -344,7 +373,7 @@ impl BackendStorage for MetalStorage { let lstride = layout.stride().to_owned(); let loffset = layout.start_offset() * dtype.size_in_bytes(); if layout.is_contiguous() && layout.start_offset() == 0 { - self.device.queue.exec_async(move || { + // self.device.queue.exec_async(move || { let device = metal; let command_buffer = device.command_buffer(); use candle_metal_kernels::unary::contiguous; @@ -388,9 +417,10 @@ impl BackendStorage for MetalStorage { &mut cloned, ) .unwrap(); - }); + // }); + } else { - self.device.queue.exec_async(move || { + // self.device.queue.exec_async(move || { let device = metal; let command_buffer = device.command_buffer(); use candle_metal_kernels::unary::strided; @@ -436,7 +466,7 @@ impl BackendStorage for MetalStorage { 0, ) .unwrap(); - }); + // }); } Ok(Self { @@ -915,10 +945,12 @@ impl BackendDevice for MetalDevice { let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let kernels = Arc::new(Kernels::new()); let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial); + let buffers = Arc::new(RwLock::new(FxHashMap::default())); Ok(Self { device, command_queue, command_buffer, + buffers, queue, kernels, }) diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 2585ca62..2d2742ab 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -10,7 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { path = "../../metal-rs", features = ["mps"] } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37"