mps matmul
@@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"], package = "candle-metal" }
 
 [profile.release-with-debug]
 inherits = "release"
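The `package` key renames the dependency: Cargo now fetches the crate published as `candle-metal` (presumably a fork of `metal-rs` that exposes the MPS matrix bindings used below; the diff itself does not say where that crate lives) while the in-code name stays `metal`, so existing imports keep working unchanged:

    // Call sites are untouched: the crate is still referenced as `metal`,
    // even though Cargo resolves it to the `candle-metal` package.
    use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};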
@@ -10,7 +10,7 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"], package="candle-metal" }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
@@ -5,6 +5,7 @@ use metal::{
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::sync::RwLock;
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 
 const AFFINE: &str = include_str!("affine.metal");
 const INDEXING: &str = include_str!("indexing.metal");
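For orientation, the code added in the next hunk follows the standard three-step MPSMatrix pattern: describe each buffer with a `MatrixDescriptor`, wrap it in a `Matrix`, then encode a `MatrixMultiplication` onto the command buffer. Below is a minimal sketch of that flow, not the commit's code: the helper name `mps_sgemm_sketch` is hypothetical, and it assumes the `candle-metal` bindings keep the signatures used in this diff (`init_single`, `init_with_buffer_descriptor`, `MatrixMultiplication::init`, `encode_to_command_buffer`).

    use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
    use metal::{Buffer, CommandBufferRef, Device, NSUInteger};

    // Encode c = a * b for row-major f32 matrices (a: m x k, b: k x n, c: m x n).
    fn mps_sgemm_sketch(
        device: &Device,
        command_buffer: &CommandBufferRef,
        a: &Buffer,
        b: &Buffer,
        c: &Buffer,
        (m, n, k): (NSUInteger, NSUInteger, NSUInteger),
    ) -> Option<()> {
        let f32_type = 0x10000000 | 32; // MPSDataTypeFloatBit | bit width, i.e. MPSDataTypeFloat32
        let size = 4; // bytes per f32 element
        // Descriptors give shape, row stride in bytes, and element type.
        let a_desc = MatrixDescriptor::init_single(m, k, k * size, f32_type);
        let b_desc = MatrixDescriptor::init_single(k, n, n * size, f32_type);
        let c_desc = MatrixDescriptor::init_single(m, n, n * size, f32_type);
        let a = Matrix::init_with_buffer_descriptor(a, 0, &a_desc)?;
        let b = Matrix::init_with_buffer_descriptor(b, 0, &b_desc)?;
        let c = Matrix::init_with_buffer_descriptor(c, 0, &c_desc)?;
        // No transposition, alpha = 1, beta = 0: a plain matmul.
        let gemm = MatrixMultiplication::init(device, false, false, m, n, k, 1.0, 0.0)?;
        gemm.encode_to_command_buffer(command_buffer, &a, &b, &c);
        Some(())
    }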
@@ -1052,123 +1053,206 @@ pub fn call_gemm(
             mnk: (m, n, k),
         })?;
     };
-    let d_trans = false;
-    let alpha = 1.0f32;
-    let beta = 0.0f32;
-    let batched = b > 1;
-    let fused_activation = false;
-    let fused_bias = false;
-    let m_simd = 16;
-    let n_simd = 16;
-    let k_simd = 16;
-    let m_splits = 2;
-    let n_splits = 2;
-    let constants = Some(ConstantValues::new(vec![
-        (0, Value::USize(m)),
-        (1, Value::USize(n)),
-        (2, Value::USize(k)),
-        (10, Value::Bool(a_trans)),
-        (11, Value::Bool(b_trans)),
-        (13, Value::Bool(d_trans)),
-        (20, Value::F32(alpha)),
-        (21, Value::F32(beta)),
-        (100, Value::Bool(batched)),
-        (101, Value::Bool(fused_activation)),
-        // Garbage
-        (102, Value::Bool(false)),
-        (103, Value::Bool(false)),
-        (113, Value::Bool(false)),
-        (50_000, Value::Bool(false)),
-        // End garbage
-        (200, Value::U16(m_simd)),
-        (201, Value::U16(n_simd)),
-        (202, Value::U16(k_simd)),
-        (210, Value::U16(m_splits)),
-        (211, Value::U16(n_splits)),
-        (50_001, Value::Bool(fused_bias)),
-    ]));
-    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
-    let m_group = m_simd * m_splits;
-    let n_group = n_simd * n_splits;
+    // let d_trans = false;
+    // let alpha = 1.0f32;
+    // let beta = 0.0f32;
+    // let batched = b > 1;
+    // let fused_activation = false;
+    // let fused_bias = false;
+    // let m_simd = 16;
+    // let n_simd = 16;
+    // let k_simd = 16;
+    // let m_splits = 2;
+    // let n_splits = 2;
+    // let constants = Some(ConstantValues::new(vec![
+    //     (0, Value::USize(m)),
+    //     (1, Value::USize(n)),
+    //     (2, Value::USize(k)),
+    //     (10, Value::Bool(a_trans)),
+    //     (11, Value::Bool(b_trans)),
+    //     (13, Value::Bool(d_trans)),
+    //     (20, Value::F32(alpha)),
+    //     (21, Value::F32(beta)),
+    //     (100, Value::Bool(batched)),
+    //     (101, Value::Bool(fused_activation)),
+    //     // Garbage
+    //     (102, Value::Bool(false)),
+    //     (103, Value::Bool(false)),
+    //     (113, Value::Bool(false)),
+    //     (50_000, Value::Bool(false)),
+    //     // End garbage
+    //     (200, Value::U16(m_simd)),
+    //     (201, Value::U16(n_simd)),
+    //     (202, Value::U16(k_simd)),
+    //     (210, Value::U16(m_splits)),
+    //     (211, Value::U16(n_splits)),
+    //     (50_001, Value::Bool(fused_bias)),
+    // ]));
+    // let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
+    // let m_group = m_simd * m_splits;
+    // let n_group = n_simd * n_splits;
+    //
+    // let a_block_length = m_group * k_simd;
+    // let b_block_length = k_simd * n_group;
+    //
+    // let mut block_elements = a_block_length + b_block_length;
+    // if (m % 8 != 0) && (n % 8 != 0) {
+    //     let c_block_length = m_group * n_group;
+    //     block_elements = std::cmp::max(c_block_length, block_elements)
+    // }
+    // if fused_bias {
+    //     if d_trans {
+    //         block_elements = std::cmp::max(block_elements, m_group);
+    //     } else {
+    //         block_elements = std::cmp::max(block_elements, n_group);
+    //     }
+    // }
+    // let bytes = match name {
+    //     "sgemm" => 4,
+    //     "hgemm" => 2,
+    //     other => {
+    //         return Err(MetalKernelError::LoadLibraryError(format!(
+    //             "{other} is not a valid kernel for gemm"
+    //         )));
+    //     }
+    // };
+    // let block_bytes = block_elements * bytes;
+    //
+    // let encoder = command_buffer.new_compute_command_encoder();
+    // encoder.wait_for_fence(&kernels.fence);
+    // encoder.set_compute_pipeline_state(&pipeline);
+    // encoder.set_threadgroup_memory_length(0, block_bytes.into());
+    // encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
+    // encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
+    // encoder.set_buffer(2, Some(output), 0);
+    // // TODO Tensor D
+    //
+    // let grid_z = b;
+    // if batched {
+    //     let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
+    //     let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
+    //     let byte_stride_c = m * n * bytes as usize;
+    //     // TODO byte_stride_d
+    //     let byte_stride_d = 0;
+    //
+    //     let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+    //     for i in 0..b {
+    //         buffer.push((i * byte_stride_a) as u64);
+    //         buffer.push((i * byte_stride_b) as u64);
+    //         buffer.push((i * byte_stride_c) as u64);
+    //         buffer.push((i * byte_stride_d) as u64);
+    //     }
+    //     encoder.set_bytes(
+    //         10,
+    //         (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+    //         buffer.as_ptr() as *const NSUInteger as *const c_void,
+    //     );
+    // }
+    //
+    // let grid_size = MTLSize {
+    //     width: divide(n, n_group.into()),
+    //     height: divide(m, m_group.into()),
+    //     depth: grid_z as NSUInteger,
+    // };
+    // let group_size = MTLSize {
+    //     width: 32 * (m_splits as u64) * (n_splits as u64),
+    //     height: 1,
+    //     depth: 1,
+    // };
+    // // println!("grid size {grid_size:?} group size {group_size:?}");
+    // encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
+    // encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
+    // encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    // encoder.dispatch_thread_groups(grid_size, group_size);
+    // encoder.update_fence(&kernels.fence);
+    // encoder.end_encoding();
 
-    let a_block_length = m_group * k_simd;
-    let b_block_length = k_simd * n_group;
+    let (b, m, n, k) = (
+        b as NSUInteger,
+        m as NSUInteger,
+        n as NSUInteger,
+        k as NSUInteger,
+    );
 
-    let mut block_elements = a_block_length + b_block_length;
-    if (m % 8 != 0) && (n % 8 != 0) {
-        let c_block_length = m_group * n_group;
-        block_elements = std::cmp::max(c_block_length, block_elements)
-    }
-    if fused_bias {
-        if d_trans {
-            block_elements = std::cmp::max(block_elements, m_group);
-        } else {
-            block_elements = std::cmp::max(block_elements, n_group);
-        }
-    }
-    let bytes = match name {
-        "sgemm" => 4,
-        "hgemm" => 2,
-        other => {
-            return Err(MetalKernelError::LoadLibraryError(format!(
-                "{other} is not a valid kernel for gemm"
-            )));
-        }
-    };
-    let block_bytes = block_elements * bytes;
+    let (size, data_type) = if name == "sgemm" {
+        (4, 0x10000000 | 32)
+    } else {
+        (2, 0x10000000 | 16)
+    };
 
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
-    encoder.set_compute_pipeline_state(&pipeline);
-    encoder.set_threadgroup_memory_length(0, block_bytes.into());
-    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
-    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
-    encoder.set_buffer(2, Some(output), 0);
-    // TODO Tensor D
+    let left_matrix = create_matrix(
+        lhs_buffer,
+        (b, m, k),
+        a_trans,
+        size,
+        lhs_offset as NSUInteger,
+        data_type,
+    ).unwrap();
 
-    let grid_z = b;
-    if batched {
-        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
-        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
-        let byte_stride_c = m * n * bytes as usize;
-        // TODO byte_stride_d
-        let byte_stride_d = 0;
+    let right_matrix = create_matrix(
+        rhs_buffer,
+        (b, k, n),
+        b_trans,
+        size,
+        rhs_offset as NSUInteger,
+        data_type,
+    ).unwrap();
 
-        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
-        for i in 0..b {
-            buffer.push((i * byte_stride_a) as u64);
-            buffer.push((i * byte_stride_b) as u64);
-            buffer.push((i * byte_stride_c) as u64);
-            buffer.push((i * byte_stride_d) as u64);
-        }
-        encoder.set_bytes(
-            10,
-            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
-            buffer.as_ptr() as *const NSUInteger as *const c_void,
-        );
-    }
+    let result_matrix = create_matrix(
+        output,
+        (b, m, n),
+        false,
+        size,
+        0,
+        data_type,
+    ).unwrap();
 
-    let grid_size = MTLSize {
-        width: divide(n, n_group.into()),
-        height: divide(m, m_group.into()),
-        depth: grid_z as NSUInteger,
-    };
-    let group_size = MTLSize {
-        width: 32 * (m_splits as u64) * (n_splits as u64),
-        height: 1,
-        depth: 1,
-    };
-    // println!("grid size {grid_size:?} group size {group_size:?}");
-    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
-    encoder.dispatch_thread_groups(grid_size, group_size);
-    encoder.update_fence(&kernels.fence);
-    encoder.end_encoding();
+    // Create kernel
+    let matrix_multiplication = MatrixMultiplication::init(
+        &device,
+        a_trans,
+        b_trans,
+        m,
+        n,
+        k,
+        1.0,
+        0.0,
+    ).unwrap();
 
+    matrix_multiplication.encode_to_command_buffer(
+        command_buffer,
+        &left_matrix,
+        &right_matrix,
+        &result_matrix,
+    );
 
     Ok(())
 }
+
+fn create_matrix(
+    buffer: &Buffer,
+    (b, rows, columns): (NSUInteger, NSUInteger, NSUInteger),
+    transpose: bool,
+    size: NSUInteger,
+    offset: NSUInteger,
+    data_type: u32,
+) -> Option<Matrix> {
+    let (rows, columns) = if transpose {
+        (columns, rows)
+    } else {
+        (rows, columns)
+    };
+    let descriptor = if b == 1 {
+        MatrixDescriptor::init_single(rows, columns, columns * size, data_type)
+    } else {
+        MatrixDescriptor::init_multiple(
+            rows,
+            columns,
+            b,
+            columns * size,
+            rows * columns * size,
+            data_type,
+        )
+    };
+    return Matrix::init_with_buffer_descriptor(&buffer, offset * size, &descriptor);
+}
+
+fn divide(m: usize, b: usize) -> NSUInteger {
+    ((m + b - 1) / b) as NSUInteger
+}
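As a usage note: `create_matrix` takes the logical `(batch, rows, columns)` shape plus the transpose flag and element size, swaps rows and columns for transposed operands so the descriptor matches the buffer's in-memory layout, and returns `None` if the MPS matrix cannot be created; `divide` is ceiling division for sizing the dispatch grid, so `divide(100, 32) == 4`. A hypothetical call for a single non-transposed 2 x 3 f32 matrix stored in `buf` (the variable name is illustrative):

    // batch = 1, 2 rows, 3 columns, f32 (element size 4, MPSDataTypeFloat32), no offset.
    let mat = create_matrix(&buf, (1, 2, 3), false, 4, 0, 0x10000000 | 32).unwrap();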
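For scale, the MFA defaults removed above (16 x 16 x 16 SIMD tiles, 2 x 2 splits) gave m_group = n_group = 16 * 2 = 32, so each threadgroup computed a 32 x 32 output tile with 32 * 2 * 2 = 128 threads and, for sgemm, reserved (32 * 16 + 16 * 32) * 4 = 4096 bytes of threadgroup memory. The MPS path delegates all such tiling decisions to MPSMatrixMultiplication.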