mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Better batched matmul.
This commit is contained in:
@ -754,25 +754,24 @@ impl BackendStorage for MetalStorage {
|
|||||||
let k = k as NSUInteger;
|
let k = k as NSUInteger;
|
||||||
|
|
||||||
let left_descriptor = if transpose_left {
|
let left_descriptor = if transpose_left {
|
||||||
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
|
||||||
} else {
|
} else {
|
||||||
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
|
||||||
};
|
};
|
||||||
let right_descriptor = if transpose_right {
|
let right_descriptor = if transpose_right {
|
||||||
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
|
||||||
} else {
|
} else {
|
||||||
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
|
||||||
};
|
};
|
||||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
|
||||||
|
|
||||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
for bi in 0..b {
|
|
||||||
// Create matrix objects
|
// Create matrix objects
|
||||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
(bi * stride_left + lhs_l.start_offset() as u64) * size,
|
lhs_l.start_offset() as NSUInteger * size,
|
||||||
&left_descriptor,
|
&left_descriptor,
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
@ -780,7 +779,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
})?;
|
})?;
|
||||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
&rhs.buffer,
|
&rhs.buffer,
|
||||||
(bi * stride_right + rhs_l.start_offset() as u64) * size,
|
rhs_l.start_offset() as NSUInteger * size,
|
||||||
&right_descriptor,
|
&right_descriptor,
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
@ -789,7 +788,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
let result_matrix = Matrix::init_with_buffer_descriptor(
|
let result_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
&out_buffer,
|
&out_buffer,
|
||||||
bi * m * n * size,
|
0,
|
||||||
&result_descriptor,
|
&result_descriptor,
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
@ -812,6 +811,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||||
})?;
|
})?;
|
||||||
|
matrix_multiplication.set_batch_size(b);
|
||||||
|
|
||||||
// Encode kernel to command buffer
|
// Encode kernel to command buffer
|
||||||
matrix_multiplication.encode_to_command_buffer(
|
matrix_multiplication.encode_to_command_buffer(
|
||||||
@ -820,7 +820,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&right_matrix,
|
&right_matrix,
|
||||||
&result_matrix,
|
&result_matrix,
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: out_buffer,
|
||||||
|
Reference in New Issue
Block a user