mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Expose the t5 config fields + allow t5-large. (#1987)
This commit is contained in:
@ -22,6 +22,7 @@ const DTYPE: DType = DType::F32;
|
|||||||
enum Which {
|
enum Which {
|
||||||
T5Base,
|
T5Base,
|
||||||
T5Small,
|
T5Small,
|
||||||
|
T5Large,
|
||||||
T5_3B,
|
T5_3B,
|
||||||
Mt5Base,
|
Mt5Base,
|
||||||
Mt5Small,
|
Mt5Small,
|
||||||
@ -108,6 +109,7 @@ impl T5ModelBuilder {
|
|||||||
let (default_model, default_revision) = match args.which {
|
let (default_model, default_revision) = match args.which {
|
||||||
Which::T5Base => ("t5-base", "main"),
|
Which::T5Base => ("t5-base", "main"),
|
||||||
Which::T5Small => ("t5-small", "refs/pr/15"),
|
Which::T5Small => ("t5-small", "refs/pr/15"),
|
||||||
|
Which::T5Large => ("t5-large", "main"),
|
||||||
Which::T5_3B => ("t5-3b", "main"),
|
Which::T5_3B => ("t5-3b", "main"),
|
||||||
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
|
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
|
||||||
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
|
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
|
||||||
|
@ -70,26 +70,26 @@ where
|
|||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
d_model: usize,
|
pub d_model: usize,
|
||||||
d_kv: usize,
|
pub d_kv: usize,
|
||||||
d_ff: usize,
|
pub d_ff: usize,
|
||||||
num_layers: usize,
|
pub num_layers: usize,
|
||||||
num_decoder_layers: Option<usize>,
|
pub num_decoder_layers: Option<usize>,
|
||||||
num_heads: usize,
|
pub num_heads: usize,
|
||||||
relative_attention_num_buckets: usize,
|
pub relative_attention_num_buckets: usize,
|
||||||
#[serde(default = "default_relative_attention_max_distance")]
|
#[serde(default = "default_relative_attention_max_distance")]
|
||||||
relative_attention_max_distance: usize,
|
pub relative_attention_max_distance: usize,
|
||||||
dropout_rate: f64,
|
pub dropout_rate: f64,
|
||||||
layer_norm_epsilon: f64,
|
pub layer_norm_epsilon: f64,
|
||||||
initializer_factor: f64,
|
pub initializer_factor: f64,
|
||||||
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
|
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
|
||||||
feed_forward_proj: ActivationWithOptionalGating,
|
pub feed_forward_proj: ActivationWithOptionalGating,
|
||||||
#[serde(default = "default_tie_word_embeddings")]
|
#[serde(default = "default_tie_word_embeddings")]
|
||||||
tie_word_embeddings: bool,
|
pub tie_word_embeddings: bool,
|
||||||
#[serde(default = "default_is_decoder")]
|
#[serde(default = "default_is_decoder")]
|
||||||
is_decoder: bool,
|
pub is_decoder: bool,
|
||||||
is_encoder_decoder: bool,
|
pub is_encoder_decoder: bool,
|
||||||
#[serde(default = "default_use_cache")]
|
#[serde(default = "default_use_cache")]
|
||||||
pub use_cache: bool,
|
pub use_cache: bool,
|
||||||
pub pad_token_id: usize,
|
pub pad_token_id: usize,
|
||||||
|
Reference in New Issue
Block a user