Add a couple t5 models. (#1958)

This commit is contained in:
Laurent Mazare
2024-03-28 17:58:06 +01:00
committed by GitHub
parent cdc8b57b5c
commit c5092f2c29

View File

@ -12,12 +12,19 @@ use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use clap::Parser; use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32; const DTYPE: DType = DType::F32;
// Built-in T5 model choices exposed through the `--which` command-line flag
// (see the `which` field of `Args`, whose default value is "t5-small").
// `T5ModelBuilder::load` maps each variant to a default model id and a default
// revision; presumably these are Hugging Face Hub repository names — confirm.
//
// NOTE(review): plain `//` comments are used on purpose here; `///` doc comments
// on `ValueEnum` variants would be picked up by clap as help text.
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
// Default model id "t5-base", default revision "main".
T5Base,
// Default model id "t5-small", default revision "refs/pr/15".
T5Small,
// Default model id "t5-3b", default revision "main".
// NOTE(review): the name is not UpperCamelCase; presumably it is kept so the
// derived command-line value reads `t5-3b` — confirm before renaming.
T5_3B,
}
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
struct Args { struct Args {
@ -71,6 +78,10 @@ struct Args {
/// The context size to consider for the repeat penalty. /// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)] #[arg(long, default_value_t = 64)]
repeat_last_n: usize, repeat_last_n: usize,
/// The model to be used.
#[arg(long, default_value = "t5-small")]
which: Which,
} }
struct T5ModelBuilder { struct T5ModelBuilder {
@ -82,8 +93,13 @@ struct T5ModelBuilder {
impl T5ModelBuilder { impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let default_model = "t5-small".to_string(); let (default_model, default_revision) = match args.which {
let default_revision = "refs/pr/15".to_string(); Which::T5Base => ("t5-base", "main"),
Which::T5Small => ("t5-small", "refs/pr/15"),
Which::T5_3B => ("t5-3b", "main"),
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()), (Some(model_id), None) => (model_id, "main".to_string()),