c12ad45562
Add a KV cache to marian decoding. ( #1226 )
2023-10-31 08:47:44 +00:00
7d0202710b
Instructions for generating the tokenizer configs for marian-mt. ( #1225 )
2023-10-31 07:56:26 +01:00
392a00a147
Add support for the marian base model. ( #1221 )
2023-10-30 19:20:36 +00:00
4c967b9184
Use the hub files for the marian example. ( #1220 )
...
* Use the hub files for the marian example.
* Use the secondary decoder.
* Add a readme.
* More readme.
2023-10-30 17:29:36 +00:00
c05c0a8213
PyO3: Add `equal` and `__richcmp__` to `candle.Tensor` ( #1099 )
...
* add `equal` to tensor
* add `__richcmp__` support for tensors and scalars
* typo
* more typos
* Add `abs` + `candle.testing`
* remove duplicated `broadcast_shape_binary_op`
* `candle.i16` => `candle.i64`
* `tensor.nelements` -> `tensor.nelement`
* Cleanup `abs`
2023-10-30 15:17:28 +00:00
969960847a
Bugfixes for marian-mt. ( #1219 )
...
* Bugfixes for marian-mt.
* Apply the final decoding head.
* More fixes.
2023-10-30 11:44:19 +00:00
5fc66bd4ba
Support negative steps in arange. ( #1218 )
2023-10-30 07:40:54 +00:00
174b208052
PyO3: Better shape handling ( #1143 )
...
* Negative and `*args` shape handling
* Rename to `PyShapeWithHole` + validate that only one hole exists
* Regenerate stubs
---------
Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-29 15:41:44 +00:00
154c674a79
Add i64-abs. ( #1216 )
2023-10-29 15:28:53 +00:00
7bbde55c61
Marian MT model ( #1210 )
...
* Skeleton files for the marian MT model.
* Marian initialization.
* Implement the attention forward method.
* Forward pass for the encoder side.
* Expose the encoder and decoder.
* Start plugging the decoder.
* Forward pass for the decoder layer.
* Set up the marian example.
* Add some missing backtraces.
* Bugfix.
2023-10-29 15:12:22 +00:00
c3f2676d49
PyO3: Add CI to build & upload wheels as artifacts. ( #1215 )
...
* Add maturin ci
* fix paths
* Change sdist path
2023-10-29 13:44:05 +00:00
46d6566c99
Fix the conv2d gradient computation. ( #1214 )
2023-10-29 09:50:04 +00:00
55bc3382cf
Allow for different behavior between training and eval ( #1213 )
...
* Forward with training.
* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
dece37c6f4
feat: implement VGG13, VGG16 and VGG19 ( #1211 )
...
* feat: implement VGG13, VGG16 and VGG19
* Cosmetic fixes.
* More cosmetic tweaks + avoid re-loading the weights on each final layer.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-29 06:10:23 +00:00
498c50348c
Add DDPG and fix Gym wrapper ( #1207 )
...
* Fix Gym wrapper
- It was returning things in the wrong order
- Gym now differentiates between terminated and truncated
* Add DDPG
* Apply fixes
* Remove Result annotations
* Also remove Vec annotation
* rustfmt
* Various small improvements (avoid cloning, mutability, get clippy to pass, ...)
---------
Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-28 19:53:34 +01:00
012ae0090e
Infer the config for llama2-c. ( #1208 )
2023-10-28 19:00:39 +01:00
95a857cf57
Move the llama2-c model in transformers. ( #1205 )
2023-10-28 16:51:19 +01:00
612f5b8156
Make more models cloneable. ( #1203 )
2023-10-28 07:43:08 +01:00
ef33df7ae2
No need for the even constraint on vecdot-q40-q80. ( #1202 )
2023-10-28 07:23:59 +01:00
c8face3f95
Add the relu2 and relu6 activations. ( #1201 )
2023-10-27 20:51:16 +01:00
85bea43e5b
Make the whisper model cloneable ( #1200 )
...
* Add a quantized variant of llama2.c
* Clippy fixes.
* Make the whisper model cloneable.
2023-10-27 16:59:19 +01:00
b3181455d5
Add fuse-conv-bn method for Conv2d ( #1196 )
...
* Add fuse-conv-bn method for Conv2d
* no unwrap
* run rustfmt and clippy
2023-10-27 15:56:50 +01:00
e2826e70b3
Add a quantized variant of llama2.c ( #1197 )
...
* Add a quantized variant of llama2.c
* Clippy fixes.
2023-10-27 15:34:06 +01:00
916619f70b
Minor cleanup ( #1194 )
...
* Add some missing backtraces.
* Small cleanup.
2023-10-27 14:08:29 +01:00
9b1158b315
Add some missing backtraces. ( #1193 )
2023-10-27 06:09:11 +01:00
70d06ab4b0
Add support for the phi-hermes finetuned model. ( #1192 )
2023-10-27 05:57:08 +01:00
0ec5ebcec4
Use the hub model file when possible. ( #1190 )
...
* Use the hub model file when possible.
* And add a mention in the main readme.
2023-10-26 20:00:50 +01:00
c8e197f68c
Fixes for jina-bert. ( #1189 )
2023-10-26 18:52:30 +01:00
5f20697918
Add the jina-bert embeddings model. ( #1187 )
...
* Add the jina-bert model.
* Use alibi.
* Remove the unused pragma.
* Recompute the alibi embeddings.
* Generate the token type ids.
* Use the module trait.
* Add the jina-bert example.
* DType fix.
* Get the inference to work.
2023-10-26 16:54:36 +01:00
e37b487767
Add Blip to online demos README.md ( #1184 )
...
* Add Blip to online demos README.md
* Punctuation.
---------
Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-26 11:07:01 +01:00
e5dc8cb4f4
[Wasm] BLIP Example ( #1183 )
...
* blip wasm start
* fix dependency issue, move token stream here
* vanilla js worker
* roll back vscode
* spell
2023-10-26 07:24:02 +01:00
e7b886d56f
Add a link to the optimisers crate. ( #1180 )
2023-10-25 21:51:45 +01:00
6a446d9d73
convert pytorch's tensor in Python API ( #1172 )
...
* convert pytorch's tensor
* separate tests for convert pytorch tensor
2023-10-25 19:39:14 +01:00
0acd16751d
Expose the fields from batch-norm. ( #1176 )
2023-10-25 15:35:32 +01:00
c698e17619
Enable the test for meshgrid + fix the implementation. ( #1175 )
2023-10-25 13:47:54 +01:00
e4c9adfdbe
Implemented meshgrid ( #1174 )
...
* Implemented meshgrid
* Resolved feedback from LaurentMazare
* Rustfmt
* Updated docstring
* Removed outdated error mode from docstring
2023-10-25 12:49:11 +01:00
b6053b938b
[Wasm] Add puffin phi model to wasm ( #1166 )
...
* load config from file, add puffin phi links
* format
* add prompt examples
2023-10-25 07:09:03 +01:00
45dbe541bc
fix ucopy for `f64` tensors ( #1170 )
2023-10-24 17:06:03 +01:00
7bd0faba75
Add support for accelerate in the pyo3 bindings. ( #1167 )
2023-10-24 06:34:37 +01:00
807e3f9f52
derivative for GELU ( #1160 )
...
* derivative for GELU
* add tests
2023-10-23 20:23:45 +01:00
eae94a451b
PyO3: Add `mkl` support ( #1159 )
...
* Add `mkl` support
* Set `mkl` path on linux
2023-10-23 20:10:59 +01:00
86e1803191
Add Binary Cross Entropy With Logit Loss to nn crate ( #1157 )
...
* add bce with logit loss
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
2023-10-23 17:12:44 +01:00
25c3cc4149
Mention the flash-attention restriction in the readme. ( #1158 )
2023-10-23 10:26:56 +01:00
a11af79e23
Add a quantized blip model. ( #1155 )
...
* Add a quantized blip model.
* Integrate the quantized blip model to the actual example.
2023-10-22 20:33:25 +01:00
8a82d623e5
Handle LongStorage in pytorch checkpoints. ( #1152 )
2023-10-22 18:34:36 +01:00
df2f89b6cf
Add some KV cache to blip. ( #1150 )
...
* Add some KV cache to blip.
* Mention BLIP in the readme.
2023-10-22 09:44:48 +01:00
62fc965617
Expose the track-op method. ( #1148 )
2023-10-22 06:57:03 +01:00
5b32c2a41e
Remove the unused pragma and properly apply the bias. ( #1147 )
2023-10-22 06:47:40 +01:00
3115fe42e4
Blip attention mask + readme ( #1146 )
...
* Add the attention mask to the blip model.
* Add a readme.
2023-10-21 22:44:13 +01:00
2531b13bf8
Blip fixes ( #1145 )
...
* Some fixes for the blip example.
* Stop generating on sep tokens.
* Clippy fixes.
* rustfmt.
2023-10-21 21:34:48 +01:00