Commit Graph

394 Commits

Author SHA1 Message Date
a52b76ae82 Expose the cudnn algo in the conv ops. (#2892)
* Set the algo.

* Expose the cudnn preferred algo for conv ops.
2025-04-14 08:25:32 +02:00
fb660b8d43 Add a cudnn feature to candle-nn/candle-transformers. (#2890) 2025-04-13 17:43:41 +02:00
eb478ece92 Implementing DistilBertForMaskedLM. (#2866)
* Initial commit: model weights working, prediciton incorrect

* moved distilbertformaskedlm into distilbert modeling file

* made maskedLM like bert example, still incorrect predictions

* finally not getting NaNs, fixed attention mask

* getting correct output sentences

* get top k predictions

* fixed output formatting slightly

* added default arg for model_id

* lint

* moved masked token example code from distilbertformaskedlm example to distilbert example

* lint

* removed distilbertformaskedlm example

* cleanup

* clippy

* removed embedding normalization from example

* made output and model dependent on args instead of prompt

* lint

* replaced or_ok anyhow error with anyhow context

* changed error message for mask token not found
2025-04-11 13:25:39 +02:00
d339b01726 Fix hardcoded f32 dtype for attention_mask. Use the model dtype for compatibility. (#2872) 2025-04-08 06:12:14 +02:00
e3370c6316 Add the SNAC audio tokenizer. (#2869)
* Add the SNAC audio tokenizer.

* More snac.

* Again more snac.

* Add some example code for snac.

* Get the weights to load.

* Add to the snac model.

* Fixes.

* Get round-tripping to work.

* Save/load code files.

* Clippy fix.

* Fmt fix.
2025-04-06 22:15:36 +02:00
cf9d7bf24c Add the CSM model. (#2862)
* Add the CSM model.

* Add some code to load the model.

* Load the text tokenizer.

* Add frame generation.

* Get the sampling to work.

* Rope fix.

* Autoregressive generation.

* Generate some audio file.

* Use the actual prompt.

* Support multiple turns.

* Add a very barebone readme.

* Move some of the shared bits to the model.
2025-04-04 06:48:03 +02:00
9d31361c4f Fix for clippy 1.86. (#2864)
* Fix for clippy 1.86.

* More clippy fixes.

* More fixes.
2025-04-03 19:38:27 +02:00
d6db305829 Added new language pairs to marian-mt example. (#2860)
* added new language pairs to marian-mt

* lint

* seperated python code for converting tokenizers into its own file and and added a reqirements.txt for dependencies, updated instructions in readme and included python version

* Cleanup.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-02 23:50:14 +02:00
c930ab7e1a upgrade half library to fix rand (#2806)
fix lints
2025-03-14 09:01:54 +01:00
111edbc4ea Gemma 3 initial setup (text only). (#2802)
* Gemma 3 initial setup (text only).

* Use the rotating kv cache for the sliding window.
2025-03-14 07:56:02 +01:00
e286cf7cc9 Parse the json config for siglip models. (#2800)
* Parse the json config for siglip models.

* Bump the tokenizers dependency.

* Add a v2 model.

* Support more v2 model.s
2025-03-09 14:01:09 +01:00
e4ffb85228 Add ModernBert sentency classifier (#2796) 2025-03-08 14:48:22 +01:00
37db86ff79 Allow ModernBert to be used to generate embeddings. (#2791) 2025-03-03 12:39:04 +01:00
e6cc76fc37 Implement DeepSeek V2 (#2744)
* Add deepseek v2

* Fix

* Remove unused

* Add kv cache

* Remove from cargo.toml

* Fix dtype selection logic

* Fix unnecessary u32->f32->gather->u32

* Remove fromstr impl

* Use local scopes for some clarity

* Typo

* Repeat k_pe

* Chain calls to remove mut

* Actually, remove all muts

* Update readme
2025-02-19 10:51:01 +01:00
2423d633fc add dynamic position encoding to Siglip (#2770)
* add dynamic position encoding

* remove debug messages
2025-02-14 13:50:50 +01:00
43017539ab Adds DebertaV2/V3 (#2743)
* Adds DebertaV2/V3

* Fixes all clippy warnings

* Typos.

* Addresses PR review findings. Some refactorings

* Avoid some unwrap/unwrap_or.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-01-29 08:59:28 +01:00
333d94a19a fix: fix the codegeex4 model examples and transformers model (#2738)
* Update main.rs

* Update codegeex4_9b.rs

* Get things to compile.

* Add some default for when rope_ratio is missing.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-01-25 17:41:12 +01:00
e4c3a71f11 Fix GLM4 alignment issue (#2723)
* Fix GLM4 alignment issue

* Cleanups.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-01-20 22:51:46 +01:00
309cd0f7c7 Add the helium model. (#2715) 2025-01-13 17:39:49 +01:00
ab7ff7081e Fixes for running Phi-4 quantized. (#2714) 2025-01-13 14:35:33 +01:00
461e8c1685 ModernBERT model (#2713)
* layer_norm_no_bias

* Modernbert model.

* Format + cleanup error.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-01-13 08:39:27 +01:00
57f41da13b Fix mistral attention on Metal (#2699)
Co-authored-by: Luka Zakrajsek <luka.zakrajsek@soniox.com>
2025-01-04 16:11:20 +01:00
cbaa0ad46f UniPC for diffusion sampling (#2684)
* feat: Add unipc multistep scheduler

* chore: Clippy and formatting

* chore: Update comments

* chore: Avoid unsafety in float ordering

* refactor: Update Scheduler::step mutability requirements

* fix: Corrector img2img

* chore: Update unipc ref link to latest diffusers release

* chore: Deduplicate float ordering

* fix: Panic when running with dev profile
2025-01-01 21:34:17 +01:00
91f1f019b1 Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base

* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions

- Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
- Updated the logic in `main.rs` to handle tasks using the new enum.
- Enhanced README with example output for fill-mask task.
- Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-30 11:16:57 +01:00
cd639131f0 Fix bug in whisper transformer (#2681)
* Fix bug in whisper transformer
- due to num_threads going to zero
in single threaded case

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-12-24 13:58:21 +01:00
1be6b090c7 Fix position encodings for Pixtral (#2678)
* init commit: add position id in meshgrid

* pass in subsampled positions

* clippy fix

* clippy fix
2024-12-23 13:22:35 +01:00
62ced44ea9 Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.

* Switch two unwrap to context.
2024-12-22 09:18:13 +01:00
5c2f893e5a make DepthAnythingV2 more reusable (#2675)
* make DepthAnythingV2 more reusable

* Fix clippy lints.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-21 12:06:03 +01:00
1807be84f4 Change/bert encoder public (#2658)
* change: BertEncoder struct to public

* change: make certain fields in Config struct public

* change: all fields in bert config struct to be public

* change: add clone to bert encoder and others

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-12-04 21:22:30 +01:00
145aa7193c Add Nvembed v2 model (#2649)
* Update mod.rs

* Create mod.rs

* Create decoder.rs

* Create model.rs

* Create main.rs

* Create README.md

* Update README.md

* Update main.rs

* Update and rename decoder.rs to embedding.rs

* Update mod.rs

* Update model.rs
2024-12-03 10:56:01 +01:00
4f59ed38b0 Adds support for stella_en_v5 embedding model -400M variant (#2608)
* Adds support for stella_en_v5 embedding model -400M variant

* Unified stella

* WIP: Unified Stella

* Combined stella for both 1.5B and 400M variants

* Cargo fmt for the CI

* removed redundant stella-400m model and example after merge into stella-en-v5

* cargo fmt --all

---------

Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-11-29 09:01:08 +01:00
54e7fc3c97 Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83

* rustfmt

* Fix more lints.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-28 23:00:21 +01:00
3159f91b90 20241118 docs (#2629)
* module docs

* varbuilder gguf docs

* add a link to gguf files

* small additonal mod doc titles

* safetensor docs

* more core docs

* more module docs in canlde_core

* 2 more link fixes
2024-11-19 04:07:07 +01:00
e86565624b Fix for clippy. (#2626) 2024-11-18 14:32:38 +01:00
386fd8abb4 Module Docs (#2624)
* update whisper

* update llama2c

* update t5

* update phi and t5

* add a blip model

* qlamma doc

* add two new docs

* add docs and emoji

* additional models

* openclip

* pixtral

* edits on the  model docs

* update yu

* update a fe wmore models

* add persimmon

* add model-level doc

* names

* update module doc

* links in heira

* remove empty URL

* update more hyperlinks

* updated hyperlinks

* more links

* Update mod.rs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-11-18 14:19:23 +01:00
12d7e7b145 More Model Module Docs (#2623)
* dinov2

* add another example

* ad dinov2reg4

* eva2

* efficientvit

* moondream

* update t5

* update t5

* rwkv

* stable diffusion docs

* add wasm link

* add segment_anything

* adjsut for clippy

* ignore bertdoc

* dinov2 ignore

* update block to be text

* remove the rust blocks for the moment

* bump python to 3.11

* add a setup-python step

* add py311 to test as well
2024-11-17 20:27:24 +01:00
a3f200e369 Module Docs (#2620)
* update bert docs

* update based

* update bigcode

* add pixtral

* add flux as well
2024-11-16 09:09:17 +01:00
00d8a0c178 Remove some unused macros. (#2618)
* Remove some unused macros.

* More unused fixes.
2024-11-15 16:46:55 +01:00
f689ce5d39 Documentation Pass for Models (#2617)
* links in chinese_clip

* links for clip model

* add mod docs for flux and llava

* module doc for MMDIT and MIMI

* add docs for a few more modesl

* mod docs for bert naser and beit

* add module docs for convmixer colpali codegeex and chatglm

* add another series of moddocs

* add  fastvit-llama2_c

* module docs mamba -> mobileone

* module docs from moondream-phi3

* mod docs for quantized and qwen

* update to yi

* fix long names

* Update llama2_c.rs

* Update llama2_c_weights.rs

* Fix the link for mimi + tweaks

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-11-15 08:30:15 +01:00
06350c31c7 Add some missing index-select metal kernels. (#2613)
* Add some missing index-select metal kernels.

* Make some matrix contiguous pre-matmul.
2024-11-12 17:10:12 +01:00
e2b6b367fa Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
3fba2b5fc4 Add the SmolLM2 models. (#2595)
* Add the SmolLM2 models.

* More SmolLM2 support.
2024-11-03 17:11:12 +01:00
530ab96036 Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-01 18:10:40 +01:00
7ac0de15a9 Lazy upcasting for t5. (#2589) 2024-10-30 18:08:51 +01:00
d232e132f6 Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn

* further adjust attn and joint_attn

* add mmdit-x support

* support sd3.5-medium in the example

* update README.md
2024-10-30 06:19:07 +01:00
37e0ab8c64 Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
2024-10-27 10:01:04 +01:00
a2e9d41b20 use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572) 2024-10-23 20:07:09 +02:00
3d1dc06cdb Enable stable-diffusion 3 on metal. (#2560) 2024-10-14 08:59:12 +02:00
f553ab5eb4 Adds support for Stella_en_v5 embedding model - 1.5B variant (#2551)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README
2024-10-13 23:09:12 +02:00
41ade774e8 fix: Allow marian configs to deserialize from json. (#2556) 2024-10-13 23:05:50 +02:00