* Start updating to cudarc 0.14.
* Adapt a couple more things.
* And a couple more fixes.
* More tweaks.
* And a couple more fixes.
* Bump the major version number.
* Proper module system for the cuda kernels.
* Proper ptx loading.
* Launch the sort kernel.
* Custom op.
* Start using the builder pattern.
* More builder.
* More builder.
* Get candle-core to compile.
* Get the tests to pass.
* Get candle-nn to work too.
* Support for custom cuda functions.
* cudnn fixes.
* Get flash attn to run.
* Switch the crate versions to be alpha.
* Bump the ug dependency.
* update to cudarc to v0.13.5 to support cuda 12.8
* Bump the crate version.
---------
Co-authored-by: Michael McCulloch <michael.james.mcculloch@fastmail.com>
* Improve reduce perf and add contiguous impl
* Improve arg reduce and add contiguous impl
* Improve softmax kernel. 33%-39% higher thrpt
* fmt
* Fixed all bugs. Improved code quality. Added tests.
* Stash for debugging
* Stash for debugging 2
* Fixing argmax bug and improve performance
Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
* Fix test and add is_valid_simgroup_reduce_type trait
* Online softmax. Improved threadgroup reduce. Tidying up a bit.
* Remove redundant threadgroup_barrier from arg reduce
* Mostly tidying up. Some improvements
* Simplify indexed struct
* tidying
* Reuse operation operator instead of passing it in as a parameter
* Fix how operators are applied to indexed<vec<T,N>>
* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.
* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.
* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math
* Use constant for input instead of const device. Fix strided reduce.
* Use contiguous reduce in tests
* Rename finalize -> to_scalar
* Support integer types max/min (switch with trait-inferred impl later)
* Was worried I was skipping work -> shuffling the 1D test cases
* Add build.rs to avoid metal kernel jit compile overhead
* Improve build. Extract utils
* Compile metal kernels for both macos and ios
* Fixed over xmas and then forgot about it
* Add calculate_reduce_threads util
* Remove old reduce.metal
* Improve f16/bf16 softmax precision by accumulating in f32
* Remove build.rs (for now)
* Move softmax bench to candle-nn
* Remove redundant thread calc util fn
* Use uint over ushort for indices etc
* Use fast exp in MDReduceOp
* Remove nested metal define for softmax
* Fix some clippy lint.
---------
Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Add some metal sort kernels imported from MLX.
* Add another test.
* Start adding the multiblock version.
* Proper kernel names.
* Split out the main metal file.
* Multi-block sort.
* More sorting.
* DType parametrization.
* Add a larger test.
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel,
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* WIP: hopefully better const impl
* with GPU
* More tests on
* Reverting primitive for
* Incorporating review changes - added check elem count check in kerner, using for call strategy
* rustfmt ran
* Include the MLX gemm kernels.
* Clippy lints.
* Export the gemm_f32 kernel.
* Add the f16/bf16 variants.
* Add the initial dispatch code.
* More plugging of the mlx kernels.
* Add a currently broken test.
* Tweaks.
* Bugfix + get the tests to pass.
* Enable the gemm bf16 tests.
* Add some randomized tests.
* Update candle-metal-kernels/src/lib.rs
Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
* More fixes.
* More clippy fixes.
---------
Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
* Separate quantized phi-3 implementation.
* Integrate the quantized phi3 model.=
* Small fixes, get the generation to work properly.
* Keep the old llama implementation around.
* Change the default.
* add sigmoid op
* small fix
* add as a method on `Tensor`
* implement gradient calculation for sigmoid
* add sigmoid tests
* we should have a specialized op for this
* fix clippy
* fix clippy 2
* Revert all previous commits in favor of a `CustomOp` based solution
* use `CustomOp1` implementation
* fix rustfmt
* experimental add metal impl
* add cuda kernel impl
* fix fmt
* Add a test + reduce some cuda duplication.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated test.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
* add basic unary bench for sqrt
* process unary commands in tiles of 4
* re-enable all benchmarks
* rename helper to unary
* modify approach to split up tiled and non-tiled operations
* undo bench ignore for other tests
* update tile size to 2
* only perform the optimization on the contiguous even numbered element case