Commit Graph

25 Commits

Author SHA1 Message Date
b06e1a7e54 [nn] Move the Embedding and Activation parts. (#116)
* Share the Embedding and Activation parts.

* Tweak some activations.
2023-07-10 10:24:52 +01:00
9ce0f1c010 Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate.

* Tweak the cuda dependencies.

* More cuda tweaks.
2023-07-10 08:50:09 +01:00
0a2c82e301 Merge pull request #92 from LaurentMazare/sync_hub
Creating new sync Api for `candle-hub`.
2023-07-07 00:10:47 +02:00
115629fe08 Creating new sync Api for candle-hub.
- `api::Api` -> `api::tokio::api` (And created new `api::sync::Api`).
- Remove `tokio` from all our examples.
- Using similar codebase for now instead of ureq (for simplicity).
2023-07-06 15:15:25 +02:00
3f291bdf9d Enabling roberta for the example (it's the same model as Bert, with
just different naming.)
2023-07-06 13:25:21 +02:00
c297a50960 Add mkl support for matrix multiply. (#86)
* Fix some rebase issues.

* Use mkl instead.

* Use mkl in bert.

* Add the optional mkl feature.

* Conditional compilation based on the mkl feature.

* Add more mkl support.
2023-07-06 11:05:05 +01:00
2c3d871b2e Add a simpler way to specify the dim index for some ops. 2023-07-05 20:22:43 +01:00
174e57d216 Use avg pooling before the cosine similarity. 2023-07-05 17:05:50 +01:00
914e84deec Add some sentence similarity comparision to the bert example. 2023-07-05 16:49:57 +01:00
d8f75ceeaa Some polish. 2023-07-05 07:41:14 +00:00
963c75cb89 Adding offline mode. 2023-07-05 07:19:57 +00:00
43a007cba4 Upgrading bert example to work with bert-base-uncased.
- Always take weights from the hub
- Optional `model_id` + `revision` to use safetensors version
  potentially
- Optional loading for `bert-base-uncased` (`weight` vs `gamma`).
- Take the config from the hub.
2023-07-04 14:12:14 +00:00
a57b314780 Add a batch dimension on the bert example. 2023-07-04 06:10:52 +01:00
b6d179cc1c Allow for batch dimensions in the embedding layer. 2023-07-03 18:37:40 +01:00
9784d1ed9f Minor tweaks. 2023-07-03 18:31:55 +01:00
5524ca29cc Remove the fixed length hack. 2023-07-03 17:13:23 +01:00
1ea6690557 Bugfix for transpose. 2023-07-03 17:06:23 +01:00
a7f03a7bb6 Fix the layer norm to properly handle bias. 2023-07-03 16:45:03 +01:00
f379b8feae Get some embeddings out. 2023-07-03 16:11:16 +01:00
54850e7525 Get the tensors to be loaded properly. 2023-07-03 15:53:31 +01:00
ad52b0377c Add the varbuilder + check shapes. 2023-07-03 15:32:20 +01:00
f74bddca31 Model creation. 2023-07-03 14:09:46 +01:00
12ac9e1460 Complete (?) the forward pass. 2023-07-03 13:33:32 +01:00
d796945ad8 Add more to the forward pass. 2023-07-03 13:04:41 +01:00
2309c5fac5 Boilerplate code for Bert. 2023-07-03 12:17:06 +01:00