Commit Graph

2371 Commits

Author SHA1 Message Date
3769206583 Update docs (#2553)
* add module docs for candle-core

* doc each of the candle-nn modules and add the links to the doc page
2024-11-11 22:13:52 +01:00
e2b6b367fa Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
6454597943 Improved launch config for layer-norm/rms-norm. (#2591)
* Improved launch config for layer-norm/rms-norm.

* Add more testing for the fused layer/rms norm kernels.
2024-11-04 10:42:18 +01:00
3fba2b5fc4 Add the SmolLM2 models. (#2595)
* Add the SmolLM2 models.

* More SmolLM2 support.
2024-11-03 17:11:12 +01:00
530ab96036 Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-01 18:10:40 +01:00
7ac0de15a9 Lazy upcasting for t5. (#2589) 2024-10-30 18:08:51 +01:00
d232e132f6 Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn

* further adjust attn and joint_attn

* add mmdit-x support

* support sd3.5-medium in the example

* update README.md
2024-10-30 06:19:07 +01:00
139ff56aeb Reduce memory usage for sd 3.5. (#2582) 2024-10-28 22:45:02 +01:00
498bc2cdc9 Release the mmdit model earlier to reduce memory usage. (#2581)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.

* Release the mmdit model earlier to reduce memory usage.
2024-10-28 16:06:53 +01:00
0e2c8c17fb UG metal integration. (#2580) 2024-10-27 15:20:37 +01:00
594d984f9c Support for UG kernels. (#2579)
* Support for UG kernels.

* Add a dedicated test.
2024-10-27 13:37:19 +01:00
37e0ab8c64 Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
2024-10-27 10:01:04 +01:00
07849aa595 Update README.md (#2577) 2024-10-26 18:23:52 +02:00
3699c1a053 Fix the repo name for llama 3.1. (#2576)
* Fix the repo name for llama 3.1.

* Fix the book.
2024-10-26 11:25:04 +02:00
a2e9d41b20 use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572) 2024-10-23 20:07:09 +02:00
7c09215ef4 ONNX: GatherElements, Xor (#2568) 2024-10-17 20:22:35 +02:00
dcd83336b6 Testcases (#2567) 2024-10-17 13:00:45 +02:00
a01aa89799 onnx: ReduceMin/Max Ops (#2563)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README

* WIP: ONNX Reduce-max ops

* WIP: tests for ReduceMin

* Reduce min/ max v18+

* Reformatting tests for better review readability

* Error on empty set, backward compatibility (13 and below) with 'axes'
2024-10-15 10:34:07 +02:00
3d1dc06cdb Enable stable-diffusion 3 on metal. (#2560) 2024-10-14 08:59:12 +02:00
f553ab5eb4 Adds support for Stella_en_v5 embedding model - 1.5B variant (#2551)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README
2024-10-13 23:09:12 +02:00
41ade774e8 fix: Allow marian configs to deserialize from json. (#2556) 2024-10-13 23:05:50 +02:00
6eab6b57f5 Fix the guide to gain access to Stable Diffusion 3 Medium (#2559) 2024-10-13 22:55:26 +02:00
ca7cf5cb3b Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-13 22:08:40 +02:00
0d96ec31e8 feat: intergrate chinese clip and add example (#2555)
* start to impl chinese clip

* impl vision model

* copy code from bert

* refactor use

* refactor use again

* fix text model

* refactor

* try to fix text model

* tuning

* tuning chinese clip

* delete useless code

* revert code

* Clippy fixes.

* Also apply cargo fmt.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-10-10 15:18:55 +02:00
937e8eda74 Add BertForMaskedLM to support SPLADE Models (#2550)
* add bert for masked lm

* working example

* add example readme

* Clippy fix.

* And apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-07 23:28:21 +02:00
edf7668291 improve (#2548) 2024-10-07 17:30:56 +02:00
e4a96f9e7c Switch to using the MLX matmul by default. (#2547) 2024-10-06 23:24:55 +02:00
f856b5c3a7 pyo3 update. (#2545)
* pyo3 update.

* Stub fix.
2024-10-06 10:09:38 +02:00
d2e432914e Tensor tools print all (#2543)
* Support whisper large-v3 turbo in the whisper-microphone example.

* Print all tensors when no argument is provided.
2024-10-05 10:05:14 +02:00
410c89f72a Add required feature for whisper example in Readme (#2539) 2024-10-04 14:29:55 +02:00
56aacb05da Make the RNN configs accessible from the models. (#2541) 2024-10-04 14:22:23 +02:00
6faecaa616 Fix for cudnn bf16 conv2d. (#2535) 2024-10-02 23:18:55 +02:00
90d04ff622 Support whisper large-v3 turbo in the whisper-microphone example. (#2533) 2024-10-02 22:09:14 +02:00
7b60bda4ed Add support for cuda streams. (#2532) 2024-10-02 21:30:58 +02:00
936300678d Add whisper large-v3 turbo to the example. (#2531) 2024-10-02 21:07:08 +02:00
f479840ce6 Add a seed to the flux example. (#2529) 2024-10-02 10:52:02 +02:00
fd08d3d0a4 Tweak some metal tests. (#2528) 2024-10-02 10:22:31 +02:00
a2bcc227df Efficient implementation of Tensor::ones() for metal (#2512)
* WIP: hopefully better const impl

* with GPU

* More tests on

* Reverting primitive for

* Incorporating review changes - added check elem count check in kerner, using  for call strategy

* rustfmt ran
2024-10-01 19:11:59 +02:00
def4c6cdee Cuda quantized mmv bugfix. (#2526) 2024-10-01 12:57:55 +02:00
888d886dd8 Add ColPali (#2524)
* add colpali

* cleanup

* fix clippy
2024-10-01 11:48:39 +02:00
6110ad8d4f Refactor the whisper microphone example. (#2523)
* Refactor the whisper microphone example.

* Tweak the whisper microphone example more.
2024-10-01 00:24:17 +02:00
aa35bf2ff5 Add/lstm direction (#2455)
* add: direction for lstm layer

* lint: remove unused Error import

* refactor: remove unnecessary int assignment to Direction enum:

* refactor: use &'static str type instead of String for direction_str:

* Run cargofmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-30 22:44:07 +02:00
724650446c Yet another cuda qmm padding fix. (#2509) 2024-09-30 21:53:30 +02:00
dfe9a00683 Pixtral polishing. (#2522)
* Pixtral polishing.

* Clippy fix.
2024-09-30 21:23:54 +02:00
683ab698de Add Pixtral. (#2521)
* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
2024-09-30 19:31:14 +02:00
2f49e1b534 Add PaliGemma. (#2519)
* Add PaliGemma.

* PaliGemma inference loop.

* Running PaliGemma example.

* Tweak the prompt.
2024-09-29 19:56:56 +02:00
0ebb38813b Paligemma siglip vision config (#2518)
* Add the paligemma siglip vision config.

* More paligemma configs.
2024-09-29 17:53:52 +02:00
3a3c48b14b Bump the crate version to 0.7.2. (#2517) 0.7.2 2024-09-29 10:56:50 +02:00
261ed65f36 Add the SigLIP model. (#2515)
* Add the SigLIP model.

* Add more to the forward pass of the vision model.

* Complete the forward pass.

* Add the siglip example.

* Fix.

* Another fix.

* Get everything in place.

* Add a readme.
2024-09-28 23:48:00 +02:00
62525e8352 Remove some extra whitelines. (#2513) 2024-09-28 14:41:28 +02:00