diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 7145d42b..349edc49 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -754,73 +754,72 @@ impl BackendStorage for MetalStorage { let k = k as NSUInteger; let left_descriptor = if transpose_left { - MatrixDescriptor::init_single(k, m, m * size, type_id) + MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id) } else { - MatrixDescriptor::init_single(m, k, k * size, type_id) + MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id) }; let right_descriptor = if transpose_right { - MatrixDescriptor::init_single(n, k, k * size, type_id) + MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id) } else { - MatrixDescriptor::init_single(k, n, n * size, type_id) + MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id) }; - let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); + let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); let out_buffer = self.device.new_buffer(elem_count, self.dtype); let command_buffer = self.device.command_buffer(); - for bi in 0..b { - // Create matrix objects - let left_matrix = Matrix::init_with_buffer_descriptor( - &self.buffer, - (bi * stride_left + lhs_l.start_offset() as u64) * size, - &left_descriptor, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - let right_matrix = Matrix::init_with_buffer_descriptor( - &rhs.buffer, - (bi * stride_right + rhs_l.start_offset() as u64) * size, - &right_descriptor, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; + // Create matrix objects + let left_matrix = Matrix::init_with_buffer_descriptor( + &self.buffer, + lhs_l.start_offset() as NSUInteger * size, + &left_descriptor, + ) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + let right_matrix = Matrix::init_with_buffer_descriptor( + &rhs.buffer, + rhs_l.start_offset() as NSUInteger * size, + &right_descriptor, + ) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; - let result_matrix = Matrix::init_with_buffer_descriptor( - &out_buffer, - bi * m * n * size, - &result_descriptor, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; + let result_matrix = Matrix::init_with_buffer_descriptor( + &out_buffer, + 0, + &result_descriptor, + ) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; - let alpha = 1.0f64; - let beta = 0.0f64; - // Create kernel - let matrix_multiplication = MatrixMultiplication::init( - &self.device, - transpose_left, - transpose_right, - m, - n, - k, - alpha, - beta, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; + let alpha = 1.0f64; + let beta = 0.0f64; + // Create kernel + let matrix_multiplication = MatrixMultiplication::init( + &self.device, + transpose_left, + transpose_right, + m, + n, + k, + alpha, + beta, + ) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + matrix_multiplication.set_batch_size(b); - // Encode kernel to command buffer - matrix_multiplication.encode_to_command_buffer( - &command_buffer, - &left_matrix, - &right_matrix, - &result_matrix, - ); - } + // Encode kernel to command buffer + matrix_multiplication.encode_to_command_buffer( + &command_buffer, + &left_matrix, + &right_matrix, + &result_matrix, + ); Ok(Self { buffer: out_buffer,