Tmp for allocator.

This commit is contained in:
Nicolas Patry
2023-11-16 12:50:41 +01:00
parent 181d2299b2
commit 7e49e0af96
4 changed files with 50 additions and 14 deletions

View File

@ -61,8 +61,10 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0" wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] } yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false } zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } #metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0" dispatch = "0.2.0"
rustc-hash = "1.1"
[profile.release-with-debug] [profile.release-with-debug]
inherits = "release" inherits = "release"

View File

@ -31,6 +31,7 @@ thiserror = { workspace = true }
yoke = { workspace = true } yoke = { workspace = true }
zip = { workspace = true } zip = { workspace = true }
dispatch = { workspace = true, optional = true } dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }

View File

@ -8,6 +8,8 @@ use half::f16;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use dispatch::{Queue, QueueAttribute}; use dispatch::{Queue, QueueAttribute};
/// Metal related errors /// Metal related errors
@ -37,6 +39,7 @@ pub struct MetalDevice {
device: metal::Device, device: metal::Device,
command_queue: metal::CommandQueue, command_queue: metal::CommandQueue,
command_buffer: Arc<RwLock<metal::CommandBuffer>>, command_buffer: Arc<RwLock<metal::CommandBuffer>>,
buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
queue : Queue, queue : Queue,
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
} }
@ -86,7 +89,23 @@ impl MetalDevice {
} }
/// Returns a device buffer able to hold `element_count` elements of `dtype`,
/// reusing a previously allocated buffer of the same byte size when one is free.
///
/// Diff-view note: the first two lines below show the old text (left) and the
/// new text (right) fused on one line; the comments describe the new side.
///
/// The pool (`self.buffers`) is keyed by the exact byte size. Nothing in this
/// function ever removes an entry from it.
///
/// NOTE(review): `try_write().unwrap()` does not block; it panics when the lock
/// is currently held elsewhere or is poisoned.
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger; let size = element_count * dtype.size_in_bytes();
let mut buffers = self.buffers.try_write().unwrap();
// Fetch (or create empty) the list of pooled buffers for this byte size.
let subbuffers = buffers.entry(size).or_insert(vec![]);
for sub in &mut *subbuffers{
// NOTE(review): presumably a retain count of 1 means only the pool still
// references the buffer, so it can be handed out again — confirm.
if sub.retain_count() == 1{
return sub.clone();
// println!("{size } {:?}", sub.retain_count());
}
}
// No reusable buffer found: allocate a fresh one with private storage.
let new_buffer = self.device
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
// Keep one handle in the pool and return the other to the caller.
subbuffers.push(new_buffer.clone());
new_buffer
}
/// Allocates a new device buffer of `size` bytes with
/// `MTLResourceOptions::StorageModeManaged`.
///
/// Unlike `new_buffer`, this does not consult or update the buffer pool.
/// Diff-view note: the body lines below show old and new text fused on one line.
pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
self.device self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged) .new_buffer(size, MTLResourceOptions::StorageModeManaged)
} }
@ -124,29 +143,39 @@ impl BackendStorage for MetalStorage {
} }
/// Copies this storage's contents into a `CpuStorage` of the matching dtype.
///
/// Diff-view note: lines that appear doubled show the old text (left) and the
/// new text (right). The old side read directly from `self.buffer`; the new
/// side reads from a managed staging buffer filled by a blit copy.
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
// Staging buffer with the same byte length as the source buffer.
let buffer = self.device.new_buffer_managed(self.buffer.length());
{
// Encode a full-length copy from `self.buffer` into the staging buffer.
let command = self.device.command_buffer();
let blit = command.new_blit_command_encoder();
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
}
// NOTE(review): presumably this commits the pending command buffer and blocks
// until the GPU has finished — confirm against its definition.
// NOTE(review): GPU writes to a managed buffer may need an explicit
// synchronize step before the CPU reads them — confirm.
self.device.wait_until_completed(); self.device.wait_until_completed();
// Element count passed to `read_to_vec` is the byte length divided by the
// element size of the dtype (1, 2, 4 or 8 bytes).
match self.dtype { match self.dtype {
DType::U8 => Ok(CpuStorage::U8( DType::U8 => Ok(CpuStorage::U8(
self.buffer.read_to_vec(self.buffer.length() as usize), buffer.read_to_vec(buffer.length() as usize),
)), )),
DType::U32 => Ok(CpuStorage::U32( DType::U32 => Ok(CpuStorage::U32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4), buffer.read_to_vec(buffer.length() as usize / 4),
)), )),
DType::I64 => Ok(CpuStorage::I64( DType::I64 => Ok(CpuStorage::I64(
self.buffer.read_to_vec(self.buffer.length() as usize / 8), buffer.read_to_vec(buffer.length() as usize / 8),
)), )),
DType::F16 => Ok(CpuStorage::F16( DType::F16 => Ok(CpuStorage::F16(
self.buffer.read_to_vec(self.buffer.length() as usize / 2), buffer.read_to_vec(buffer.length() as usize / 2),
)), )),
DType::BF16 => Ok(CpuStorage::BF16( DType::BF16 => Ok(CpuStorage::BF16(
self.buffer.read_to_vec(self.buffer.length() as usize / 2), buffer.read_to_vec(buffer.length() as usize / 2),
)), )),
DType::F32 => Ok(CpuStorage::F32( DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4), buffer.read_to_vec(buffer.length() as usize / 4),
)), )),
DType::F64 => Ok(CpuStorage::F64( DType::F64 => Ok(CpuStorage::F64(
self.buffer.read_to_vec(self.buffer.length() as usize / 8), buffer.read_to_vec(buffer.length() as usize / 8),
)), )),
} }
} }
@ -344,7 +373,7 @@ impl BackendStorage for MetalStorage {
let lstride = layout.stride().to_owned(); let lstride = layout.stride().to_owned();
let loffset = layout.start_offset() * dtype.size_in_bytes(); let loffset = layout.start_offset() * dtype.size_in_bytes();
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
self.device.queue.exec_async(move || { // self.device.queue.exec_async(move || {
let device = metal; let device = metal;
let command_buffer = device.command_buffer(); let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::contiguous; use candle_metal_kernels::unary::contiguous;
@ -388,9 +417,10 @@ impl BackendStorage for MetalStorage {
&mut cloned, &mut cloned,
) )
.unwrap(); .unwrap();
}); // });
} else { } else {
self.device.queue.exec_async(move || { // self.device.queue.exec_async(move || {
let device = metal; let device = metal;
let command_buffer = device.command_buffer(); let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::strided; use candle_metal_kernels::unary::strided;
@ -436,7 +466,7 @@ impl BackendStorage for MetalStorage {
0, 0,
) )
.unwrap(); .unwrap();
}); // });
} }
Ok(Self { Ok(Self {
@ -915,10 +945,12 @@ impl BackendDevice for MetalDevice {
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new()); let kernels = Arc::new(Kernels::new());
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial); let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
let buffers = Arc::new(RwLock::new(FxHashMap::default()));
Ok(Self { Ok(Self {
device, device,
command_queue, command_queue,
command_buffer, command_buffer,
buffers,
queue, queue,
kernels, kernels,
}) })

View File

@ -10,7 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } # metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0" once_cell = "1.18.0"
thiserror = "1" thiserror = "1"
tracing = "0.1.37" tracing = "0.1.37"