Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Working with merging encoders and using fences.
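The changes below thread a single shared metal::Fence through the device, the kernels registry, and every blit or compute encoder, so that encoders merged onto one command buffer still observe each other's writes. As a rough illustration only (not code from this commit; the helper name launch_with_fence and the buffer indices are assumptions), the recurring dispatch pattern looks like this:

use metal::{Buffer, CommandBufferRef, ComputePipelineState, Fence, MTLSize};

// Hedged sketch of the fence-bracketed dispatch applied throughout the diff:
// wait on the shared fence before reading inputs, signal it after writing
// outputs, so successive encoders on the same command buffer stay ordered.
fn launch_with_fence(
    command_buffer: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    fence: &Fence,
    input: &Buffer,
    output: &Buffer,
    thread_group_count: MTLSize,
    thread_group_size: MTLSize,
) {
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(fence);
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(input), 0);
    encoder.set_buffer(1, Some(output), 0);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(fence);
    encoder.end_encoding();
}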
@@ -38,6 +38,7 @@ pub struct MetalDevice {
    command_queue: metal::CommandQueue,
    command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
    command_buffer_index: Arc<RwLock<usize>>,
    fence: metal::Fence,
    kernels: Arc<candle_metal_kernels::Kernels>,
    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}
@@ -71,68 +72,32 @@ impl MetalDevice {

    pub fn command_buffer(&self) -> CommandBuffer {
        let mut command_buffers = self.command_buffers.try_write().unwrap();
        let mut command_buffer = command_buffers[0].to_owned();
        let mut index = self.command_buffer_index.try_write().unwrap();
        let n = command_buffers.len();
        if *index == n {
            // todo!("Cycle buffers");
            for i in 0..n {
                let command_buffer = &command_buffers[i];
                match command_buffer.status() {
                    metal::MTLCommandBufferStatus::Committed
                    | metal::MTLCommandBufferStatus::Scheduled => {
                        // println!("Wait during cycling {i}");
                        // println!("Command {i} / {n}: {:?}", command_buffer.status());
                        command_buffer.wait_until_completed();
                    }
                    metal::MTLCommandBufferStatus::Completed => {}
                    _ => {
                        panic!("Command buffer {i} not committed during cycling");
                    }
                }
            }
            let new_buffers = (0..n)
                .map(|i| {
                    // println!("Creating command buffer {i}");
                    let command_buffer = self.command_queue.new_command_buffer().to_owned();
                    command_buffer.set_label(&format!("num {i}"));
                    command_buffer.enqueue();
                    command_buffer
                })
                .collect();
            *command_buffers = new_buffers;
        if *index > 20 {
            command_buffer.commit();
            command_buffer = self.command_queue.new_command_buffer().to_owned();
            *command_buffers = vec![command_buffer.clone()];
            *index = 0;
            // println!("Reset");
        }
        // println!("Giving buffer {} / {n}", *index);
        let out = &command_buffers[*index];
        assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued);
        *index += 1;
        out.to_owned()
        command_buffer
    }

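Because the hunk above interleaves removed and added lines without +/- markers, here is a hedged reading of how the simplified command_buffer plausibly looks after this commit, assembled only from lines shown above (the cycling loop and the enqueue assert being the removed half); it is an interpretation, not an authoritative listing:

    pub fn command_buffer(&self) -> CommandBuffer {
        let mut command_buffers = self.command_buffers.try_write().unwrap();
        let mut command_buffer = command_buffers[0].to_owned();
        let mut index = self.command_buffer_index.try_write().unwrap();
        if *index > 20 {
            // Commit the current buffer and start a fresh one after ~20 uses.
            command_buffer.commit();
            command_buffer = self.command_queue.new_command_buffer().to_owned();
            *command_buffers = vec![command_buffer.clone()];
            *index = 0;
        }
        *index += 1;
        command_buffer
    }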
    pub fn wait_until_completed(&self) {
        let command_buffers = self.command_buffers.try_write().unwrap();
        let index = self.command_buffer_index.try_write().unwrap();
        // let n = command_buffers.len();
        // for i in 0..*index {
        // let command_buffer = &command_buffers[i];
        // println!("Command {i} / {n}: {:?}", command_buffer.status());
        // }
        for i in 0..*index {
            let command_buffer = &command_buffers[i];
            match command_buffer.status() {
                metal::MTLCommandBufferStatus::Committed
                | metal::MTLCommandBufferStatus::Scheduled => {}
                metal::MTLCommandBufferStatus::Completed => {}
                _ => {
                    panic!("Command buffer not committed");
                }
        let mut command_buffers = self.command_buffers.try_write().unwrap();
        let command_buffer = &command_buffers[0];
        match command_buffer.status() {
            metal::MTLCommandBufferStatus::Committed
            | metal::MTLCommandBufferStatus::Scheduled
            | metal::MTLCommandBufferStatus::Completed => {
                panic!("Alredy committed");
            }
                // println!("Wait {i}");
                command_buffer.wait_until_completed();
                // println!("Ok {i}");
                // command_buffer.wait_until_completed();
            _ => {}
        }
        command_buffer.commit();
        command_buffer.wait_until_completed();
        *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
    }

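Read the same way, the new wait_until_completed plausibly reduces to committing the single live command buffer, blocking on it, and replacing it; again a reconstruction from the lines above, not an authoritative listing:

    pub fn wait_until_completed(&self) {
        let mut command_buffers = self.command_buffers.try_write().unwrap();
        let command_buffer = &command_buffers[0];
        match command_buffer.status() {
            metal::MTLCommandBufferStatus::Committed
            | metal::MTLCommandBufferStatus::Scheduled
            | metal::MTLCommandBufferStatus::Completed => {
                panic!("Alredy committed");
            }
            _ => {}
        }
        // Commit the pending work, block until the GPU is done, then swap in
        // a fresh command buffer for subsequent encoders.
        command_buffer.commit();
        command_buffer.wait_until_completed();
        *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
    }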
    pub fn kernels(&self) -> &Kernels {
@@ -176,7 +141,7 @@ impl MetalDevice {
    }

    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
        self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
        self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
    }

    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
@@ -184,7 +149,7 @@ impl MetalDevice {
        let tmp = self.device.new_buffer_with_data(
            data.as_ptr() as *const core::ffi::c_void,
            size,
            metal::MTLResourceOptions::StorageModeShared,
            metal::MTLResourceOptions::StorageModeManaged,
        );
        let real = self._new_buffer(
            size,
@@ -194,15 +159,15 @@ impl MetalDevice {
        let command_buffer = self.command_buffer();
        command_buffer.set_label("with_data");
        let blit = command_buffer.new_blit_command_encoder();
        blit.wait_for_fence(&self.fence);
        blit.set_label("with_data_blit");
        blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
        blit.update_fence(&self.fence);
        blit.end_encoding();
        command_buffer.commit();
        drop(command_buffer);
        // drop(command_buffer);
        // real.did_modify_range(metal::NSRange::new(0, real.length()));
        // println!("Command {:?}", command.status());

        // self.commit();
        // This is necessary, for mmaped safetensors
        // Because of the unsafe slice cast we're doing.
        // The slice might not live long enough for metal
@@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage {
                self.dtype
            );
        }
        self.device.wait_until_completed();
        self.buffer
            .did_modify_range(metal::NSRange::new(0, self.buffer.length()));
        let buffer = self.device.new_buffer_managed(self.buffer.length());
        {
            let command_buffer = self.device.command_buffer();
            command_buffer.set_label("to_cpu");
            let blit = command_buffer.new_blit_command_encoder();
            blit.set_label("blit_to_cpu");
            blit.wait_for_fence(&self.device.fence);
            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
            blit.update_fence(&self.device.fence);
            blit.end_encoding();

            command_buffer.commit();
        }
        self.device.wait_until_completed();

@@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage {
            )
            .map_err(MetalError::from)?;
        }
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        // buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(Self::new(buffer, device.clone(), dtype))
    }

@@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage {
            )
            .map_err(MetalError::from)?;
        }
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(Self::new(buffer, device.clone(), dtype))
    }

@@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage {
            )
            .map_err(MetalError::from)?;
        }
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(Self::new(buffer, device.clone(), dtype))
    }
@@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage {
            &buffer,
        )
        .map_err(MetalError::from)?;
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));

        Ok(Self::new(buffer, device, dtype))
    }
@@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage {
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "todtype");
        device.wait_until_completed();
        let command_buffer = device.command_buffer();
        if layout.is_contiguous() && layout.start_offset() == 0 {
            let kernel_name = match (self.dtype, dtype) {
@@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage {
            .map_err(MetalError::from)?;
        }
        command_buffer.set_label("to_dtype");
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        device.wait_until_completed();

        Ok(Self::new(buffer, device.clone(), dtype))
    }

@@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage {
            )
            .map_err(MetalError::from)?;
        }
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(Self::new(buffer, device.clone(), dtype))
    }

@@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage {
            .map_err(MetalError::from)?;
        }
        command_buffer.set_label("binary");
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(Self::new(buffer, device.clone(), dtype))
    }

@@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage {
            &buffer,
        )
        .map_err(MetalError::from)?;
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(Self::new(buffer, device, dtype))
    }

@@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage {
            &buffer,
        )
        .map_err(MetalError::from)?;
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(Self::new(buffer, device.clone(), dtype))
    }

@@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage {
            )
            .map_err(MetalError::from)?;
        // Create kernel
        command_buffer.commit();
        self.device.wait_until_completed();

        Ok(Self::new(buffer, self.device.clone(), self.dtype()))
    }
@@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage {
            .map_err(MetalError::from)?;
            command_buffer.set_label("copy_strided");
        }
        command_buffer.commit();
        Ok(())
    }
}
@@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice {
        // println!("CREATING DEVICE");
        let device = metal::Device::all().swap_remove(ordinal);

        let n = 64;
        let n = 1;
        let command_queue = device.new_command_queue();

        let command_buffers = (0..n)
@@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice {
            .collect();
        let command_buffers = Arc::new(RwLock::new(command_buffers));
        let command_buffer_index = Arc::new(RwLock::new(0));
        let kernels = Arc::new(Kernels::new());
        let fence = device.new_fence();
        let kernels = Arc::new(Kernels::new(fence.clone()));
        let buffers = Arc::new(RwLock::new(HashMap::new()));
        Ok(Self {
            device,
            fence,
            command_queue,
            command_buffers,
            command_buffer_index,
@@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice {
            0,
        );
        blit.end_encoding();
        command_buffer.commit();
        buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
        Ok(MetalStorage::new(buffer, self.clone(), dtype))
    }

@@ -900,7 +900,9 @@ fn matmul(device: &Device) -> Result<()> {
    let b = Tensor::from_slice(&data, (2, 2), device)?;

    let c = a.matmul(&b)?;
    let d = a.matmul(&c)?;
    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
    assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);

    let data = vec![1.0f32, 2.0];
    let a = Tensor::from_slice(&data, (2, 1), device)?;
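As a quick sanity check on the asserted values (assuming, as the surrounding test suggests, that a and b are both the 2 x 2 tensor [[1, 2], [3, 4]] built from data just above this hunk):

d = a \cdot c
  = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
    \begin{pmatrix} 7 & 10 \\ 15 & 22 \end{pmatrix}
  = \begin{pmatrix} 1\cdot7 + 2\cdot15 & 1\cdot10 + 2\cdot22 \\ 3\cdot7 + 4\cdot15 & 3\cdot10 + 4\cdot22 \end{pmatrix}
  = \begin{pmatrix} 37 & 54 \\ 81 & 118 \end{pmatrix}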
@@ -184,19 +184,21 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;

#[derive(Debug, Default)]
#[derive(Debug)]
pub struct Kernels {
    libraries: RwLock<Libraries>,
    pipelines: RwLock<Pipelines>,
    fence: metal::Fence,
}

impl Kernels {
    pub fn new() -> Self {
    pub fn new(fence: metal::Fence) -> Self {
        let libraries = RwLock::new(Libraries::new());
        let pipelines = RwLock::new(Pipelines::new());
        Self {
            libraries,
            pipelines,
            fence,
        }
    }

@@ -304,12 +306,14 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -331,6 +335,7 @@ pub fn call_unary_strided(

    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();
@@ -350,6 +355,7 @@ pub fn call_unary_strided(
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -368,6 +374,7 @@ pub fn call_binary_contiguous(
    let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, left, right, output));
@@ -375,6 +382,7 @@ pub fn call_binary_contiguous(
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -399,6 +407,7 @@ pub fn call_binary_strided(
    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    let width: usize = shape.iter().product();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();
@@ -420,6 +429,7 @@ pub fn call_binary_strided(
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -438,12 +448,14 @@ pub fn call_cast_contiguous(
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, (input, input_offset), output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -463,6 +475,7 @@ pub fn call_cast_strided(
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();
@@ -482,6 +495,7 @@ pub fn call_cast_strided(
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -501,6 +515,7 @@ pub fn call_reduce_contiguous(
    let elements_to_sum = length / out_length;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
@@ -527,6 +542,7 @@ pub fn call_reduce_contiguous(
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -544,6 +560,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, elements_to_sum, input, output));
@@ -569,6 +586,7 @@ pub fn call_last_softmax(
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -588,12 +606,14 @@ pub fn call_affine(
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, add, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -616,6 +636,7 @@ pub fn call_affine_strided(
    let size: usize = shape.iter().product();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
@@ -634,6 +655,7 @@ pub fn call_affine_strided(

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -652,12 +674,14 @@ pub fn call_powf(
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -679,6 +703,7 @@ pub fn call_powf_strided(
    let size: usize = shape.iter().product();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
@@ -696,6 +721,7 @@ pub fn call_powf_strided(

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -714,12 +740,14 @@ pub fn call_elu(
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -741,6 +769,7 @@ pub fn call_elu_strided(
    let size: usize = shape.iter().product();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
@@ -758,6 +787,7 @@ pub fn call_elu_strided(

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -779,6 +809,7 @@ pub fn call_where_cond_strided(
    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let size: usize = shape.iter().product();
@@ -803,6 +834,7 @@ pub fn call_where_cond_strided(
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -829,6 +861,7 @@ pub fn call_index_select(

    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
@@ -848,6 +881,7 @@ pub fn call_index_select(
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
@@ -1045,6 +1079,7 @@ pub fn call_gemm(
    let block_bytes = block_elements * bytes;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    // println!("Threadgroup {block_bytes}");
    encoder.set_threadgroup_memory_length(0, block_bytes.into());
@@ -1087,6 +1122,7 @@ pub fn call_gemm(
    };
    // println!("grid size {grid_size:?} group size {group_size:?}");
    encoder.dispatch_thread_groups(grid_size, group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();

    Ok(())
candle-metal-kernels/src/test.swift (new file, 209 lines)
@@ -0,0 +1,209 @@
import Metal
import MetalPerformanceShadersGraph


let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 2;
var K = 2;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)


var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)

let M_group = M_simd * M_splits
let N_group = N_simd * N_splits

// Satisfy Metal API validation.
#if DEBUG
do {
    var garbage: SIMD4<UInt64> = .zero
    constants.setConstantValue(&garbage, type: .bool, index: 102)
    constants.setConstantValue(&garbage, type: .bool, index: 103)
    constants.setConstantValue(&garbage, type: .bool, index: 113)
    constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif

let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true

var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)

var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
let function = try! library.makeFunction(
    name: name, constantValues: constants)

let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group

var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
    let C_block_length = M_group * N_group;
    blockElements = max(C_block_length, blockElements)
}
if fused_bias {
    if D_trans {
        blockElements = max(blockElements, M_group)
    } else {
        blockElements = max(blockElements, N_group)
    }
}
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)

func ceilDivide(target: Int, granularity: UInt16) -> Int {
    (target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
    width: ceilDivide(target: N, granularity: N_group),
    height: ceilDivide(target: M, granularity: M_group),
    depth: 1)
let groupSize = MTLSize(
    width: Int(32 * M_splits * N_splits),
    height: 1,
    depth: 1)

let commandQueue = device.makeCommandQueue()!

let threadgroupMemoryLength = blockBytes;

let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)

var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)

var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
    arrayA[i] = Float(i)
}

for i in 0..<arrayB.count {
    arrayB[i] = Float(i)
}

let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])!

let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])!

let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!


let pipeline = try device.makeComputePipelineState(function: function)

func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){
    let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
    encoder.setComputePipelineState(pipeline)
    encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)

    encoder.setBuffer(bufferA, offset: 0, index: 0)
    encoder.setBuffer(bufferB, offset: 0, index: 1)
    encoder.setBuffer(bufferC, offset: 0, index: 2)
    let gridZ: Int = B
    if batched{
        func byteStride(shape: [Int]) -> Int {
            let rank = shape.count
            var output = elementSize * shape[rank - 2] * shape[rank - 1]
            if shape.dropLast(2).reduce(1, *) == 1 {
                output = 0
            }
            return output
        }
        let byteStrideA = M*K*elementSize
        let byteStrideB = N*K*elementSize
        let byteStrideC = M*N*elementSize

        let byteStrideD = 0
        withUnsafeTemporaryAllocation(
            of: SIMD4<UInt64>.self, capacity: gridZ
        ) { buffer in
            for i in 0..<buffer.count {
                buffer[i] = SIMD4(
                    UInt64(truncatingIfNeeded: i * byteStrideA),
                    UInt64(truncatingIfNeeded: i * byteStrideB),
                    UInt64(truncatingIfNeeded: i * byteStrideC),
                    UInt64(truncatingIfNeeded: i * byteStrideD))
            }

            let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
            assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
            encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
        }
    }
    gridSize.depth = gridZ


    encoder.dispatchThreadgroups(
        gridSize, threadsPerThreadgroup: groupSize
    )
    encoder.endEncoding()
}

var commandBuffer = commandQueue.makeCommandBuffer()!
call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC)
commandBuffer.commit()
commandBuffer = commandQueue.makeCommandBuffer()!
commandBuffer.encodeWaitForEvent(event, value: 2)
call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD)
commandBuffer.commit()

commandBuffer.waitUntilCompleted()
var contents = bufferC.contents();
var count = B * rowsA * columnsB;
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("First matmul is OK", Array(bufferedPointer))

contents = bufferD.contents();
count = B * rowsA * columnsB;
typedPointer = contents.bindMemory(to: Float.self, capacity: count)
bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("This should be filled", Array(bufferedPointer))
@@ -238,8 +238,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
                &mut output,
            )
            .unwrap();
        command_buffer.commit();
        output.did_modify_range(metal::NSRange::new(0, output.length()));
        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
Block a user