mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Better batched matmul.
This commit is contained in:
@ -754,73 +754,72 @@ impl BackendStorage for MetalStorage {
|
||||
let k = k as NSUInteger;
|
||||
|
||||
let left_descriptor = if transpose_left {
|
||||
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
||||
MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
||||
MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
|
||||
};
|
||||
let right_descriptor = if transpose_right {
|
||||
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
||||
MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
||||
MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
|
||||
};
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
|
||||
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
for bi in 0..b {
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&self.buffer,
|
||||
(bi * stride_left + lhs_l.start_offset() as u64) * size,
|
||||
&left_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&rhs.buffer,
|
||||
(bi * stride_right + rhs_l.start_offset() as u64) * size,
|
||||
&right_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&self.buffer,
|
||||
lhs_l.start_offset() as NSUInteger * size,
|
||||
&left_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&rhs.buffer,
|
||||
rhs_l.start_offset() as NSUInteger * size,
|
||||
&right_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&out_buffer,
|
||||
bi * m * n * size,
|
||||
&result_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&out_buffer,
|
||||
0,
|
||||
&result_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
matrix_multiplication.set_batch_size(b);
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
}
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
|
Reference in New Issue
Block a user