Add support to UL2 model family (#1300)

* Add support to UL2 model family

* Update docs with UL2

* Create ActivationWithOptionalGating to avoid polluting activations

* Also refactor quantized t5

* Remove useless conversion

* Revert Activation::NewGelu name change

* Remove useless return

* Apply rustfmt and clippy recommendations

* Reuse t5::ActivationWithOptionalGating in quantized version

* (cosmetic change) use a match rather than ifs + avoid early returns.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Juarez Bochi
2023-11-09 12:55:09 -05:00
committed by GitHub
parent 6958384327
commit 18d30005c5
9 changed files with 71 additions and 15 deletions

View File

@ -9,6 +9,8 @@ $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate
9 tokens generated (2.42 token/s)
```
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
@ -22,7 +24,7 @@ cargo run --example t5 --release -- \
Wie geht es dir, mein Freund?
```
## Sentence embedding example:
## Sentence embedding example
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."

View File

@ -104,6 +104,17 @@ impl T5ModelBuilder {
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};