Compare commits

..

188 Commits

Author SHA1 Message Date
7e49e0af96 Tmp for allocator. 2023-11-16 12:50:41 +01:00
181d2299b2 TMp. 2023-11-16 11:41:06 +01:00
2801541e5f new_owned -> new()..to_owned(). 2023-11-16 11:07:56 +01:00
4289984d32 Remove some prints. 2023-11-13 14:51:40 +01:00
1471f98f0b BF16 metal fix. 2023-11-13 14:44:20 +01:00
dd4a40f1c0 Fixes + cache compute_pipeline_state. 2023-11-13 14:33:16 +01:00
79845bd93b Working version for llama2-c. 2023-11-13 12:36:27 +01:00
6071797450 Add erf. 2023-11-11 18:22:16 +01:00
b58b247323 Putting back f16 index select. 2023-11-11 17:43:35 +01:00
3900091e75 All tests are panicking instead of random failure. 2023-11-11 17:43:35 +01:00
54355ff997 Adding some half kernels. 2023-11-11 17:43:35 +01:00
e02f1912bb Reusing a single buffer (for now) to speed things up. 2023-11-11 17:43:35 +01:00
a52b71686b Going back on remote metal-rs. 2023-11-11 17:43:35 +01:00
7adfb70dff Few fixes. 2023-11-11 17:43:35 +01:00
3ad02147e4 Starting to fix some tests. 2023-11-11 17:43:34 +01:00
4f39695465 Missing new test. 2023-11-11 17:42:53 +01:00
4cf4844c9d Adding the test scaffolding. 2023-11-11 17:27:19 +01:00
d840838e95 Cleanup fixed a few ops removed debugging scaffolding. 2023-11-11 17:18:00 +01:00
61a070fdd1 Debugging rope. 2023-11-11 17:18:00 +01:00
e35669647d Fixed matmul (display still broken without casting back to CPU first? ) 2023-11-11 17:18:00 +01:00
53e8b7ee3e Tmp state. 2023-11-11 17:18:00 +01:00
cc26cce23c Fixing the kernels + launches to make them faster.
Cool work by @ivarflakstad

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-11 17:18:00 +01:00
02c2ec2c71 Adding indexing.
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-11 17:18:00 +01:00
9a2784b8ab Refactor to simplify our lives for settings the params in the encoder. 2023-11-11 17:18:00 +01:00
0f652f0e3d Adding the actual backend 2023-11-11 17:18:00 +01:00
ddee9dc1dd Remove tracing. 2023-11-11 17:18:00 +01:00
fc9bb7784a Metal part 1 - Scaffolding for metal. 2023-11-11 17:18:00 +01:00
f1e678b39c Mention the Yi-6b/Yi-34b models in the readme. (#1321) 2023-11-11 12:39:11 +01:00
a007f8fdb4 Add the Yi-6b and Yi-34b models. (#1320)
* Add the Yi-6b model.

* Add the 34b model.

* Add the yi example.

* Fix the weight file names.
2023-11-11 12:00:48 +01:00
2341aa079e Fix quantized zephyr chat prompt (#1314) (#1317)
* Fix quantized zephyr chat prompt (#1314)

* Avoid using a mutable variable.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-11 09:14:12 +01:00
9e666d4229 Add the var method. (#1315)
* Add the var method.

* Add a test.
2023-11-10 22:47:57 +01:00
1b12142a02 Add min to buckets in relative_position_bucket (#1312) 2023-11-10 11:57:25 +01:00
d2c3f14773 Fix for flash-attn. (#1310)
Co-authored-by: laurent <laurent@par2dc5-ai-prd-cl01dgx02.cm.cluster>
2023-11-10 10:27:27 +01:00
26c4e5bf1d Metal part 1 - Scaffolding for metal. (#1308)
* Metal part 1 - Scaffolding for metal.

* Remove tracing.
2023-11-10 08:35:48 +01:00
18d30005c5 Add support to UL2 model family (#1300)
* Add support to UL2 model family

* Update docs with UL2

* Create ActivationWithOptionalGating to avoid polluting activations

* Also refactor quantized t5

* Remove useless conversion

* Revert Activation::NewGelu name change

* Remove useless return

* Apply rustfmt and clippy recommendations

* Reuse t5::ActivationWithOptionalGating in quantized version

* (cosmetic change) use a match rather than ifs + avoid early returns.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-09 18:55:09 +01:00
6958384327 Add support for TrOCR Model (#1303)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* add trocr model

* fix formatting

* commit the actual model lol

* more formatting

* remove tokenizer config
2023-11-09 18:49:17 +01:00
e6697471bb Add weight and bias functions to LayerNorm (#1306) 2023-11-09 16:09:01 +01:00
73d02f4f57 fix: negative axis (#1296)
* fix: negative axis

* Use normalize_axis.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-08 23:28:21 +01:00
f772213e84 Fix bug introduced in madlad PR (#1298) 2023-11-08 17:55:46 +01:00
2feb0b054f Add the mel filters for 128 bins. (#1295) 2023-11-08 08:23:53 +01:00
2d28497197 Preliminary support for whisper v3. (#1294)
* Preliminary support for whisper v3.

* Add the missing files.
2023-11-08 06:42:52 +01:00
f3a4f3db76 PyO3: Add optional candle.onnx module (#1282)
* Start onnx integration

* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx

* Implement ONNXModel

* `fmt`

* add `onnx` flag to python ci

* Pin `protoc` to `25.0`

* Setup `protoc` in wheel builds

* Build wheels with `onnx`

* Install `protoc` in manylinux containers

* `apt` -> `yum`

* Download `protoc` via bash script

* Back to `manylinux: auto`

* Disable `onnx` builds for linux
2023-11-08 06:37:50 +01:00
7920b45c8a Support for timegroupnorm in encodec. (#1291) 2023-11-07 22:39:59 +01:00
d4a45c936a Quantized model small tweaks (#1290)
* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.

* Tweaks for the quantized model.
2023-11-07 21:21:37 +01:00
c912d24570 Update README: Move T5 to Text to Text section (#1288)
I think it makes more sense to have it there, since it's a seq2seq model with cross attention, and not a LM. There are also Decoder only T5 models that work as LMs, but that's not the standard.
2023-11-07 16:14:04 +01:00
d5c2a7b64b Add info about MADLAD-400 in readme files (#1287) 2023-11-07 15:21:59 +01:00
508f811b93 Add support for MADLAD400 (#1285)
* Add support for madlad

* Add support for quantized MADLAD
2023-11-07 05:35:37 +01:00
a773a4b22b [ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.
2023-11-06 22:44:58 +01:00
5a363dbc26 Adds check for 7b-zephyr and uses correct template (#1283)
* Adds check for 7b-zephyr and uses correct template

* Handle zephyr as mistral.

* Disable the protoc bits of the CI.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-06 21:05:39 +01:00
abc4f698c5 Add candle-sampling (#1278) 2023-11-06 12:53:29 +01:00
a923e8b53a Add a link to candle-ext to README.md (#1277) 2023-11-06 12:44:39 +01:00
2a45bcf943 Put the onnx example behind a feature flag. (#1276)
* Put the onnx example behind a feature flag.

* Exclude the onnx bits from the workspace.

* README tweaks.
2023-11-06 07:45:07 +01:00
47f4ddb011 Added info about missing protoc (#1275)
Co-authored-by: figgefigge <fredric.1337mail.com>
2023-11-06 06:47:32 +01:00
f365a075e5 Add more models to the onnx example. (#1273)
* Add more models to the onnx example.

* Input validation.

* Input validation.

* Bugfix.

* Implement clip.

* BatchNorm support.

* Get the efficientnet onnx to work.
2023-11-05 16:57:26 +01:00
60fdab4e17 Detach all grads during backprop. (#1243)
* Detach all grads during backprop.

* Add an environment variable to select the backprop behavior.

* Update the comment.
2023-11-05 14:07:41 +01:00
928a9d906e [ONNX] Do not generate values for constants. (#1272)
* Do not generate values for constants.

* Add an onnx based example using squeezenet.
2023-11-05 11:23:14 +01:00
d1d89bac1f feat: download cifar dataset parquet files (#1259) 2023-11-05 10:55:49 +01:00
39ad840a90 Better tensor initialization in ONNX. (#1270)
* Better tensor initialization in ONNX.

* MaxPool support.

* Add AvgPool.

* Get the squeezenet example to work.
2023-11-04 22:17:45 +01:00
b5e4f84bed Refactor the onnx attribute getters. (#1268)
* Refactor the onnx attribute getters.

* Add get-attr-opt.

* Add support for convolutions.

* Add support for convolutions.
2023-11-04 21:31:48 +01:00
7051fb8098 feat: add backprop for elu (#1269)
* feat: add backprop for elu

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-04 21:26:41 +01:00
dc68c130e4 Support more ONNX ops. (#1267)
* Add LogSoftmax.

* Support for Transpose.
2023-11-04 15:10:14 +01:00
bc9a1bf239 Improve the ONNX basic example + bugfixes (#1266)
* Generate some zeros tensor in the onnx simple-eval example.

* Fix the casting operation.

* Support more ops.

* Handle reshape.

* Concat.

* Softmax.
2023-11-04 10:02:47 +01:00
f7c957d64f ONNX casting support. (#1265)
* ONNX casting support.

* Handle tensor constants.

* Bugfix the binary ops.
2023-11-04 08:34:24 +01:00
8cbb9d0e6c Add some preliminary ONNX support (#1260)
* Add the onnx protos.

* Move the reading bits.

* Install protoc on the CI.

* Install protoc on the cuda CI too.

* Use clap for the onnx tool.

* Tweak the CI protoc install.

* Add some simple evalution function.

* Add some binary operator support.
2023-11-04 06:36:05 +01:00
bfe95115c6 Update README.md (#1264) 2023-11-04 05:32:32 +01:00
6fa3151820 Allow using gguf-v3 files. (#1262) 2023-11-03 23:07:53 +01:00
0a58886ccb add distil-whisper link (#1261) 2023-11-03 21:34:42 +01:00
3173b1ce3b feat: impl backprop for erf and gelu-erf (#1258)
* impl backprop for erf anf gelu-erf

* feat: unary tests added for erf and gelu-erf

* fix: (clippy) remove immediately dereferenced ref

* fix: improve comments with pytorch code snippet

* fix: adjust comment typo in backprop impl
2023-11-03 21:32:30 +01:00
ad63f20781 add Kalosm to the list of external resources (#1257) 2023-11-03 19:16:46 +01:00
1cfc5d6d0c Backprop support for conv1d (cpu only for now). (#1255) 2023-11-03 14:23:53 +01:00
b07b2350b6 Test for the transposed conv1d. (#1254) 2023-11-03 13:10:28 +01:00
1b5063f3ca Add vllm external resource (#1253) 2023-11-03 12:40:31 +01:00
3b0d1e7d03 Transposed conv1d in candle-nn. (#1252) 2023-11-03 11:18:25 +01:00
be4555c5a5 Add the conv-transpose1d op. (#1251)
* Skeleton structure for conv-transpose1d.

* CPU implementation for conv-transpose1d.
2023-11-03 09:44:46 +01:00
6975c65112 Share the layer-norm implementation. (#1248) 2023-11-03 06:30:05 +01:00
a2a20aeecc Add the swiglu activation from the chatglm PR. (#1246) 2023-11-02 20:01:34 +01:00
e08fbb6543 Add support for distil whisper (#1245)
* Add support for distil-whisper.

* Add distil-large.

* Rename the large model.
2023-11-02 19:32:35 +01:00
d39d0c40fd Add hard-sigmoid and hard-swish activations (#1244)
* Add hard-sigmoid and hard-swish activations

* Update ops.rs

* Use / rather than div.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-02 18:20:27 +01:00
b97463098c llama2-c wasm fix. 2023-11-02 10:31:47 +01:00
fbd69f952c Lazy detach. (#1242) 2023-11-02 07:33:48 +00:00
6c990a33ea Remove the unused pragma for marian. (#1236) 2023-11-01 20:04:52 +00:00
1704f1b3ae Consolidate the with-tracing usage. (#1234) 2023-11-01 18:21:36 +00:00
693fad511c Preliminary support for ssd1b. (#1233) 2023-11-01 14:37:52 +00:00
36fb84f038 Add a hack for generating random uniform/normal for f16/bf16. (#1228) 2023-10-31 20:27:59 +00:00
c12ad45562 Add a KV cache to marian decoding. (#1226) 2023-10-31 08:47:44 +00:00
7d0202710b Instructions for generating the tokenizer configs for marian-mt. (#1225) 2023-10-31 07:56:26 +01:00
392a00a147 Add support for the marian base model. (#1221) 2023-10-30 19:20:36 +00:00
4c967b9184 Use the hub files for the marian example. (#1220)
* Use the hub files for the marian example.

* Use the secondary decoder.

* Add a readme.

* More readme.
2023-10-30 17:29:36 +00:00
c05c0a8213 PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)
* add `equal` to tensor

* add `__richcmp__` support  for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`
2023-10-30 15:17:28 +00:00
969960847a Bugfixes for marian-mt. (#1219)
* Bugfixes for marian-mt.

* Apply the final decoding head.

* More fixes.
2023-10-30 11:44:19 +00:00
5fc66bd4ba Support negative steps in arange. (#1218) 2023-10-30 07:40:54 +00:00
174b208052 PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling

* Rename to `PyShapeWithHole` + validate that only one hole exists

* Regenerate stubs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-29 15:41:44 +00:00
154c674a79 Add i64-abs. (#1216) 2023-10-29 15:28:53 +00:00
7bbde55c61 Marian MT model (#1210)
* Skeleton files for the marian MT model.

* Marian initialization.

* Implement the attention forward method.

* Forward pass for the encoder side.

* Expose the encoder and decoder.

* Start plugging the decoder.

* Forward pass for the decoder layer.

* Set up the marian example.

* Add some missing backtraces.

* Bugfix.
2023-10-29 15:12:22 +00:00
c3f2676d49 PyO3: Add CI to build & upload wheels as artifacts. (#1215)
* Add maturin ci

* fix paths

* Change sdist path
2023-10-29 13:44:05 +00:00
46d6566c99 Fix the conv2d gradient computation. (#1214) 2023-10-29 09:50:04 +00:00
55bc3382cf Allow for different behavior between training and eval (#1213)
* Forward with training.

* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
dece37c6f4 feat: implement VGG13, VGG16 and VGG19 (#1211)
* feat: implement VGG13, VGG16 and VGG19

* Cosmetic fixes.

* More cosmetic tweaks + avoid re-loading the weights on each final layer.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-29 06:10:23 +00:00
498c50348c Add DDPG and fix Gym wrapper (#1207)
* Fix Gym wrapper
- It was returning things in the wrong order
- Gym now differentiates between terminated and truncated

* Add DDPG

* Apply fixes

* Remove Result annotations

* Also remove Vec annotation

* rustfmt

* Various small improvements (avoid cloning, mutability, get clippy to pass, ...)

---------

Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-28 19:53:34 +01:00
012ae0090e Infer the config for llama2-c. (#1208) 2023-10-28 19:00:39 +01:00
95a857cf57 Move the llama2-c model in transformers. (#1205) 2023-10-28 16:51:19 +01:00
612f5b8156 Make more models cloneable. (#1203) 2023-10-28 07:43:08 +01:00
ef33df7ae2 No need for the even constraint on vecdot-q40-q80. (#1202) 2023-10-28 07:23:59 +01:00
c8face3f95 Add the relu2 and relu6 activations. (#1201) 2023-10-27 20:51:16 +01:00
85bea43e5b Make the whisper model cloneable (#1200)
* Add a quantized variant of llama2.c

* Clippy fixes.

* Make the whisper model cloneable.
2023-10-27 16:59:19 +01:00
b3181455d5 Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d

* no unwrap

* run rustfmp and clippy
2023-10-27 15:56:50 +01:00
e2826e70b3 Add a quantized variant of llama2.c (#1197)
* Add a quantized variant of llama2.c

* Clippy fixes.
2023-10-27 15:34:06 +01:00
916619f70b Minor cleanup (#1194)
* Add some missing backtraces.

* Small cleanup.
2023-10-27 14:08:29 +01:00
9b1158b315 Add some missing backtraces. (#1193) 2023-10-27 06:09:11 +01:00
70d06ab4b0 Add support for the phi-hermes finetuned model. (#1192) 2023-10-27 05:57:08 +01:00
0ec5ebcec4 Use the hub model file when possible. (#1190)
* Use the hub model file when possible.

* And add a mention in the main readme.
2023-10-26 20:00:50 +01:00
c8e197f68c Fixes for jina-bert. (#1189) 2023-10-26 18:52:30 +01:00
5f20697918 Add the jina-bert embeddings model. (#1187)
* Add the jina-bert model.

* Use alibi.

* Remove the unused pragma.

* Recompute the alibi embeddings.

* Generate the token type ids.

* Use the module trait.

* Add the jina-bert example.

* DType fix.

* Get the inference to work.
2023-10-26 16:54:36 +01:00
e37b487767 Add Blip to online demos README.md (#1184)
* Add Blip to online demos README.md

* Punctuation.

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-26 11:07:01 +01:00
e5dc8cb4f4 [Wasm] BLIP Example (#1183)
* blip wasm start

* fix dependency issue, move token stream here

* vanilla js worker

* roll back vscode

* spell
2023-10-26 07:24:02 +01:00
e7b886d56f Add a link to the optimisers crate. (#1180) 2023-10-25 21:51:45 +01:00
6a446d9d73 convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor

* separate tests for convert pytorch tensor
2023-10-25 19:39:14 +01:00
0acd16751d Expose the fields from batch-norm. (#1176) 2023-10-25 15:35:32 +01:00
c698e17619 Enable the test for meshgrid + fix the implementation. (#1175) 2023-10-25 13:47:54 +01:00
e4c9adfdbe Implemented meshgrid (#1174)
* Implemented meshgrid

* Resolved feedback from LaurentMazare

* Rustfmt

* Updated docstring

* Removed outdated error mode from docstring
2023-10-25 12:49:11 +01:00
b6053b938b [Wasm] Add puffin phi model to wasm (#1166)
* load config from file, add puffin phi links

* format

* add prompt examples
2023-10-25 07:09:03 +01:00
45dbe541bc fix ucopy for f64 tensors (#1170) 2023-10-24 17:06:03 +01:00
7bd0faba75 Add support for accelerate in the pyo3 bindings. (#1167) 2023-10-24 06:34:37 +01:00
807e3f9f52 derivative for GELU (#1160)
* derivative for GELU

* add tests
2023-10-23 20:23:45 +01:00
eae94a451b PyO3: Add mkl support (#1159)
* Add `mkl` support

* Set `mkl` path on linux
2023-10-23 20:10:59 +01:00
86e1803191 Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting
2023-10-23 17:12:44 +01:00
25c3cc4149 Mention the flash-attention restriction in the readme. (#1158) 2023-10-23 10:26:56 +01:00
a11af79e23 Add a quantized blip model. (#1155)
* Add a quantized blip model.

* Integrate the quantized blip model to the actual example.
2023-10-22 20:33:25 +01:00
8a82d623e5 Handle LongStorage in pytorch checkpoints. (#1152) 2023-10-22 18:34:36 +01:00
df2f89b6cf Add some KV cache to blip. (#1150)
* Add some KV cache to blip.

* Mention BLIP in the readme.
2023-10-22 09:44:48 +01:00
62fc965617 Expose the track-op method. (#1148) 2023-10-22 06:57:03 +01:00
5b32c2a41e Remove the unused pragma and properly apply the bias. (#1147) 2023-10-22 06:47:40 +01:00
3115fe42e4 Blip attention mask + readme (#1146)
* Add the attention mask to the blip model.

* Add a readme.
2023-10-21 22:44:13 +01:00
2531b13bf8 Blip fixes (#1145)
* Some fixes for the blip example.

* Stop generating on sep tokens.

* Clippy fixes.

* rustfmt.
2023-10-21 21:34:48 +01:00
0d9bb4eb18 Add the blip example. (#1144)
* Add the blip example.

* Tweak the example.

* Implement the cross-attn logic.

* Fix some shape mismatches.

* Get some logits out.

* Get some caption to be generated.
2023-10-21 20:05:02 +01:00
e8f760ee44 Add get_on_dim. (#1142) 2023-10-21 15:01:38 +01:00
94e3373883 Blip forward pass (#1141)
* More forward methods for the blip model.

* Blipping continues.
2023-10-21 10:19:23 +01:00
34d9e91748 Add the blip image captioning model (#1140)
* Blip text model.

* Blip vision bits.

* Blippity.

* More blip.
2023-10-20 22:09:11 +01:00
cfb423ab76 PyO3: Add CI (#1135)
* Add PyO3 ci

* Update python.yml

* Format `bert.py`
2023-10-20 19:05:14 +01:00
7366aeac21 Make func cloneable. (#1137) 2023-10-20 16:28:50 +01:00
99cf13e8e2 Add the sequential layer. (#1136) 2023-10-20 16:08:50 +01:00
b43ab6cd1d PyO3: Add None and Tensor indexing to candle.Tensor (#1098)
* Add proper `None` and `tensor` indexing

* Allow indexing via lists + allow tensor/list indexing outside of first dimension
2023-10-20 09:59:00 +01:00
31ca4897bb Readme updates. (#1134) 2023-10-20 09:08:39 +01:00
55351ef57d Add some vision transformers models (#1132)
* Start adding vision-transformers.

* Add self-attn.

* More vision transformers.

* vit-vit.

* Add the actual vit model.

* Add the example code for the vision transformers.
2023-10-19 22:24:18 +01:00
6684b7127a PyO3: Add pytorch like .to() operator to candle.Tensor (#1100)
* add `.to()` operator

* Only allow each value to be provided once via `args` or `kwargs`
2023-10-19 21:46:21 +01:00
93c25e8844 Expose the larger resnets (50/101/152) in the example. (#1131) 2023-10-19 13:48:28 +01:00
cd53c472df Support ResNet 50/101/152. (#1130) 2023-10-19 10:48:31 +01:00
6f76383f38 Add a readme for the resnet example. (#1129) 2023-10-19 09:58:50 +01:00
8e773cc0c6 Experiment with resnet (#1128)
* Add some preliminary support for resnet.

* Add an actual resnet example.
2023-10-19 09:25:03 +01:00
87eb1658e1 Add pad_with_same. (#1127)
* More model cloning.

* More cloning on quantized models.

* Add pad-with-same.

* Add some tests.
2023-10-18 23:13:37 +01:00
902d0b9166 More model cloning. (#1126)
* More model cloning.

* More cloning on quantized models.
2023-10-18 21:55:46 +01:00
185b54a33b Make some model cloneable. (#1125) 2023-10-18 19:30:47 +01:00
620c94d12e Add support for Zephyr-7b in the quantized model. (#1124) 2023-10-18 17:31:26 +01:00
86e7d539d2 Add the quantized mpt model. (#1123)
* Add the quantized mpt model.

* Support the quantized model for replit-code.
2023-10-18 16:29:38 +01:00
cb034506cd Remove the unused pragma in mpt. (#1122) 2023-10-18 15:47:50 +01:00
63c204c79e Add a mention to the replit-code model in the readme. (#1121) 2023-10-18 11:27:23 +01:00
767a6578f1 MPT alibi fixes. (#1120)
* MPT alibi fixes.

* Some more fixes.

* Finally get the model to return some sensible outputs.

* Add a readme.
2023-10-18 10:58:05 +01:00
662c186fd5 Better error message when overflowing in narrow. (#1119) 2023-10-18 08:40:14 +01:00
2cd745a97c MPT fixes. (#1117)
* MPT fixes.

* Another couple fixes.

* Another shape fix.
2023-10-17 21:53:31 +01:00
a72b50e2c0 Build alibi bias. (#1115)
* Build alibi bias.

* Apply the alibi attention bias.

* Add the replit-code example.
2023-10-17 20:41:37 +01:00
872c3f14b0 Add the MPT model. (#1114)
* Add the MPT model.

* Add ffn and block.

* Forward pass for the mpt block.

* Repeat-kv.
2023-10-17 16:06:48 +01:00
f9e93f5b69 Extend stub.py to accept external typehinting (#1102) 2023-10-17 11:07:26 +01:00
b355ab4e2e Always broadcast magic methods (#1101) 2023-10-17 10:57:12 +01:00
2fe24ac5b1 Rework the cuda casting bits. (#1112) 2023-10-17 09:44:51 +01:00
00948eb656 Formatting tweak. (#1111) 2023-10-16 21:02:53 +01:00
af67672207 Add support for Puffin-Phi-v2. (#1110)
* Add support for Puffin-Phi-v2.

* Tweak the file name.

* Support the config for puffin-phi-v2.

* Update the readme.
2023-10-16 20:54:21 +01:00
6c588c4792 Refactor the pth tensor exctraction. (#1109) 2023-10-16 18:16:34 +01:00
122da87580 feat: add pth varbuilder (#1108) 2023-10-16 16:20:36 +01:00
75629981bc feat: parse Cuda compute cap from env (#1066)
* feat: add support for multiple compute caps

* Revert to one compute cap

* fmt

* fix
2023-10-16 15:37:38 +01:00
0106b0b04c Read all the tensors in a PyTorch pth file. (#1106) 2023-10-16 13:50:07 +01:00
588ad4835a Fix the verbose prompt for phi. (#1097) 2023-10-15 10:53:25 +01:00
b73c35cc57 Improve the reshape error messages. (#1096)
* Improve the reshape error messages.

* Add the verbose-prompt flag to the phi example.
2023-10-15 10:43:10 +01:00
8f310cc666 Avoid trying to backprop through non-differentiable layers. (#1094) 2023-10-14 22:03:41 +01:00
8921d5027c Add support for phi-1.0 (#1093)
* Add support for phi-1.0

* Update the readme.
2023-10-14 20:15:43 +01:00
29c7f2565d Add some reinforcement learning example. (#1090)
* Add some reinforcement learning example.

* Python initialization.

* Get the example to run.

* Vectorized gym envs for the atari wrappers.

* Get some simulation loop to run.
2023-10-14 16:46:43 +01:00
9309cfc47d Create a new curand instead of reseeding. (#1089) 2023-10-14 10:03:59 +01:00
a193bf5f60 Another gemm update. (#1088) 2023-10-14 09:36:52 +01:00
2c110ac7d9 Add the pooling operators to the pyo3 layer. (#1086) 2023-10-13 20:18:10 +01:00
75989fc3b7 Use an attention mask in the e5 padding case. (#1085) 2023-10-13 18:53:40 +01:00
07af87a1d8 Typos. (#1084) 2023-10-13 16:21:20 +01:00
eefad2b95f Update to gemm 0.16.1 (#1083) 2023-10-13 06:40:20 +01:00
5e6df4a3f7 Update to gemm-0.16. (#1082)
* Update to gemm-0.16.

* Enable wasm-simd128.
2023-10-12 21:56:59 +01:00
7473c4ceca Fix the npy read function and add some testing. (#1080) 2023-10-12 15:25:05 +02:00
c096f02411 Add a matvec cpu benchmark. (#1076) 2023-10-12 09:29:18 +01:00
e7560443e4 Convmixer example (#1074)
* Add a convmixer based example.

* Mention the model in the readme.
2023-10-11 19:51:10 +01:00
89b525b5e7 Convmixer (#1073)
* Only optimize float tensors.

* Use full tensors for zeros and ones.

* Add a benchmark for the matmul slowness.

* Add the convmixer model.

* Proper adaptive pooling.
2023-10-11 18:24:32 +01:00
37dbbff261 Use full tensors for zeros and ones (#1071)
* Only optimize float tensors.

* Use full tensors for zeros and ones.
2023-10-11 08:16:04 +01:00
9fea56d28e Only optimize float tensors. (#1069) 2023-10-10 09:05:41 +01:00
209 changed files with 20603 additions and 1080 deletions

View File

@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev -y
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:

BIN
.github/workflows/maturin.yml vendored Normal file

Binary file not shown.

68
.github/workflows/python.yml vendored Normal file
View File

@ -0,0 +1,68 @@
name: PyO3-CI
on:
workflow_dispatch:
push:
branches:
- main
paths:
- candle-pyo3/**
pull_request:
paths:
- candle-pyo3/**
jobs:
build_and_test:
name: Check everything builds & tests
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: 3.11
architecture: "x64"
- name: Cache Cargo Registry
uses: actions/cache@v1
with:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install
working-directory: ./candle-pyo3
run: |
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r --features onnx
- name: Check style
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python stub.py --check
black --check .
- name: Run tests
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests

View File

@ -7,16 +7,15 @@ members = [
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/phi",
"candle-wasm-examples/t5",
"candle-wasm-examples/*",
"candle-wasm-tests",
]
exclude = ["candle-flash-attn", "candle-kernels"]
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
resolver = "2"
[workspace.package]
@ -34,8 +33,7 @@ anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.16.0", package = "candle-gemm" }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -53,6 +51,7 @@ rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false }
@ -62,6 +61,10 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"
[profile.release-with-debug]
inherits = "release"

View File

@ -51,22 +51,26 @@ For more advanced examples, please have a look at the following section.
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
We also provide a some command line based examples using state of the art models:
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
@ -94,10 +98,15 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.
Run them using commands like:
```
@ -129,8 +138,18 @@ And then head over to
<!--- ANCHOR: useful_libraries --->
## Useful Libraries
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
If you have an addition to this list, please submit a pull request.
@ -155,16 +174,21 @@ If you have an addition to this list, please submit a pull request.
- Phi v1.5.
- Mistral 7b v0.1.
- StableLM-3B-4E1T.
- T5.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- Computer Vision Models.
- DINOv2.
- EfficientNet.
- yolo-v3.
- yolo-v8.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
@ -202,6 +226,7 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
## FAQ

View File

@ -12,6 +12,9 @@ compute_cap
8.9
```
You can also compile the Cuda kernels for a specific compute cap using the
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.

View File

@ -13,6 +13,8 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -28,6 +30,8 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -39,3 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]

View File

@ -39,6 +39,14 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;
fn conv2d(
&self,
_l: &Layout,

View File

@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
}
}
thread_local! {
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
@ -36,6 +47,8 @@ impl Tensor {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
} else if node.dtype().is_int() {
nodes
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
@ -55,6 +68,11 @@ impl Tensor {
kernel: rhs,
..
}
| Op::ConvTranspose1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
@ -103,7 +121,6 @@ impl Tensor {
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Permute(node, _)
@ -116,6 +133,15 @@ impl Tensor {
track_grad |= tg;
nodes
}
Op::ToDType(node) => {
if node.dtype().is_float() {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
} else {
nodes
}
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
@ -140,10 +166,16 @@ impl Tensor {
if node.is_variable() {
continue;
}
let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph). The only drawback would be if we wanted to support grad of grad but
// this is out of scope.
let grad = grads
.remove(node)
.expect("candle internal error - grad not populated");
// https://github.com/huggingface/candle/issues/1241
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -198,7 +230,44 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv1D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose1d is:
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
let grad_l_in = grad.dim(2)?;
let k_size = kernel.dim(2)?;
let out_size =
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose1d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0) = kernel.dims3()?;
let (_, _, g_k0) = grad_kernel.dims3()?;
let grad_kernel = if g_k0 != k0 {
grad_kernel.narrow(2, 0, k0)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv2D {
arg,
kernel,
@ -228,8 +297,18 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
@ -374,7 +453,7 @@ impl Tensor {
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
}
Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?;
@ -461,17 +540,47 @@ impl Tensor {
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
let gelu_grad = (((0.5 * &tanh)?
+ (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(arg, UnaryOp::Erf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
}
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?;

View File

@ -25,6 +25,33 @@ impl ParamsConv1D {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
pub(crate) b_size: usize,
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in - 1) * self.stride - 2 * self.padding
+ self.dilation * (self.k_size - 1)
+ self.output_padding
+ 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
@ -160,6 +187,49 @@ impl Tensor {
}
}
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage =
self.storage()

View File

@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "index-select" })?,
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
};
let dim = self.dim;
let n_ids = match self.ids_l.dims() {
@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
for left_i in 0..ids_left_len {
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
@ -1256,6 +1256,74 @@ impl Map1 for Im2Col {
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
// Output shape: [b_size, c_out, l_out].
let dst_elems = p.c_out * l_out * p.b_size;
let dst = vec![T::zero(); dst_elems];
let dst_s0 = p.c_out * l_out;
let dst_s1 = l_out;
let dst_s2 = 1;
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
let cont_s0 = p.l_in * p.c_in;
let cont_s1 = p.c_in;
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
for c_idx in 0..p.c_in {
let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
for k_idx in 0..p.k_size {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
let out_idx = l_idx * p.stride + k_idx * p.dilation;
if out_idx < p.padding {
continue;
}
let out_idx = out_idx - p.padding;
if out_idx < l_out {
let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
let mut d = T::zero();
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to
// parallelise the different tasks so no two threads can try to
// write at the same location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
})
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@ -2435,6 +2503,16 @@ impl BackendStorage for CpuStorage {
Ok(res_t)
}
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
}
fn conv2d(
&self,
l: &Layout,
@ -2539,25 +2617,25 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::U32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
}
}

View File

@ -224,8 +224,10 @@ impl BackendDevice for CudaDevice {
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0.set_seed(seed).w()?;
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
@ -1806,6 +1808,16 @@ impl BackendStorage for CudaStorage {
Ok(res_t)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
todo!()
}
#[cfg(not(feature = "cudnn"))]
fn conv2d(
&self,
@ -2169,7 +2181,7 @@ impl BackendStorage for CudaStorage {
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.

View File

@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
pub trait NdArray {
@ -128,10 +130,23 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@ -140,21 +155,20 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
matches!(self, Self::Cpu)
}
pub fn is_cuda(&self) -> bool {
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@ -178,8 +192,19 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
}
}
}
@ -206,8 +231,18 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
Ok(Storage::Metal(storage))
}
}
}
@ -231,6 +266,10 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -244,6 +283,10 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -255,6 +298,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
@ -266,6 +314,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
}

View File

@ -14,6 +14,9 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(f, "Tensor[")?;
@ -476,6 +479,9 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(

View File

@ -79,6 +79,16 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d(
&self,
_: &Layout,

View File

@ -0,0 +1,223 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}
impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}
impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, Shape};
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -142,6 +142,9 @@ pub enum Error {
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },
#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },
@ -149,6 +152,9 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,
#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },
@ -156,6 +162,9 @@ pub enum Error {
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -49,9 +49,12 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
@ -87,6 +90,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -125,3 +134,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
self(xs)
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -250,8 +250,6 @@ impl Tensor {
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let mut data: Vec<u8> = vec![];
reader.read_to_end(&mut data)?;
Self::from_reader(header.shape(), header.descr, &mut reader)
}

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -90,6 +90,16 @@ pub enum Op {
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
Conv2D {
arg: Tensor,
@ -174,6 +184,18 @@ pub trait CustomOp1 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
@ -209,6 +231,20 @@ pub trait CustomOp2 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
@ -251,6 +287,22 @@ pub trait CustomOp3 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
@ -536,13 +588,13 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// `gelu` operation
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
@ -632,6 +684,8 @@ impl UnaryOpT for Gelu {
}
}
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
@ -666,6 +720,40 @@ impl UnaryOpT for Erf {
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
@ -887,6 +975,10 @@ impl BackpropOp {
};
Self(op)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
}
impl std::ops::Deref for BackpropOp {

View File

@ -193,6 +193,50 @@ impl Object {
_ => Err(self),
}
}
pub fn into_tensor_info(
self,
name: Self,
dir_name: &std::path::Path,
) -> Result<Option<TensorInfo>> {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => return Ok(None),
};
let (callable, args) = match self.reduce() {
Ok(callable_args) => callable_args,
_ => return Ok(None),
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
let mut path = dir_name.to_path_buf();
path.push(file_path);
Ok(Some(TensorInfo {
name,
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
}))
}
}
impl TryFrom<Object> for String {
@ -565,6 +609,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
"HalfStorage" => DType::F16,
"BFloat16Storage" => DType::BF16,
"ByteStorage" => DType::U8,
"LongStorage" => DType::I64,
other => {
crate::bail!("unsupported storage type {other}")
}
@ -623,50 +668,10 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
};
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => continue,
};
let (callable, args) = match value.reduce() {
Ok(callable_args) => callable_args,
_ => continue,
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor"
&& class_name == "_rebuild_from_type_v2" =>
{
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => continue,
};
match rebuild_args(args) {
Ok((layout, dtype, file_path, storage_size)) => {
let mut path = dir_name.clone();
path.push(file_path);
tensor_infos.push(TensorInfo {
name,
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
})
}
Err(err) => {
eprintln!("skipping {name}: {err:?}")
}
match value.into_tensor_info(name, &dir_name) {
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
Ok(None) => {}
Err(err) => eprintln!("skipping: {err:?}"),
}
}
}
@ -723,3 +728,16 @@ impl PthTensors {
Ok(Some(tensor))
}
}
/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}

View File

@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {

View File

@ -29,6 +29,7 @@ impl TryFrom<u32> for Magic {
pub enum VersionedMagic {
GgufV1,
GgufV2,
GgufV3,
}
impl VersionedMagic {
@ -39,6 +40,7 @@ impl VersionedMagic {
let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
@ -84,7 +86,9 @@ pub struct Content {
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut v = vec![0u8; len];
reader.read_exact(&mut v)?;
@ -284,7 +288,9 @@ impl Value {
let value_type = ValueType::from_u32(value_type)?;
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut vs = Vec::with_capacity(len);
for _ in 0..len {
@ -381,11 +387,15 @@ impl Content {
let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut metadata = HashMap::new();
@ -407,7 +417,7 @@ impl Content {
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
VersionedMagic::GgufV2 => {
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()

View File

@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 {
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
// Generic implementation.
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {

View File

@ -19,42 +19,29 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
for i in 0..nb {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr());
let v0_1 = vld1q_u8(x1.qs.as_ptr());
// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
let v0_1ls = vsubq_s8(v0_1l, s8b);
let v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let v1_1l = vld1q_s8(y1.qs.as_ptr());
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@ -62,28 +49,16 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
Ok(vaddvq_f32(sumv0))
}
}
@ -94,28 +69,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
for i in 0..nb {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
let x1_0 = vld1q_s8(x1.qs.as_ptr());
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let y1_0 = vld1q_s8(y1.qs.as_ptr());
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO dotprod once this is the intrinsics are.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@ -123,28 +88,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(p2, p3)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
Ok(vaddvq_f32(sumv0))
}
}

View File

@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
@ -61,10 +57,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {

View File

@ -203,7 +203,7 @@ impl Shape {
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
@ -511,154 +511,119 @@ impl ShapeWithOneHole for ((),) {
}
}
fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
if prod_d == 0 {
crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
}
if el_count % prod_d != 0 {
crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
}
Ok(el_count / prod_d)
}
impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((el_count / d1, d1).into())
Ok((hole_size(el_count, d1, &self)?, d1).into())
}
}
impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((d1, el_count / d1).into())
Ok((d1, hole_size(el_count, d1, &self)?).into())
}
}
impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2).into())
Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
}
}
impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2).into())
Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
}
}
impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d).into())
Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3).into())
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d, d1, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3).into())
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3).into())
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d2, d, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d).into())
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d2, d3, d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3, d4).into())
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d, d1, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3, d4).into())
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3, d4).into())
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d, d4).into())
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d3, d, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, d4, el_count / d).into())
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d3, d4, d).into())
}
}

View File

@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage),
Metal(MetalStorage),
}
impl Storage {
@ -18,6 +19,10 @@ impl Storage {
let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.try_clone(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -25,6 +30,7 @@ impl Storage {
match self {
Self::Cpu(_) => Device::Cpu,
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
Self::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
@ -32,6 +38,7 @@ impl Storage {
match self {
Self::Cpu(storage) => storage.dtype(),
Self::Cuda(storage) => storage.dtype(),
Self::Metal(storage) => storage.dtype(),
}
}
@ -65,6 +72,10 @@ impl Storage {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Metal(storage))
}
}
}
@ -78,6 +89,10 @@ impl Storage {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -91,6 +106,10 @@ impl Storage {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -112,6 +131,10 @@ impl Storage {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -135,6 +158,10 @@ impl Storage {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Metal(storage))
}
}
}
@ -148,6 +175,10 @@ impl Storage {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Metal(storage))
}
}
}
@ -161,6 +192,10 @@ impl Storage {
let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape))
}
Self::Metal(storage) => {
let (storage, shape) = c.metal_fwd(storage, l)?;
Ok((Self::Metal(storage), shape))
}
}
}
@ -181,6 +216,10 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
@ -205,6 +244,10 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
@ -219,6 +262,10 @@ impl Storage {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -239,6 +286,10 @@ impl Storage {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -270,6 +321,10 @@ impl Storage {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -279,6 +334,33 @@ impl Storage {
}
}
pub(crate) fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
self.same_device(kernel, "conv-transpose1d")?;
self.same_dtype(kernel, "conv-transpose1d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv-transpose1d",
}
.bt()),
}
}
pub(crate) fn conv2d(
&self,
l: &Layout,
@ -297,6 +379,10 @@ impl Storage {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -324,6 +410,10 @@ impl Storage {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -348,6 +438,10 @@ impl Storage {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
@ -366,6 +460,10 @@ impl Storage {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
@ -379,6 +477,10 @@ impl Storage {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Metal(storage))
}
}
}
@ -392,6 +494,10 @@ impl Storage {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Metal(storage))
}
}
}
@ -415,6 +521,10 @@ impl Storage {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Metal(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -441,6 +551,10 @@ impl Storage {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -465,6 +579,10 @@ impl Storage {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -489,6 +607,10 @@ impl Storage {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -510,6 +632,10 @@ impl Storage {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -537,6 +663,10 @@ impl Storage {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -556,6 +686,9 @@ impl Storage {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),

View File

@ -6,7 +6,7 @@ use crate::op::{
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
/// Unique identifier for tensors.
@ -385,11 +385,21 @@ impl Tensor {
step: D,
device: &Device,
) -> Result<Self> {
if D::is_zero(&step) {
crate::bail!("step cannot be zero")
}
let mut data = vec![];
let mut current = start;
while current < end {
data.push(current);
current += step;
if step >= D::zero() {
while current < end {
data.push(current);
current += step;
}
} else {
while current > end {
data.push(current);
current += step;
}
}
let len = data.len();
Self::from_vec_impl(data, len, device, false)
@ -449,7 +459,7 @@ impl Tensor {
/// Returns true if the computation graph should track this op, that is if it is
/// a variable or if it has some variable as dependencies.
pub(crate) fn track_op(&self) -> bool {
pub fn track_op(&self) -> bool {
self.is_variable || self.op.is_some()
}
@ -467,6 +477,12 @@ impl Tensor {
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
@ -513,6 +529,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -540,6 +557,73 @@ impl Tensor {
Ok(inp)
}
/// Creates grids of coordinates specified by the 1D inputs.
///
/// # Arguments
///
/// * `args` - A slice of 1D tensors.
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
/// first dimension corresponds to the cardinality of the second input and the second
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
/// dimensions are in the same order as the cardinality of the inputs.
///
/// # Examples
///
/// ```rust
/// use candle_core::{Tensor, Device, Shape};
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
///
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
///
/// assert_eq!(grids_xy.len(), 2);
/// assert_eq!(grids_xy[0].dims(), &[3, 3]);
///
/// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
/// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
///
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
///
/// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
/// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
///
/// # Errors
///
/// * Will return `Err` if `args` contains less than 2 tensors.
///
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
if args.len() <= 1 {
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
}
let args: Vec<_> = if xy_indexing {
args.iter().rev().collect()
} else {
args.iter().collect()
};
let mut shape = Vec::with_capacity(args.len());
for arg in args.iter() {
shape.push(arg.as_ref().dims1()?)
}
let mut grids = Vec::with_capacity(args.len());
for idx in 0..args.len() {
let mut ones = vec![1usize; args.len()];
ones[idx] = shape[idx];
let arg = args[idx].as_ref().reshape(ones)?;
let mut repeats = shape.clone();
repeats[idx] = 1;
let repeated_tensor = arg.repeat(repeats)?;
grids.push(repeated_tensor);
}
if xy_indexing {
grids.reverse();
}
Ok(grids)
}
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
/// be performed.
@ -615,15 +699,23 @@ impl Tensor {
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
if start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
msg: "start + len > dim_len",
}
.bt())?
let err = |msg| {
Err::<(), _>(
Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
msg,
}
.bt(),
)
};
if start > dims[dim] {
err("start > dim_len")?
}
if start.saturating_add(len) > dims[dim] {
err("start + len > dim_len")?
}
if start == 0 && dims[dim] == len {
Ok(self.clone())
@ -764,6 +856,20 @@ impl Tensor {
self.sum_impl(mean_dims, false)? * scale
}
/// Returns the unbiased variance over the selected dimension.
pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
let mean = self.mean_keepdim(dim)?;
let squares = self.broadcast_sub(&mean)?.sqr()?;
squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
}
/// Returns the unbiased variance over the selected dimension.
pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
self.var_keepdim(dim)?.squeeze(dim)
}
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
@ -1111,14 +1217,16 @@ impl Tensor {
op: "scatter-add (self, src)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
let storage = self.storage().scatter_add(
self.layout(),
@ -1190,7 +1298,8 @@ impl Tensor {
op: "slice-scatter (self, src)",
lhs: self.shape().clone(),
rhs: src.shape().clone(),
})?
}
.bt())?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage()
@ -1224,7 +1333,8 @@ impl Tensor {
op: "index-add (self, source)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
// The number of element in indexes must match the dimension on which the add is
// performed on the source tensor (and the index values from `indexes` are taken from
@ -1235,7 +1345,8 @@ impl Tensor {
op: "index-add (ids, source))",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
let storage = self.storage().index_add(
self.layout(),
@ -1283,7 +1394,8 @@ impl Tensor {
op: "gather",
lhs: self.shape().clone(),
rhs: indexes.shape().clone(),
})?
}
.bt())?
}
let storage =
self.storage()
@ -1357,6 +1469,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1387,6 +1500,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1427,6 +1541,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1590,6 +1705,24 @@ impl Tensor {
}
}
/// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let t = tensor.get_on_dim(1, 0)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
/// let t = tensor.get_on_dim(1, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
/// let t = tensor.get_on_dim(0, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
let dim = dim.to_index(self.shape(), "get_on_dim")?;
self.narrow(dim, index, 1)?.squeeze(dim)
}
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
/// input are swapped.
///
@ -1698,17 +1831,23 @@ impl Tensor {
/// Returns a new tensor detached from the current graph, gradient are not propagated through
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
if self.op.is_none() && !self.is_variable {
Ok(self.clone())
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
}
/// If the target device is the same as the tensor device, only a shallow copy is performed.
@ -1720,7 +1859,14 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
@ -1728,6 +1874,9 @@ impl Tensor {
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
_ => {
bail!("not implemented yet")
}
};
let op = BackpropOp::new1(self, Op::ToDevice);
let tensor_ = Tensor_ {
@ -2127,11 +2276,56 @@ impl Tensor {
}
}
/// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after.
pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
if left == 0 && right == 0 {
Ok(self.clone())
} else if self.elem_count() == 0 {
crate::bail!("cannot use pad_with_same on an empty tensor")
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
let mut v = vec![self];
for _ in 0..right {
v.push(&r)
}
Tensor::cat(&v, dim)
} else if right == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let l = self.narrow(dim, 0, 1)?;
let mut v = vec![];
for _ in 0..left {
v.push(&l)
}
v.push(self);
Tensor::cat(&v, dim)
} else {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let l = self.narrow(dim, 0, 1)?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
let mut v = vec![];
for _ in 0..left {
v.push(&l)
}
v.push(self);
for _ in 0..right {
v.push(&r)
}
Tensor::cat(&v, dim)
}
}
/// Run the `forward` method of `m` on `self`.
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
m.forward(self)
}
/// Run the `forward` method of `m` on `self`.
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
m.forward_t(self, train)
}
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
@ -2246,6 +2440,23 @@ impl Tensor {
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}
}
macro_rules! bin_trait {

View File

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,6 +15,12 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}

View File

@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
}
pub fn metal_is_available() -> bool {
cfg!(feature = "metal")
}
pub fn with_avx() -> bool {
cfg!(target_feature = "avx")
}

View File

@ -13,6 +13,11 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -45,6 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
if dev.is_cpu() {
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
}
Ok(())
}
@ -479,17 +495,103 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
]
);
// Replicate the issue from https://github.com/huggingface/candle/issues/1212
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 7.87, 0.0, 0.0],
[-1.8, -7.82, 5.9, 0.0, 0.0],
[-3.12, 4.49, 5.52, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[21.73, 3.39, 4.77, 0.0, 0.0],
[8.25, 3.73, 27.61, 0.0, 0.0],
[-20.55, -5.61, -2.77, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-8.98, 9.91, -7.15, 0.0, 0.0],
[4.93, -0.33, 4.56, 0.0, 0.0],
[-6.7, -5.76, -8.05, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[23.54, 6.98, -10.0, 0.0, 0.0],
[9.65, 6.18, 18.72, 0.0, 0.0],
[3.29, -5.27, 0.79, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[-3.47, 7.44, 0.66],
[12.89, -3.4, -9.29],
[-14.16, -0.83, 7.14]
],
[
[-3.23, 5.37, -3.02],
[-2.12, -11.24, 1.94],
[6.97, 7.2, 2.99]
],
[
[-4.04, -3.31, 4.87],
[-6.68, -5.68, 1.73],
[-5.54, 4.32, 0.52]
],
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
]
);
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
conv1d_small,
conv1d_small_cpu,
conv1d_small_gpu,
conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
conv2d_non_square_gpu
conv2d_non_square_gpu,
conv2d_non_square_metal
);
test_device!(
conv2d_small,
conv2d_small_cpu,
conv2d_small_gpu,
conv2d_small_metal
);
test_device!(
conv2d_smaller,
conv2d_smaller_cpu,
conv2d_smaller_gpu,
conv2d_smaller_metal
);
test_device!(
conv2d_grad,
conv2d_grad_cpu,
conv2d_grad_gpu,
conv2_grad_metal
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);

View File

@ -192,6 +192,84 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 2)?,
[0.01, 0.42, 0.0, 0.98],
);
// testing compared to pytorch nn.GELU(approximate = 'tanh')
let y = x.gelu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9964, 0.8412, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188],
);
// Testing compared to pytorch torch.erf
//
// import torch
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = x.erf()
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.0001, 0.4151, 0.0, 1.1033],
);
// Testing compared to pytorch nn.GELU(approximate = 'none')
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = F.gelu(x, approximate='none')
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.gelu_erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9960, 0.8413, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0119, 1.0833, 1.0005, 0.6188],
);
// Testing compared to pytorch elu
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
// y = F.elu(x, alpha=2.0)
// print(y)
// loss = y.min
// loss = y.sum()
// loss.backward()
// print(x.grad)
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = elu_x.elu(2.)?;
let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.7358, 2.0000, 0.2707, 1.0000]
);
Ok(())
}
@ -237,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
test_device!(
simple_grad,
simple_grad_cpu,
simple_grad_gpu,
simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);

View File

@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(())
}
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
#[test]
fn strided_blocks() -> Result<()> {

9
candle-core/tests/npy.py Normal file
View File

@ -0,0 +1,9 @@
import numpy as np
x = np.arange(10)
# Write a npy file.
np.save("test.npy", x)
# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)

View File

@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu
avg_pool2d_pytorch_gpu,
avg_pool2d_pytorch_metal
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
upsample_nearest2d_gpu
upsample_nearest2d_gpu,
upsample_nearest2d_metal
);

View File

@ -0,0 +1,24 @@
use candle_core::{DType, Result, Tensor};
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
assert_eq!(
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
Ok(())
}
#[test]
fn npz() -> Result<()> {
let npz = Tensor::read_npz("tests/test.npz")?;
assert_eq!(npz.len(), 2);
assert_eq!(npz[0].0, "x");
assert_eq!(npz[1].0, "x_plus_one");
assert_eq!(
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
Ok(())
}

View File

@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
[0, 2, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
[0, 3],
);
assert_eq!(
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
[5, 4, 3, 2, 1],
);
Ok(())
}
@ -161,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
[0.2035f32, 1.2959, 1.8101, -0.4644],
[1.5027, -0.3270, 0.5905, 0.6538],
[-1.5745, 1.3330, -0.5596, -0.6548],
[0.1264, -0.5080, 1.6420, 0.1992],
];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
&[[1.0631], [0.559], [1.4893], [0.8258]]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@ -1035,33 +1070,60 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -1073,3 +1135,27 @@ fn randn_hasneg() -> Result<()> {
}
Ok(())
}
#[test]
fn pad_with_same() -> Result<()> {
let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
let t0 = t.pad_with_same(0, 1, 2)?;
assert_eq!(
t0.to_vec2::<f32>()?,
[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
);
let t1 = t.pad_with_same(1, 1, 2)?;
assert_eq!(
t1.to_vec2::<f32>()?,
[[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
);
Ok(())
}
#[test]
fn i64_abs() -> Result<()> {
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
let t = t.abs()?;
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}

BIN
candle-core/tests/test.npy Normal file

Binary file not shown.

BIN
candle-core/tests/test.npz Normal file

Binary file not shown.

View File

@ -4,7 +4,9 @@
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used.
use crate::vision::Dataset;
use candle::{DType, Device, Result, Tensor};
use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{BufReader, Read};
@ -60,3 +62,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
labels: 10,
})
}
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
let samples = parquet.metadata().file_metadata().num_rows() as usize;
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
for row in parquet.into_iter().flatten() {
for (_name, field) in row.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
buffer_labels.push(*label as u8);
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
}
pub fn load() -> Result<Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "cifar10".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo
.get("plain_text/test/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("plain_text/train/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -16,11 +16,13 @@ candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
@ -54,7 +56,20 @@ cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]
[[example]]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -5,11 +5,11 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{anyhow, Error as E, Result};
use anyhow::{Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
@ -19,10 +19,6 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Run offline (you must have the files already cached)
#[arg(long)]
offline: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -38,6 +34,10 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
@ -60,34 +60,27 @@ impl Args {
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
let cache = Cache::default().repo(repo);
(
cache
.get("config.json")
.ok_or(anyhow!("Missing config file in cache"))?,
cache
.get("tokenizer.json")
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
cache
.get("model.safetensors")
.ok_or(anyhow!("Missing weights file in cache"))?,
)
} else {
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
(
api.get("config.json")?,
api.get("tokenizer.json")?,
api.get("model.safetensors")?,
)
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}

View File

@ -0,0 +1,19 @@
# candle-blip
The
[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
model can generate captions for an input image.
## Running on an example
```bash
cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
```
Running on CPU, to run on GPU, build this example with `--features cuda`
loaded image Tensor[dims 3, 384, 384; f32]
model built
several cyclists are riding down a road with cars behind them%
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -0,0 +1,154 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;
use tokenizers::Tokenizer;
enum Model {
M(blip::BlipForConditionalGeneration),
Q(quantized_blip::BlipForConditionalGeneration),
}
impl Model {
fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
match self {
Self::M(m) => m.text_decoder().forward(xs, img_xs),
Self::Q(m) => m.text_decoder().forward(xs, img_xs),
}
}
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
}
const SEP_TOKEN_ID: u32 = 102;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean =
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
if args.quantized {
let api = api.model("lmz/candle-blip".to_string());
api.get("blip-image-captioning-large-q4k.gguf")?
} else {
let api = api.repo(hf_hub::Repo::with_revision(
"Salesforce/blip-image-captioning-large".to_string(),
hf_hub::RepoType::Model,
"refs/pr/18".to_string(),
));
api.get("model.safetensors")?
}
}
Some(model) => model.into(),
};
let tokenizer = match args.tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("Salesforce/blip-image-captioning-large".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let mut tokenizer = TokenOutputStream::new(tokenizer);
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let config = blip::Config::image_captioning_large();
let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
} else {
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::M(model))
};
let mut token_ids = vec![30522u32];
for index in 0..1000 {
let context_size = if index > 0 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
if token == SEP_TOKEN_ID {
break;
}
token_ids.push(token);
if let Some(t) = tokenizer.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -0,0 +1,59 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convmixer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-convmixer".into());
api.get("convmixer_1024_20_ks9_p14.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = convmixer::c1024_20(1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -0,0 +1,45 @@
# candle-jina-bert
Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"
> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355],
> [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807],
> [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087],
> ...
> [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811],
> [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830],
> [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```
## Similarities
In this example, Jina-Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the examples). Then cosine similarities are computed for
each sentence pair and they are reported by decreasing values, hence the first
reported pair contains the two sentences that have the highest similarity score.
The sentence embeddings are computed using average pooling through all the
sentence tokens, including some potential padding.
```bash
cargo run --example jina-bert --release
> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```

View File

@ -0,0 +1,180 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::jina_bert::{BertModel, Config};
use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
model: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model = match &self.model {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()?
.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
RepoType::Model,
))
.get("model.safetensors")?,
};
let tokenizer = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()?
.repo(Repo::new(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
RepoType::Model,
))
.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

View File

@ -6,9 +6,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod model;
use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod training;
mod weights;
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
@ -19,6 +20,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;
#[derive(Parser, Debug, Clone)]
@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
Ok(())
}
enum Model {
Llama(Llama),
QLlama(QLlama),
}
impl Model {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self {
Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
}
}
}
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
use std::io::BufRead;
@ -241,24 +257,66 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (vb, config) = if is_safetensors {
let config = Config::tiny();
let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
.dims2()?;
let config = match dim {
64 => Config::tiny_260k(),
288 => Config::tiny_15m(),
512 => Config::tiny_42m(),
768 => Config::tiny_110m(),
_ => anyhow::bail!("no config for dim {dim}"),
};
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
.dequantize(&candle::Device::Cpu)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
.dequantize(&candle::Device::Cpu)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
("freq_cis_real".to_string(), freq_cis_real),
("freq_cis_imag".to_string(), freq_cis_imag),
]
.into_iter()
.collect(),
candle::DType::F32,
&candle::Device::Cpu,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
(vb, config)
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
(vb, config)
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
};
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@ -273,7 +331,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let start_gen = std::time::Instant::now();
for index in 0.. {
if tokens.len() >= model.config.seq_len {
if tokens.len() >= config.seq_len {
break;
}
let context_size = if index > 0 { 1 } else { tokens.len() };

View File

@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let config = Config::tiny_15m();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

View File

@ -0,0 +1,38 @@
# candle-marian-mt
`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.
## Running an example
```bash
cargo run --example marian-mt --release -- \
--text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```
```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```
## Generating the tokenizer.json files
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be install and use the `convert_slow_tokenizer.py` script from this
directory.
```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,152 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Big,
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
tokenizer_dec: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "big")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
/// Text to be translated
#[arg(long)]
text: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let config = match args.which {
Which::Base => marian::Config::opus_mt_fr_en(),
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-fr.json",
Which::Big => "tokenizer-marian-fr.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let tokenizer_dec = {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-en.json",
Which::Big => "tokenizer-marian-en.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
))
.get("model.safetensors")?,
Which::Big => Api::new()?
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
.get("model.safetensors")?,
},
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};
let mut model = marian::MTModel::new(&config, vb)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let encoder_xs = {
let mut tokens = tokenizer
.encode(args.text, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.push(config.eos_token_id);
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
model.encoder().forward(&tokens, 0)?
};
let mut token_ids = vec![config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == config.eos_token_id || token == config.forced_eos_token_id {
break;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*;
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@ -95,7 +95,7 @@ impl ConvNet {
.flatten_from(1)?
.apply(&self.fc1)?
.relu()?;
self.dropout.forward(&xs, train)?.apply(&self.fc2)
self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
}
}

View File

@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
@ -292,7 +294,7 @@ impl EncodecConv1d {
},
vb.pp("conv"),
)?,
NormType::None => conv1d(
NormType::None | NormType::TimeGroupNorm => conv1d(
in_c,
out_c,
kernel_size,
@ -305,9 +307,17 @@ impl EncodecConv1d {
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
})
}
}
@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
// If we add support for NormType "time_group_norm", we should add some normalization here.
Ok(xs)
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}

View File

@ -0,0 +1,10 @@
## Using ONNX models in Candle
This example demonstrates how to run ONNX based models in Candle, the model
being used here is a small sequeezenet variant.
You can run the example with the following command:
```bash
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -0,0 +1,78 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{IndexOp, D};
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
image: String,
#[arg(long)]
model: Option<String>,
/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
};
println!("loaded image {image:?}");
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::SqueezeNet => hf_hub::api::sync::Api::new()?
.model("lmz/candle-onnx".into())
.get("squeezenet1.1-7.onnx")?,
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
},
};
let model = candle_onnx::read_file(model)?;
let graph = model.graph.as_ref().unwrap();
let mut inputs = std::collections::HashMap::new();
inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
let output = outputs.remove(&graph.output[0].name).unwrap();
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
Ok(())
}

View File

@ -0,0 +1,87 @@
use anyhow::Result;
use candle::{Device, Tensor};
use clap::{Parser, Subcommand};
#[derive(Subcommand, Debug, Clone)]
enum Command {
Print {
#[arg(long)]
file: String,
},
SimpleEval {
#[arg(long)]
file: String,
},
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[command(subcommand)]
command: Command,
}
pub fn main() -> Result<()> {
let args = Args::parse();
match args.command {
Command::Print { file } => {
let model = candle_onnx::read_file(file)?;
println!("{model:?}");
let graph = model.graph.unwrap();
for node in graph.node.iter() {
println!("{node:?}");
}
}
Command::SimpleEval { file } => {
let model = candle_onnx::read_file(file)?;
let graph = model.graph.as_ref().unwrap();
let constants: std::collections::HashSet<_> =
graph.initializer.iter().map(|i| i.name.as_str()).collect();
let mut inputs = std::collections::HashMap::new();
for input in graph.input.iter() {
use candle_onnx::onnx::tensor_proto::DataType;
if constants.contains(input.name.as_str()) {
continue;
}
let type_ = input.r#type.as_ref().expect("no type for input");
let type_ = type_.value.as_ref().expect("no type.value for input");
let value = match type_ {
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
let dt = match DataType::try_from(tt.elem_type) {
Ok(dt) => match candle_onnx::dtype(dt) {
Some(dt) => dt,
None => {
anyhow::bail!(
"unsupported 'value' data-type {dt:?} for {}",
input.name
)
}
},
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
let dims = shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
})
.collect::<Result<Vec<usize>>>()?;
Tensor::zeros(dims, dt, &Device::Cpu)?
}
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
println!("input {}: {value:?}", input.name);
inputs.insert(input.name.clone(), value);
}
let outputs = candle_onnx::simple_eval(&model, inputs)?;
for (name, value) in outputs.iter() {
println!("output {name}: {value:?}")
}
}
}
Ok(())
}

View File

@ -41,3 +41,16 @@ def median(arr):
else:
return arr[n//2]
```
This also supports the [Puffin Phi v2
model](https://huggingface.co/teknium/Puffin-Phi-v2) for human interaction.
```
$ cargo run --example phi --release -- \
--prompt "USER: What would you do on a sunny day in Paris?\nASSISTANT:" \
--sample-len 200 --model puffin-phi-v2 --quantized
USER: What would you do on a sunny day in Paris?
ASSISTANT: On a sunny day in Paris, you could visit the Musée du Louvre to admire the famous
painting "Mona Lisa" by Leonardo da Vinci. You might also want to stroll along the Champs-Élysées
and enjoy the beautiful architecture of the buildings around you. Don't forget to stop by a café
for a cup of coffee and to soak up the sun!"
```

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use clap::{Parser, ValueEnum};
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
@ -28,6 +28,7 @@ struct TextGeneration {
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
@ -40,6 +41,7 @@ impl TextGeneration {
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
@ -49,6 +51,7 @@ impl TextGeneration {
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
@ -56,20 +59,24 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the phi model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@ -110,6 +117,16 @@ impl TextGeneration {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
#[value(name = "1")]
V1,
#[value(name = "1.5")]
V1_5,
PuffinPhiV2,
PhiHermes,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -121,6 +138,10 @@ struct Args {
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
@ -140,15 +161,21 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "microsoft/phi-1_5")]
model_id: String,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "refs/pr/18")]
revision: String,
#[arg(long, default_value = "1.5")]
model: WhichModel,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
quantized: bool,
@ -189,20 +216,62 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => {
if args.quantized {
"lmz/candle-quantized-phi".to_string()
} else {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
}
}
}
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => {
if args.quantized {
"main".to_string()
} else {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
}
}
}
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
},
};
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
None => {
if args.quantized {
api.model("lmz/candle-quantized-phi".to_string())
.get("model-q4k.gguf")?
match args.model {
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
}
} else {
repo.get("model.safetensors")?
match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
}
}
}
};
@ -210,7 +279,12 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::v1_5();
let config = match args.model {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = QMixFormer::new(&config, vb)?;
@ -231,6 +305,7 @@ fn main() -> Result<()> {
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;

View File

@ -1,5 +1,7 @@
# candle-quantized-t5
## Seq2Seq example
This example uses a quantized version of the t5 model.
```bash
@ -8,6 +10,8 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
Eine schöne Kerze.
```
## Generating Quantized weight files
The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
@ -16,8 +20,11 @@ generate quantized weight files from the original safetensors file by using the
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
To use a different model, specify the `model-id`. For example, you can use
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
## Using custom models
To use a different model, specify the `model-id`.
For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
```bash
$ cargo run --example quantized-t5 --release -- \
@ -26,6 +33,7 @@ $ cargo run --example quantized-t5 --release -- \
--temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
```
By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:
@ -40,3 +48,16 @@ cargo run --example quantized-t5 --release -- \
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```
### [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
```bash
cargo run --example quantized-t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
--prompt "<2de> How are you, my friend?" \
--temperature 0
...
Wie geht es dir, mein Freund?
```

View File

@ -173,7 +173,11 @@ fn main() -> Result<()> {
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id) as u32]
.to_vec();
let temperature = if args.temperature <= 0. {
None
} else {

View File

@ -12,6 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
@ -24,7 +25,7 @@ enum Prompt {
One(String),
}
#[derive(Clone, Debug, Copy, ValueEnum)]
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "7b")]
L7b,
@ -48,6 +49,10 @@ enum Which {
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
#[value(name = "7b-zephyr-a")]
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
Zephyr7bBeta,
}
impl Which {
@ -62,7 +67,28 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => false,
Self::Mistral7b | Self::Mistral7bInstruct => true,
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
}
}
fn is_zephyr(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Mistral7b
| Self::Mistral7bInstruct => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
}
@ -81,7 +107,7 @@ struct Args {
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 100)]
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
@ -174,6 +200,13 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
Which::Zephyr7bAlpha => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
),
Which::Zephyr7bBeta => {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -184,31 +217,6 @@ impl Args {
}
}
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ");
let ascii = text
.strip_prefix("<0x")
.and_then(|t| t.strip_suffix('>'))
.and_then(|t| u8::from_str_radix(t, 16).ok());
match ascii {
None => print!("{text}"),
Some(ascii) => {
if let Some(chr) = char::from_u32(ascii as u32) {
if chr.is_ascii() {
print!("{chr}")
}
}
}
}
let _ = std::io::stdout().flush();
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
@ -295,7 +303,12 @@ fn main() -> anyhow::Result<()> {
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => 1,
Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b
| Which::L70bChat => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
@ -303,6 +316,7 @@ fn main() -> anyhow::Result<()> {
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
@ -311,10 +325,11 @@ fn main() -> anyhow::Result<()> {
};
let mut pre_prompt_tokens = vec![];
loop {
for prompt_index in 0.. {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
let is_interactive = matches!(prompt, Prompt::Interactive);
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
@ -325,7 +340,13 @@ fn main() -> anyhow::Result<()> {
prompt.pop();
}
}
if args.which.is_mistral() {
if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
} else {
format!("<|user|>\n{prompt}</s>\n<|assistant|>")
}
} else if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else {
prompt
@ -333,7 +354,8 @@ fn main() -> anyhow::Result<()> {
}
};
print!("{}", &prompt_str);
let tokens = tokenizer
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
if args.verbose_prompt {
@ -363,11 +385,15 @@ fn main() -> anyhow::Result<()> {
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
@ -384,11 +410,19 @@ fn main() -> anyhow::Result<()> {
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
@ -396,9 +430,8 @@ fn main() -> anyhow::Result<()> {
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{:4} tokens generated: {:.2} token/s",
to_sample,
to_sample as f64 / dt.as_secs_f64(),
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
match prompt {

View File

@ -0,0 +1,16 @@
# candle-reinforcement-learning
Reinforcement Learning examples for candle.
This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash
pip install "gymnasium[accept-rom-license]"
```
In order to run the example, use the following command. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
```

View File

@ -0,0 +1,308 @@
import gymnasium as gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe
# atari_wrappers.py
class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
gym.Wrapper.__init__(self, env)
self.noop_max = noop_max
self.override_num_noops = None
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def reset(self):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset()
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
if done:
obs = self.env.reset()
return obs
class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
class ImageSaver(gym.Wrapper):
def __init__(self, env, img_path, rank):
gym.Wrapper.__init__(self, env)
self._cnt = 0
self._img_path = img_path
self._rank = rank
def step(self, action):
step_result = self.env.step(action)
obs, _, _, _ = step_result
img = Image.fromarray(obs, 'RGB')
img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
self._cnt += 1
return step_result
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def reset(self):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = deque(maxlen=2)
self._skip = skip
def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
def reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
class ClipRewardEnv(gym.RewardWrapper):
def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward)
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env):
"""Warp frames to 84x84 as done in the Nature paper and later work."""
gym.ObservationWrapper.__init__(self, env)
self.res = 84
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')
def observation(self, obs):
frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
resample=Image.BILINEAR), dtype=np.uint8)
return frame.reshape((self.res, self.res, 1))
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
assert shp[2] == 1 # can only stack 1-channel frames
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')
def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k): self.frames.append(ob)
return self.observation()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self.observation(), reward, done, info
def observation(self):
assert len(self.frames) == self.k
return np.concatenate(self.frames, axis=2)
def wrap_deepmind(env, episode_life=True, clip_rewards=True):
"""Configure environment for DeepMind-style Atari.
Note: this does not include frame stacking!"""
assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
if episode_life:
env = EpisodicLifeEnv(env)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
return env
# envs.py
def make_env(env_id, img_dir, seed, rank):
def _thunk():
env = gym.make(env_id)
env.reset(seed=(seed + rank))
if img_dir is not None:
env = ImageSaver(env, img_dir, rank)
env = wrap_deepmind(env)
env = WrapPyTorch(env)
return env
return _thunk
class WrapPyTorch(gym.ObservationWrapper):
def __init__(self, env=None):
super(WrapPyTorch, self).__init__(env)
self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')
def observation(self, observation):
return observation.transpose(2, 0, 1)
# vecenv.py
class VecEnv(object):
"""
Vectorized environment base class
"""
def step(self, vac):
"""
Apply sequence of actions to sequence of environments
actions -> (observations, rewards, news)
where 'news' is a boolean vector indicating whether each element is new.
"""
raise NotImplementedError
def reset(self):
"""
Reset all environments
"""
raise NotImplementedError
def close(self):
pass
# subproc_vec_env.py
def worker(remote, env_fn_wrapper):
env = env_fn_wrapper.x()
while True:
cmd, data = remote.recv()
if cmd == 'step':
ob, reward, done, info = env.step(data)
if done:
ob = env.reset()
remote.send((ob, reward, done, info))
elif cmd == 'reset':
ob = env.reset()
remote.send(ob)
elif cmd == 'close':
remote.close()
break
elif cmd == 'get_spaces':
remote.send((env.action_space, env.observation_space))
else:
raise NotImplementedError
class CloudpickleWrapper(object):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
"""
def __init__(self, x):
self.x = x
def __getstate__(self):
import cloudpickle
return cloudpickle.dumps(self.x)
def __setstate__(self, ob):
import pickle
self.x = pickle.loads(ob)
class SubprocVecEnv(VecEnv):
def __init__(self, env_fns):
"""
envs: list of gym environments to run in subprocesses
"""
nenvs = len(env_fns)
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
for p in self.ps:
p.start()
self.remotes[0].send(('get_spaces', None))
self.action_space, self.observation_space = self.remotes[0].recv()
def step(self, actions):
for remote, action in zip(self.remotes, actions):
remote.send(('step', action))
results = [remote.recv() for remote in self.remotes]
obs, rews, dones, infos = zip(*results)
return np.stack(obs), np.stack(rews), np.stack(dones), infos
def reset(self):
for remote in self.remotes:
remote.send(('reset', None))
return np.stack([remote.recv() for remote in self.remotes])
def close(self):
for remote in self.remotes:
remote.send(('close', None))
for p in self.ps:
p.join()
@property
def num_envs(self):
return len(self.remotes)
# Create the environment.
def make(env_name, img_dir, num_processes):
envs = SubprocVecEnv([
make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
])
return envs

View File

@ -0,0 +1,451 @@
use std::collections::VecDeque;
use std::fmt::Display;
use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};
pub struct OuNoise {
mu: f64,
theta: f64,
sigma: f64,
state: Tensor,
}
impl OuNoise {
pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
Ok(Self {
mu,
theta,
sigma,
state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
})
}
pub fn sample(&mut self) -> Result<Tensor> {
let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
self.state = (&self.state + dx)?;
Ok(self.state.clone())
}
}
#[derive(Clone)]
struct Transition {
state: Tensor,
action: Tensor,
reward: Tensor,
next_state: Tensor,
terminated: bool,
truncated: bool,
}
impl Transition {
fn new(
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) -> Self {
Self {
state: state.clone(),
action: action.clone(),
reward: reward.clone(),
next_state: next_state.clone(),
terminated,
truncated,
}
}
}
pub struct ReplayBuffer {
buffer: VecDeque<Transition>,
capacity: usize,
size: usize,
}
impl ReplayBuffer {
pub fn new(capacity: usize) -> Self {
Self {
buffer: VecDeque::with_capacity(capacity),
capacity,
size: 0,
}
}
pub fn push(
&mut self,
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) {
if self.size == self.capacity {
self.buffer.pop_front();
} else {
self.size += 1;
}
self.buffer.push_back(Transition::new(
state, action, reward, next_state, terminated, truncated,
));
}
#[allow(clippy::type_complexity)]
pub fn random_batch(
&self,
batch_size: usize,
) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
if self.size < batch_size {
Ok(None)
} else {
let transitions: Vec<&Transition> = thread_rng()
.sample_iter(Uniform::from(0..self.size))
.take(batch_size)
.map(|i| self.buffer.get(i).unwrap())
.collect();
let states: Vec<Tensor> = transitions
.iter()
.map(|t| t.state.unsqueeze(0))
.collect::<Result<_>>()?;
let actions: Vec<Tensor> = transitions
.iter()
.map(|t| t.action.unsqueeze(0))
.collect::<Result<_>>()?;
let rewards: Vec<Tensor> = transitions
.iter()
.map(|t| t.reward.unsqueeze(0))
.collect::<Result<_>>()?;
let next_states: Vec<Tensor> = transitions
.iter()
.map(|t| t.next_state.unsqueeze(0))
.collect::<Result<_>>()?;
let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();
Ok(Some((
Tensor::cat(&states, 0)?,
Tensor::cat(&actions, 0)?,
Tensor::cat(&rewards, 0)?,
Tensor::cat(&next_states, 0)?,
terminateds,
truncateds,
)))
}
}
}
fn track(
varmap: &mut VarMap,
vb: &VarBuilder,
target_prefix: &str,
network_prefix: &str,
dims: &[(usize, usize)],
tau: f64,
) -> Result<()> {
for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
varmap.set_one(
format!("{target_prefix}-fc{i}.weight"),
((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
)?;
let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
varmap.set_one(
format!("{target_prefix}-fc{i}.bias"),
((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
)?;
}
Ok(())
}
struct Actor<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
network: Sequential,
target_network: Sequential,
size_state: usize,
size_action: usize,
dims: Vec<(usize, usize)>,
}
impl Actor<'_> {
fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
let mut varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, dtype, device);
let dims = vec![(size_state, 400), (400, 300), (300, size_action)];
let make_network = |prefix: &str| {
let seq = seq()
.add(linear(
dims[0].0,
dims[0].1,
vb.pp(format!("{prefix}-fc0")),
)?)
.add(Activation::Relu)
.add(linear(
dims[1].0,
dims[1].1,
vb.pp(format!("{prefix}-fc1")),
)?)
.add(Activation::Relu)
.add(linear(
dims[2].0,
dims[2].1,
vb.pp(format!("{prefix}-fc2")),
)?)
.add(func(|xs| xs.tanh()));
Ok::<Sequential, Error>(seq)
};
let network = make_network("actor")?;
let target_network = make_network("target-actor")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);
Ok(Self {
varmap,
vb,
network,
target_network,
size_state,
size_action,
dims,
})
}
fn forward(&self, state: &Tensor) -> Result<Tensor> {
self.network.forward(state)
}
fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
self.target_network.forward(state)
}
fn track(&mut self, tau: f64) -> Result<()> {
track(
&mut self.varmap,
&self.vb,
"target-actor",
"actor",
&self.dims,
tau,
)
}
}
struct Critic<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
network: Sequential,
target_network: Sequential,
size_state: usize,
size_action: usize,
dims: Vec<(usize, usize)>,
}
impl Critic<'_> {
fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
let mut varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, dtype, device);
let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];
let make_network = |prefix: &str| {
let seq = seq()
.add(linear(
dims[0].0,
dims[0].1,
vb.pp(format!("{prefix}-fc0")),
)?)
.add(Activation::Relu)
.add(linear(
dims[1].0,
dims[1].1,
vb.pp(format!("{prefix}-fc1")),
)?)
.add(Activation::Relu)
.add(linear(
dims[2].0,
dims[2].1,
vb.pp(format!("{prefix}-fc2")),
)?);
Ok::<Sequential, Error>(seq)
};
let network = make_network("critic")?;
let target_network = make_network("target-critic")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);
Ok(Self {
varmap,
vb,
network,
target_network,
size_state,
size_action,
dims,
})
}
fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
let xs = Tensor::cat(&[action, state], 1)?;
self.network.forward(&xs)
}
fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
let xs = Tensor::cat(&[action, state], 1)?;
self.target_network.forward(&xs)
}
fn track(&mut self, tau: f64) -> Result<()> {
track(
&mut self.varmap,
&self.vb,
"target-critic",
"critic",
&self.dims,
tau,
)
}
}
#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
actor: Actor<'a>,
actor_optim: AdamW,
critic: Critic<'a>,
critic_optim: AdamW,
gamma: f64,
tau: f64,
replay_buffer: ReplayBuffer,
ou_noise: OuNoise,
size_state: usize,
size_action: usize,
pub train: bool,
}
impl DDPG<'_> {
#[allow(clippy::too_many_arguments)]
pub fn new(
device: &Device,
size_state: usize,
size_action: usize,
train: bool,
actor_lr: f64,
critic_lr: f64,
gamma: f64,
tau: f64,
buffer_capacity: usize,
ou_noise: OuNoise,
) -> Result<Self> {
let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
varmap
.data()
.lock()
.unwrap()
.iter()
.filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
.collect::<Vec<Var>>()
};
let actor = Actor::new(device, DType::F32, size_state, size_action)?;
let actor_optim = AdamW::new(
filter_by_prefix(&actor.varmap, "actor"),
ParamsAdamW {
lr: actor_lr,
..Default::default()
},
)?;
let critic = Critic::new(device, DType::F32, size_state, size_action)?;
let critic_optim = AdamW::new(
filter_by_prefix(&critic.varmap, "critic"),
ParamsAdamW {
lr: critic_lr,
..Default::default()
},
)?;
Ok(Self {
actor,
actor_optim,
critic,
critic_optim,
gamma,
tau,
replay_buffer: ReplayBuffer::new(buffer_capacity),
ou_noise,
size_state,
size_action,
train,
})
}
pub fn remember(
&mut self,
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) {
self.replay_buffer
.push(state, action, reward, next_state, terminated, truncated)
}
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self
.actor
.forward(&state.detach()?.unsqueeze(0)?)?
.squeeze(0)?;
let actions = if self.train {
(actions + self.ou_noise.sample()?)?
} else {
actions
};
actions.squeeze(0)?.to_scalar::<f32>()
}
pub fn train(&mut self, batch_size: usize) -> Result<()> {
let (states, actions, rewards, next_states, _, _) =
match self.replay_buffer.random_batch(batch_size)? {
Some(v) => v,
_ => return Ok(()),
};
let q_target = self
.critic
.target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
let q_target = (rewards + (self.gamma * q_target)?.detach())?;
let q = self.critic.forward(&states, &actions)?;
let diff = (q_target - q)?;
let critic_loss = diff.sqr()?.mean_all()?;
self.critic_optim.backward_step(&critic_loss)?;
let actor_loss = self
.critic
.forward(&states, &self.actor.forward(&states)?)?
.mean_all()?
.neg()?;
self.actor_optim.backward_step(&actor_loss)?;
self.critic.track(self.tau)?;
self.actor.track(self.tau)?;
Ok(())
}
}

View File

@ -0,0 +1,112 @@
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;
/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
pub state: Tensor,
pub action: A,
pub reward: f64,
pub terminated: bool,
pub truncated: bool,
}
impl<A: Copy> Step<A> {
/// Returns a copy of this step changing the observation tensor.
pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
Step {
state: state.clone(),
action: self.action,
reward: self.reward,
terminated: self.terminated,
truncated: self.truncated,
}
}
}
/// An OpenAI Gym session.
pub struct GymEnv {
env: PyObject,
action_space: usize,
observation_space: Vec<usize>,
}
fn w(res: PyErr) -> candle::Error {
candle::Error::wrap(res)
}
impl GymEnv {
/// Creates a new session of the specified OpenAI Gym environment.
pub fn new(name: &str) -> Result<GymEnv> {
Python::with_gil(|py| {
let gym = py.import("gymnasium")?;
let make = gym.getattr("make")?;
let env = make.call1((name,))?;
let action_space = env.getattr("action_space")?;
let action_space = if let Ok(val) = action_space.getattr("n") {
val.extract()?
} else {
let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
action_space[0]
};
let observation_space = env.getattr("observation_space")?;
let observation_space = observation_space.getattr("shape")?.extract()?;
Ok(GymEnv {
env: env.into(),
action_space,
observation_space,
})
})
.map_err(w)
}
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("seed", seed)?;
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(state, &Device::Cpu)
}
/// Applies an environment step using the specified action.
pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
&self,
action: A,
) -> Result<Step<A>> {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let terminated: bool = step.get_item(2)?.extract()?;
let truncated: bool = step.get_item(3)?.extract()?;
Ok((state, reward, terminated, truncated))
})
.map_err(w)?;
let state = Tensor::new(state, &Device::Cpu)?;
Ok(Step {
state,
action,
reward,
terminated,
truncated,
})
}
/// Returns the number of allowed actions for this environment.
pub fn action_space(&self) -> usize {
self.action_space
}
/// Returns the shape of the observation tensors.
pub fn observation_space(&self) -> &[usize] {
&self.observation_space
}
}

View File

@ -0,0 +1,144 @@
#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod gym_env;
mod vec_gym_env;
mod ddpg;
use candle::{Device, Result, Tensor};
use clap::Parser;
use rand::Rng;
// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;
// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;
const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let env = gym_env::GymEnv::new("Pendulum-v1")?;
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let size_state = env.observation_space().iter().product::<usize>();
let size_action = env.action_space();
let mut agent = ddpg::DDPG::new(
&Device::Cpu,
size_state,
size_action,
true,
ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE,
GAMMA,
TAU,
REPLAY_BUFFER_CAPACITY,
ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
agent.remember(
&state,
&Tensor::new(vec![action], &Device::Cpu)?,
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
&step.state,
step.terminated,
step.truncated,
);
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
for _ in 0..TRAINING_ITERATIONS {
agent.train(TRAINING_BATCH_SIZE)?;
}
}
println!("Testing...");
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
}
Ok(())
}

View File

@ -0,0 +1,91 @@
#![allow(unused)]
//! Vectorized version of the gym environment.
use candle::{DType, Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;
#[derive(Debug)]
pub struct Step {
pub obs: Tensor,
pub reward: Tensor,
pub is_done: Tensor,
}
pub struct VecGymEnv {
env: PyObject,
action_space: usize,
observation_space: Vec<usize>,
}
fn w(res: PyErr) -> candle::Error {
candle::Error::wrap(res)
}
impl VecGymEnv {
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
Python::with_gil(|py| {
let sys = py.import("sys")?;
let path = sys.getattr("path")?;
let _ = path.call_method1(
"append",
("candle-examples/examples/reinforcement-learning",),
)?;
let gym = py.import("atari_wrappers")?;
let make = gym.getattr("make")?;
let env = make.call1((name, img_dir, nprocesses))?;
let action_space = env.getattr("action_space")?;
let action_space = action_space.getattr("n")?.extract()?;
let observation_space = env.getattr("observation_space")?;
let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
let observation_space =
[vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
Ok(VecGymEnv {
env: env.into(),
action_space,
observation_space,
})
})
.map_err(w)
}
pub fn reset(&self) -> Result<Tensor> {
let obs = Python::with_gil(|py| {
let obs = self.env.call_method0(py, "reset")?;
let obs = obs.call_method0(py, "flatten")?;
obs.extract::<Vec<f32>>(py)
})
.map_err(w)?;
Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
}
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
let (obs, reward, is_done) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action,), None)?;
let step = step.as_ref(py);
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
let reward: Vec<f32> = step.get_item(1)?.extract()?;
let is_done: Vec<f32> = step.get_item(2)?.extract()?;
Ok((obs, reward, is_done))
})
.map_err(w)?;
let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
.to_dtype(DType::F32)?;
let reward = Tensor::new(reward, &Device::Cpu)?;
let is_done = Tensor::new(is_done, &Device::Cpu)?;
Ok(Step {
obs,
reward,
is_done,
})
}
pub fn action_space(&self) -> usize {
self.action_space
}
pub fn observation_space(&self) -> &[usize] {
&self.observation_space
}
}

View File

@ -0,0 +1,40 @@
# candle-replit-code: code completion specialized model.
[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
language model specialized for code completion. This model uses 3.3B parameters
in `bfloat16` (so the GPU version will only work on recent nvidia cards).
## Running some example
```bash
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
```
This produces the following output.
```
def fibonacci(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
a, b = 0, 1
while a < n:
print(a, end=' ')
a, b = b, a+b
print()
def fibonacci_loop(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
result = []
a, b = 0, 1
while a < n:
result.append(a)
a, b = b, a+b
return result
def fibonacci_generator(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
a, b = 0, 1
while a < n:
yield a
a, b = b, a+b
```

View File

@ -0,0 +1,265 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mpt::{Config, Model as M};
use candle_transformers::models::quantized_mpt::Model as Q;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
M(M),
Q(Q),
}
impl Model {
fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::M(model) => model.forward(xs),
Self::Q(model) => model.forward(xs),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the phi model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 1000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
quantized: bool,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "lmz/candle-replit-code".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
None => {
if args.quantized {
repo.get("model-replit-code-v1_5-q4k.gguf")?
} else {
repo.get("model.safetensors")?
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::replit_code_v1_5_3b();
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
(model, Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,19 @@
# candle-resnet
A candle implementation of inference using a pre-trained [ResNet](https://arxiv.org/abs/1512.03385).
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example resnet --release -- --image tiger.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris : 90.21%
tiger cat : 8.93%
lion, king of beasts, Panthera leo: 0.35%
leopard, Panthera pardus: 0.16%
jaguar, panther, Panthera onca, Felis onca: 0.09%
```

View File

@ -0,0 +1,12 @@
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt
m = torchvision.models.resnet50(pretrained=True)
stt.save_file(m.state_dict(), 'resnet50.safetensors')
m = torchvision.models.resnet101(pretrained=True)
stt.save_file(m.state_dict(), 'resnet101.safetensors')
m = torchvision.models.resnet152(pretrained=True)
stt.save_file(m.state_dict(), 'resnet152.safetensors')

View File

@ -0,0 +1,90 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::resnet;
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
#[value(name = "18")]
Resnet18,
#[value(name = "34")]
Resnet34,
#[value(name = "50")]
Resnet50,
#[value(name = "101")]
Resnet101,
#[value(name = "152")]
Resnet152,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Variant of the model to use.
#[arg(value_enum, long, default_value_t = Which::Resnet18)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-resnet".into());
let filename = match args.which {
Which::Resnet18 => "resnet18.safetensors",
Which::Resnet34 => "resnet34.safetensors",
Which::Resnet50 => "resnet50.safetensors",
Which::Resnet101 => "resnet101.safetensors",
Which::Resnet152 => "resnet152.safetensors",
};
api.get(filename)?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
let model = match args.which {
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
Which::Resnet50 => resnet::resnet50(class_count, vb)?,
Which::Resnet101 => resnet::resnet101(class_count, vb)?,
Which::Resnet152 => resnet::resnet152(class_count, vb)?,
};
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -50,6 +50,9 @@ cached.
Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`.
Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs
(e.g., A100/H100, RTX 3090/4090).
## Image to Image Pipeline
...

View File

@ -5,12 +5,26 @@
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```
## Sentence embedding example:
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
```bash
cargo run --example t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" \
--prompt "<2de> How are you, my friend?" \
--decode --temperature 0
...
Wie geht es dir, mein Freund?
```
## Sentence embedding example
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."

View File

@ -104,6 +104,17 @@ impl T5ModelBuilder {
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};
@ -143,7 +154,6 @@ fn main() -> Result<()> {
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
@ -173,7 +183,12 @@ fn main() -> Result<()> {
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id)
as u32]
.to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

View File

@ -0,0 +1,154 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
do_resize: bool,
height: u32,
width: u32,
do_rescale: bool,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl Default for ProcessorConfig {
fn default() -> Self {
Self {
do_resize: true,
height: 384,
width: 384,
do_rescale: true,
do_normalize: true,
image_mean: vec![0.5, 0.5, 0.5],
image_std: vec![0.5, 0.5, 0.5],
}
}
}
pub struct ViTImageProcessor {
do_resize: bool,
height: u32,
width: u32,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl ViTImageProcessor {
pub fn new(config: &ProcessorConfig) -> Self {
Self {
do_resize: config.do_resize,
height: config.height,
width: config.width,
do_normalize: config.do_normalize,
image_mean: config.image_mean.clone(),
image_std: config.image_std.clone(),
}
}
pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let images = self.load_images(images)?;
let resized_images: Vec<DynamicImage> = if self.do_resize {
images
.iter()
.map(|image| self.resize(image.clone(), None).unwrap())
.collect()
} else {
images
};
let normalized_images: Vec<Tensor> = if self.do_normalize {
resized_images
.iter()
.map(|image| self.normalize(image.clone(), None, None).unwrap())
.collect()
} else {
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
resized_images.iter().map(|image| image.to_rgb8()).collect();
let data = resized_images
.into_iter()
.map(|image| image.into_raw())
.collect::<Vec<Vec<u8>>>();
data.iter()
.map(|image| {
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
.unwrap()
.permute((2, 0, 1))
.unwrap()
})
.collect::<Vec<Tensor>>()
};
Tensor::stack(&normalized_images, 0)
}
fn resize(
&self,
image: image::DynamicImage,
size: Option<HashMap<String, u32>>,
) -> Result<image::DynamicImage> {
let (height, width) = match &size {
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
None => (&self.height, &self.width),
};
let resized_image =
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
Ok(resized_image)
}
fn normalize(
&self,
image: image::DynamicImage,
mean: Option<Vec<f32>>,
std: Option<Vec<f32>>,
) -> Result<Tensor> {
let mean = match mean {
Some(mean) => mean,
None => self.image_mean.clone(),
};
let std = match std {
Some(std) => std,
None => self.image_std.clone(),
};
let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
let image = image.to_rgb8();
let data = image.into_raw();
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let data =
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
(data.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
let mut images: Vec<image::DynamicImage> = Vec::new();
for path in image_path {
let img = image::io::Reader::open(path)?.decode().unwrap();
images.push(img);
}
Ok(images)
}
}

View File

@ -0,0 +1,132 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;
use tokenizers::Tokenizer;
mod image_processor;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Large,
}
#[derive(Parser, Debug)]
struct Args {
#[arg(long)]
model: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "base")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Text to be translated
#[arg(long)]
image: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let tokenizer_dec = {
let tokenizer = Api::new()?
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?;
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
))
.get("model.safetensors")?,
Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
};
println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
};
let encoder_config = match args.which {
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
Which::Large => {
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
}
};
let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;
let encoder_xs = model.encoder().forward(&image)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == decoder_config.eos_token_id {
break;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -0,0 +1,16 @@
# candle-trocr
`TrOCR` is a transformer OCR Model. In this example it is used to
transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.
## Running an example
```bash
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
```
```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```

View File

@ -0,0 +1,13 @@
## VGG Model Implementation
This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library.
The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image.
You can run the example with the following command:
```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
```
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).

View File

@ -0,0 +1,77 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{ModuleT, VarBuilder};
use candle_transformers::models::vgg::{Models, Vgg};
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Vgg13,
Vgg16,
Vgg19,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Variant of the model to use.
#[arg(value_enum, long, default_value_t = Which::Vgg13)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let api = hf_hub::api::sync::Api::new()?;
let repo = match args.which {
Which::Vgg13 => "timm/vgg13.tv_in1k",
Which::Vgg16 => "timm/vgg16.tv_in1k",
Which::Vgg19 => "timm/vgg19.tv_in1k",
};
let api = api.model(repo.into());
let filename = "model.safetensors";
let model_file = api.get(filename)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = match args.which {
Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?,
Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
};
let logits = model.forward_t(&image, /*train=*/ false)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
Ok(())
}

View File

@ -0,0 +1,20 @@
# candle-vit
Vision Transformer (ViT) model implementation following the lines of
[vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224)
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example vit --release -- --image tiger.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris : 100.00%
tiger cat : 0.00%
jaguar, panther, Panthera onca, Felis onca: 0.00%
leopard, Panthera pardus: 0.00%
lion, king of beasts, Panthera leo: 0.00%
```

View File

@ -0,0 +1,59 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, D};
use candle_nn::VarBuilder;
use candle_transformers::models::vit;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/vit-base-patch16-224".into());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = vit::Model::new(&vit::Config::vit_base_patch16_224(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -345,7 +345,7 @@ enum Task {
Translate,
}
#[derive(Clone, Copy, Debug, ValueEnum)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum WhichModel {
Tiny,
#[value(name = "tiny.en")]
@ -361,15 +361,27 @@ enum WhichModel {
MediumEn,
Large,
LargeV2,
LargeV3,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
}
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
Self::Tiny | Self::Base | Self::Small | Self::Medium | Self::Large | Self::LargeV2 => {
true
Self::Tiny
| Self::Base
| Self::Small
| Self::Medium
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
}
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}
@ -385,6 +397,9 @@ impl WhichModel {
Self::MediumEn => ("openai/whisper-medium.en", "main"),
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
}
}
@ -496,17 +511,25 @@ fn main() -> Result<()> {
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
)
let config = repo.get("config.json")?;
let tokenizer = if args.model == WhichModel::LargeV3 {
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
} else {
repo.get("tokenizer.json")?
};
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
};
(config, tokenizer, model, sample)
};
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_bytes = include_bytes!("melfilters.bytes");
let mel_bytes = match config.num_mel_bins {
80 => include_bytes!("melfilters.bytes").as_slice(),
128 => include_bytes!("melfilters128.bytes").as_slice(),
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
};
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
@ -522,12 +545,15 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
let mel = Tensor::from_vec(
mel,
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
&device,
)?;
println!("loaded mel: {:?}", mel.dims());
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;

Binary file not shown.

View File

@ -0,0 +1,268 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::yi::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "6b")]
L6b,
#[value(name = "34b")]
L34b,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "01-ai/Yi-6B")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "6b")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.which {
Which::L6b => vec![
repo.get("model-00001-of-00002.safetensors")?,
repo.get("model-00002-of-00002.safetensors")?,
],
Which::L34b => vec![
repo.get("model-00001-of-00007.safetensors")?,
repo.get("model-00002-of-00007.safetensors")?,
repo.get("model-00003-of-00007.safetensors")?,
repo.get("model-00004-of-00007.safetensors")?,
repo.get("model-00005-of-00007.safetensors")?,
repo.get("model-00006-of-00007.safetensors")?,
repo.get("model-00007-of-00007.safetensors")?,
],
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = match args.which {
Which::L6b => Config::config_6b(),
Which::L34b => Config::config_34b(),
};
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -108,7 +108,7 @@ pub fn parse_config<T: AsRef<Path>>(path: T) -> Result<Darknet> {
}
enum Bl {
Layer(Box<dyn candle_nn::Module + Send>),
Layer(Box<dyn candle_nn::Module + Send + Sync>),
Route(Vec<usize>),
Shortcut(usize),
Yolo(usize, Vec<(usize, usize)>),

View File

@ -1,7 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
};
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Multiples {
@ -76,7 +74,6 @@ impl Module for Upsample {
#[derive(Debug)]
struct ConvBlock {
conv: Conv2d,
bn: BatchNorm,
span: tracing::Span,
}
@ -96,11 +93,10 @@ impl ConvBlock {
groups: 1,
dilation: 1,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
Ok(Self {
conv,
bn,
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
})
}
@ -110,7 +106,6 @@ impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
}
}

View File

@ -2,17 +2,28 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
let device = Device::cuda_if_available(0)?;
if !device.is_cuda() {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(device)
Ok(Device::Cpu)
}
}

View File

@ -84,12 +84,19 @@ fn main() -> Result<()> {
(kernel_dir.join(f), obj_file)
})
.collect();
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
cu_files.iter().any(|(cu_file, _)| {
let out_modified = out_file.metadata().unwrap().modified().unwrap();
let in_modified = cu_file.metadata().unwrap().modified().unwrap();
in_modified.duration_since(out_modified).is_ok()
})
kernel_dir
.read_dir()
.expect("kernels folder should exist")
.any(|entry| {
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
let in_modified = entry.metadata().unwrap().modified().unwrap();
in_modified.duration_since(*out_modified).is_ok()
} else {
true
}
})
} else {
true
};
@ -100,12 +107,19 @@ fn main() -> Result<()> {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
if let Ok(ccbin_path) = &ccbin_env {
command
@ -203,13 +217,21 @@ fn set_cuda_include_dir() -> Result<()> {
#[allow(unused)]
fn compute_cap() -> Result<usize> {
// Grab compute code from nvidia-smi
let mut compute_cap = {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse compute cap")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
@ -220,16 +242,19 @@ fn compute_cap() -> Result<usize> {
.next()
.context("missing line in stdout")?
.replace('.', "");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
@ -243,30 +268,21 @@ fn compute_cap() -> Result<usize> {
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}

View File

@ -233,8 +233,8 @@ impl FlashAttnVarLen {
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
let seqlens_q = match &*seqlens_q {
candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_q must be a cuda tensor"),
};
let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_q.slice(o1..o2),
@ -243,8 +243,8 @@ impl FlashAttnVarLen {
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
let seqlens_k = match &*seqlens_k {
candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_k must be a cuda tensor"),
};
let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_k.slice(o1..o2),

View File

@ -12,5 +12,6 @@ license = "MIT OR Apache-2.0"
[dependencies]
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
rayon = "1.7.0"

View File

@ -1,4 +1,5 @@
use std::io::Write;
fn main() {
println!("cargo:rerun-if-changed=build.rs");
@ -23,6 +24,8 @@ fn main() {
}
mod cuda {
use anyhow::{Context, Result};
pub fn set_include_dir() {
use std::path::PathBuf;
// NOTE: copied from cudarc build.rs.
@ -100,107 +103,52 @@ mod cuda {
include_directories.sort();
include_directories.dedup();
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
#[allow(unused)]
let include_options: Vec<String> = include_directories
.into_iter()
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
.collect::<Vec<_>>();
// let start = std::time::Instant::now();
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let mut lines = out.lines();
assert_eq!(lines.next().unwrap(), "compute_cap");
let cap = lines.next().unwrap().replace('.', "");
cap.parse::<usize>().unwrap()
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
panic!("nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}.");
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str.parse::<usize>().unwrap();
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
}else{
false
};
if ignore{
None
}else{
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
} else {
false
};
if ignore {
None
} else {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}})
.collect::<Vec<_>>();
}
})
.collect::<Vec<_>>();
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
.unwrap()
@ -220,4 +168,76 @@ mod cuda {
}
(write, kernel_paths)
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
}

View File

@ -1,6 +1,53 @@
#include "cuda_utils.cuh"
#include<stdint.h>
template <typename S, typename T>
__device__ void cast_(
const size_t numel,
const size_t num_dims,
const size_t *info,
const S *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = inp[i];
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = inp[strided_i];
}
}
}
template <typename S, typename T, typename I>
__device__ void cast_through(
const size_t numel,
const size_t num_dims,
const size_t *info,
const S *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = static_cast<T>(static_cast<I>(inp[i]));
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = static_cast<T>(static_cast<I>(inp[strided_i]));
}
}
}
#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
@ -9,22 +56,10 @@ extern "C" __global__ void FN_NAME( \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = inp[i]; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[i] = inp[strided_i]; \
} \
} \
cast_<SRC_TYPENAME, DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \
#define CAST_BF_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
#define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
@ -32,25 +67,12 @@ extern "C" __global__ void FN_NAME( \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = (DST_TYPENAME) (float) inp[i]; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[i] = (DST_TYPENAME) (float) inp[strided_i]; \
} \
} \
cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
} \
#if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
// CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8)
CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
@ -58,14 +80,15 @@ CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_BF_OP(__nv_bfloat16, __half, cast_bf16_f16)
CAST_BF_OP(__half, __nv_bfloat16, cast_f16_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
#endif
#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)
// CAST_OP(__half, uint8_t, cast_f16_u8 )
CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
CAST_OP(__half, uint32_t, cast_f16_u32)
CAST_OP(__half, float, cast_f16_f32)
CAST_OP(__half, double, cast_f16_f64)

View File

@ -0,0 +1,21 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
edition = "2021"
description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"

View File

@ -0,0 +1,3 @@
# candle-metal-kernels
This crate contains Metal kernels used from candle.

View File

@ -0,0 +1,61 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif

Some files were not shown because too many files have changed in this diff Show More