Better batched matmul.

This commit is contained in:
Nicolas Patry
2023-11-17 10:36:57 +01:00
parent 2801541e5f
commit a0010898cc

View File

@ -754,73 +754,72 @@ impl BackendStorage for MetalStorage {
let k = k as NSUInteger; let k = k as NSUInteger;
let left_descriptor = if transpose_left { let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id) MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
} else { } else {
MatrixDescriptor::init_single(m, k, k * size, type_id) MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
}; };
let right_descriptor = if transpose_right { let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id) MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
} else { } else {
MatrixDescriptor::init_single(k, n, n * size, type_id) MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
}; };
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype); let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer();
for bi in 0..b { // Create matrix objects
// Create matrix objects let left_matrix = Matrix::init_with_buffer_descriptor(
let left_matrix = Matrix::init_with_buffer_descriptor( &self.buffer,
&self.buffer, lhs_l.start_offset() as NSUInteger * size,
(bi * stride_left + lhs_l.start_offset() as u64) * size, &left_descriptor,
&left_descriptor, )
) .ok_or_else(|| {
.ok_or_else(|| { MetalError::from("Failed to create matrix multiplication kernel".to_string())
MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?;
})?; let right_matrix = Matrix::init_with_buffer_descriptor(
let right_matrix = Matrix::init_with_buffer_descriptor( &rhs.buffer,
&rhs.buffer, rhs_l.start_offset() as NSUInteger * size,
(bi * stride_right + rhs_l.start_offset() as u64) * size, &right_descriptor,
&right_descriptor, )
) .ok_or_else(|| {
.ok_or_else(|| { MetalError::from("Failed to create matrix multiplication kernel".to_string())
MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?;
})?;
let result_matrix = Matrix::init_with_buffer_descriptor( let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer, &out_buffer,
bi * m * n * size, 0,
&result_descriptor, &result_descriptor,
) )
.ok_or_else(|| { .ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string()) MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?; })?;
let alpha = 1.0f64; let alpha = 1.0f64;
let beta = 0.0f64; let beta = 0.0f64;
// Create kernel // Create kernel
let matrix_multiplication = MatrixMultiplication::init( let matrix_multiplication = MatrixMultiplication::init(
&self.device, &self.device,
transpose_left, transpose_left,
transpose_right, transpose_right,
m, m,
n, n,
k, k,
alpha, alpha,
beta, beta,
) )
.ok_or_else(|| { .ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string()) MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?; })?;
matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer // Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer( matrix_multiplication.encode_to_command_buffer(
&command_buffer, &command_buffer,
&left_matrix, &left_matrix,
&right_matrix, &right_matrix,
&result_matrix, &result_matrix,
); );
}
Ok(Self { Ok(Self {
buffer: out_buffer, buffer: out_buffer,