mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Tmp for allocator.
This commit is contained in:
@ -61,8 +61,10 @@ tracing-subscriber = "0.3.7"
|
|||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
metal = { path = "../metal-rs", features = ["mps"] }
|
||||||
dispatch = "0.2.0"
|
dispatch = "0.2.0"
|
||||||
|
rustc-hash = "1.1"
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
@ -31,6 +31,7 @@ thiserror = { workspace = true }
|
|||||||
yoke = { workspace = true }
|
yoke = { workspace = true }
|
||||||
zip = { workspace = true }
|
zip = { workspace = true }
|
||||||
dispatch = { workspace = true, optional = true }
|
dispatch = { workspace = true, optional = true }
|
||||||
|
rustc-hash = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
|
@ -8,6 +8,8 @@ use half::f16;
|
|||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
use dispatch::{Queue, QueueAttribute};
|
use dispatch::{Queue, QueueAttribute};
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
@ -37,6 +39,7 @@ pub struct MetalDevice {
|
|||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
command_queue: metal::CommandQueue,
|
command_queue: metal::CommandQueue,
|
||||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||||
|
buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
|
||||||
queue : Queue,
|
queue : Queue,
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
}
|
}
|
||||||
@ -86,7 +89,23 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
let size = element_count * dtype.size_in_bytes();
|
||||||
|
let mut buffers = self.buffers.try_write().unwrap();
|
||||||
|
let subbuffers = buffers.entry(size).or_insert(vec![]);
|
||||||
|
|
||||||
|
for sub in &mut *subbuffers{
|
||||||
|
if sub.retain_count() == 1{
|
||||||
|
return sub.clone();
|
||||||
|
// println!("{size } {:?}", sub.retain_count());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let new_buffer = self.device
|
||||||
|
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
|
||||||
|
subbuffers.push(new_buffer.clone());
|
||||||
|
new_buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
|
||||||
self.device
|
self.device
|
||||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||||
}
|
}
|
||||||
@ -124,29 +143,39 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
||||||
|
{
|
||||||
|
let command = self.device.command_buffer();
|
||||||
|
let blit = command.new_blit_command_encoder();
|
||||||
|
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||||
|
blit.end_encoding();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
self.device.wait_until_completed();
|
self.device.wait_until_completed();
|
||||||
|
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
DType::U8 => Ok(CpuStorage::U8(
|
DType::U8 => Ok(CpuStorage::U8(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize),
|
buffer.read_to_vec(buffer.length() as usize),
|
||||||
)),
|
)),
|
||||||
DType::U32 => Ok(CpuStorage::U32(
|
DType::U32 => Ok(CpuStorage::U32(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
buffer.read_to_vec(buffer.length() as usize / 4),
|
||||||
)),
|
)),
|
||||||
DType::I64 => Ok(CpuStorage::I64(
|
DType::I64 => Ok(CpuStorage::I64(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
buffer.read_to_vec(buffer.length() as usize / 8),
|
||||||
)),
|
)),
|
||||||
DType::F16 => Ok(CpuStorage::F16(
|
DType::F16 => Ok(CpuStorage::F16(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
buffer.read_to_vec(buffer.length() as usize / 2),
|
||||||
)),
|
)),
|
||||||
DType::BF16 => Ok(CpuStorage::BF16(
|
DType::BF16 => Ok(CpuStorage::BF16(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
buffer.read_to_vec(buffer.length() as usize / 2),
|
||||||
)),
|
)),
|
||||||
DType::F32 => Ok(CpuStorage::F32(
|
DType::F32 => Ok(CpuStorage::F32(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
buffer.read_to_vec(buffer.length() as usize / 4),
|
||||||
)),
|
)),
|
||||||
DType::F64 => Ok(CpuStorage::F64(
|
DType::F64 => Ok(CpuStorage::F64(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
buffer.read_to_vec(buffer.length() as usize / 8),
|
||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -344,7 +373,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let lstride = layout.stride().to_owned();
|
let lstride = layout.stride().to_owned();
|
||||||
let loffset = layout.start_offset() * dtype.size_in_bytes();
|
let loffset = layout.start_offset() * dtype.size_in_bytes();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
self.device.queue.exec_async(move || {
|
// self.device.queue.exec_async(move || {
|
||||||
let device = metal;
|
let device = metal;
|
||||||
let command_buffer = device.command_buffer();
|
let command_buffer = device.command_buffer();
|
||||||
use candle_metal_kernels::unary::contiguous;
|
use candle_metal_kernels::unary::contiguous;
|
||||||
@ -388,9 +417,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
&mut cloned,
|
&mut cloned,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
// });
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
self.device.queue.exec_async(move || {
|
// self.device.queue.exec_async(move || {
|
||||||
let device = metal;
|
let device = metal;
|
||||||
let command_buffer = device.command_buffer();
|
let command_buffer = device.command_buffer();
|
||||||
use candle_metal_kernels::unary::strided;
|
use candle_metal_kernels::unary::strided;
|
||||||
@ -436,7 +466,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
// });
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -915,10 +945,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
||||||
let kernels = Arc::new(Kernels::new());
|
let kernels = Arc::new(Kernels::new());
|
||||||
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
|
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
|
||||||
|
let buffers = Arc::new(RwLock::new(FxHashMap::default()));
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
command_queue,
|
command_queue,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
|
buffers,
|
||||||
queue,
|
queue,
|
||||||
kernels,
|
kernels,
|
||||||
})
|
})
|
||||||
|
@ -10,7 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
metal = { path = "../../metal-rs", features = ["mps"] }
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
|
Reference in New Issue
Block a user