mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Better batched matmul.
This commit is contained in:
@ -754,73 +754,72 @@ impl BackendStorage for MetalStorage {
|
|||||||
let k = k as NSUInteger;
|
let k = k as NSUInteger;
|
||||||
|
|
||||||
let left_descriptor = if transpose_left {
|
let left_descriptor = if transpose_left {
|
||||||
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
|
||||||
} else {
|
} else {
|
||||||
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
|
||||||
};
|
};
|
||||||
let right_descriptor = if transpose_right {
|
let right_descriptor = if transpose_right {
|
||||||
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
|
||||||
} else {
|
} else {
|
||||||
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
|
||||||
};
|
};
|
||||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
|
||||||
|
|
||||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
for bi in 0..b {
|
// Create matrix objects
|
||||||
// Create matrix objects
|
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
&self.buffer,
|
||||||
&self.buffer,
|
lhs_l.start_offset() as NSUInteger * size,
|
||||||
(bi * stride_left + lhs_l.start_offset() as u64) * size,
|
&left_descriptor,
|
||||||
&left_descriptor,
|
)
|
||||||
)
|
.ok_or_else(|| {
|
||||||
.ok_or_else(|| {
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
})?;
|
||||||
})?;
|
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
&rhs.buffer,
|
||||||
&rhs.buffer,
|
rhs_l.start_offset() as NSUInteger * size,
|
||||||
(bi * stride_right + rhs_l.start_offset() as u64) * size,
|
&right_descriptor,
|
||||||
&right_descriptor,
|
)
|
||||||
)
|
.ok_or_else(|| {
|
||||||
.ok_or_else(|| {
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
})?;
|
||||||
})?;
|
|
||||||
|
|
||||||
let result_matrix = Matrix::init_with_buffer_descriptor(
|
let result_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
&out_buffer,
|
&out_buffer,
|
||||||
bi * m * n * size,
|
0,
|
||||||
&result_descriptor,
|
&result_descriptor,
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let alpha = 1.0f64;
|
let alpha = 1.0f64;
|
||||||
let beta = 0.0f64;
|
let beta = 0.0f64;
|
||||||
// Create kernel
|
// Create kernel
|
||||||
let matrix_multiplication = MatrixMultiplication::init(
|
let matrix_multiplication = MatrixMultiplication::init(
|
||||||
&self.device,
|
&self.device,
|
||||||
transpose_left,
|
transpose_left,
|
||||||
transpose_right,
|
transpose_right,
|
||||||
m,
|
m,
|
||||||
n,
|
n,
|
||||||
k,
|
k,
|
||||||
alpha,
|
alpha,
|
||||||
beta,
|
beta,
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
})?;
|
})?;
|
||||||
|
matrix_multiplication.set_batch_size(b);
|
||||||
|
|
||||||
// Encode kernel to command buffer
|
// Encode kernel to command buffer
|
||||||
matrix_multiplication.encode_to_command_buffer(
|
matrix_multiplication.encode_to_command_buffer(
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&left_matrix,
|
&left_matrix,
|
||||||
&right_matrix,
|
&right_matrix,
|
||||||
&result_matrix,
|
&result_matrix,
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: out_buffer,
|
||||||
|
Reference in New Issue
Block a user