mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Working with merging encoders and using fences.
This commit is contained in:
@ -38,6 +38,7 @@ pub struct MetalDevice {
|
|||||||
command_queue: metal::CommandQueue,
|
command_queue: metal::CommandQueue,
|
||||||
command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
|
command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
|
||||||
command_buffer_index: Arc<RwLock<usize>>,
|
command_buffer_index: Arc<RwLock<usize>>,
|
||||||
|
fence: metal::Fence,
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
|
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
|
||||||
}
|
}
|
||||||
@ -71,68 +72,32 @@ impl MetalDevice {
|
|||||||
|
|
||||||
pub fn command_buffer(&self) -> CommandBuffer {
|
pub fn command_buffer(&self) -> CommandBuffer {
|
||||||
let mut command_buffers = self.command_buffers.try_write().unwrap();
|
let mut command_buffers = self.command_buffers.try_write().unwrap();
|
||||||
|
let mut command_buffer = command_buffers[0].to_owned();
|
||||||
let mut index = self.command_buffer_index.try_write().unwrap();
|
let mut index = self.command_buffer_index.try_write().unwrap();
|
||||||
let n = command_buffers.len();
|
if *index > 20 {
|
||||||
if *index == n {
|
command_buffer.commit();
|
||||||
// todo!("Cycle buffers");
|
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||||
for i in 0..n {
|
*command_buffers = vec![command_buffer.clone()];
|
||||||
let command_buffer = &command_buffers[i];
|
|
||||||
match command_buffer.status() {
|
|
||||||
metal::MTLCommandBufferStatus::Committed
|
|
||||||
| metal::MTLCommandBufferStatus::Scheduled => {
|
|
||||||
// println!("Wait during cycling {i}");
|
|
||||||
// println!("Command {i} / {n}: {:?}", command_buffer.status());
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
}
|
|
||||||
metal::MTLCommandBufferStatus::Completed => {}
|
|
||||||
_ => {
|
|
||||||
panic!("Command buffer {i} not committed during cycling");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let new_buffers = (0..n)
|
|
||||||
.map(|i| {
|
|
||||||
// println!("Creating command buffer {i}");
|
|
||||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
|
||||||
command_buffer.set_label(&format!("num {i}"));
|
|
||||||
command_buffer.enqueue();
|
|
||||||
command_buffer
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
*command_buffers = new_buffers;
|
|
||||||
*index = 0;
|
*index = 0;
|
||||||
// println!("Reset");
|
|
||||||
}
|
}
|
||||||
// println!("Giving buffer {} / {n}", *index);
|
|
||||||
let out = &command_buffers[*index];
|
|
||||||
assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued);
|
|
||||||
*index += 1;
|
*index += 1;
|
||||||
out.to_owned()
|
command_buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn wait_until_completed(&self) {
|
pub fn wait_until_completed(&self) {
|
||||||
let command_buffers = self.command_buffers.try_write().unwrap();
|
let mut command_buffers = self.command_buffers.try_write().unwrap();
|
||||||
let index = self.command_buffer_index.try_write().unwrap();
|
let command_buffer = &command_buffers[0];
|
||||||
// let n = command_buffers.len();
|
match command_buffer.status() {
|
||||||
// for i in 0..*index {
|
metal::MTLCommandBufferStatus::Committed
|
||||||
// let command_buffer = &command_buffers[i];
|
| metal::MTLCommandBufferStatus::Scheduled
|
||||||
// println!("Command {i} / {n}: {:?}", command_buffer.status());
|
| metal::MTLCommandBufferStatus::Completed => {
|
||||||
// }
|
panic!("Alredy committed");
|
||||||
for i in 0..*index {
|
|
||||||
let command_buffer = &command_buffers[i];
|
|
||||||
match command_buffer.status() {
|
|
||||||
metal::MTLCommandBufferStatus::Committed
|
|
||||||
| metal::MTLCommandBufferStatus::Scheduled => {}
|
|
||||||
metal::MTLCommandBufferStatus::Completed => {}
|
|
||||||
_ => {
|
|
||||||
panic!("Command buffer not committed");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// println!("Wait {i}");
|
_ => {}
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
// println!("Ok {i}");
|
|
||||||
// command_buffer.wait_until_completed();
|
|
||||||
}
|
}
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
*command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn kernels(&self) -> &Kernels {
|
pub fn kernels(&self) -> &Kernels {
|
||||||
@ -176,7 +141,7 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
||||||
self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
|
self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
||||||
@ -184,7 +149,7 @@ impl MetalDevice {
|
|||||||
let tmp = self.device.new_buffer_with_data(
|
let tmp = self.device.new_buffer_with_data(
|
||||||
data.as_ptr() as *const core::ffi::c_void,
|
data.as_ptr() as *const core::ffi::c_void,
|
||||||
size,
|
size,
|
||||||
metal::MTLResourceOptions::StorageModeShared,
|
metal::MTLResourceOptions::StorageModeManaged,
|
||||||
);
|
);
|
||||||
let real = self._new_buffer(
|
let real = self._new_buffer(
|
||||||
size,
|
size,
|
||||||
@ -194,15 +159,15 @@ impl MetalDevice {
|
|||||||
let command_buffer = self.command_buffer();
|
let command_buffer = self.command_buffer();
|
||||||
command_buffer.set_label("with_data");
|
command_buffer.set_label("with_data");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
|
blit.wait_for_fence(&self.fence);
|
||||||
blit.set_label("with_data_blit");
|
blit.set_label("with_data_blit");
|
||||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||||
|
blit.update_fence(&self.fence);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
command_buffer.commit();
|
// drop(command_buffer);
|
||||||
drop(command_buffer);
|
|
||||||
// real.did_modify_range(metal::NSRange::new(0, real.length()));
|
// real.did_modify_range(metal::NSRange::new(0, real.length()));
|
||||||
// println!("Command {:?}", command.status());
|
// println!("Command {:?}", command.status());
|
||||||
|
|
||||||
// self.commit();
|
|
||||||
// This is necessary, for mmaped safetensors
|
// This is necessary, for mmaped safetensors
|
||||||
// Because of the unsafe slice cast we're doing.
|
// Because of the unsafe slice cast we're doing.
|
||||||
// The slice might not live long enough for metal
|
// The slice might not live long enough for metal
|
||||||
@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage {
|
|||||||
self.dtype
|
self.dtype
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
self.device.wait_until_completed();
|
|
||||||
self.buffer
|
|
||||||
.did_modify_range(metal::NSRange::new(0, self.buffer.length()));
|
|
||||||
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
||||||
{
|
{
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
command_buffer.set_label("to_cpu");
|
command_buffer.set_label("to_cpu");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
blit.set_label("blit_to_cpu");
|
blit.set_label("blit_to_cpu");
|
||||||
|
blit.wait_for_fence(&self.device.fence);
|
||||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||||
|
blit.update_fence(&self.device.fence);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
|
|
||||||
command_buffer.commit();
|
|
||||||
}
|
}
|
||||||
self.device.wait_until_completed();
|
self.device.wait_until_completed();
|
||||||
|
|
||||||
@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
|
|
||||||
Ok(Self::new(buffer, device, dtype))
|
Ok(Self::new(buffer, device, dtype))
|
||||||
}
|
}
|
||||||
@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let buffer = device.new_buffer(el_count, dtype, "todtype");
|
let buffer = device.new_buffer(el_count, dtype, "todtype");
|
||||||
device.wait_until_completed();
|
|
||||||
let command_buffer = device.command_buffer();
|
let command_buffer = device.command_buffer();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
let kernel_name = match (self.dtype, dtype) {
|
let kernel_name = match (self.dtype, dtype) {
|
||||||
@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.set_label("to_dtype");
|
command_buffer.set_label("to_dtype");
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
device.wait_until_completed();
|
|
||||||
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.set_label("binary");
|
command_buffer.set_label("binary");
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device, dtype))
|
Ok(Self::new(buffer, device, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
// Create kernel
|
// Create kernel
|
||||||
command_buffer.commit();
|
|
||||||
self.device.wait_until_completed();
|
|
||||||
|
|
||||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||||
}
|
}
|
||||||
@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.set_label("copy_strided");
|
command_buffer.set_label("copy_strided");
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
// println!("CREATING DEVICE");
|
// println!("CREATING DEVICE");
|
||||||
let device = metal::Device::all().swap_remove(ordinal);
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
|
|
||||||
let n = 64;
|
let n = 1;
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
|
|
||||||
let command_buffers = (0..n)
|
let command_buffers = (0..n)
|
||||||
@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
.collect();
|
.collect();
|
||||||
let command_buffers = Arc::new(RwLock::new(command_buffers));
|
let command_buffers = Arc::new(RwLock::new(command_buffers));
|
||||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||||
let kernels = Arc::new(Kernels::new());
|
let fence = device.new_fence();
|
||||||
|
let kernels = Arc::new(Kernels::new(fence.clone()));
|
||||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
|
fence,
|
||||||
command_queue,
|
command_queue,
|
||||||
command_buffers,
|
command_buffers,
|
||||||
command_buffer_index,
|
command_buffer_index,
|
||||||
@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice {
|
|||||||
0,
|
0,
|
||||||
);
|
);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
command_buffer.commit();
|
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -900,7 +900,9 @@ fn matmul(device: &Device) -> Result<()> {
|
|||||||
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
|
||||||
let c = a.matmul(&b)?;
|
let c = a.matmul(&b)?;
|
||||||
|
let d = a.matmul(&c)?;
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||||
|
assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);
|
||||||
|
|
||||||
let data = vec![1.0f32, 2.0];
|
let data = vec![1.0f32, 2.0];
|
||||||
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
||||||
|
@ -184,19 +184,21 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
|||||||
type Libraries = HashMap<Source, Library>;
|
type Libraries = HashMap<Source, Library>;
|
||||||
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
|
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug)]
|
||||||
pub struct Kernels {
|
pub struct Kernels {
|
||||||
libraries: RwLock<Libraries>,
|
libraries: RwLock<Libraries>,
|
||||||
pipelines: RwLock<Pipelines>,
|
pipelines: RwLock<Pipelines>,
|
||||||
|
fence: metal::Fence,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Kernels {
|
impl Kernels {
|
||||||
pub fn new() -> Self {
|
pub fn new(fence: metal::Fence) -> Self {
|
||||||
let libraries = RwLock::new(Libraries::new());
|
let libraries = RwLock::new(Libraries::new());
|
||||||
let pipelines = RwLock::new(Pipelines::new());
|
let pipelines = RwLock::new(Pipelines::new());
|
||||||
Self {
|
Self {
|
||||||
libraries,
|
libraries,
|
||||||
pipelines,
|
pipelines,
|
||||||
|
fence,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -304,12 +306,14 @@ pub fn call_unary_contiguous(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, input, output));
|
set_params!(encoder, (length, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -331,6 +335,7 @@ pub fn call_unary_strided(
|
|||||||
|
|
||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
@ -350,6 +355,7 @@ pub fn call_unary_strided(
|
|||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -368,6 +374,7 @@ pub fn call_binary_contiguous(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, left, right, output));
|
set_params!(encoder, (length, left, right, output));
|
||||||
@ -375,6 +382,7 @@ pub fn call_binary_contiguous(
|
|||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -399,6 +407,7 @@ pub fn call_binary_strided(
|
|||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
let width: usize = shape.iter().product();
|
let width: usize = shape.iter().product();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
@ -420,6 +429,7 @@ pub fn call_binary_strided(
|
|||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -438,12 +448,14 @@ pub fn call_cast_contiguous(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, (input, input_offset), output));
|
set_params!(encoder, (length, (input, input_offset), output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -463,6 +475,7 @@ pub fn call_cast_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
@ -482,6 +495,7 @@ pub fn call_cast_strided(
|
|||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -501,6 +515,7 @@ pub fn call_reduce_contiguous(
|
|||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -527,6 +542,7 @@ pub fn call_reduce_contiguous(
|
|||||||
};
|
};
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -544,6 +560,7 @@ pub fn call_last_softmax(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, elements_to_sum, input, output));
|
set_params!(encoder, (length, elements_to_sum, input, output));
|
||||||
@ -569,6 +586,7 @@ pub fn call_last_softmax(
|
|||||||
};
|
};
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -588,12 +606,14 @@ pub fn call_affine(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, add, input, output));
|
set_params!(encoder, (size, mul, add, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -616,6 +636,7 @@ pub fn call_affine_strided(
|
|||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -634,6 +655,7 @@ pub fn call_affine_strided(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -652,12 +674,14 @@ pub fn call_powf(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -679,6 +703,7 @@ pub fn call_powf_strided(
|
|||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -696,6 +721,7 @@ pub fn call_powf_strided(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -714,12 +740,14 @@ pub fn call_elu(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -741,6 +769,7 @@ pub fn call_elu_strided(
|
|||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -758,6 +787,7 @@ pub fn call_elu_strided(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -779,6 +809,7 @@ pub fn call_where_cond_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
@ -803,6 +834,7 @@ pub fn call_where_cond_strided(
|
|||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -829,6 +861,7 @@ pub fn call_index_select(
|
|||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -848,6 +881,7 @@ pub fn call_index_select(
|
|||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1045,6 +1079,7 @@ pub fn call_gemm(
|
|||||||
let block_bytes = block_elements * bytes;
|
let block_bytes = block_elements * bytes;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
// println!("Threadgroup {block_bytes}");
|
// println!("Threadgroup {block_bytes}");
|
||||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||||
@ -1087,6 +1122,7 @@ pub fn call_gemm(
|
|||||||
};
|
};
|
||||||
// println!("grid size {grid_size:?} group size {group_size:?}");
|
// println!("grid size {grid_size:?} group size {group_size:?}");
|
||||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
209
candle-metal-kernels/src/test.swift
Normal file
209
candle-metal-kernels/src/test.swift
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
|
||||||
|
import Metal
|
||||||
|
import MetalPerformanceShadersGraph
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
let type = MTLDataType.float;
|
||||||
|
let dataType = type;
|
||||||
|
var B = 2;
|
||||||
|
var M = 2;
|
||||||
|
var N = 2;
|
||||||
|
var K = 2;
|
||||||
|
var A_trans = false;
|
||||||
|
var B_trans = false;
|
||||||
|
var D_trans = false;
|
||||||
|
var alpha = Float(1.0);
|
||||||
|
var beta = Float(0.0);
|
||||||
|
var batched = B > 1;
|
||||||
|
var fused_activation = false;
|
||||||
|
var fused_bias = false;
|
||||||
|
let constants = MTLFunctionConstantValues()
|
||||||
|
constants.setConstantValue(&M, type: .uint, index: 0)
|
||||||
|
constants.setConstantValue(&N, type: .uint, index: 1)
|
||||||
|
constants.setConstantValue(&K, type: .uint, index: 2)
|
||||||
|
constants.setConstantValue(&A_trans, type: .bool, index: 10)
|
||||||
|
constants.setConstantValue(&B_trans, type: .bool, index: 11)
|
||||||
|
constants.setConstantValue(&D_trans, type: .bool, index: 13)
|
||||||
|
constants.setConstantValue(&alpha, type: .float, index: 20)
|
||||||
|
constants.setConstantValue(&beta, type: .float, index: 21)
|
||||||
|
constants.setConstantValue(&batched, type: .bool, index: 100)
|
||||||
|
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
|
||||||
|
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
|
||||||
|
|
||||||
|
|
||||||
|
var M_simd = UInt16(16)
|
||||||
|
var N_simd = UInt16(16)
|
||||||
|
var K_simd = UInt16(32)
|
||||||
|
var M_splits = UInt16(2)
|
||||||
|
var N_splits = UInt16(2)
|
||||||
|
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
|
||||||
|
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
|
||||||
|
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
|
||||||
|
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
|
||||||
|
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
|
||||||
|
|
||||||
|
let M_group = M_simd * M_splits
|
||||||
|
let N_group = N_simd * N_splits
|
||||||
|
|
||||||
|
// Satisfy Metal API validation.
|
||||||
|
#if DEBUG
|
||||||
|
do {
|
||||||
|
var garbage: SIMD4<UInt64> = .zero
|
||||||
|
constants.setConstantValue(&garbage, type: .bool, index: 102)
|
||||||
|
constants.setConstantValue(&garbage, type: .bool, index: 103)
|
||||||
|
constants.setConstantValue(&garbage, type: .bool, index: 113)
|
||||||
|
constants.setConstantValue(&garbage, type: .bool, index: 50000)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
let device = MTLCopyAllDevices().first!
|
||||||
|
device.shouldMaximizeConcurrentCompilation = true
|
||||||
|
|
||||||
|
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
|
||||||
|
libraryURL.append(component: "src")
|
||||||
|
libraryURL.append(component: "libMetalFlashAttention.metallib")
|
||||||
|
let library = try! device.makeLibrary(URL: libraryURL)
|
||||||
|
|
||||||
|
var name: String
|
||||||
|
switch dataType {
|
||||||
|
case .half: name = "hgemm"
|
||||||
|
case .float: name = "sgemm"
|
||||||
|
default: fatalError()
|
||||||
|
}
|
||||||
|
let function = try! library.makeFunction(
|
||||||
|
name: name, constantValues: constants)
|
||||||
|
|
||||||
|
let A_block_length = M_group * K_simd
|
||||||
|
let B_block_length = K_simd * N_group
|
||||||
|
|
||||||
|
var blockElements = A_block_length + B_block_length;
|
||||||
|
if (M % 8 != 0) && (N % 8 != 0) {
|
||||||
|
let C_block_length = M_group * N_group;
|
||||||
|
blockElements = max(C_block_length, blockElements)
|
||||||
|
}
|
||||||
|
if fused_bias {
|
||||||
|
if D_trans {
|
||||||
|
blockElements = max(blockElements, M_group)
|
||||||
|
} else {
|
||||||
|
blockElements = max(blockElements, N_group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// let blockBytes = blockElements * UInt16(dataType.size)
|
||||||
|
let elementSize = 4
|
||||||
|
let blockBytes = blockElements * UInt16(elementSize)
|
||||||
|
|
||||||
|
func ceilDivide(target: Int, granularity: UInt16) -> Int {
|
||||||
|
(target + Int(granularity) - 1) / Int(granularity)
|
||||||
|
}
|
||||||
|
var gridSize = MTLSize(
|
||||||
|
width: ceilDivide(target: N, granularity: N_group),
|
||||||
|
height: ceilDivide(target: M, granularity: M_group),
|
||||||
|
depth: 1)
|
||||||
|
let groupSize = MTLSize(
|
||||||
|
width: Int(32 * M_splits * N_splits),
|
||||||
|
height: 1,
|
||||||
|
depth: 1)
|
||||||
|
|
||||||
|
let commandQueue = device.makeCommandQueue()!
|
||||||
|
|
||||||
|
let threadgroupMemoryLength = blockBytes;
|
||||||
|
|
||||||
|
let rowsA = M;
|
||||||
|
let columnsA = K;
|
||||||
|
let rowsB = K;
|
||||||
|
let columnsB = N;
|
||||||
|
let rowsC = M;
|
||||||
|
let columnsC = N;
|
||||||
|
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
|
||||||
|
|
||||||
|
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
|
||||||
|
|
||||||
|
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
|
||||||
|
var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC)
|
||||||
|
for i in 0..<arrayA.count {
|
||||||
|
arrayA[i] = Float(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..<arrayB.count {
|
||||||
|
arrayB[i] = Float(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])!
|
||||||
|
|
||||||
|
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])!
|
||||||
|
|
||||||
|
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
|
||||||
|
let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
|
||||||
|
|
||||||
|
|
||||||
|
let pipeline = try device.makeComputePipelineState(function: function)
|
||||||
|
|
||||||
|
func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){
|
||||||
|
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
|
||||||
|
encoder.setComputePipelineState(pipeline)
|
||||||
|
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
|
||||||
|
|
||||||
|
encoder.setBuffer(bufferA, offset: 0, index: 0)
|
||||||
|
encoder.setBuffer(bufferB, offset: 0, index: 1)
|
||||||
|
encoder.setBuffer(bufferC, offset: 0, index: 2)
|
||||||
|
let gridZ: Int = B
|
||||||
|
if batched{
|
||||||
|
func byteStride(shape: [Int]) -> Int {
|
||||||
|
let rank = shape.count
|
||||||
|
var output = elementSize * shape[rank - 2] * shape[rank - 1]
|
||||||
|
if shape.dropLast(2).reduce(1, *) == 1 {
|
||||||
|
output = 0
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
let byteStrideA = M*K*elementSize
|
||||||
|
let byteStrideB = N*K*elementSize
|
||||||
|
let byteStrideC = M*N*elementSize
|
||||||
|
|
||||||
|
let byteStrideD = 0
|
||||||
|
withUnsafeTemporaryAllocation(
|
||||||
|
of: SIMD4<UInt64>.self, capacity: gridZ
|
||||||
|
) { buffer in
|
||||||
|
for i in 0..<buffer.count {
|
||||||
|
buffer[i] = SIMD4(
|
||||||
|
UInt64(truncatingIfNeeded: i * byteStrideA),
|
||||||
|
UInt64(truncatingIfNeeded: i * byteStrideB),
|
||||||
|
UInt64(truncatingIfNeeded: i * byteStrideC),
|
||||||
|
UInt64(truncatingIfNeeded: i * byteStrideD))
|
||||||
|
}
|
||||||
|
|
||||||
|
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
|
||||||
|
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
|
||||||
|
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gridSize.depth = gridZ
|
||||||
|
|
||||||
|
|
||||||
|
encoder.dispatchThreadgroups(
|
||||||
|
gridSize, threadsPerThreadgroup: groupSize
|
||||||
|
)
|
||||||
|
encoder.endEncoding()
|
||||||
|
}
|
||||||
|
|
||||||
|
var commandBuffer = commandQueue.makeCommandBuffer()!
|
||||||
|
call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC)
|
||||||
|
commandBuffer.commit()
|
||||||
|
commandBuffer = commandQueue.makeCommandBuffer()!
|
||||||
|
commandBuffer.encodeWaitForEvent(event, value: 2)
|
||||||
|
call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD)
|
||||||
|
commandBuffer.commit()
|
||||||
|
|
||||||
|
commandBuffer.waitUntilCompleted()
|
||||||
|
var contents = bufferC.contents();
|
||||||
|
var count = B * rowsA * columnsB;
|
||||||
|
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
|
||||||
|
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
|
||||||
|
print("First matmul is OK", Array(bufferedPointer))
|
||||||
|
|
||||||
|
contents = bufferD.contents();
|
||||||
|
count = B * rowsA * columnsB;
|
||||||
|
typedPointer = contents.bindMemory(to: Float.self, capacity: count)
|
||||||
|
bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
|
||||||
|
print("This should be filled", Array(bufferedPointer))
|
@ -238,8 +238,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
&mut output,
|
&mut output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
|
||||||
output.did_modify_range(metal::NSRange::new(0, output.length()));
|
|
||||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
||||||
Ok((newstorage, layout.shape().clone()))
|
Ok((newstorage, layout.shape().clone()))
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user