More cleanup.

This commit is contained in:
Nicolas Patry
2023-12-15 01:44:22 +01:00
parent 40c3e1bd5a
commit cf27868b57

View File

@@ -173,6 +173,12 @@ pub enum MetalKernelError {
FailedToCreateComputeFunction, FailedToCreateComputeFunction,
#[error("Failed to create pipeline")] #[error("Failed to create pipeline")]
FailedToCreatePipeline(String), FailedToCreatePipeline(String),
#[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
} }
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -1029,24 +1035,22 @@ pub fn call_gemm(
} else if lhs_m1 == m && lhs_m2 == 1 { } else if lhs_m1 == m && lhs_m2 == 1 {
true true
} else { } else {
todo!(); return Err(MetalKernelError::MatMulNonContiguous {
// Err(MetalError::MatMulNonContiguous { lhs_stride: lhs_stride.to_vec(),
// lhs_stride: lhs_stride.to_vec(), rhs_stride: rhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(), mnk: (m, n, k),
// mnk: (m, n, k), })?;
// })?
}; };
let b_trans = if rhs_m1 == 1 && rhs_m2 == n { let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
false false
} else if rhs_m1 == k && rhs_m2 == 1 { } else if rhs_m1 == k && rhs_m2 == 1 {
true true
} else { } else {
todo!(); return Err(MetalKernelError::MatMulNonContiguous {
// Err(MetalError::MatMulNonContiguous { lhs_stride: lhs_stride.to_vec(),
// lhs_stride: lhs_stride.to_vec(), rhs_stride: rhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(), mnk: (m, n, k),
// mnk: (m, n, k), })?;
// })?
}; };
let d_trans = false; let d_trans = false;
let alpha = 1.0f32; let alpha = 1.0f32;
@@ -1083,7 +1087,6 @@ pub fn call_gemm(
(211, Value::U16(n_splits)), (211, Value::U16(n_splits)),
(50_001, Value::Bool(fused_bias)), (50_001, Value::Bool(fused_bias)),
])); ]));
// println!("Constants {constants:?}");
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
let m_group = m_simd * m_splits; let m_group = m_simd * m_splits;
let n_group = n_simd * n_splits; let n_group = n_simd * n_splits;
@@ -1103,7 +1106,6 @@ pub fn call_gemm(
block_elements = std::cmp::max(block_elements, n_group); block_elements = std::cmp::max(block_elements, n_group);
} }
} }
// TODO adapt for f16
let bytes = match name { let bytes = match name {
"sgemm" => 4, "sgemm" => 4,
"hgemm" => 2, "hgemm" => 2,
@@ -1118,7 +1120,6 @@ pub fn call_gemm(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
// println!("Threadgroup {block_bytes}");
encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);