241 Commits

Author SHA1 Message Date
cd96fa80da Add a scattered kv cache. (#2936)
* Add a scattered kv cache.

* Update some comments.
2025-05-01 10:20:48 +02:00
d4bac37a61 Fix the gumbel softmax by casting to f32. (#2928) 2025-04-28 19:48:51 +02:00
e3db30021f Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
2025-04-27 15:12:02 +02:00
21055b5697 Add PRelu operation (#2904)
* Add PRelu operation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-19 07:24:10 +02:00
7f0f83a7c1 Rotating kv cache positions (#2901)
* Retrieve the current positions for rotating KV caches.

* Add the function to the kv cache too.

* More testing.
2025-04-15 23:09:26 +02:00
2653002f29 Gumbel-Softmax sampling. (#2894)
* Gumbel-Softmax sampling.

* Add a sampling test.

* Share the gumbel-softmax bits.
2025-04-14 15:42:42 +02:00
a52b76ae82 Expose the cudnn algo in the conv ops. (#2892)
* Set the algo.

* Expose the cudnn preferred algo for conv ops.
2025-04-14 08:25:32 +02:00
fb660b8d43 Add a cudnn feature to candle-nn/candle-transformers. (#2890) 2025-04-13 17:43:41 +02:00
34505fdf3a Avoid using batched-matmul in nn::Linear. (#2883)
* Avoid using batched-matmul in nn::Linear.

* Also avoid batched matmul in conv1d.

* Also tweak the conv2d.

* Batched tests.

* Also cover conv2d.
2025-04-12 19:53:58 +02:00
acc5bd335f Cuda cleanup. (#2880)
* Cuda cleanup.

* More fixes.
2025-04-11 21:43:35 +02:00
9d31361c4f Fix for clippy 1.86. (#2864)
* Fix for clippy 1.86.

* More clippy fixes.

* More fixes.
2025-04-03 19:38:27 +02:00
d9904a3baf Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
2025-04-03 09:12:19 +02:00
3afb04925a Allow for growing the default KV cache when needed. (#2810) 2025-03-16 17:30:25 +01:00
c930ab7e1a upgrade half library to fix rand (#2806)
fix lints
2025-03-14 09:01:54 +01:00
7c2449f623 Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl

* Improve arg reduce and add contiguous impl

* Improve softmax kernel. 33%-39% higher thrpt

* fmt

* Fixed all bugs. Improved code quality. Added tests.

* Stash for debugging

* Stash for debugging 2

* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait

* Online softmax. Improved threadgroup reduce. Tidying up a bit.

* Remove redundant threadgroup_barrier from arg reduce

* Mostly tidying up. Some improvements

* Simplify indexed struct

* tidying

* Reuse operation operator instead of passing it in as a parameter

* Fix how operators are applied to indexed<vec<T,N>>

* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.

* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.

* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math

* Use constant for input instead of const device. Fix strided reduce.

* Use contiguous reduce in tests

* Rename finalize -> to_scalar

* Support integer types max/min (switch with trait-inferred impl later)

* Was worried I was skipping work -> shuffling the 1D test cases

* Add build.rs to avoid metal kernel jit compile overhead

* Improve build. Extract utils

* Compile metal kernels for both macos and ios

* Fixed over xmas and then forgot about it

* Add calculate_reduce_threads util

* Remove old reduce.metal

* Improve f16/bf16 softmax precision by accumulating in f32

* Remove build.rs (for now)

* Move softmax bench to candle-nn

* Remove redundant thread calc util fn

* Use uint over ushort for indices etc

* Use fast exp in MDReduceOp

* Remove nested metal define for softmax

* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-02-08 07:27:01 +01:00
ab9019425a Make the metal sdpa tests deterministic. (#2750) 2025-01-28 09:05:24 +01:00
1a32107fab Add a few metal gather ops. (#2740)
* Add a few metal gather ops.

* Fix some compilation issues.

* Adjust the tolerance.
2025-01-25 23:31:03 +01:00
85f0aaefe5 Add serde::serialize to activations. (#2732) 2025-01-22 10:23:34 +01:00
17cbbe4286 Sync upstream MLX sdpa vector kernels with mask (#2718)
* Sync upstream mlx sdpa vector kernels with mask

* Dispatch to the 2pass kernel

* Format
2025-01-16 11:30:10 +01:00
461e8c1685 ModernBERT model (#2713)
* layer_norm_no_bias

* Modernbert model.

* Format + cleanup error.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-01-13 08:39:27 +01:00
2344c4e4b8 Clippy fixes for 1.84. (#2710) 2025-01-10 10:15:15 +01:00
54e7fc3c97 Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83

* rustfmt

* Fix more lints.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-28 23:00:21 +01:00
b4deb5c5a9 Provide a method to allow PTH files with state maps to be loaded. (#2639)
* Provide a method to allow PTH files iwth state maps to be loaded.

* add a line to the doc

* String-. &str
2024-11-26 22:52:53 +01:00
3769206583 Update docs (#2553)
* add module docs for candle-core

* doc each of the candle-nn modules and add the links to the doc page
2024-11-11 22:13:52 +01:00
e2b6b367fa Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
6454597943 Improved launch config for layer-norm/rms-norm. (#2591)
* Improved launch config for layer-norm/rms-norm.

* Add more testing for the fused layer/rms norm kernels.
2024-11-04 10:42:18 +01:00
56aacb05da Make the RNN configs accessible from the models. (#2541) 2024-10-04 14:22:23 +02:00
aa35bf2ff5 Add/lstm direction (#2455)
* add: direction for lstm layer

* lint: remove unused Error import

* refactor: remove unnecessary int assignment to Direction enum:

* refactor: use &'static str type instead of String for direction_str:

* Run cargofmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-30 22:44:07 +02:00
683ab698de Add Pixtral. (#2521)
* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
2024-09-30 19:31:14 +02:00
d01207dbf3 Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.

* Add some KvCache tests.

* Test the reset too.

* More kv-cache testing.

* More tests for the rotating kv-cache.

* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.

* Handle contiguity + bugfix + use in mimi.

* Add a way to test the mimi streaming mode.

* Mimi streaming fixes.

* More rotating kv-cache.

* Fix the attn mask generation.

* Handle the abs case.

* Add some tests for the generated mask.
2024-09-23 13:14:32 +02:00
31a1075f4b onnx: implement LSTM op (#2268)
use candle-nn LSTM
2024-08-19 09:06:17 +02:00
6991a37b94 update: LSTMState and GRUState fields to be public (#2384) 2024-08-01 16:30:32 +02:00
0f5cbb08b3 Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope

* Clippy

* Format

* Clippy

* Add support for multiple eos tokens:

* Untagged either

* Remove either dep and fix settings.json

* Make the max positional embeddings configurable
2024-07-26 21:32:26 +02:00
242e006bbb Depth Anything v2 (#2279)
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-24 19:12:52 +02:00
1ec3b2cc18 add where_cond f32 for metal (#2236) 2024-06-02 14:30:06 +02:00
3ceca9901a Enable the new layer-norm. (#2213)
* Enable the new layer-norm.

* Shape fixes.
2024-05-24 16:48:21 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
45e235a747 Simplify the KvCache api. (#2207) 2024-05-23 17:07:21 +02:00
31cf64147b Add a couple kv-cache helper functions. (#2206) 2024-05-23 16:21:47 +02:00
01545f7303 Add a slice_set op. (#2193)
* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
2024-05-18 15:58:18 +02:00
13c64f6828 Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors

Signed-off-by: Harry Stern <harry@harrystern.net>
2024-05-12 07:26:06 +02:00
21f82a5155 Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.

* And add some testing.
2024-05-11 13:15:42 +02:00
89f53b9d7b Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1.

* Fix clippy lints for 1.78.

* More clippy fixes.
2024-05-03 11:17:05 +02:00
a0d03aded1 Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.

* Add a test to ensure training a batch norm works with VarMaps

---------

Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
2024-04-29 11:21:53 +02:00
3bbb88fcb4 Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-29 11:04:43 +02:00
e5c8b88f90 Apply the cast before the scaling. (#2135) 2024-04-28 08:30:35 +02:00
b2e816752b Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
2024-04-22 18:52:00 +02:00
8b8fb630df Add a convenient way to rename tensors accessed through a varbuilder. (#2052) 2024-04-13 12:09:41 +02:00
2ac302a5d1 Add the rope THD kernel. (#2014)
* Add the rope THD kernel.

* Cuda kernel for rope-thd.

* Add the metal kernels.

* Add a dedicated test.
2024-04-05 08:32:58 +02:00
318d143224 Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels.

* Ensure contiguity for RNNs.

* Unrelated fix for segment anything.

* Better error message + allow concatenating empty slices.
2024-04-03 09:02:38 +02:00