@@ -5,6 +5,7 @@ use metal::{
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};

const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
@@ -1052,123 +1053,206 @@ pub fn call_gemm(
            mnk: (m, n, k),
        })?;
    };
    let d_trans = false;
    let alpha = 1.0f32;
    let beta = 0.0f32;
    let batched = b > 1;
    let fused_activation = false;
    let fused_bias = false;
    let m_simd = 16;
    let n_simd = 16;
    let k_simd = 16;
    let m_splits = 2;
    let n_splits = 2;
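    // Function constants consumed by the MFA GEMM kernel: 0..2 are M/N/K, 10/11/13 the
    // transpose flags, 20/21 alpha/beta, 100/101 batching and fused activation,
    // 200..211 the simdgroup tile sizes and splits, and 50_001 the fused-bias flag.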
    let constants = Some(ConstantValues::new(vec![
        (0, Value::USize(m)),
        (1, Value::USize(n)),
        (2, Value::USize(k)),
        (10, Value::Bool(a_trans)),
        (11, Value::Bool(b_trans)),
        (13, Value::Bool(d_trans)),
        (20, Value::F32(alpha)),
        (21, Value::F32(beta)),
        (100, Value::Bool(batched)),
        (101, Value::Bool(fused_activation)),
        // Garbage
        (102, Value::Bool(false)),
        (103, Value::Bool(false)),
        (113, Value::Bool(false)),
        (50_000, Value::Bool(false)),
        // End garbage
        (200, Value::U16(m_simd)),
        (201, Value::U16(n_simd)),
        (202, Value::U16(k_simd)),
        (210, Value::U16(m_splits)),
        (211, Value::U16(n_splits)),
        (50_001, Value::Bool(fused_bias)),
    ]));
    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
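    // Each threadgroup computes an m_group x n_group tile of the output.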
    let m_group = m_simd * m_splits;
    let n_group = n_simd * n_splits;
    let a_block_length = m_group * k_simd;
    let b_block_length = k_simd * n_group;
    // Keep the usize shapes for the MFA dispatch below; the MPS matrix descriptors
    // need NSUInteger copies.
    let (b_nsu, m_nsu, n_nsu, k_nsu) = (
        b as NSUInteger,
        m as NSUInteger,
        n as NSUInteger,
        k as NSUInteger,
    );
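    // Threadgroup memory holds an A tile (m_group x k_simd) and a B tile (k_simd x n_group);
    // edge tiles and the fused-bias path may need extra room.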
    let mut block_elements = a_block_length + b_block_length;
    if (m % 8 != 0) && (n % 8 != 0) {
        let c_block_length = m_group * n_group;
        block_elements = std::cmp::max(c_block_length, block_elements)
    }
    if fused_bias {
        if d_trans {
            block_elements = std::cmp::max(block_elements, m_group);
        } else {
            block_elements = std::cmp::max(block_elements, n_group);
        }
    }
    let bytes = match name {
        "sgemm" => 4,
        "hgemm" => 2,
        other => {
            return Err(MetalKernelError::LoadLibraryError(format!(
                "{other} is not a valid kernel for gemm"
            )));
        }
    };
    let block_bytes = block_elements * bytes;
    // MPSDataTypeFloat32 / MPSDataTypeFloat16: the float-type bit (0x10000000) OR'd with the bit width.
    let (size, data_type) = if name == "sgemm" {
        (4, 0x10000000 | 32)
    } else {
        (2, 0x10000000 | 16)
    };

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    encoder.set_threadgroup_memory_length(0, block_bytes.into());
    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
    encoder.set_buffer(2, Some(output), 0);
    // TODO Tensor D
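    // Wrap the LHS/RHS/output buffers as MPS matrices for the MPSMatrixMultiplication
    // encoded at the end of this function.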
    let left_matrix = create_matrix(
        lhs_buffer,
        (b_nsu, m_nsu, k_nsu),
        a_trans,
        size,
        lhs_offset as NSUInteger,
        data_type,
    ).unwrap();
    let right_matrix = create_matrix(
        rhs_buffer,
        (b_nsu, k_nsu, n_nsu),
        b_trans,
        size,
        rhs_offset as NSUInteger,
        data_type,
    ).unwrap();

    let grid_z = b;
    if batched {
        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
        let byte_stride_c = m * n * bytes as usize;
        // TODO byte_stride_d
        let byte_stride_d = 0;

        // The kernel reads per-batch byte offsets for A/B/C/D from buffer index 10.
        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
        for i in 0..b {
            buffer.push((i * byte_stride_a) as u64);
            buffer.push((i * byte_stride_b) as u64);
            buffer.push((i * byte_stride_c) as u64);
            buffer.push((i * byte_stride_d) as u64);
        }
        encoder.set_bytes(
            10,
            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
            buffer.as_ptr() as *const NSUInteger as *const c_void,
        );
    }

    let result_matrix = create_matrix(
        output,
        (b_nsu, m_nsu, n_nsu),
        false,
        size,
        0,
        data_type,
    ).unwrap();
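    // Launch one threadgroup per output tile; each threadgroup runs one 32-thread
    // simdgroup per (m_split, n_split) pair.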

    let grid_size = MTLSize {
        width: divide(n, n_group.into()),
        height: divide(m, m_group.into()),
        depth: grid_z as NSUInteger,
    };
    let group_size = MTLSize {
        width: 32 * (m_splits as u64) * (n_splits as u64),
        height: 1,
        depth: 1,
    };
    // println!("grid size {grid_size:?} group size {group_size:?}");
    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid_size, group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
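
    // MPS path: encode the same GEMM again with MPSMatrixMultiplication. MPS manages its
    // own encoder, so this can only happen after the compute encoder above has ended.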
    // Create kernel
    let matrix_multiplication = MatrixMultiplication::init(
        &device,
        a_trans,
        b_trans,
        m_nsu,
        n_nsu,
        k_nsu,
        1.0,
        0.0,
    ).unwrap();

    matrix_multiplication.encode_to_command_buffer(
        command_buffer,
        &left_matrix,
        &right_matrix,
        &result_matrix,
    );

    Ok(())
}
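
/// Wraps `buffer` in an MPS `Matrix`, swapping rows and columns when `transpose` is set.
/// `size` is the element size in bytes; `offset` is scaled by `size` for the buffer offset.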
fn create_matrix(
    buffer: &Buffer,
    (b, rows, columns): (NSUInteger, NSUInteger, NSUInteger),
    transpose: bool,
    size: NSUInteger,
    offset: NSUInteger,
    data_type: u32,
) -> Option<Matrix> {
    let (rows, columns) = if transpose {
        (columns, rows)
    } else {
        (rows, columns)
    };
    let descriptor = if b == 1 {
        MatrixDescriptor::init_single(rows, columns, columns * size, data_type)
    } else {
        MatrixDescriptor::init_multiple(
            rows,
            columns,
            b,
            columns * size,
            rows * columns * size,
            data_type,
        )
    };
    Matrix::init_with_buffer_descriptor(&buffer, offset * size, &descriptor)
}
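
/// Ceiling division: how many threadgroups of size `b` are needed to cover `m` elements.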
fn divide(m: usize, b: usize) -> NSUInteger {
    ((m + b - 1) / b) as NSUInteger
}