From e4553fb355ffebe6781ea2d35ba0734a310cab9b Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 13 Sep 2023 08:37:04 +0200
Subject: [PATCH] T5 tweaks (#831)

* Use default values rather than options.

* Avoid exposing the device field.

* More tweaks.
---
 candle-examples/examples/t5/main.rs  | 16 ++-------
 candle-transformers/src/models/t5.rs | 51 ++++++++++++++++++----------
 2 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index bcba846d..84be0204 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -30,7 +30,7 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
-    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    /// The model repository to use on the HuggingFace hub.
     #[arg(long)]
     model_id: Option<String>,
 
@@ -94,22 +94,10 @@ impl Args {
 }
 
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
     let args = Args::parse();
-    let _guard = if args.tracing {
-        println!("tracing...");
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
 
     let start = std::time::Instant::now();
     let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let device = &model.device;
     let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
     let tokenizer = tokenizer
         .with_padding(None)
@@ -120,7 +108,7 @@
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
     println!("Loaded and encoded {:?}", start.elapsed());
     for idx in 0..args.n {
         let start = std::time::Instant::now();
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 3700f1e0..691817d1 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -6,6 +6,18 @@ use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module
 use serde::Deserialize;
 use std::sync::Arc;
 
+fn default_relative_attention_max_distance() -> usize {
+    128
+}
+
+fn default_is_decoder() -> bool {
+    false
+}
+
+fn default_use_cache() -> bool {
+    true
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
@@ -16,15 +28,18 @@ pub struct Config {
     num_decoder_layers: Option<usize>,
     num_heads: usize,
     relative_attention_num_buckets: usize,
-    relative_attention_max_distance: Option<usize>,
+    #[serde(default = "default_relative_attention_max_distance")]
+    relative_attention_max_distance: usize,
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
     #[serde(default)]
     feed_forward_proj: Activation,
-    is_decoder: Option<bool>,
+    #[serde(default = "default_is_decoder")]
+    is_decoder: bool,
     is_encoder_decoder: bool,
-    use_cache: Option<bool>,
+    #[serde(default = "default_use_cache")]
+    use_cache: bool,
     pad_token_id: usize,
     eos_token_id: usize,
 }
@@ -40,14 +55,14 @@ impl Default for Config {
             num_decoder_layers: None,
             num_heads: 8,
             relative_attention_num_buckets: 32,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
-            use_cache: Some(true),
+            use_cache: true,
             pad_token_id: 0,
             eos_token_id: 1,
         }
@@ -65,16 +80,16 @@ impl Config {
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
             initializer_factor: 1.0,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
             layer_norm_epsilon: 1e-6,
             num_decoder_layers: Some(12),
             num_heads: 12,
             num_layers: 12,
             pad_token_id: 0,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             relative_attention_num_buckets: 32,
-            use_cache: Some(true),
+            use_cache: true,
             vocab_size: 32128,
         }
     }
@@ -199,7 +214,7 @@ impl T5Attention {
             d_kv: cfg.d_kv,
             relative_attention_bias,
             relative_attention_num_buckets: cfg.relative_attention_num_buckets,
-            relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
             inner_dim,
         })
     }
@@ -345,7 +360,7 @@ impl T5Block {
     fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = vb.pp("layer");
         let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
-        let cross_attn = if cfg.is_decoder.unwrap_or(false) {
+        let cross_attn = if cfg.is_decoder {
             Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
         } else {
             None
@@ -402,23 +417,20 @@ impl T5Stack {
 
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
-
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
             (hidden_states, position_bias) = block.forward(&hidden_states, position_bias.as_ref())?
         }
 
-        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
-        Ok(hidden_states)
+        self.final_layer_norm.forward(&hidden_states)
     }
 }
 
 #[derive(Debug)]
 pub struct T5EncoderModel {
     encoder: T5Stack,
-    pub device: Device,
+    device: Device,
 }
 
 impl T5EncoderModel {
@@ -433,7 +445,10 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let encoder_outputs = self.encoder.forward(input_ids)?;
-        Ok(encoder_outputs)
+        self.encoder.forward(input_ids)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
     }
 }
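
Note (not part of the patch): the serde-default pattern introduced above can be exercised on its own. The sketch below is a minimal standalone reproduction; the MiniConfig struct and the use of serde_json as the deserializer are assumptions for illustration only. It shows how #[serde(default = "...")] lets a config JSON omit use_cache and relative_attention_max_distance while still yielding plain bool/usize fields rather than Options, which is what removes the unwrap_or(...) calls from the model code. The other change in the patch is analogous on the example side: callers now go through model.device() instead of reaching into a public device field.

// Standalone sketch, assuming serde (with the derive feature) and serde_json.
// MiniConfig is hypothetical; it only mirrors the default-fn pattern used by the
// T5 Config in this patch.
use serde::Deserialize;

fn default_relative_attention_max_distance() -> usize {
    128
}

fn default_use_cache() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct MiniConfig {
    num_heads: usize,
    // Missing in the JSON -> filled by the default fn, no Option needed.
    #[serde(default = "default_relative_attention_max_distance")]
    relative_attention_max_distance: usize,
    #[serde(default = "default_use_cache")]
    use_cache: bool,
}

fn main() -> Result<(), serde_json::Error> {
    // A config.json that omits both fields, as some hub T5 configs do.
    let raw = r#"{ "num_heads": 8 }"#;
    let cfg: MiniConfig = serde_json::from_str(raw)?;
    assert_eq!(cfg.relative_attention_max_distance, 128);
    assert!(cfg.use_cache);
    // Downstream code reads cfg.use_cache directly instead of
    // cfg.use_cache.unwrap_or(true).
    println!("{cfg:?}");
    Ok(())
}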