Expose the t5 config fields + allow t5-large. (#1987)

This commit is contained in:
Laurent Mazare
2024-04-01 20:58:34 +02:00
committed by GitHub
parent ea0d8d3753
commit be9c200cbb
2 changed files with 18 additions and 16 deletions

View File

@ -22,6 +22,7 @@ const DTYPE: DType = DType::F32;
enum Which {
T5Base,
T5Small,
T5Large,
T5_3B,
Mt5Base,
Mt5Small,
@ -108,6 +109,7 @@ impl T5ModelBuilder {
let (default_model, default_revision) = match args.which {
Which::T5Base => ("t5-base", "main"),
Which::T5Small => ("t5-small", "refs/pr/15"),
Which::T5Large => ("t5-large", "main"),
Which::T5_3B => ("t5-3b", "main"),
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),