eb24875856
Reworked affine and it works? No idea how it's different.
2023-11-08 02:37:20 +01:00
3f662e54cd
Reworked affine and it works? No idea how it's different.
2023-11-08 02:34:08 +01:00
480a3e22e6
Adding cast + binary kernels.
2023-11-07 23:45:53 +01:00
0c24a885a6
Updated everything and output a trace.
2023-11-07 21:12:42 +01:00
76d3116f5d
Broken metal?
2023-11-07 14:20:13 +01:00
63cce76b84
Improve metal kernel loading and associated errors
2023-11-06 09:48:18 +01:00
634a4e7168
BlitEncoder added to affine for copying buffer contents quickly.
2023-11-06 08:23:36 +01:00
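For context, a minimal sketch of copying buffer contents with a blit encoder, assuming the `metal` crate (metal-rs) rather than the actual candle-metal-kernels code; the buffer sizes and data below are made up for illustration:

```rust
// Hypothetical standalone example using the `metal` crate (metal-rs).
use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device found");
    let queue = device.new_command_queue();

    // Source buffer filled with some data, destination left uninitialized.
    let data: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let size = (data.len() * std::mem::size_of::<f32>()) as u64;
    let src = device.new_buffer_with_data(
        data.as_ptr() as *const std::ffi::c_void,
        size,
        MTLResourceOptions::StorageModeShared,
    );
    let dst = device.new_buffer(size, MTLResourceOptions::StorageModeShared);

    // A blit encoder performs the GPU-side copy without a custom kernel.
    let command_buffer = queue.new_command_buffer();
    let blit = command_buffer.new_blit_command_encoder();
    blit.copy_from_buffer(&src, 0, &dst, 0, size);
    blit.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();
}
```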
8124d1003f
Affine metal kernel works. Need to extract buffer contents based on layout offset (like CudaSlice.slice) for candle integration.
2023-11-06 04:46:56 +01:00
6d4c8c0707
Use metal encode_gemm
2023-11-06 03:27:22 +01:00
c921cc3784
Add Arc to MetalStorage buffer for quick cloning.
2023-11-04 09:03:23 +01:00
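A minimal sketch of the idea behind this change: wrapping the Metal buffer in an `Arc` turns cloning the storage into a reference-count bump rather than a data copy. The struct below is illustrative only, not candle's actual `MetalStorage` definition:

```rust
use std::sync::Arc;
use metal::Buffer;

// Illustrative stand-in for candle's MetalStorage; the real struct holds
// more fields (device, dtype, ...).
#[derive(Clone)]
struct MetalStorage {
    // Cloning an Arc only bumps the reference count; the underlying
    // MTLBuffer is shared, not copied.
    buffer: Arc<Buffer>,
}

impl MetalStorage {
    fn new(buffer: Buffer) -> Self {
        Self { buffer: Arc::new(buffer) }
    }
}
```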
0794e70a19
Debugging index_add.
2023-11-03 12:09:05 +01:00
f57e3164ae
Implemented cos for now.
2023-11-03 01:24:51 +01:00
7161002a34
Finished scaffolding, lots of TODOs
...
- Most kernels just copy themselves to get the shapes correct
- Matmul works in only one case and simply allocates an empty result otherwise
- Logits are randomized to make the demo finish by itself.
Performance is quite bad (30ms/token), but there are lots of prints and allocs and some actual sending to Metal.
Couldn't get it much faster by removing the obvious blockers (println + the actual matmul runs).
Allocations take between 1µs and 100µs and seem very stable. Maybe Metal doesn't really have a smart allocator and we'll need to own it.
2023-11-02 15:32:28 +01:00
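The allocation note above speculates about owning allocation; a rough, hypothetical sketch of what that could mean is a size-bucketed pool that reuses Metal buffers instead of calling `new_buffer` for every op. This is not candle's actual allocator, just an illustration of the idea:

```rust
use std::collections::HashMap;
use metal::{Buffer, Device, MTLResourceOptions};

/// Hypothetical buffer pool: reuse previously allocated buffers of the
/// same size instead of asking the Metal device for a new one each time.
struct BufferPool {
    device: Device,
    free: HashMap<u64, Vec<Buffer>>,
}

impl BufferPool {
    fn new(device: Device) -> Self {
        Self { device, free: HashMap::new() }
    }

    fn alloc(&mut self, size: u64) -> Buffer {
        // Reuse a free buffer of the exact size if one is available.
        if let Some(bufs) = self.free.get_mut(&size) {
            if let Some(buf) = bufs.pop() {
                return buf;
            }
        }
        self.device.new_buffer(size, MTLResourceOptions::StorageModeShared)
    }

    /// Return a buffer to the pool once the op using it has completed.
    fn release(&mut self, buf: Buffer) {
        self.free.entry(buf.length()).or_default().push(buf);
    }
}
```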
82cce52e73
Rename candle-metal -> candle-metal-kernels
2023-11-02 09:53:29 +01:00
71fcb31873
Owned command buffer now.
2023-11-01 18:03:53 +01:00
198009453a
Matmul (no batch, no strided, f32, f32 only) sort of done.
2023-11-01 17:36:51 +01:00
492d164235
More scaffolding, now need to implement matmul (for precompute_cos_sin to work).
2023-11-01 16:54:09 +01:00
2d84c16fed
First pass (Quantized scaffolding work done + quantized example scaffolding).
2023-11-01 15:10:11 +01:00
4525b7b52a
Initial setup
2023-10-31 18:09:10 +01:00
c05c0a8213
PyO3: Add `equal` and `__richcmp__` to candle.Tensor ( #1099 )
...
* add `equal` to tensor
* add `__richcmp__` support for tensors and scalars
* typo
* more typos
* Add `abs` + `candle.testing`
* remove duplicated `broadcast_shape_binary_op`
* `candle.i16` => `candle.i64`
* `tensor.nelements` -> `tensor.nelement`
* Cleanup `abs`
2023-10-30 15:17:28 +00:00
5fc66bd4ba
Support negative steps in arange. ( #1218 )
2023-10-30 07:40:54 +00:00
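A small usage sketch, assuming candle's `Tensor::arange_step` API; the values are illustrative:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Ascending range: 0, 1, 2, 3, 4
    let up = Tensor::arange_step(0f32, 5f32, 1f32, &device)?;
    // With negative-step support, a descending range also works: 5, 4, 3, 2, 1
    let down = Tensor::arange_step(5f32, 0f32, -1f32, &device)?;
    println!("{up}\n{down}");
    Ok(())
}
```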
154c674a79
Add i64-abs. ( #1216 )
2023-10-29 15:28:53 +00:00
7bbde55c61
Marian MT model ( #1210 )
...
* Skeleton files for the marian MT model.
* Marian initialization.
* Implement the attention forward method.
* Forward pass for the encoder side.
* Expose the encoder and decoder.
* Start plugging the decoder.
* Forward pass for the decoder layer.
* Set up the marian example.
* Add some missing backtraces.
* Bugfix.
2023-10-29 15:12:22 +00:00
46d6566c99
Fix the conv2d gradient computation. ( #1214 )
2023-10-29 09:50:04 +00:00
55bc3382cf
Allow for different behavior between training and eval ( #1213 )
...
* Forward with training.
* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
ef33df7ae2
No need for the even constraint on vecdot-q40-q80. ( #1202 )
2023-10-28 07:23:59 +01:00
e2826e70b3
Add a quantized variant of llama2.c ( #1197 )
...
* Add a quantized variant of llama2.c
* Clippy fixes.
2023-10-27 15:34:06 +01:00
9b1158b315
Add some missing backtraces. ( #1193 )
2023-10-27 06:09:11 +01:00
c698e17619
Enable the test for meshgrid + fix the implementation. ( #1175 )
2023-10-25 13:47:54 +01:00
e4c9adfdbe
Implemented meshgrid ( #1174 )
...
* Implemented meshgrid
* Resolved feedback from LaurentMazare
* Rustfmt
* Updated docstring
* Removed outdated error mode from docstring
2023-10-25 12:49:11 +01:00
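A small usage sketch, assuming a `Tensor::meshgrid(&[...], xy_indexing)` API; the inputs are illustrative:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let x = Tensor::new(&[1f32, 2., 3.], &device)?;
    let y = Tensor::new(&[4f32, 5.], &device)?;
    // With xy indexing, the outputs have shape (len(y), len(x)),
    // mirroring numpy.meshgrid's default behaviour.
    let grids = Tensor::meshgrid(&[&x, &y], true)?;
    for g in &grids {
        println!("{g}");
    }
    Ok(())
}
```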
45dbe541bc
fix ucopy for f64 tensors ( #1170 )
2023-10-24 17:06:03 +01:00
807e3f9f52
derivative for GELU ( #1160 )
...
* derivative for GELU
* add tests
2023-10-23 20:23:45 +01:00
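For reference, with the erf-based GELU definition GELU(x) = x·Φ(x), where Φ is the standard normal CDF (the tanh approximation used in some implementations differs slightly), the derivative follows from the product rule:

```latex
\frac{d}{dx}\,\mathrm{GELU}(x) = \Phi(x) + x\,\phi(x),
\qquad \phi(x) = \frac{1}{\sqrt{2\pi}}\, e^{-x^2/2}
```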
8a82d623e5
Handle LongStorage in pytorch checkpoints. ( #1152 )
2023-10-22 18:34:36 +01:00
62fc965617
Expose the track-op method. ( #1148 )
2023-10-22 06:57:03 +01:00
e8f760ee44
Add get_on_dim. ( #1142 )
2023-10-21 15:01:38 +01:00
87eb1658e1
Add pad_with_same. ( #1127 )
...
* More model cloning.
* More cloning on quantized models.
* Add pad-with-same.
* Add some tests.
2023-10-18 23:13:37 +01:00
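A small usage sketch, assuming candle's `Tensor::pad_with_same(dim, left, right)` replicates the edge values along the given dimension; the input is illustrative:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let t = Tensor::new(&[1f32, 2., 3.], &device)?;
    // Pad dimension 0 with one copy of the edge value on each side:
    // expected result is [1, 1, 2, 3, 3].
    let padded = t.pad_with_same(0, 1, 1)?;
    println!("{padded}");
    Ok(())
}
```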
662c186fd5
Better error message when overflowing in narrow. ( #1119 )
2023-10-18 08:40:14 +01:00
6c588c4792
Refactor the pth tensor extraction. ( #1109 )
2023-10-16 18:16:34 +01:00
0106b0b04c
Read all the tensors in a PyTorch pth file. ( #1106 )
2023-10-16 13:50:07 +01:00
b73c35cc57
Improve the reshape error messages. ( #1096 )
...
* Improve the reshape error messages.
* Add the verbose-prompt flag to the phi example.
2023-10-15 10:43:10 +01:00
8f310cc666
Avoid trying to backprop through non-differentiable layers. ( #1094 )
2023-10-14 22:03:41 +01:00
9309cfc47d
Create a new curand instead of reseeding. ( #1089 )
2023-10-14 10:03:59 +01:00
7473c4ceca
Fix the npy read function and add some testing. ( #1080 )
2023-10-12 15:25:05 +02:00
37dbbff261
Use full tensors for zeros and ones ( #1071 )
...
* Only optimize float tensors.
* Use full tensors for zeros and ones.
2023-10-11 08:16:04 +01:00
9fea56d28e
Only optimize float tensors. ( #1069 )
2023-10-10 09:05:41 +01:00
b34d7f0248
Remove some unused bits. ( #1067 )
2023-10-09 19:49:57 +01:00
9abeddd750
Make the cuda rng seedable. ( #1056 )
2023-10-08 09:32:36 +01:00
aa53368aeb
Better control on the optional dequantization in QMatMul ( #1049 )
...
* Cosmetic change to the quantized whisper model.
* Fix the dequantization.
* Add the dequantize all variable.
2023-10-07 10:16:18 +01:00
7f7d95e2c3
Add the round-to function. ( #1039 )
2023-10-05 20:28:09 +01:00
8f7973958c
fix: fix index_select cuda kernel for src target dim different from ids dim when selecting dim > 0 ( #1037 )
...
* fix: fix index_select cuda kernel for src target dim different from ids dim when selecting dim > 0
* cargo fmt
2023-10-05 18:46:13 +01:00