Integrate the MLX gemm kernels (#2468)

* Include the MLX gemm kernels.

* Clippy lints.

* Export the gemm_f32 kernel.

* Add the f16/bf16 variants.

* Add the initial dispatch code.

* More plugging of the mlx kernels.

* Add a currently broken test.

* Tweaks.

* Bugfix + get the tests to pass.

* Enable the gemm bf16 tests.

* Add some randomized tests.

* Update candle-metal-kernels/src/lib.rs

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>

* More fixes.

* More clippy fixes.

---------

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
Laurent Mazare
2024-09-11 15:56:48 +01:00
committed by GitHub
parent 13b2a8a4a0
commit 5635650d38
5 changed files with 1874 additions and 55 deletions

View File

@ -165,7 +165,7 @@ pub trait EncoderProvider {
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
fn encoder(&self) -> Self::Encoder<'_>;
}
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
@ -178,7 +178,7 @@ impl<'a> Drop for WrappedEncoder<'a> {
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
&self.0
self.0
}
}
@ -186,7 +186,7 @@ impl EncoderProvider for &metal::CommandBuffer {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder(self.new_compute_command_encoder())
}
}
@ -195,7 +195,7 @@ impl EncoderProvider for &metal::CommandBufferRef {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder(self.new_compute_command_encoder())
}
}