mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a couple t5 models. (#1958)
This commit is contained in:
@ -12,12 +12,19 @@ use anyhow::{Error as E, Result};
|
|||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use clap::Parser;
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
|
/// Built-in T5 checkpoint choices for the `which` command-line argument.
///
/// Each variant is mapped in `T5ModelBuilder::load` to a default model id and
/// a default revision.
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
/// Default model id `t5-base`, default revision `main`.
T5Base,
|
||||||
|
/// Default model id `t5-small`, default revision `refs/pr/15`.
T5Small,
|
||||||
|
/// Default model id `t5-3b`, default revision `main`.
T5_3B,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -71,6 +78,10 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model to be used.
|
||||||
|
#[arg(long, default_value = "t5-small")]
|
||||||
|
which: Which,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct T5ModelBuilder {
|
struct T5ModelBuilder {
|
||||||
@ -82,8 +93,13 @@ struct T5ModelBuilder {
|
|||||||
impl T5ModelBuilder {
|
impl T5ModelBuilder {
|
||||||
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
|
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let default_model = "t5-small".to_string();
|
let (default_model, default_revision) = match args.which {
|
||||||
let default_revision = "refs/pr/15".to_string();
|
Which::T5Base => ("t5-base", "main"),
|
||||||
|
Which::T5Small => ("t5-small", "refs/pr/15"),
|
||||||
|
Which::T5_3B => ("t5-3b", "main"),
|
||||||
|
};
|
||||||
|
let default_model = default_model.to_string();
|
||||||
|
let default_revision = default_revision.to_string();
|
||||||
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
|
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||||
|
Reference in New Issue
Block a user