0e2c8c17fb
UG metal integration. ( #2580 )
2024-10-27 15:20:37 +01:00
594d984f9c
Support for UG kernels. ( #2579 )
...
* Support for UG kernels.
* Add a dedicated test.
2024-10-27 13:37:19 +01:00
dcd83336b6
Testcases ( #2567 )
2024-10-17 13:00:45 +02:00
e4a96f9e7c
Switch to using the MLX matmul by default. ( #2547 )
2024-10-06 23:24:55 +02:00
6faecaa616
Fix for cudnn bf16 conv2d. ( #2535 )
2024-10-02 23:18:55 +02:00
7b60bda4ed
Add support for cuda streams. ( #2532 )
2024-10-02 21:30:58 +02:00
a2bcc227df
Efficient implementation of Tensor::ones() for metal ( #2512 )
...
* WIP: hopefully better const impl
* with GPU
* More tests on
* Reverting primitive for
* Incorporating review changes - added an elem count check in the kernel, used for the call strategy
* Ran rustfmt
2024-10-01 19:11:59 +02:00
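The entry above gives Tensor::ones a dedicated constant-fill kernel on Metal instead of materializing the data host-side first. The call site is unchanged; a minimal sketch (CPU device shown so it runs anywhere, Device::new_metal(0) would exercise the new path):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Device::new_metal(0) would hit the dedicated fill kernel on a Mac;
    // the CPU device keeps this sketch runnable everywhere.
    let device = Device::Cpu;
    let t = Tensor::ones((2, 3), DType::F32, &device)?;
    println!("{t}");
    Ok(())
}
```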
def4c6cdee
Cuda quantized mmv bugfix. ( #2526 )
2024-10-01 12:57:55 +02:00
724650446c
Yet another cuda qmm padding fix. ( #2509 )
2024-09-30 21:53:30 +02:00
844d45cde4
Bugfix for the metal elu kernel. ( #2490 )
...
* Bugfix for the metal elu kernel.
* Add a test.
2024-09-21 15:03:19 +02:00
af2104078f
Metal commands refactoring ( #2489 )
...
* Split out the commands part of the metal device.
* Make most fields private.
* Move the allocator back.
* Rework the encoder provider type.
2024-09-21 13:18:42 +02:00
382c6b51af
Improve error message ( #2485 )
2024-09-20 07:11:41 -06:00
6eea45a761
Add a couple cast metal kernels. ( #2479 )
2024-09-15 22:27:46 +02:00
ebf722b446
Export TensorIndexer public to candle users ( #2477 )
2024-09-13 22:21:57 +02:00
b60faebea4
Missing metal kernels. ( #2474 )
2024-09-12 13:58:50 +02:00
72d649058b
Hook the MLX matmul kernels in candle-core. ( #2473 )
2024-09-12 13:52:59 +02:00
afb6575835
Use the new MLX kernels to handle the BF16 matmul. ( #2470 )
2024-09-11 17:34:05 +02:00
13b2a8a4a0
Complete the missing backticks in the comments ( #2469 )
2024-09-11 16:37:05 +02:00
aafa24ed93
Update cudarc to 0.12. ( #2451 )
...
* Update cudarc to 0.12.
* Some cudnn tweaks.
2024-08-27 10:10:30 +02:00
736d8eb752
Stream tensor ( #2429 )
...
* Support Minus(u) for arbitrary values of u, e.g. Minus(3).
* Forces u to be strictly positive.
* Add StreamTensor.
2024-08-17 21:54:28 +02:00
7cff5898ec
Support Minus(u) for arbitrary values of u, e.g. Minus(3). ( #2428 )
...
* Support Minus(u) for arbitrary values of u, e.g. Minus(3).
* Forces u to be strictly positive.
2024-08-17 21:29:01 +02:00
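These two commits generalize candle's end-relative dimension indexing from the fixed Minus1/Minus2 variants to an arbitrary D::Minus(u). A short sketch of the intended usage; the exact variant shape is taken from the commit titles above and should be read as such:

```rust
use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // D::Minus(u) counts dimensions from the end, so on this rank-3 tensor
    // D::Minus(1) resolves to dimension 2 and D::Minus(3) to dimension 0.
    let s = t.sum_keepdim(D::Minus(1))?;
    assert_eq!(s.dims(), &[2, 3, 1]);
    Ok(())
}
```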
d3fe989d08
Add documentation examples for Tensor::i and Tensor::narrow methods ( #2308 )
...
* Add documentation examples for `Tensor` methods
* Apply fmt.
* Cosmetic tweaks.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-10 08:11:09 +02:00
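For reference, a small example of the two methods the documentation above covers, `i` (via the IndexOp trait) and `narrow`:

```rust
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // `i` indexes along the leading dimension: row 0 of the matrix.
    let row = t.i(0)?;
    assert_eq!(row.dims(), &[3]);
    // `narrow(dim, start, len)` keeps `len` elements of `dim` from `start`.
    let cols = t.narrow(1, 1, 2)?;
    assert_eq!(cols.dims(), &[2, 2]);
    Ok(())
}
```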
c0a559d427
Optimize gradient for silu a bit ( #2393 )
2024-08-04 11:24:17 +02:00
0fcb40b229
Revert the bf16 gemm metal changes for now. ( #2386 )
2024-08-01 23:08:47 +02:00
d4b6f6eef6
Add a minimal test for the metal bf16 matmul. ( #2381 )
2024-08-01 11:22:46 +02:00
957d604a78
Enable BF16 on metal. ( #2380 )
2024-08-01 11:05:07 +02:00
ce90287f45
Add get_ids to GradStore ( #2379 )
2024-08-01 10:56:13 +02:00
1ba87a9450
Use BF16 on metal when possible. ( #2378 )
2024-08-01 10:48:58 +02:00
bd80078acf
Fix log_sum_exp to handle large positive/negative inputs ( #2367 )
2024-08-01 10:37:02 +02:00
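The log_sum_exp fix above addresses the classic overflow issue: exponentiating large inputs directly produces inf, and inf arithmetic then yields NaN. A standalone plain-Rust sketch of the standard max-shift technique, independent of candle's actual kernel:

```rust
/// Numerically stable log-sum-exp: shifting by the maximum keeps exp() in
/// range, so inputs like 1e5 no longer overflow, and all -inf inputs no
/// longer turn the result into NaN.
fn log_sum_exp(xs: &[f64]) -> f64 {
    let m = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if m.is_infinite() {
        return m; // all -inf, or a +inf element dominates
    }
    m + xs.iter().map(|x| (x - m).exp()).sum::<f64>().ln()
}

fn main() {
    // Naive ln(sum(exp(x))) overflows here; the shifted form is exact.
    assert_eq!(log_sum_exp(&[1e5, 1e5]), 1e5 + 2f64.ln());
}
```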
8696cf6494
Enable the affine kernel for u8/u32. ( #2376 )
2024-08-01 10:03:11 +02:00
0f5cbb08b3
Add support for Llama 3.1 ( #2359 )
...
* Add Llama 3.1 rope
* Clippy
* Format
* Clippy
* Add support for multiple eos tokens:
* Untagged either
* Remove either dep and fix settings.json
* Make the max positional embeddings configurable
2024-07-26 21:32:26 +02:00
f25173d68b
Fix for backprop in ConvTranspose2D with stride of 2 ( #2337 )
...
* Add gradient test for conv_transpose2d with stride of 2.
* Swap dilation and stride in ConvTranspose2D backpropagation.
Without this, a shape mismatch occurs with a stride of 2 and dilation of 1.
* Add further tests of the ConvTranspose2D gradient.
Values calculated with torch, minor numerical errors adjusted and commented.
2024-07-17 19:22:23 +02:00
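To make the shape mismatch concrete: the output length of a transposed convolution depends asymmetrically on stride and dilation, so swapping the two changes the result whenever they differ. A small sketch using the standard formula:

```rust
/// Output length of a 1-d transposed convolution; the same formula applies
/// per spatial dimension in 2-d. Mixing up `stride` and `dilation` changes
/// the result whenever they differ, which is the mismatch described above
/// for stride 2, dilation 1.
fn conv_transpose_out_len(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_padding: usize,
) -> usize {
    (input - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
}

fn main() {
    // stride 2, dilation 1: length 5 -> 11
    assert_eq!(conv_transpose_out_len(5, 3, 2, 0, 1, 0), 11);
    // with stride and dilation swapped: length 5 -> 9
    assert_eq!(conv_transpose_out_len(5, 3, 1, 0, 2, 0), 9);
}
```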
6a4741bbf9
Fix Elu gradient NaN on large input ( #2328 )
...
* Fix Elu gradient NaN on large input
* Reuse previously computed exp in Elu
2024-07-16 14:41:16 +02:00
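A plain-Rust sketch of the idea behind both bullets (the actual fix lives in candle's backprop code): on the non-positive branch, elu'(x) = alpha * exp(x) = elu(x) + alpha, so the gradient can reuse the forward value instead of re-exponentiating, and the positive branch never calls exp() on a large input:

```rust
/// ELU and its derivative. Computing exp(x) for large positive x is what
/// blows up: it overflows to inf, which becomes NaN once multiplied by a
/// zero mask in a naive branchless gradient.
fn elu(x: f64, alpha: f64) -> f64 {
    if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }
}

fn elu_grad(x: f64, alpha: f64) -> f64 {
    // Reuses the forward value: alpha * exp(x) == elu(x) + alpha for x <= 0.
    if x > 0.0 { 1.0 } else { elu(x, alpha) + alpha }
}

fn main() {
    // A naive grad of the form exp(x) * mask would yield inf * 0 = NaN here.
    assert!(elu_grad(1e4, 1.0).is_finite());
    assert_eq!(elu_grad(1e4, 1.0), 1.0);
}
```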
25960676ca
Add a basic metal example with capture ( #2324 )
...
* Add some tracing.
* Get the trace to work.
2024-07-09 12:38:11 +02:00
6baa1d486b
Fix a bug in the metal implementation of col2im1d. ( #2284 )
2024-06-22 23:21:20 +02:00
36cf54525d
Fix the fast bf16 gemm cublas kernels. ( #2274 )
...
* Use flash-attn in gemma.
* Fix for the fast bf16 cublas gemm.
* Fix some clippy lints.
* Fix another lint.
* Proper clippy fix.
2024-06-18 23:46:58 +02:00
9182c828e6
Automatically upcast for to_u64 ( #2244 )
2024-06-04 11:32:36 +02:00
1ec3b2cc18
add where_cond f32 for metal ( #2236 )
2024-06-02 14:30:06 +02:00
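where_cond is candle's elementwise select; the entry above fills in the f32 case for the Metal backend. Typical usage:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let cond = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
    let on_true = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let on_false = Tensor::new(&[-1f32, -2., -3.], &Device::Cpu)?;
    // Elementwise select: picks from `on_true` where `cond` is non-zero.
    let out = cond.where_cond(&on_true, &on_false)?;
    assert_eq!(out.to_vec1::<f32>()?, [1., -2., 3.]);
    Ok(())
}
```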
0814dfd148
Add a metal kernel for col2im1d. ( #2214 )
...
* Add a metal kernel for col2im1d.
* Enable the col2im variant.
* Bugfix.
* Revert the quantized tweak.
2024-05-25 11:03:23 +02:00
1df2bddccf
Add the layernorm specialized op. ( #2212 )
...
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
2024-05-24 15:58:01 +02:00
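The fused cuda/metal kernels above all compute the usual layer norm; as a reference for what they implement, a plain-Rust version over the last dimension:

```rust
/// Reference layer norm: y = (x - mean) / sqrt(var + eps) * gamma + beta,
/// with mean and variance taken over the normalized dimension.
fn layer_norm(xs: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    xs.iter()
        .zip(gamma.iter().zip(beta))
        .map(|(x, (g, b))| (x - mean) * inv_std * g + b)
        .collect()
}

fn main() {
    let y = layer_norm(&[1.0, 2.0, 3.0], &[1.0; 3], &[0.0; 3], 1e-5);
    println!("{y:?}"); // roughly [-1.2247, 0.0, 1.2247]
}
```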
6f0b807ffd
More efficient cuda implementation for ConvTranspose1d. ( #2211 )
...
* More efficient cuda implementation for ConvTranspose1d.
* Small tweak.
2024-05-24 11:05:43 +02:00
01545f7303
Add a slice_set op. ( #2193 )
...
* Add a slice_set op.
* Add some testing.
* Add the dedicated kv-cache module.
* Derive debug and clone.
* Expose more kv-cache functions.
* Return the current data when appending.
* Use the new cache in the quantized phi3 model.
2024-05-18 15:58:18 +02:00
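A hypothetical sketch of how slice_set underpins the kv-cache: it writes a tensor into a pre-allocated buffer in place at a given offset. The dst.slice_set(&src, dim, offset) signature below matches the description in this entry but should be treated as an assumption:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Pre-allocated cache for up to 8 steps of a (1, 4)-shaped state.
    let cache = Tensor::zeros((8, 1, 4), DType::F32, &dev)?;
    let step = Tensor::ones((1, 1, 4), DType::F32, &dev)?;
    // Write `step` into the cache at offset 3 along dim 0, in place; the
    // kv-cache module layers an append()/current-data API on top of this.
    cache.slice_set(&step, 0, 3)?;
    Ok(())
}
```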
21f82a5155
Add SliceSafetensors. ( #2179 )
...
* Add SliceSafetensors.
* And add some testing.
2024-05-11 13:15:42 +02:00
9cff7bc3f4
Make it possible to use TF32 accumulation in F32 matmuls. ( #2178 )
...
* Allow the use of tf32 accumulation in matmul.
* Better timings.
* Dummy versions for use when cuda is not enabled.
2024-05-11 12:28:39 +02:00
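A sketch of the intended usage, assuming the toggle is exposed as candle_core::cuda::set_gemm_reduced_precision_f32 (the exact name is an assumption; the "Dummy versions" bullet suggests it stays callable even without the cuda feature):

```rust
fn main() {
    // Opt in to TF32 accumulation for f32 matmuls on CUDA (assumed API; a
    // no-op dummy is provided when the cuda feature is off). TF32 trades
    // mantissa bits for considerably faster tensor-core gemms.
    candle_core::cuda::set_gemm_reduced_precision_f32(true);
    assert!(candle_core::cuda::gemm_reduced_precision_f32());
}
```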
01794dc16e
Use write rather than try-write on the metal rw-locks. ( #2162 )
2024-05-05 07:22:46 +02:00
b13a82a438
Separate quantized phi-3 implementation. ( #2157 )
...
* Separate quantized phi-3 implementation.
* Integrate the quantized phi3 model.
* Small fixes, get the generation to work properly.
* Keep the old llama implementation around.
* Change the default.
2024-05-04 10:14:57 +02:00
89f53b9d7b
Bump the version number to 0.5.1. ( #2155 )
...
* Bump the version number to 0.5.1.
* Fix clippy lints for 1.78.
* More clippy fixes.
2024-05-03 11:17:05 +02:00
fa06f5f5f9
F16/BF16 bugfix (bis). ( #2143 )
...
* F16/BF16 bugfix (bis).
* Another fix.
* Yet another fix.
2024-04-29 14:08:44 +02:00
09d4845aa8
Bugfix the recent f16/bf16 changes. ( #2142 )
2024-04-29 13:30:11 +02:00
a0d03aded1
Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. ( #2124 )
...
* When converting a tensor to a variable, clone if the tensor is already a variable.
* Add a test to ensure training a batch norm works with VarMaps
---------
Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
2024-04-29 11:21:53 +02:00
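A minimal sketch of the scenario this fixes, using the Var::from_tensor entry point: re-wrapping a tensor that is already a variable should clone the data rather than alias it, so the two variables can be trained independently:

```rust
use candle_core::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let v = Var::from_tensor(&Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?)?;
    // Before the fix, re-wrapping a tensor that was already a variable could
    // alias the original storage; cloning gives the new var independent data.
    let _v2 = Var::from_tensor(v.as_tensor())?;
    Ok(())
}
```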