Commit Graph

215 Commits

Author SHA1 Message Date
56aacb05da Make the RNN configs accessible from the models. (#2541) 2024-10-04 14:22:23 +02:00
aa35bf2ff5 Add/lstm direction (#2455)
* add: direction for lstm layer

* lint: remove unused Error import

* refactor: remove unnecessary int assignment to Direction enum:

* refactor: use &'static str type instead of String for direction_str:

* Run cargofmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-30 22:44:07 +02:00
683ab698de Add Pixtral. (#2521)
* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
2024-09-30 19:31:14 +02:00
d01207dbf3 Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.

* Add some KvCache tests.

* Test the reset too.

* More kv-cache testing.

* More tests for the rotating kv-cache.

* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.

* Handle contiguity + bugfix + use in mimi.

* Add a way to test the mimi streaming mode.

* Mimi streaming fixes.

* More rotating kv-cache.

* Fix the attn mask generation.

* Handle the abs case.

* Add some tests for the generated mask.
2024-09-23 13:14:32 +02:00
31a1075f4b onnx: implement LSTM op (#2268)
use candle-nn LSTM
2024-08-19 09:06:17 +02:00
6991a37b94 update: LSTMState and GRUState fields to be public (#2384) 2024-08-01 16:30:32 +02:00
0f5cbb08b3 Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope

* Clippy

* Format

* Clippy

* Add support for multiple eos tokens:

* Untagged either

* Remove either dep and fix settings.json

* Make the max positional embeddings configurable
2024-07-26 21:32:26 +02:00
242e006bbb Depth Anything v2 (#2279)
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-24 19:12:52 +02:00
1ec3b2cc18 add where_cond f32 for metal (#2236) 2024-06-02 14:30:06 +02:00
3ceca9901a Enable the new layer-norm. (#2213)
* Enable the new layer-norm.

* Shape fixes.
2024-05-24 16:48:21 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
45e235a747 Simplify the KvCache api. (#2207) 2024-05-23 17:07:21 +02:00
31cf64147b Add a couple kv-cache helper functions. (#2206) 2024-05-23 16:21:47 +02:00
01545f7303 Add a slice_set op. (#2193)
* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
2024-05-18 15:58:18 +02:00
13c64f6828 Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors

Signed-off-by: Harry Stern <harry@harrystern.net>
2024-05-12 07:26:06 +02:00
21f82a5155 Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.

* And add some testing.
2024-05-11 13:15:42 +02:00
89f53b9d7b Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1.

* Fix clippy lints for 1.78.

* More clippy fixes.
2024-05-03 11:17:05 +02:00
a0d03aded1 Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.

* Add a test to ensure training a batch norm works with VarMaps

---------

Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
2024-04-29 11:21:53 +02:00
3bbb88fcb4 Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-29 11:04:43 +02:00
e5c8b88f90 Apply the cast before the scaling. (#2135) 2024-04-28 08:30:35 +02:00
b2e816752b Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
2024-04-22 18:52:00 +02:00
8b8fb630df Add a convenient way to rename tensors accessed through a varbuilder. (#2052) 2024-04-13 12:09:41 +02:00
2ac302a5d1 Add the rope THD kernel. (#2014)
* Add the rope THD kernel.

* Cuda kernel for rope-thd.

* Add the metal kernels.

* Add a dedicated test.
2024-04-05 08:32:58 +02:00
318d143224 Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels.

* Ensure contiguity for RNNs.

* Unrelated fix for segment anything.

* Better error message + allow concatenating empty slices.
2024-04-03 09:02:38 +02:00
cd6b9e317c Add benchmarks for the candle-nn package (#1995)
* add benchmarks for the candle-nn package

* uncomment test

* format
2024-04-03 07:03:54 +02:00
5522bbc57c Add fn 'get_with_hints_dtype' in VarBuilder (#1877) (#1897)
* quantized models(awq/squeezellm/...) have multiple data type tensors, use 'get_with_hints_dtype' to load tensors with given dtype
2024-04-01 12:10:08 +02:00
60676780a9 Fix detail in new RoPE implementation (#1935) 2024-03-25 18:20:09 +01:00
e7f8e72588 Contiguous variant of the rope kernel. (#1929)
* Contiguous variant of the rope kernel.

* Add the cuda kernel.

* Metal kernel.
2024-03-25 09:11:20 +01:00
1b98f84a2b Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.

* Add a test for the fast CPU kernel.

* Rope cuda bindings.

* Cuda kernel.

* Metal kernel (part 1).

* Cuda kernels.

* Finish the metal kernel.

* Use the new kernels in the quantized example.

* Fix warning.
2024-03-24 22:48:52 +01:00
0fddec762e RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal.

* Wrapper for the metal kernel.

* Get the ops to actually work.

* Fix, get the tests to pass.
2024-03-21 09:48:56 +01:00
af7f8b87d3 Custom op for RmsNorm (#1890)
* Trying out a custom RmsNorm cuda kernel.

* CPU implementation for rms-norm.

* Cuda wrappers.

* Add some validation.

* Add some testing.

* More testing.
2024-03-21 06:36:28 +01:00
ce9fbc3682 Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d.

* Move the cat operations.

* Avoid transpositions in cat.

* Bugfix.

* Bugfix for the cuda kernel.

* Add a benchmark.

* Add more testing.

* Test fix.

* Faster kernel.

* Add the missing kernel.

* Tweak the test.

* Add a metal kernel.

* Fix for the metal kernel.

* Get the tests to pass on metal.

* Also use this opportunity to fix the metal kernel for ELU.

* Add some bf16 kernels.

* Clippy fixes.
2024-03-17 10:49:13 +01:00
758366160e add clone to candle dropout (#1814) 2024-03-08 08:18:01 +01:00
0c09d10f32 Improve metal buffer usage (#1807)
* Improve metal buffer usage

* Clone cpu storage when loading to reduce wait_until_complete calls
* Use powers of two for buffer sizes so reuse is more likely.
* Select best available buffer by size.
* Add count to MetalStorage -> can use buffer with different size

Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co>

* Simplify new buffer creation without blit copy. Revert &[] -> Vec

* Add documentation on newBufferWithBytes safety / synchronization

* Drop unused buffers after command buffer is done syncing.

---------

Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co>
2024-03-07 09:42:34 +01:00
4fd00b8900 Add the StarCoder2 model. (#1779)
* Add the StarCoder2 model.

* Add the example code and get things to work.

* And also tweak the readme.
2024-02-28 21:02:41 +01:00
0c49e95dfb Encodec model. (#1771)
* Encodec model.

* Fixes.

* Add the padding functions.

* Get the LSTM bit to work.

* Get the encodec model to generate some tokens (decoder only for now).

* Minor tweak.

* Minor tweak.
2024-02-27 22:59:40 +01:00
1a6043af51 Tweak the VarMap set type. (#1758) 2024-02-25 20:50:08 +01:00
c753f72c85 Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit.

* Fix the cuda tests.
2024-02-22 09:35:28 +01:00
3ba37443e5 Bugfix for applying the bias in conv1d-transpose. (#1732) 2024-02-18 22:51:20 +01:00
1fb728772d Support for groups in conv-transpose1d. (#1731)
* Groups support in conv-transpose-1d.

* Remove dangling file.
2024-02-18 21:28:07 +01:00
678d44a7f6 Expose the weights and biases in transposed convolutions. (#1727) 2024-02-18 10:35:01 +01:00
41416d2376 Expose more conv1d functions/structs. (#1726) 2024-02-17 18:50:55 +01:00
b60064780d feat: add silu activation function (#1706)
* feat: add silu activation function

* use silu/arg in grad

* update candle-nn

* use node
2024-02-14 10:27:22 +01:00
ad73e93da2 Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
2024-02-13 14:26:32 +01:00
020a979de2 Fix clippy lints for 1.76. (#1682) 2024-02-08 16:48:47 +01:00
b75e8945bc Enhance pickle to retrieve state_dict with a given key (#1671) 2024-02-06 21:17:33 +01:00
a90fc5ca5a Add VarBuilder::from_backend (#1670)
`candle-nn` already exposes a trait to define custom backends. However,
it's not possible to actually construct a `VarBuilder` with a custom
backend because the constructor is not exposed.

This change makes the constructor public and renames it from `new` to
`from_backend` to avoid that it is seen as the primary
constructor (which could be confusing to users).
2024-02-06 15:26:11 +01:00
403680f17d Quantized GGUF style (#1523)
* Metal quantized modifications proposal.

- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catch it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).

* Cleanup.

* Fix the rebase.

* Removing the fences speeds everything up and *is* correct this time...

* Cleanup the fence.

* After rebase.

* Bad code removal.

* Rebase after phi2 merge + fix replit default to CPU.

* Making the CI happy.

* More happy tests.

---------

Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
2024-01-17 10:27:58 +01:00
539ead927a Update the Phi model to use the updated architecture. (#1580)
* Update the Phi model to use the updated architecture.

* Add more of the phi model.

* Repeat KV + caching.

* Apply the rotary embeddings.

* Add support for the new phi model in the phi example.

* Fix a couple glitches.

* Fix a couple more glitches.
2024-01-13 17:38:27 +01:00
b4cb982e49 Simplifying our internal cargo dependencies. (#1529) 2024-01-07 12:04:14 +01:00