* Add some metal sort kernels imported from MLX.
* Add another test.
* Start adding the multiblock version.
* Proper kernel names.
* Split out the main metal file.
* Multi-block sort.
* More sorting.
* DType parametrization.
* Add a larger test.
* Update main.rs
* Update codegeex4_9b.rs
* Get things to compile.
* Add some default for when rope_ratio is missing.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Update the stable diffusion example with inpainting support for 1.5, 2 and XL.
* Apply cargo fmt.
* Clippy fixes.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
* unpadded lse added
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
* add xlm-roberta-base
* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions
- Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
- Updated the logic in `main.rs` to handle tasks using the new enum.
- Enhanced README with example output for fill-mask task.
- Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.
* Clippy fix.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
* Fix bug in whisper transformer
- due to num_threads going to zero
in single threaded case
* Apply rustfmt.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* change: BertEncoder struct to public
* change: make certain fields in Config struct public
* change: all fields in bert config struct to be public
* change: add clone to bert encoder and others
* Clippy fix.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Adds support for stella_en_v5 embedding model -400M variant
* Unified stella
* WIP: Unified Stella
* Combined stella for both 1.5B and 400M variants
* Cargo fmt for the CI
* removed redundant stella-400m model and example after merge into stella-en-v5
* cargo fmt --all
---------
Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>