Working with merging encoders and using fences.

This commit is contained in:
Nicolas Patry
2023-12-14 16:05:33 +01:00
parent 931432ed55
commit 361f2ad2af
5 changed files with 279 additions and 94 deletions

View File

@ -38,6 +38,7 @@ pub struct MetalDevice {
command_queue: metal::CommandQueue,
command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
command_buffer_index: Arc<RwLock<usize>>,
fence: metal::Fence,
kernels: Arc<candle_metal_kernels::Kernels>,
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}
@ -71,68 +72,32 @@ impl MetalDevice {
pub fn command_buffer(&self) -> CommandBuffer {
let mut command_buffers = self.command_buffers.try_write().unwrap();
let mut command_buffer = command_buffers[0].to_owned();
let mut index = self.command_buffer_index.try_write().unwrap();
let n = command_buffers.len();
if *index == n {
// todo!("Cycle buffers");
for i in 0..n {
let command_buffer = &command_buffers[i];
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled => {
// println!("Wait during cycling {i}");
// println!("Command {i} / {n}: {:?}", command_buffer.status());
command_buffer.wait_until_completed();
}
metal::MTLCommandBufferStatus::Completed => {}
_ => {
panic!("Command buffer {i} not committed during cycling");
}
}
}
let new_buffers = (0..n)
.map(|i| {
// println!("Creating command buffer {i}");
let command_buffer = self.command_queue.new_command_buffer().to_owned();
command_buffer.set_label(&format!("num {i}"));
command_buffer.enqueue();
command_buffer
})
.collect();
*command_buffers = new_buffers;
if *index > 20 {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffers = vec![command_buffer.clone()];
*index = 0;
// println!("Reset");
}
// println!("Giving buffer {} / {n}", *index);
let out = &command_buffers[*index];
assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued);
*index += 1;
out.to_owned()
command_buffer
}
pub fn wait_until_completed(&self) {
let command_buffers = self.command_buffers.try_write().unwrap();
let index = self.command_buffer_index.try_write().unwrap();
// let n = command_buffers.len();
// for i in 0..*index {
// let command_buffer = &command_buffers[i];
// println!("Command {i} / {n}: {:?}", command_buffer.status());
// }
for i in 0..*index {
let command_buffer = &command_buffers[i];
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled => {}
metal::MTLCommandBufferStatus::Completed => {}
_ => {
panic!("Command buffer not committed");
}
let mut command_buffers = self.command_buffers.try_write().unwrap();
let command_buffer = &command_buffers[0];
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Alredy committed");
}
// println!("Wait {i}");
command_buffer.wait_until_completed();
// println!("Ok {i}");
// command_buffer.wait_until_completed();
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
}
pub fn kernels(&self) -> &Kernels {
@ -176,7 +141,7 @@ impl MetalDevice {
}
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
@ -184,7 +149,7 @@ impl MetalDevice {
let tmp = self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
size,
metal::MTLResourceOptions::StorageModeShared,
metal::MTLResourceOptions::StorageModeManaged,
);
let real = self._new_buffer(
size,
@ -194,15 +159,15 @@ impl MetalDevice {
let command_buffer = self.command_buffer();
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence);
blit.end_encoding();
command_buffer.commit();
drop(command_buffer);
// drop(command_buffer);
// real.did_modify_range(metal::NSRange::new(0, real.length()));
// println!("Command {:?}", command.status());
// self.commit();
// This is necessary for mmapped safetensors
// Because of the unsafe slice cast we're doing.
// The slice might not live long enough for metal
@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage {
self.dtype
);
}
self.device.wait_until_completed();
self.buffer
.did_modify_range(metal::NSRange::new(0, self.buffer.length()));
let buffer = self.device.new_buffer_managed(self.buffer.length());
{
let command_buffer = self.device.command_buffer();
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence);
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence);
blit.end_encoding();
command_buffer.commit();
}
self.device.wait_until_completed();
@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device, dtype))
}
@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage {
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype");
device.wait_until_completed();
let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
let kernel_name = match (self.dtype, dtype) {
@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
}
command_buffer.set_label("to_dtype");
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
device.wait_until_completed();
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
}
command_buffer.set_label("binary");
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device, dtype))
}
@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
// Create kernel
command_buffer.commit();
self.device.wait_until_completed();
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
}
@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
}
command_buffer.commit();
Ok(())
}
}
@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice {
// println!("CREATING DEVICE");
let device = metal::Device::all().swap_remove(ordinal);
let n = 64;
let n = 1;
let command_queue = device.new_command_queue();
let command_buffers = (0..n)
@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice {
.collect();
let command_buffers = Arc::new(RwLock::new(command_buffers));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new());
let fence = device.new_fence();
let kernels = Arc::new(Kernels::new(fence.clone()));
let buffers = Arc::new(RwLock::new(HashMap::new()));
Ok(Self {
device,
fence,
command_queue,
command_buffers,
command_buffer_index,
@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice {
0,
);
blit.end_encoding();
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(MetalStorage::new(buffer, self.clone(), dtype))
}