Commit Graph

156 Commits

Author SHA1 Message Date
9dbaf958dc Add an enum for scalar values. (#2909)
* Add a scalar enum type.

* Add a bit more to the scalar type.

* Small tweak.

* More scalar usage.
2025-04-18 22:13:38 +02:00
e4e7b0b2da Use cudarc 0.16. (#2900)
* Use cudarc 0.16.

* Allow for disabling event tracking.

* Tweaks.

* Bump the ug version.

* And bump the candle version too.
2025-04-15 21:40:18 +02:00
1d1d6d4fe6 Bump the crate version. (#2895) 2025-04-14 15:52:11 +02:00
19fb6dac1f Bump the crate version. (#2881) 2025-04-11 22:28:21 +02:00
d9904a3baf Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
2025-04-03 09:12:19 +02:00
468d1d525f Bump the crate version to 0.8.4. (#2808) 2025-03-15 07:42:24 +01:00
fd7f7242a1 Bump the crate version to 0.8.3 (#2772)
* update to cudarc to v0.13.5 to support cuda 12.8

* Bump the crate version.

---------

Co-authored-by: Michael McCulloch <michael.james.mcculloch@fastmail.com>
2025-02-15 15:54:48 +01:00
7c2449f623 Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl

* Improve arg reduce and add contiguous impl

* Improve softmax kernel. 33%-39% higher thrpt

* fmt

* Fixed all bugs. Improved code quality. Added tests.

* Stash for debugging

* Stash for debugging 2

* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait

* Online softmax. Improved threadgroup reduce. Tidying up a bit.

* Remove redundant threadgroup_barrier from arg reduce

* Mostly tidying up. Some improvements

* Simplify indexed struct

* tidying

* Reuse operation operator instead of passing it in as a parameter

* Fix how operators are applied to indexed<vec<T,N>>

* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.

* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.

* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math

* Use constant for input instead of const device. Fix strided reduce.

* Use contiguous reduce in tests

* Rename finalize -> to_scalar

* Support integer types max/min (switch with trait-inferred impl later)

* Was worried I was skipping work -> shuffling the 1D test cases

* Add build.rs to avoid metal kernel jit compile overhead

* Improve build. Extract utils

* Compile metal kernels for both macos and ios

* Fixed over xmas and then forgot about it

* Add calculate_reduce_threads util

* Remove old reduce.metal

* Improve f16/bf16 softmax precision by accumulating in f32

* Remove build.rs (for now)

* Move softmax bench to candle-nn

* Remove redundant thread calc util fn

* Use uint over ushort for indices etc

* Use fast exp in MDReduceOp

* Remove nested metal define for softmax

* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-02-08 07:27:01 +01:00
d2c53f4f2f Remove the MFA gemm library. (#2755) 2025-01-28 21:48:17 +01:00
8f20f2a722 Add the MLX merge sort kernels (#2751)
* Add some metal sort kernels imported from MLX.

* Add another test.

* Start adding the multiblock version.

* Proper kernel names.

* Split out the main metal file.

* Multi-block sort.

* More sorting.

* DType parametrization.

* Add a larger test.
2025-01-28 14:09:43 +01:00
da02b59516 Allow using composed strings as metal kernel names. (#2747) 2025-01-27 22:40:12 +01:00
27996a1a9e Remove the old MFA gemm kernels. (#2742)
* Remove the old MFA gemm kernels.

* Use bf16 in helium on metal.
2025-01-26 20:36:31 +01:00
1a32107fab Add a few metal gather ops. (#2740)
* Add a few metal gather ops.

* Fix some compilation issues.

* Adjust the tolerance.
2025-01-25 23:31:03 +01:00
17cbbe4286 Sync upstream MLX sdpa vector kernels with mask (#2718)
* Sync upstream mlx sdpa vector kernels with mask

* Dispatch to the 2pass kernel

* Format
2025-01-16 11:30:10 +01:00
236c35e578 Bump the caret version to 0.8.2. (#2703) 2025-01-07 15:50:16 +01:00
67cab7d6b8 Bump the crate version to 0.8.1. (#2662) 2024-12-07 17:03:53 +01:00
6f715f9256 add scatter add (#2656) 2024-12-01 18:39:38 +01:00
dba7a9c93e add u32 - U32 gather (#2653) 2024-11-30 23:18:07 +01:00
54e7fc3c97 Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83

* rustfmt

* Fix more lints.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-28 23:00:21 +01:00
06350c31c7 Add some missing index-select metal kernels. (#2613)
* Add some missing index-select metal kernels.

* Make some matrix contiguous pre-matmul.
2024-11-12 17:10:12 +01:00
9453cc3095 Bump the crate version to 0.8.0. (#2612) 2024-11-12 14:11:46 +01:00
e2b6b367fa Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
0e2c8c17fb UG metal integration. (#2580) 2024-10-27 15:20:37 +01:00
fd08d3d0a4 Tweak some metal tests. (#2528) 2024-10-02 10:22:31 +02:00
a2bcc227df Efficient implementation of Tensor::ones() for metal (#2512)
* WIP: hopefully better const impl

* with GPU

* More tests on

* Reverting primitive for

* Incorporating review changes - added check elem count check in kerner, using  for call strategy

* rustfmt ran
2024-10-01 19:11:59 +02:00
3a3c48b14b Bump the crate version to 0.7.2. (#2517) 2024-09-29 10:56:50 +02:00
8097559c1a Move the candle version to 0.7.1. (#2495) 2024-09-22 20:44:39 +02:00
c2fca0ca11 Bump the crate version. (#2491) 2024-09-21 15:13:12 +02:00
844d45cde4 Bugfix for the metal elu kernel. (#2490)
* Bugfix for the metal elu kernel.

* Add a test.
2024-09-21 15:03:19 +02:00
af2104078f Metal commands refactoring (#2489)
* Split out the commands part of the metal device.

* Make most fields private.

* Move the allocator back.

* Rework the encoder provider type.
2024-09-21 13:18:42 +02:00
c09afc211c Fix for metal tanh. (#2475) 2024-09-13 07:08:36 +02:00
0cb0bd1dfa Add some metal gemm benchark. (#2471)
* Add some metal gemm benchark.

* More benchmarks.
2024-09-11 22:52:37 +02:00
5635650d38 Integrate the MLX gemm kernels (#2468)
* Include the MLX gemm kernels.

* Clippy lints.

* Export the gemm_f32 kernel.

* Add the f16/bf16 variants.

* Add the initial dispatch code.

* More plugging of the mlx kernels.

* Add a currently broken test.

* Tweaks.

* Bugfix + get the tests to pass.

* Enable the gemm bf16 tests.

* Add some randomized tests.

* Update candle-metal-kernels/src/lib.rs

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>

* More fixes.

* More clippy fixes.

---------

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
2024-09-11 16:56:48 +02:00
6070278a31 Bump the version to 0.6.1. (#2438) 2024-08-22 09:23:52 +02:00
0fcb40b229 Revert the bf16 gemm metal changes for now. (#2386) 2024-08-01 23:08:47 +02:00
fea46cb719 Metal bgemm min changes (#2364)
* Add updated mfa metallib

* Add bgemm and tests
2024-08-01 10:06:04 +02:00
8696cf6494 Enable the affine kernel for u8/u32. (#2376) 2024-08-01 10:03:11 +02:00
ddafc61055 Use RAII for terminating the encoding. (#2353) 2024-07-24 16:29:56 +02:00
a925ae6bc6 Use a trait for the encoder provider (so that encoder can ultimately be reused). (#2352) 2024-07-24 09:27:30 +02:00
f65e90e7ef Bump the crate version. (#2248) 2024-06-05 15:49:15 +02:00
1ec3b2cc18 add where_cond f32 for metal (#2236) 2024-06-02 14:30:06 +02:00
0814dfd148 Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d.

* Enable the col2im variant.

* Bugfix.

* Revert the quantized tweak.
2024-05-25 11:03:23 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
72e7ca529a Add some missing where-cond kernels for metal. (#2203) 2024-05-22 09:44:52 +02:00
b13a82a438 Separate quantized phi-3 implementation. (#2157)
* Separate quantized phi-3 implementation.

* Integrate the quantized phi3 model.=

* Small fixes, get the generation to work properly.

* Keep the old llama implementation around.

* Change the default.
2024-05-04 10:14:57 +02:00
89f53b9d7b Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1.

* Fix clippy lints for 1.78.

* More clippy fixes.
2024-05-03 11:17:05 +02:00
3bbb88fcb4 Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-29 11:04:43 +02:00
96a48e5cc4 Add argsort. (#2132)
* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
2024-04-27 20:17:35 +02:00
0067fe00a8 Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt

* process unary commands in tiles of 4

* re-enable all benchmarks

* rename helper to unary

* modify approach to split up tiled and non-tiled operations

* undo bench ignore for other tests

* update tile size to 2

* only perform the optimization on the contiguous even numbered element case
2024-04-21 00:10:33 +02:00
dd78422701 Handle multiple dimensions in metal QMM + two fixes. (#2097) 2024-04-20 18:55:45 +02:00