Better batched matmul.

This commit is contained in:
Nicolas Patry
2023-11-17 10:36:57 +01:00
parent 2801541e5f
commit a0010898cc

View File

@ -754,73 +754,72 @@ impl BackendStorage for MetalStorage {
let k = k as NSUInteger;
let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id)
MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
} else {
MatrixDescriptor::init_single(m, k, k * size, type_id)
MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
};
let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id)
MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
} else {
MatrixDescriptor::init_single(k, n, n * size, type_id)
MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
};
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
(bi * stride_left + lhs_l.start_offset() as u64) * size,
&left_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
(bi * stride_right + rhs_l.start_offset() as u64) * size,
&right_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
lhs_l.start_offset() as NSUInteger * size,
&left_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
rhs_l.start_offset() as NSUInteger * size,
&right_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
bi * m * n * size,
&result_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
0,
&result_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
}
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
Ok(Self {
buffer: out_buffer,