7 Commits

SHA1 Message Date
62ef494dc1 Use multiple transformer layers in the same cross-attn blocks. (#653)
* Use multiple transformer layers in the same cross-attn blocks.

* Make the context contiguous if required.
2023-08-29 11:13:43 +01:00
aba1e90797 Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions.

* Avoid some unnecessary groups checks.

* Move the tensor convolution bits.

* Proper handling of groups.

* Bump the crate version.

* And add a changelog.
2023-08-23 12:58:55 +01:00
c78ce76501 Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
2023-08-18 09:38:22 +01:00
5d99026fd2 F16 support for stable diffusion (#488)
* F16 support for stable diffusion.

* Keep the attention bits in F32.

* Keep more of the attention bits in F32.

* More mixed precision support.
2023-08-17 13:48:56 +01:00
c3176f0dfb Flash-attention support in stable diffusion (#487)
* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after applying flash-attn.
2023-08-17 12:16:40 +01:00
9af438ac1b Track the conv2d operations in stable-diffusion. (#431)
* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
2023-08-13 15:58:26 +01:00
d34039e352 Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.

* Proper computation of the causal mask.

* Add the chunk operation.

* Work in progress: port the attention module.

* Add some dummy modules for conv2d and group-norm, get the attention module to compile.

* Re-enable the 2d convolution.

* Add the embeddings module.

* Add the resnet module.

* Add the unet blocks.

* Add the unet.

* And add the variational auto-encoder.

* Use the pad function from utils.
2023-08-06 17:49:43 +01:00