mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Adding a bunch of docs !
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
@ -34,12 +34,48 @@ impl From<String> for MetalError {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||
device: metal::Device,
|
||||
|
||||
/// Single command queue for the entire device.
|
||||
command_queue: metal::CommandQueue,
|
||||
command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
|
||||
/// One command buffer at a time.
|
||||
/// The scheduler works by allowing multiple
|
||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
|
||||
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
|
||||
/// to start to work).
|
||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||
/// by their START time, but there's no guarantee that command buffer1 will finish before
|
||||
/// command buffer2 starts (or there are metal bugs there)
|
||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
/// Keeps track of the current number of compute command encoders on the current
|
||||
/// command buffer
|
||||
/// Arc, RwLock because of the interior mutability.
|
||||
command_buffer_index: Arc<RwLock<usize>>,
|
||||
/// The maximum number of [compute command encoders](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
compute_per_buffer: usize,
|
||||
/// Every compute command encoder (and blit encoder) is guarded by this fence, forcing the
|
||||
/// execution order to be linear.
|
||||
/// It could be relaxed in some circumstances by managing the dependencies ourselves in the
|
||||
/// compute graph.
|
||||
fence: metal::Fence,
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`], both fences need to match
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
||||
/// (could be linked to FFI communication overhead).
|
||||
///
|
||||
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
|
||||
/// graph calculation, and only we, the allocator, kept a reference to it; therefore it's free
|
||||
/// to be reused. However, in order for this to work, we need to guarantee the order of
|
||||
/// operation, so that this buffer is not being used by another kernel at the same time.
|
||||
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
||||
///
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
||||
/// (strong_count = 1).
|
||||
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
|
||||
}
|
||||
|
||||
@ -71,13 +107,13 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> CommandBuffer {
|
||||
let mut command_buffers = self.command_buffers.try_write().unwrap();
|
||||
let mut command_buffer = command_buffers[0].to_owned();
|
||||
let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
|
||||
let mut command_buffer = command_buffer_lock.to_owned();
|
||||
let mut index = self.command_buffer_index.try_write().unwrap();
|
||||
if *index > 20 {
|
||||
if *index > self.compute_per_buffer {
|
||||
command_buffer.commit();
|
||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*command_buffers = vec![command_buffer.clone()];
|
||||
*command_buffer_lock = command_buffer.clone();
|
||||
*index = 0;
|
||||
}
|
||||
*index += 1;
|
||||
@ -85,8 +121,7 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) {
|
||||
let mut command_buffers = self.command_buffers.try_write().unwrap();
|
||||
let command_buffer = &command_buffers[0];
|
||||
let mut command_buffer = self.command_buffer.try_write().unwrap();
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
@ -97,7 +132,7 @@ impl MetalDevice {
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
*command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
|
||||
*command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
@ -108,12 +143,65 @@ impl MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Creates a new buffer (not necessarily zeroed).
|
||||
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
/// This means the buffer data cannot be read on the CPU directly.
|
||||
///
|
||||
/// `name` is only used to keep track of the resource origin in case of bugs.
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
|
||||
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
|
||||
}
|
||||
|
||||
fn _new_buffer(
|
||||
/// Creates a new buffer (not necessarily zeroed).
|
||||
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
/// This means the buffer can be read on the CPU but will require manual
|
||||
/// synchronization when the CPU memory is modified
|
||||
/// Used as a bridge to gather data back from the GPU
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
||||
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
||||
}
|
||||
|
||||
/// Creates a new buffer from data.
|
||||
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
///
|
||||
/// This method will block the computation because of the
|
||||
/// lack of lifetime management through the GPU.
|
||||
/// Internal comment for technical details.
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||
let tmp = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let real = self.allocate_buffer(
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModePrivate,
|
||||
"with_data",
|
||||
);
|
||||
let command_buffer = self.command_buffer();
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
|
||||
// This is necessary, for mmaped safetensors
|
||||
// Because of the unsafe slice cast we're doing.
|
||||
// The slice might not live long enough for metal
|
||||
// To actually fill the GPU buffer.
|
||||
// Putting this wait forces the GPU buffer to be filled
|
||||
// with the actual data allowing the CPU storage to
|
||||
// deallocate properly.
|
||||
self.wait_until_completed();
|
||||
real
|
||||
}
|
||||
|
||||
/// The critical allocator algorithm
|
||||
fn allocate_buffer(
|
||||
&self,
|
||||
size: NSUInteger,
|
||||
option: MTLResourceOptions,
|
||||
@ -142,42 +230,7 @@ impl MetalDevice {
|
||||
new_buffer
|
||||
}
|
||||
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||
let tmp = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let real = self._new_buffer(
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModePrivate,
|
||||
"with_data",
|
||||
);
|
||||
let command_buffer = self.command_buffer();
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
|
||||
// This is necessary, for mmaped safetensors
|
||||
// Because of the unsafe slice cast we're doing.
|
||||
// The slice might not live long enough for metal
|
||||
// To actually fill the GPU buffer.
|
||||
// Putting this wait forces the GPU buffer to be filled
|
||||
// with the actual data allowing the CPU storage to
|
||||
// deallocate properly.
|
||||
self.wait_until_completed();
|
||||
real
|
||||
}
|
||||
|
||||
/// Create a metal GPU capture trace on `path`.
|
||||
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||
let capture = metal::CaptureManager::shared();
|
||||
let descriptor = metal::CaptureDescriptor::new();
|
||||
@ -194,8 +247,11 @@ impl MetalDevice {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalStorage {
|
||||
/// The actual buffer containing the data.
|
||||
buffer: Arc<metal::Buffer>,
|
||||
/// a reference to the device owning this buffer
|
||||
device: MetalDevice,
|
||||
/// The dtype is kept since buffers are untyped.
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
@ -952,29 +1008,25 @@ impl BackendDevice for MetalDevice {
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let n = 1;
|
||||
let command_queue = device.new_command_queue();
|
||||
|
||||
let command_buffers = (0..n)
|
||||
.map(|i| {
|
||||
let command_buffer = command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.enqueue();
|
||||
command_buffer.set_label(&format!("num {i}"));
|
||||
command_buffer
|
||||
})
|
||||
.collect();
|
||||
let command_buffers = Arc::new(RwLock::new(command_buffers));
|
||||
let command_buffer = command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.enqueue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||
let fence = device.new_fence();
|
||||
let kernels = Arc::new(Kernels::new(fence.clone()));
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 20,
|
||||
};
|
||||
Ok(Self {
|
||||
device,
|
||||
fence,
|
||||
command_queue,
|
||||
command_buffers,
|
||||
command_buffer,
|
||||
command_buffer_index,
|
||||
compute_per_buffer,
|
||||
buffers,
|
||||
kernels,
|
||||
})
|
||||
|
Reference in New Issue
Block a user