Adding a bunch of docs!

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Nicolas Patry
2023-12-15 11:02:41 +01:00
parent cf27868b57
commit 243e83f2b9
2 changed files with 128 additions and 59 deletions


@@ -34,12 +34,48 @@ impl From<String> for MetalError {
#[derive(Clone)]
pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
device: metal::Device,
/// Single command queue for the entire device.
command_queue: metal::CommandQueue,
command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// would prevent overlapping of CPU and GPU commands (since a command buffer needs to be
/// committed before it starts executing).
/// Despite what the documentation says, command buffers are NOT ordered. They are only
/// ordered by their START time; there is no guarantee that command buffer 1 will finish
/// before command buffer 2 starts (or there are Metal bugs there).
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
/// Keeps track of the current number of compute command encoders on the current
/// command buffer.
/// Wrapped in `Arc<RwLock<_>>` for interior mutability.
command_buffer_index: Arc<RwLock<usize>>,
/// The maximum number of [compute command encoders](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc).
compute_per_buffer: usize,
/// Every compute command encoder (and blit encoder) is guarded by this fence, forcing the
/// execution order to be linear.
/// This could be relaxed in some circumstances, by managing the dependencies of the
/// compute graph ourselves.
fence: metal::Fence,
/// Simple keeper struct that tracks already compiled kernels so they can be reused.
/// Heavily used by [`candle_metal_kernels`]; its fence needs to match the device fence above.
kernels: Arc<candle_metal_kernels::Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-C's internal reference
/// counting (which could be linked to FFI communication overhead).
///
/// Whenever a buffer has strong_count == 1 we can reuse it: it was dropped during the graph
/// calculation and only the allocator kept a reference to it, so it is free to be reused.
/// However, for this to work we need to guarantee the order of operations, so that the
/// buffer is not being used by another kernel at the same time.
/// The Arc count is a CPU-side reference count; it means nothing on the GPU side.
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused
/// buffers (strong_count == 1).
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}
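To make the reuse rule above concrete, here is a minimal sketch (a hypothetical helper, not part of the backend): a buffer in the right (size, options) bucket is free exactly when the allocator holds the only remaining `Arc`.

use std::sync::Arc;

/// Hypothetical illustration of the reuse rule described above: scan a size
/// bucket and hand back the first buffer that only the allocator still owns.
fn find_reusable<B>(bucket: &[Arc<B>]) -> Option<Arc<B>> {
    for buffer in bucket {
        // strong_count == 1: the graph dropped its handle, only the allocator
        // keeps this buffer alive, so it is free to be handed out again.
        if Arc::strong_count(buffer) == 1 {
            return Some(buffer.clone());
        }
    }
    None
}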
@@ -71,13 +107,13 @@ impl MetalDevice {
}
pub fn command_buffer(&self) -> CommandBuffer {
let mut command_buffers = self.command_buffers.try_write().unwrap();
let mut command_buffer = command_buffers[0].to_owned();
let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self.command_buffer_index.try_write().unwrap();
if *index > 20 {
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffers = vec![command_buffer.clone()];
*command_buffer_lock = command_buffer.clone();
*index = 0;
}
*index += 1;
@@ -85,8 +121,7 @@ impl MetalDevice {
}
pub fn wait_until_completed(&self) {
let mut command_buffers = self.command_buffers.try_write().unwrap();
let command_buffer = &command_buffers[0];
let mut command_buffer = self.command_buffer.try_write().unwrap();
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
@@ -97,7 +132,7 @@ impl MetalDevice {
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
*command_buffer = self.command_queue.new_command_buffer().to_owned();
}
pub fn kernels(&self) -> &Kernels {
@@ -108,12 +143,65 @@ impl MetalDevice {
&self.device
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer uses [MTLStorageModePrivate](https://developer.apple.com/documentation/metal/mtlstoragemode),
/// which means the buffer data cannot be read on the CPU directly.
///
/// [`name`] is only used to keep track of the resource's origin in case of bugs.
pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
fn _new_buffer(
/// Creates a new buffer (not necessarily zeroed).
/// The buffer uses [MTLStorageModeManaged](https://developer.apple.com/documentation/metal/mtlstoragemode),
/// which means it can be read on the CPU but requires manual synchronization
/// whenever the CPU-side memory is modified.
/// Used as a bridge to gather data back from the GPU.
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
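As a rough illustration of the CPU read-back path this enables (a hedged sketch only, not part of the API: `dev` is an assumed `MetalDevice` handle, and the real code additionally guards blits with the device fence):

fn read_back_f32(dev: &MetalDevice, gpu_buf: &metal::Buffer, len: metal::NSUInteger) -> Vec<f32> {
    // Stage the private buffer into a managed buffer the CPU can see.
    let managed = dev.new_buffer_managed(len);
    let command_buffer = dev.command_buffer();
    let blit = command_buffer.new_blit_command_encoder();
    blit.copy_from_buffer(gpu_buf, 0, &managed, 0, len);
    blit.end_encoding();
    // Blocks until the copy has actually run on the GPU. On discrete GPUs a
    // real implementation may also need to synchronize the managed resource.
    dev.wait_until_completed();
    let count = len as usize / core::mem::size_of::<f32>();
    // Safety: the managed buffer now holds `count` f32 values.
    unsafe { std::slice::from_raw_parts(managed.contents() as *const f32, count).to_vec() }
}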
/// Creates a new buffer from data.
/// The buffer uses [MTLStorageModePrivate](https://developer.apple.com/documentation/metal/mtlstoragemode).
///
/// This method blocks the computation because there is no lifetime management
/// across the GPU boundary; see the internal comment below for the technical
/// details.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
let size = core::mem::size_of_val(data) as NSUInteger;
let tmp = self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
size,
metal::MTLResourceOptions::StorageModeManaged,
);
let real = self.allocate_buffer(
size,
metal::MTLResourceOptions::StorageModePrivate,
"with_data",
);
let command_buffer = self.command_buffer();
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence);
blit.end_encoding();
// This wait is necessary for mmapped safetensors, because of the unsafe slice
// cast we're doing: the slice might not live long enough for Metal to actually
// fill the GPU buffer. Waiting here forces the GPU buffer to be filled with the
// actual data, allowing the CPU storage to be deallocated properly.
self.wait_until_completed();
real
}
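A short usage sketch of the upload path (again assuming a `MetalDevice` value named `dev`):

// Uploads host data into a private GPU buffer; this blocks until the blit has
// completed, for the reasons given in the comment above.
let weights = vec![0f32; 1024];
let gpu_weights = dev.new_buffer_with_data(&weights);
assert_eq!(
    gpu_weights.length() as usize,
    weights.len() * core::mem::size_of::<f32>()
);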
/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
@@ -142,42 +230,7 @@ impl MetalDevice {
new_buffer
}
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
let size = core::mem::size_of_val(data) as NSUInteger;
let tmp = self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
size,
metal::MTLResourceOptions::StorageModeManaged,
);
let real = self._new_buffer(
size,
metal::MTLResourceOptions::StorageModePrivate,
"with_data",
);
let command_buffer = self.command_buffer();
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence);
blit.end_encoding();
// This is necessary, for mmaped safetensors
// Because of the unsafe slice cast we're doing.
// The slice might not live long enough for metal
// To actually fill the GPU buffer.
// Putting this wait forces the GPU buffer to be filled
// with the actual data allowing the CPU storage todo
// deallocate properly.
self.wait_until_completed();
real
}
/// Create a Metal GPU capture trace at [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
@@ -194,8 +247,11 @@ impl MetalDevice {
#[derive(Debug, Clone)]
pub struct MetalStorage {
/// The actual buffer containing the data.
buffer: Arc<metal::Buffer>,
/// A reference to the device owning this buffer.
device: MetalDevice,
/// The dtype is kept since buffers are untyped.
dtype: DType,
}
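Because the buffer itself is untyped, the element count has to be recovered from the byte length together with the stored dtype; a minimal hypothetical helper showing the relation:

// Hypothetical helper: number of elements held by an untyped Metal buffer.
fn element_count(buffer: &metal::Buffer, dtype: DType) -> usize {
    buffer.length() as usize / dtype.size_in_bytes()
}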
@@ -952,29 +1008,25 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let n = 1;
let command_queue = device.new_command_queue();
let command_buffers = (0..n)
.map(|i| {
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
command_buffer.set_label(&format!("num {i}"));
command_buffer
})
.collect();
let command_buffers = Arc::new(RwLock::new(command_buffers));
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let fence = device.new_fence();
let kernels = Arc::new(Kernels::new(fence.clone()));
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 20,
};
Ok(Self {
device,
fence,
command_queue,
command_buffers,
command_buffer,
command_buffer_index,
compute_per_buffer,
buffers,
kernels,
})
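`CANDLE_METAL_COMPUTE_PER_BUFFER` therefore controls how many compute command encoders are batched onto one command buffer before it gets committed, defaulting to 20. For experiments it can presumably be set before the device is created, e.g.:

// Batch up to 50 encoders per command buffer before committing (a sketch; the
// best value is workload dependent), then construct the Metal device as usual.
std::env::set_var("CANDLE_METAL_COMPUTE_PER_BUFFER", "50");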


@@ -15,6 +15,10 @@ const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
/// Most kernels apply similarly across the tensors.
/// This creates a launch strategy that uses the maximum number of threads per threadgroup
/// (capped at the actual total buffer length), so that each kernel invocation can simply
/// apply its op to its single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
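// (The rest of the function is elided by the hunk boundary; presumably the
// split rounds the grid up so that every element gets exactly one thread,
// roughly along these lines.)
let count = (size + width - 1) / width;
let thread_group_count = MTLSize { width: count, height: 1, depth: 1 };
let thread_group_size = MTLSize { width, height: 1, depth: 1 };
(thread_group_count, thread_group_size)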
@@ -36,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper trait (used via [`set_param`] above) to set the various kernel arguments on the
/// compute command encoder in a single line.
/// Prevents getting an argument position wrong and mixing up lengths with sizes in bytes.
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
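For illustration, an implementation for a plain scalar could presumably look like this (a sketch; the crate's actual impls may differ in detail):

// Sketch: pass a scalar kernel argument by value at the given position.
impl EncoderParam for usize {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of::<usize>() as u64,
            &data as *const usize as *const core::ffi::c_void,
        );
    }
}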
@@ -220,6 +228,9 @@ impl Kernels {
Source::Mfa => panic!("Invalid lib"),
}
}
/// Load the given library from its [`source`].
/// If it has been previously loaded, it is simply fetched from the cache.
pub fn load_library(
&self,
device: &Device,
@@ -262,6 +273,9 @@ impl Kernels {
Ok(func)
}
/// Load the given pipeline.
/// Loads the library from source, then gets the function [`name`] from
/// that library.
fn load_pipeline_with_constants(
&self,
device: &Device,
@@ -290,6 +304,9 @@ impl Kernels {
}
}
/// Load the given pipeline.
/// Loads the library from source, then gets the function [`name`] from
/// that library (without constants).
pub fn load_pipeline(
&self,
device: &Device,