* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel,
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Stable diffusion 3.5 support.
* Clippy fixes.
* CFG fix.
* Remove some unnecessary clones.
* Avoid duplicating some of the code.
* Release the mmdit model earlier to reduce memory usage.
* Stella_en_1.5B_v5
* Separated creation. This is a critical step for numerical accuracy and would be documented in the readme
* EmbedDim would require clone and copy
* WIP: example
* Examples added
* a litte more in README
* WIP: ONNX Reduce-max ops
* WIP: tests for ReduceMin
* Reduce min/ max v18+
* Reformatting tests for better review readability
* Error on empty set, backward compatibility (13 and below) with 'axes'
* Stella_en_1.5B_v5
* Separated creation. This is a critical step for numerical accuracy and would be documented in the readme
* EmbedDim would require clone and copy
* WIP: example
* Examples added
* a litte more in README
* Add stable diffusion 3 example
Add get_qkv_linear to handle different dimensionality in linears
Add stable diffusion 3 example
Add use_quant_conv and use_post_quant_conv for vae in stable diffusion
adapt existing AutoEncoderKLConfig to the change
add forward_until_encoder_layer to ClipTextTransformer
rename sd3 config to sd3_medium in mmdit; minor clean-up
Enable flash-attn for mmdit impl when the feature is enabled.
Add sd3 example codebase
add document
crediting references
pass the cargo fmt test
pass the clippy test
* fix typos
* expose cfg_scale and time_shift as options
* Replace the sample image with JPG version. Change image output format accordingly.
* make meaningful error messages
* remove the tail-end assignment in sd3_vae_vb_rename
* remove the CUDA requirement
* use default_value in clap args
* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime
* resolve clippy errors and warnings
* use default_value_t
* Pin the web-sys dependency.
* Clippy fix.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* start to impl chinese clip
* impl vision model
* copy code from bert
* refactor use
* refactor use again
* fix text model
* refactor
* try to fix text model
* tuning
* tuning chinese clip
* delete useless code
* revert code
* Clippy fixes.
* Also apply cargo fmt.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
* add bert for masked lm
* working example
* add example readme
* Clippy fix.
* And apply rustfmt.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* WIP: hopefully better const impl
* with GPU
* More tests on
* Reverting primitive for
* Incorporating review changes - added check elem count check in kerner, using for call strategy
* rustfmt ran
* add: direction for lstm layer
* lint: remove unused Error import
* refactor: remove unnecessary int assignment to Direction enum:
* refactor: use &'static str type instead of String for direction_str:
* Run cargofmt.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Add Pixtral.
* More pixtral vision encoder.
* Sketch a pixtral example.
* Sketch a pixtral example.
* Better image loading.
* Support loading images embedded in safetensor files.
* Clippy fixes.
* Add the llava multimodal adapter.
* Add more of the llava bits.
* Add the pixtral config.
* More pixtral inference.
* Add the text generation bits.
* Get the example to work.
* Bugfix.
* Run some bits of the model in f32.
* Blessed version :)
* Better rope frequency computations.
* README update.