mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Working with merging encoders and using fences.
This commit is contained in:
@ -38,6 +38,7 @@ pub struct MetalDevice {
|
||||
command_queue: metal::CommandQueue,
|
||||
command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
|
||||
command_buffer_index: Arc<RwLock<usize>>,
|
||||
fence: metal::Fence,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
|
||||
}
|
||||
@ -71,68 +72,32 @@ impl MetalDevice {
|
||||
|
||||
pub fn command_buffer(&self) -> CommandBuffer {
|
||||
let mut command_buffers = self.command_buffers.try_write().unwrap();
|
||||
let mut command_buffer = command_buffers[0].to_owned();
|
||||
let mut index = self.command_buffer_index.try_write().unwrap();
|
||||
let n = command_buffers.len();
|
||||
if *index == n {
|
||||
// todo!("Cycle buffers");
|
||||
for i in 0..n {
|
||||
let command_buffer = &command_buffers[i];
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled => {
|
||||
// println!("Wait during cycling {i}");
|
||||
// println!("Command {i} / {n}: {:?}", command_buffer.status());
|
||||
command_buffer.wait_until_completed();
|
||||
}
|
||||
metal::MTLCommandBufferStatus::Completed => {}
|
||||
_ => {
|
||||
panic!("Command buffer {i} not committed during cycling");
|
||||
}
|
||||
}
|
||||
}
|
||||
let new_buffers = (0..n)
|
||||
.map(|i| {
|
||||
// println!("Creating command buffer {i}");
|
||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.set_label(&format!("num {i}"));
|
||||
command_buffer.enqueue();
|
||||
command_buffer
|
||||
})
|
||||
.collect();
|
||||
*command_buffers = new_buffers;
|
||||
if *index > 20 {
|
||||
command_buffer.commit();
|
||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*command_buffers = vec![command_buffer.clone()];
|
||||
*index = 0;
|
||||
// println!("Reset");
|
||||
}
|
||||
// println!("Giving buffer {} / {n}", *index);
|
||||
let out = &command_buffers[*index];
|
||||
assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued);
|
||||
*index += 1;
|
||||
out.to_owned()
|
||||
command_buffer
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) {
|
||||
let command_buffers = self.command_buffers.try_write().unwrap();
|
||||
let index = self.command_buffer_index.try_write().unwrap();
|
||||
// let n = command_buffers.len();
|
||||
// for i in 0..*index {
|
||||
// let command_buffer = &command_buffers[i];
|
||||
// println!("Command {i} / {n}: {:?}", command_buffer.status());
|
||||
// }
|
||||
for i in 0..*index {
|
||||
let command_buffer = &command_buffers[i];
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled => {}
|
||||
metal::MTLCommandBufferStatus::Completed => {}
|
||||
_ => {
|
||||
panic!("Command buffer not committed");
|
||||
}
|
||||
let mut command_buffers = self.command_buffers.try_write().unwrap();
|
||||
let command_buffer = &command_buffers[0];
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
| metal::MTLCommandBufferStatus::Completed => {
|
||||
panic!("Alredy committed");
|
||||
}
|
||||
// println!("Wait {i}");
|
||||
command_buffer.wait_until_completed();
|
||||
// println!("Ok {i}");
|
||||
// command_buffer.wait_until_completed();
|
||||
_ => {}
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
*command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
@ -176,7 +141,7 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
||||
@ -184,7 +149,7 @@ impl MetalDevice {
|
||||
let tmp = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModeShared,
|
||||
metal::MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let real = self._new_buffer(
|
||||
size,
|
||||
@ -194,15 +159,15 @@ impl MetalDevice {
|
||||
let command_buffer = self.command_buffer();
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
command_buffer.commit();
|
||||
drop(command_buffer);
|
||||
// drop(command_buffer);
|
||||
// real.did_modify_range(metal::NSRange::new(0, real.length()));
|
||||
// println!("Command {:?}", command.status());
|
||||
|
||||
// self.commit();
|
||||
// This is necessary, for mmaped safetensors
|
||||
// Because of the unsafe slice cast we're doing.
|
||||
// The slice might not live long enough for metal
|
||||
@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage {
|
||||
self.dtype
|
||||
);
|
||||
}
|
||||
self.device.wait_until_completed();
|
||||
self.buffer
|
||||
.did_modify_range(metal::NSRange::new(0, self.buffer.length()));
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
||||
{
|
||||
let command_buffer = self.device.command_buffer();
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.wait_for_fence(&self.device.fence);
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.update_fence(&self.device.fence);
|
||||
blit.end_encoding();
|
||||
|
||||
command_buffer.commit();
|
||||
}
|
||||
self.device.wait_until_completed();
|
||||
|
||||
@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage {
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, "todtype");
|
||||
device.wait_until_completed();
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("to_dtype");
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
device.wait_until_completed();
|
||||
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("binary");
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
|
||||
@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// Create kernel
|
||||
command_buffer.commit();
|
||||
self.device.wait_until_completed();
|
||||
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||
}
|
||||
@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.set_label("copy_strided");
|
||||
}
|
||||
command_buffer.commit();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice {
|
||||
// println!("CREATING DEVICE");
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let n = 64;
|
||||
let n = 1;
|
||||
let command_queue = device.new_command_queue();
|
||||
|
||||
let command_buffers = (0..n)
|
||||
@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice {
|
||||
.collect();
|
||||
let command_buffers = Arc::new(RwLock::new(command_buffers));
|
||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let fence = device.new_fence();
|
||||
let kernels = Arc::new(Kernels::new(fence.clone()));
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
Ok(Self {
|
||||
device,
|
||||
fence,
|
||||
command_queue,
|
||||
command_buffers,
|
||||
command_buffer_index,
|
||||
@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice {
|
||||
0,
|
||||
);
|
||||
blit.end_encoding();
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user