Better batched matmul.

This commit is contained in:
Nicolas Patry
2023-11-17 10:36:57 +01:00
parent 2801541e5f
commit a0010898cc

View File

@ -754,25 +754,24 @@ impl BackendStorage for MetalStorage {
let k = k as NSUInteger;
let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id)
MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
} else {
MatrixDescriptor::init_single(m, k, k * size, type_id)
MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
};
let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id)
MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
} else {
MatrixDescriptor::init_single(k, n, n * size, type_id)
MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
};
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
(bi * stride_left + lhs_l.start_offset() as u64) * size,
lhs_l.start_offset() as NSUInteger * size,
&left_descriptor,
)
.ok_or_else(|| {
@ -780,7 +779,7 @@ impl BackendStorage for MetalStorage {
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
(bi * stride_right + rhs_l.start_offset() as u64) * size,
rhs_l.start_offset() as NSUInteger * size,
&right_descriptor,
)
.ok_or_else(|| {
@ -789,7 +788,7 @@ impl BackendStorage for MetalStorage {
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
bi * m * n * size,
0,
&result_descriptor,
)
.ok_or_else(|| {
@ -812,6 +811,7 @@ impl BackendStorage for MetalStorage {
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
@ -820,7 +820,6 @@ impl BackendStorage for MetalStorage {
&right_matrix,
&result_matrix,
);
}
Ok(Self {
buffer: out_buffer,