Compare commits


1 commit

Author  SHA1        Message     Date
        5edb07a5b1  mps matmul  2023-12-20 02:53:18 +01:00
3 changed files with 193 additions and 109 deletions

Cargo.toml

@@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"], package = "candle-metal" }
 
 [profile.release-with-debug]
 inherits = "release"
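
The interesting bit here is the `package` key: Cargo keeps the in-code name `metal` but now fetches the crate published as `candle-metal`, presumably a fork of `metal-rs` 0.27 carrying the MPS pieces used below. Because only the package name changes, existing imports keep compiling as-is. A minimal sketch, assuming the fork keeps the metal-rs 0.27 API surface:

    // `metal` resolves to the crate published as `candle-metal` because of
    // the `package = "candle-metal"` rename in Cargo.toml.
    use metal::Device;

    fn main() {
        // Same calls as metal-rs 0.27; only the underlying package changed.
        let device = Device::system_default().expect("no Metal device found");
        println!("running on {}", device.name());
    }

The same rename is applied to the candle-metal-kernels manifest below, so both crates resolve to a single copy of the fork.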

candle-metal-kernels/Cargo.toml

@@ -10,7 +10,7 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"], package="candle-metal" }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"

candle-metal-kernels/src/lib.rs

@@ -5,6 +5,7 @@ use metal::{
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::sync::RwLock;
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 
 const AFFINE: &str = include_str!("affine.metal");
 const INDEXING: &str = include_str!("indexing.metal");
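
The added import pulls the MPS matrix types out of the fork's `metal::mps::matrix` module, which is only compiled because the dependency enables the "mps" feature: `Matrix` wraps an existing Metal buffer, `MatrixDescriptor` records its shape and byte strides, and `MatrixMultiplication` is the Metal Performance Shaders GEMM kernel. The next hunk switches `call_gemm` from the hand-tuned MFA compute pipeline to these types.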
@@ -1052,123 +1053,206 @@ pub fn call_gemm(
             mnk: (m, n, k),
         })?;
     };
-    let d_trans = false;
-    let alpha = 1.0f32;
-    let beta = 0.0f32;
-    let batched = b > 1;
-    let fused_activation = false;
-    let fused_bias = false;
-    let m_simd = 16;
-    let n_simd = 16;
-    let k_simd = 16;
-    let m_splits = 2;
-    let n_splits = 2;
-    let constants = Some(ConstantValues::new(vec![
-        (0, Value::USize(m)),
-        (1, Value::USize(n)),
-        (2, Value::USize(k)),
-        (10, Value::Bool(a_trans)),
-        (11, Value::Bool(b_trans)),
-        (13, Value::Bool(d_trans)),
-        (20, Value::F32(alpha)),
-        (21, Value::F32(beta)),
-        (100, Value::Bool(batched)),
-        (101, Value::Bool(fused_activation)),
-        // Garbage
-        (102, Value::Bool(false)),
-        (103, Value::Bool(false)),
-        (113, Value::Bool(false)),
-        (50_000, Value::Bool(false)),
-        // End garbage
-        (200, Value::U16(m_simd)),
-        (201, Value::U16(n_simd)),
-        (202, Value::U16(k_simd)),
-        (210, Value::U16(m_splits)),
-        (211, Value::U16(n_splits)),
-        (50_001, Value::Bool(fused_bias)),
-    ]));
-    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
-    let m_group = m_simd * m_splits;
-    let n_group = n_simd * n_splits;
+    // let d_trans = false;
+    // let alpha = 1.0f32;
+    // let beta = 0.0f32;
+    // let batched = b > 1;
+    // let fused_activation = false;
+    // let fused_bias = false;
+    // let m_simd = 16;
+    // let n_simd = 16;
+    // let k_simd = 16;
+    // let m_splits = 2;
+    // let n_splits = 2;
+    // let constants = Some(ConstantValues::new(vec![
+    //     (0, Value::USize(m)),
+    //     (1, Value::USize(n)),
+    //     (2, Value::USize(k)),
+    //     (10, Value::Bool(a_trans)),
+    //     (11, Value::Bool(b_trans)),
+    //     (13, Value::Bool(d_trans)),
+    //     (20, Value::F32(alpha)),
+    //     (21, Value::F32(beta)),
+    //     (100, Value::Bool(batched)),
+    //     (101, Value::Bool(fused_activation)),
+    //     // Garbage
+    //     (102, Value::Bool(false)),
+    //     (103, Value::Bool(false)),
+    //     (113, Value::Bool(false)),
+    //     (50_000, Value::Bool(false)),
+    //     // End garbage
+    //     (200, Value::U16(m_simd)),
+    //     (201, Value::U16(n_simd)),
+    //     (202, Value::U16(k_simd)),
+    //     (210, Value::U16(m_splits)),
+    //     (211, Value::U16(n_splits)),
+    //     (50_001, Value::Bool(fused_bias)),
+    // ]));
+    // let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
+    // let m_group = m_simd * m_splits;
+    // let n_group = n_simd * n_splits;
+    //
+    // let a_block_length = m_group * k_simd;
+    // let b_block_length = k_simd * n_group;
+    //
+    // let mut block_elements = a_block_length + b_block_length;
+    // if (m % 8 != 0) && (n % 8 != 0) {
+    //     let c_block_length = m_group * n_group;
+    //     block_elements = std::cmp::max(c_block_length, block_elements)
+    // }
+    // if fused_bias {
+    //     if d_trans {
+    //         block_elements = std::cmp::max(block_elements, m_group);
+    //     } else {
+    //         block_elements = std::cmp::max(block_elements, n_group);
+    //     }
+    // }
+    // let bytes = match name {
+    //     "sgemm" => 4,
+    //     "hgemm" => 2,
+    //     other => {
+    //         return Err(MetalKernelError::LoadLibraryError(format!(
+    //             "{other} is not a valid kernel for gemm"
+    //         )));
+    //     }
+    // };
+    // let block_bytes = block_elements * bytes;
+    //
+    // let encoder = command_buffer.new_compute_command_encoder();
+    // encoder.wait_for_fence(&kernels.fence);
+    // encoder.set_compute_pipeline_state(&pipeline);
+    // encoder.set_threadgroup_memory_length(0, block_bytes.into());
+    // encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
+    // encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
+    // encoder.set_buffer(2, Some(output), 0);
+    // // TODO Tensor D
+    //
+    // let grid_z = b;
+    // if batched {
+    //     let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
+    //     let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
+    //     let byte_stride_c = m * n * bytes as usize;
+    //     // TODO byte_stride_d
+    //     let byte_stride_d = 0;
+    //
+    //     let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+    //     for i in 0..b {
+    //         buffer.push((i * byte_stride_a) as u64);
+    //         buffer.push((i * byte_stride_b) as u64);
+    //         buffer.push((i * byte_stride_c) as u64);
+    //         buffer.push((i * byte_stride_d) as u64);
+    //     }
+    //     encoder.set_bytes(
+    //         10,
+    //         (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+    //         buffer.as_ptr() as *const NSUInteger as *const c_void,
+    //     );
+    // }
+    //
+    // let grid_size = MTLSize {
+    //     width: divide(n, n_group.into()),
+    //     height: divide(m, m_group.into()),
+    //     depth: grid_z as NSUInteger,
+    // };
+    // let group_size = MTLSize {
+    //     width: 32 * (m_splits as u64) * (n_splits as u64),
+    //     height: 1,
+    //     depth: 1,
+    // };
+    // // println!("grid size {grid_size:?} group size {group_size:?}");
+    // encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
+    // encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
+    // encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    // encoder.dispatch_thread_groups(grid_size, group_size);
+    // encoder.update_fence(&kernels.fence);
+    // encoder.end_encoding();
-    let a_block_length = m_group * k_simd;
-    let b_block_length = k_simd * n_group;
+    let (b, m, n, k) = (
+        b as NSUInteger,
+        m as NSUInteger,
+        n as NSUInteger,
+        k as NSUInteger,
+    );
 
-    let mut block_elements = a_block_length + b_block_length;
-    if (m % 8 != 0) && (n % 8 != 0) {
-        let c_block_length = m_group * n_group;
-        block_elements = std::cmp::max(c_block_length, block_elements)
-    }
-    if fused_bias {
-        if d_trans {
-            block_elements = std::cmp::max(block_elements, m_group);
-        } else {
-            block_elements = std::cmp::max(block_elements, n_group);
-        }
-    }
-    let bytes = match name {
-        "sgemm" => 4,
-        "hgemm" => 2,
-        other => {
-            return Err(MetalKernelError::LoadLibraryError(format!(
-                "{other} is not a valid kernel for gemm"
-            )));
-        }
-    };
-    let block_bytes = block_elements * bytes;
+    let (size, data_type) = if name == "sgemm" { (4, 0x10000000 | 32) } else { (2, 0x10000000 | 16) };
 
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
-    encoder.set_compute_pipeline_state(&pipeline);
-    encoder.set_threadgroup_memory_length(0, block_bytes.into());
-    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
-    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
-    encoder.set_buffer(2, Some(output), 0);
-    // TODO Tensor D
+    let left_matrix = create_matrix(
+        lhs_buffer,
+        (b, m, k),
+        a_trans,
+        size,
+        lhs_offset as NSUInteger,
+        data_type,
+    ).unwrap();
 
-    let grid_z = b;
-    if batched {
-        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
-        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
-        let byte_stride_c = m * n * bytes as usize;
-        // TODO byte_stride_d
-        let byte_stride_d = 0;
+    let right_matrix = create_matrix(
+        rhs_buffer,
+        (b, k, n),
+        b_trans,
+        size,
+        rhs_offset as NSUInteger,
+        data_type,
+    ).unwrap();
 
-        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
-        for i in 0..b {
-            buffer.push((i * byte_stride_a) as u64);
-            buffer.push((i * byte_stride_b) as u64);
-            buffer.push((i * byte_stride_c) as u64);
-            buffer.push((i * byte_stride_d) as u64);
-        }
-        encoder.set_bytes(
-            10,
-            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
-            buffer.as_ptr() as *const NSUInteger as *const c_void,
-        );
-    }
+    let result_matrix = create_matrix(
+        output,
+        (b, m, n),
+        false,
+        size,
+        0,
+        data_type,
+    ).unwrap();
 
-    let grid_size = MTLSize {
-        width: divide(n, n_group.into()),
-        height: divide(m, m_group.into()),
-        depth: grid_z as NSUInteger,
-    };
-    let group_size = MTLSize {
-        width: 32 * (m_splits as u64) * (n_splits as u64),
-        height: 1,
-        depth: 1,
-    };
-    // println!("grid size {grid_size:?} group size {group_size:?}");
-    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
-    encoder.dispatch_thread_groups(grid_size, group_size);
-    encoder.update_fence(&kernels.fence);
-    encoder.end_encoding();
+    // Create kernel
+    let matrix_multiplication = MatrixMultiplication::init(
+        &device,
+        a_trans,
+        b_trans,
+        m,
+        n,
+        k,
+        1.0,
+        0.0,
+    ).unwrap();
+    matrix_multiplication.encode_to_command_buffer(
+        command_buffer,
+        &left_matrix,
+        &right_matrix,
+        &result_matrix,
+    );
 
     Ok(())
 }
 
+fn create_matrix(
+    buffer: &Buffer,
+    (b, rows, columns): (NSUInteger, NSUInteger, NSUInteger),
+    transpose: bool,
+    size: NSUInteger,
+    offset: NSUInteger,
+    data_type: u32,
+) -> Option<Matrix> {
+    let (rows, columns) = if transpose {
+        (columns, rows)
+    } else {
+        (rows, columns)
+    };
+    let descriptor = if b == 1 {
+        MatrixDescriptor::init_single(rows, columns, columns * size, data_type)
+    } else {
+        MatrixDescriptor::init_multiple(
+            rows,
+            columns,
+            b,
+            columns * size,
+            rows * columns * size,
+            data_type,
+        )
+    };
+    return Matrix::init_with_buffer_descriptor(&buffer, offset * size, &descriptor);
+}
+
 fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }