Better batched matmul.

This commit is contained in:
Nicolas Patry
2023-11-17 10:36:57 +01:00
parent 2801541e5f
commit a0010898cc

View File

@ -754,25 +754,24 @@ impl BackendStorage for MetalStorage {
let k = k as NSUInteger; let k = k as NSUInteger;
let left_descriptor = if transpose_left { let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id) MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
} else { } else {
MatrixDescriptor::init_single(m, k, k * size, type_id) MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
}; };
let right_descriptor = if transpose_right { let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id) MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
} else { } else {
MatrixDescriptor::init_single(k, n, n * size, type_id) MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
}; };
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype); let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects // Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor( let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer, &self.buffer,
(bi * stride_left + lhs_l.start_offset() as u64) * size, lhs_l.start_offset() as NSUInteger * size,
&left_descriptor, &left_descriptor,
) )
.ok_or_else(|| { .ok_or_else(|| {
@ -780,7 +779,7 @@ impl BackendStorage for MetalStorage {
})?; })?;
let right_matrix = Matrix::init_with_buffer_descriptor( let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer, &rhs.buffer,
(bi * stride_right + rhs_l.start_offset() as u64) * size, rhs_l.start_offset() as NSUInteger * size,
&right_descriptor, &right_descriptor,
) )
.ok_or_else(|| { .ok_or_else(|| {
@ -789,7 +788,7 @@ impl BackendStorage for MetalStorage {
let result_matrix = Matrix::init_with_buffer_descriptor( let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer, &out_buffer,
bi * m * n * size, 0,
&result_descriptor, &result_descriptor,
) )
.ok_or_else(|| { .ok_or_else(|| {
@ -812,6 +811,7 @@ impl BackendStorage for MetalStorage {
.ok_or_else(|| { .ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string()) MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?; })?;
matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer // Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer( matrix_multiplication.encode_to_command_buffer(
@ -820,7 +820,6 @@ impl BackendStorage for MetalStorage {
&right_matrix, &right_matrix,
&result_matrix, &result_matrix,
); );
}
Ok(Self { Ok(Self {
buffer: out_buffer, buffer: out_buffer,