Commit Graph

144 Commits

Author SHA1 Message Date
1a32107fab Add a few metal gather ops. (#2740)
* Add a few metal gather ops.

* Fix some compilation issues.

* Adjust the tolerance.
2025-01-25 23:31:03 +01:00
17cbbe4286 Sync upstream MLX sdpa vector kernels with mask (#2718)
* Sync upstream mlx sdpa vector kernels with mask

* Dispatch to the 2pass kernel

* Format
2025-01-16 11:30:10 +01:00
236c35e578 Bump the caret version to 0.8.2. (#2703) 2025-01-07 15:50:16 +01:00
67cab7d6b8 Bump the crate version to 0.8.1. (#2662) 2024-12-07 17:03:53 +01:00
6f715f9256 add scatter add (#2656) 2024-12-01 18:39:38 +01:00
dba7a9c93e add u32 - U32 gather (#2653) 2024-11-30 23:18:07 +01:00
54e7fc3c97 Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83

* rustfmt

* Fix more lints.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-28 23:00:21 +01:00
06350c31c7 Add some missing index-select metal kernels. (#2613)
* Add some missing index-select metal kernels.

* Make some matrix contiguous pre-matmul.
2024-11-12 17:10:12 +01:00
9453cc3095 Bump the crate version to 0.8.0. (#2612) 2024-11-12 14:11:46 +01:00
e2b6b367fa Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
0e2c8c17fb UG metal integration. (#2580) 2024-10-27 15:20:37 +01:00
fd08d3d0a4 Tweak some metal tests. (#2528) 2024-10-02 10:22:31 +02:00
a2bcc227df Efficient implementation of Tensor::ones() for metal (#2512)
* WIP: hopefully better const impl

* with GPU

* More tests on

* Reverting primitive for

* Incorporating review changes - added check elem count check in kerner, using  for call strategy

* rustfmt ran
2024-10-01 19:11:59 +02:00
3a3c48b14b Bump the crate version to 0.7.2. (#2517) 2024-09-29 10:56:50 +02:00
8097559c1a Move the candle version to 0.7.1. (#2495) 2024-09-22 20:44:39 +02:00
c2fca0ca11 Bump the crate version. (#2491) 2024-09-21 15:13:12 +02:00
844d45cde4 Bugfix for the metal elu kernel. (#2490)
* Bugfix for the metal elu kernel.

* Add a test.
2024-09-21 15:03:19 +02:00
af2104078f Metal commands refactoring (#2489)
* Split out the commands part of the metal device.

* Make most fields private.

* Move the allocator back.

* Rework the encoder provider type.
2024-09-21 13:18:42 +02:00
c09afc211c Fix for metal tanh. (#2475) 2024-09-13 07:08:36 +02:00
0cb0bd1dfa Add some metal gemm benchark. (#2471)
* Add some metal gemm benchark.

* More benchmarks.
2024-09-11 22:52:37 +02:00
5635650d38 Integrate the MLX gemm kernels (#2468)
* Include the MLX gemm kernels.

* Clippy lints.

* Export the gemm_f32 kernel.

* Add the f16/bf16 variants.

* Add the initial dispatch code.

* More plugging of the mlx kernels.

* Add a currently broken test.

* Tweaks.

* Bugfix + get the tests to pass.

* Enable the gemm bf16 tests.

* Add some randomized tests.

* Update candle-metal-kernels/src/lib.rs

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>

* More fixes.

* More clippy fixes.

---------

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
2024-09-11 16:56:48 +02:00
6070278a31 Bump the version to 0.6.1. (#2438) 2024-08-22 09:23:52 +02:00
0fcb40b229 Revert the bf16 gemm metal changes for now. (#2386) 2024-08-01 23:08:47 +02:00
fea46cb719 Metal bgemm min changes (#2364)
* Add updated mfa metallib

* Add bgemm and tests
2024-08-01 10:06:04 +02:00
8696cf6494 Enable the affine kernel for u8/u32. (#2376) 2024-08-01 10:03:11 +02:00
ddafc61055 Use RAII for terminating the encoding. (#2353) 2024-07-24 16:29:56 +02:00
a925ae6bc6 Use a trait for the encoder provider (so that encoder can ultimately be reused). (#2352) 2024-07-24 09:27:30 +02:00
f65e90e7ef Bump the crate version. (#2248) 2024-06-05 15:49:15 +02:00
1ec3b2cc18 add where_cond f32 for metal (#2236) 2024-06-02 14:30:06 +02:00
0814dfd148 Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d.

* Enable the col2im variant.

* Bugfix.

* Revert the quantized tweak.
2024-05-25 11:03:23 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
72e7ca529a Add some missing where-cond kernels for metal. (#2203) 2024-05-22 09:44:52 +02:00
b13a82a438 Separate quantized phi-3 implementation. (#2157)
* Separate quantized phi-3 implementation.

* Integrate the quantized phi3 model.=

* Small fixes, get the generation to work properly.

* Keep the old llama implementation around.

* Change the default.
2024-05-04 10:14:57 +02:00
89f53b9d7b Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1.

* Fix clippy lints for 1.78.

* More clippy fixes.
2024-05-03 11:17:05 +02:00
3bbb88fcb4 Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-29 11:04:43 +02:00
96a48e5cc4 Add argsort. (#2132)
* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
2024-04-27 20:17:35 +02:00
0067fe00a8 Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt

* process unary commands in tiles of 4

* re-enable all benchmarks

* rename helper to unary

* modify approach to split up tiled and non-tiled operations

* undo bench ignore for other tests

* update tile size to 2

* only perform the optimization on the contiguous even numbered element case
2024-04-21 00:10:33 +02:00
dd78422701 Handle multiple dimensions in metal QMM + two fixes. (#2097) 2024-04-20 18:55:45 +02:00
db7dbf3071 Add missing bfloat unary strided kernels and fix typo (#2058) 2024-04-14 20:01:13 +02:00
a4d5a414e3 Support gather on bf16 for metal. (#2035) 2024-04-10 12:49:25 +02:00
718671a0d5 Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
2024-04-08 09:37:25 +02:00
c5fe4a7f89 Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
2024-04-07 22:37:53 +02:00
7f354473cf Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal.

* Add a hacky stopping rule for moondream.
2024-04-07 12:34:16 +02:00
2ac302a5d1 Add the rope THD kernel. (#2014)
* Add the rope THD kernel.

* Cuda kernel for rope-thd.

* Add the metal kernels.

* Add a dedicated test.
2024-04-05 08:32:58 +02:00
c5626b8271 Add support for "sign" on tensors (#2012)
* add the sign unary operator

* remove uneeded import

* remove uneeded import

* undo formatting

* undo formatting

* remove unnecessary redefintion

* allow gradient to flow through for sign and round

* fix cpu ops to ensure that negzero and positive zero are handled properly

* clippy fixes

* Properly avoid gradient tracking.

* Use a branchless version.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-04 22:32:47 +02:00
5aebe53dd2 update dtypes checks for several metal operations (#2010) 2024-04-04 18:39:06 +02:00
f76bb7794a Bumping the version number to 0.5.0. (#2009) 2024-04-04 17:48:45 +02:00
1e46cf8b19 Minor cleanups in reduce.metal. (#2004) 2024-04-04 08:26:02 +02:00
bd8db2a771 refactor to reduce the amount of code wrapped in template syntax (#2002) 2024-04-04 08:13:12 +02:00
b3484e7a5e Fix for the RWKV models. (#1955)
* Fix for the RWKV models.

* More general fix + revert the rwkv hack.

* Remove the old hack.
2024-03-28 10:17:38 +01:00