Commit Graph

1240 Commits

Author SHA1 Message Date
19e52e5007 T5 Wasm (#918)
* init t5 wasm model

* split workers for each model

* clean up

* add some ui

* readme

* index

* typo

* remove cache param, clear_kv_cache

* add max_length as param

* add model tasks option to ui

* add method to load quantized gguf from buffer

* Add quantized wasm module

* add quantized models to UI, dynamic import wasms

* link to quantized

* fix copy

* fix ModelEncoder

* fix README.md
2023-09-22 15:31:10 +01:00
8601537e31 Add slice-scatter. (#927)
* Add slice-scatter.

* Add the op.

* Make transpose be a no-op when the dimensions are identical.

* Add the backprop.

* And add some gradient test.
2023-09-22 12:18:16 +01:00
a96878f235 cuda cast i64 (#925) 2023-09-21 19:52:39 +01:00
aa8ec06fd2 Add the t5-xxl version. (#924) 2023-09-21 14:48:13 +01:00
b43ca493f6 Add more quantized flan t5 variants (#923)
* Add the quantized flan-t5-large variant.

* Add more sizes.
2023-09-21 13:23:30 +01:00
3b557765e8 T5 quantized example (#922)
* Load gguf files for the quantized t5.

* Add the quantized t5 example.

* Allow for loading local files.

* Add some support for quantizing safetensor files.

* Transpose before quantizing.

* Quantized t5.

* Retrieve the weights from the hub.
2023-09-21 12:33:15 +01:00
2619c4307f Add a quantized version of the t5 model. (#921) 2023-09-21 11:13:39 +01:00
c89b82b2d4 Add a clear cache function to the t5 model. (#919) 2023-09-21 09:01:06 +01:00
7b26e513f1 Add the erf function. (#917) 2023-09-21 06:19:10 +01:00
ab1d40ea97 Add more t5 tracing. (#915) 2023-09-20 20:20:54 +01:00
3a0d3e05df Add more t5 tracing. (#914)
* Add more t5 tracing.

* Revert the sm change.
2023-09-20 16:37:51 +01:00
9b24d89d2d Tracing mode for T5. (#913)
* Tracing mode for T5.

* Tracing for the linear layer.
2023-09-20 15:03:35 +01:00
fb1c2ac535 Add flash-attn support. (#912)
* Add flash-attn support.

* Add the use-flash-attn flag.

* Re-enable flash-attn.
2023-09-20 14:07:55 +01:00
728e167334 Add details on wuerstchen. (#911) 2023-09-20 13:09:35 +01:00
7b1ddcff47 Add clone to various nn layers. (#910) 2023-09-20 11:33:51 +01:00
f685b2231c Add some missing biases. (#908) 2023-09-20 10:14:51 +01:00
c0b49d5a50 Wuerstchen parameter tweaks. (#907) 2023-09-20 09:26:24 +01:00
098dd0d1e9 fix: add missing top_p in llama_multiprocess (#905) 2023-09-20 08:54:56 +01:00
05626ef492 Flan T5: Read lm_head when word embeddings are not tied (#903)
* Read lm_head when word embeddings are not tied

* Fix formatting

* Address comments
2023-09-19 22:36:47 +01:00
67a486d18d Line-up the wuerstchen model with the python implementation. (#901)
* Line-up the wuerstchen model with the python implementation.

* Missing cos.

* Fix the picture denormalization.
2023-09-19 21:59:44 +01:00
7ad82b87e4 BERT Wasm (#902)
* implement wasm module

* add example to workspace

* add UI to explore semantic similarity

* change status messages

* formatting

* minor changes
2023-09-19 21:31:37 +01:00
8696f64bae Fix T5 kv cache (#899)
* Fix T5 kv cache

* Add argument for decoder prompt

* Fix range
2023-09-19 20:36:15 +01:00
d7e48234d4 Add an erf based gelu op (#900)
* Erf based gelu.

* Add the erf backed gelu.

* Test the new gelu op (which is not gelu_new).
2023-09-19 19:54:28 +01:00
34f2ecbc3b Fix the leaky relu. (#898) 2023-09-19 18:17:17 +01:00
4f91c8e109 Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
2023-09-19 15:09:47 +01:00
06e46d7c3b Only use classifier free guidance for the prior. (#896)
* Only use classifier free guidance for the prior.

* Add another specific layer-norm structure.

* Tweaks.

* Fix the latent shape.

* Print the prior shape.

* More shape fixes.

* Remove some debugging continue.
2023-09-19 14:13:05 +01:00
9cf26c5cff Fix typo in error_manage.md (#888)
occured -> occurred
2023-09-19 07:14:15 +01:00
aaa9d4ed6c W decoding. (#893)
* W decoding.

* Add the diffusion loop.

* Use the appropriate config.
2023-09-19 07:13:44 +01:00
92db8cecd3 Specialized attention module for Wuerstchen. (#890)
* Specialized attention module for Wuerstchen.

* Reshaping ops.

* Attention processor.

* Finish the forward pass.

* Hook the new attention processor.

* Get the prior forward pass to work.

* Make it contiguous.
2023-09-18 21:16:09 +01:00
1542e92629 T5: Add option to override use_cache from config (#892)
* Add option to override use_cache from config

* Disable cache by default and cleanup code
2023-09-18 20:20:21 +01:00
82a98f6da0 Prior denoising. (#889) 2023-09-18 16:51:38 +01:00
5082954c52 Fix the W clip embeddings. (#887)
* Fix the W clip embeddings.

* Add the specialized ddpm scheduler.
2023-09-18 14:50:14 +01:00
7dd8e12472 Bump the crate versions to v0.2.3. (#886)
* Bump the crate version.

* Also update the python bindings.
2023-09-18 12:14:03 +01:00
12696b7b2d Fix typos in SAM WASM example (#884) 2023-09-18 09:41:50 +01:00
ef8cd8fea0 Update the candle-gemm version. (#885) 2023-09-18 09:36:20 +01:00
03e194123d Add return types to *.pyi stubs (#880)
* Start generating return types

* Finish tensor type hinting

* Add `save_gguf` to `utils`

* Typehint `quant-llama.py`
2023-09-17 22:11:01 +01:00
c2b866172a More Wuerstchen fixes. (#882)
* More Wuerstchen fixes.

* More shape fixes.

* Add more of the prior specific bits.

* Broadcast add.

* Fix the clip config.

* Add some masking options to the clip model.
2023-09-17 22:08:11 +01:00
06cc329e71 Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d).

* More fixes.

* Again more fixes.
2023-09-17 15:59:27 +01:00
5f83c13f17 Add the DDPM scheduler. (#877)
* Add the DDPM scheduler.

* Minor tweaks.
2023-09-17 15:03:01 +01:00
db3e9dae04 Wuerstchen main (#876)
* Wuerstchen main.

* More of the wuerstchen cli example.

* Paella creation.

* Build the prior model.

* Fix the weight file names.
2023-09-17 12:46:38 +01:00
7f65af1f0d Avoid re-encoding the input in the T5 example. (#875) 2023-09-17 10:25:54 +01:00
eeb54716dd Tweaks for the T5 example. (#874) 2023-09-17 10:05:15 +01:00
1a276b5da7 Add a KV cache to T5. (#873)
* Add a KV cache to T5.

* Suggest using release mode.

* Use the kv cache in decoding.

* Add a comment.
2023-09-17 08:00:45 +01:00
8658df3485 Generate *.pyi stubs for PyO3 wrapper (#870)
* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typehints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`
2023-09-16 17:23:38 +01:00
7cafca835a readme tweaks. (#867) 2023-09-16 07:22:24 +01:00
04ca2b9ebd Update README + SAM (#866)
* use serde-wasm-bindgen, faster serialization

* update readme with demos
2023-09-16 07:34:13 +02:00
635012d770 Do not backprop through argmin/argmax. (#865) 2023-09-15 22:15:40 +01:00
3e49f8fce5 Implement T5 decoding (#864)
* Load t5 decoder

* Run enc, dec, and lm head, but no cross attn

* Cross-attention over key_value_states

* New arg for decoder input ids

* Add mask, don't forward position biases through decoder

* Update t5 examples

* Clippy + rustfmt
2023-09-15 22:05:12 +02:00
c2007ac88f W fixes. (#862) 2023-09-15 15:11:11 +01:00
30be5b6660 Replication pad (#861)
* Add the embed mapper convolutions.

* Add the replication pad layer.

* Use the replication-pad op.

* Tweak a todo.
2023-09-15 14:06:21 +01:00