T5 tweaks (#831)

* Use default values rather than options.

* Avoid exposing the device field.

* More tweaks.
Laurent Mazare
2023-09-13 08:37:04 +02:00
committed by GitHub
parent d801e1d564
commit e4553fb355
2 changed files with 35 additions and 32 deletions

File 1/2: the T5 example (main.rs)

@@ -30,7 +30,7 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
-    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    /// The model repository to use on the HuggingFace hub.
     #[arg(long)]
     model_id: Option<String>,
@@ -94,22 +94,10 @@ impl Args {
     }
 }
 
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
     let args = Args::parse();
-    let _guard = if args.tracing {
-        println!("tracing...");
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
     let start = std::time::Instant::now();
     let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let device = &model.device;
     let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
     let tokenizer = tokenizer
         .with_padding(None)
@@ -120,7 +108,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
     println!("Loaded and encoded {:?}", start.elapsed());
     for idx in 0..args.n {
         let start = std::time::Instant::now();
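Side note on the `model.device()` call above: since the `device` field on `T5EncoderModel` is no longer public (second file below), the example reaches the device through the new accessor. A minimal sketch of the call-site pattern, assuming the repo's usual crate aliasing (candle-core imported as `candle`) and the model's module path; `encode` is an illustrative helper, not code from this commit:

use candle::{Result, Tensor};
// Assumed module path for the model file changed below.
use candle_transformers::models::t5::T5EncoderModel;

// Build the input tensor on whatever device the model was loaded on,
// then run the encoder; CPU/GPU selection stays inside the model.
fn encode(model: &T5EncoderModel, tokens: &[u32]) -> Result<Tensor> {
    let token_ids = Tensor::new(tokens, model.device())?.unsqueeze(0)?;
    model.forward(&token_ids)
}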

File 2/2: the T5 model definition (t5.rs)

@@ -6,6 +6,18 @@ use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module
 use serde::Deserialize;
 use std::sync::Arc;
 
+fn default_relative_attention_max_distance() -> usize {
+    128
+}
+
+fn default_is_decoder() -> bool {
+    false
+}
+
+fn default_use_cache() -> bool {
+    true
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
@@ -16,15 +28,18 @@ pub struct Config {
     num_decoder_layers: Option<usize>,
     num_heads: usize,
     relative_attention_num_buckets: usize,
-    relative_attention_max_distance: Option<usize>,
+    #[serde(default = "default_relative_attention_max_distance")]
+    relative_attention_max_distance: usize,
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
     #[serde(default)]
     feed_forward_proj: Activation,
-    is_decoder: Option<bool>,
+    #[serde(default = "default_is_decoder")]
+    is_decoder: bool,
     is_encoder_decoder: bool,
-    use_cache: Option<bool>,
+    #[serde(default = "default_use_cache")]
+    use_cache: bool,
     pad_token_id: usize,
     eos_token_id: usize,
 }
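The `#[serde(default = "...")]` attributes above mean a checkpoint's config.json can simply omit those keys. A self-contained sketch of the mechanism, using a toy struct that mirrors two of the fields (assumes serde with the derive feature plus serde_json):

use serde::Deserialize;

fn default_relative_attention_max_distance() -> usize {
    128
}

fn default_use_cache() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct MiniConfig {
    vocab_size: usize,
    #[serde(default = "default_relative_attention_max_distance")]
    relative_attention_max_distance: usize,
    #[serde(default = "default_use_cache")]
    use_cache: bool,
}

fn main() {
    // Both defaulted keys are absent from the payload.
    let cfg: MiniConfig = serde_json::from_str(r#"{ "vocab_size": 32128 }"#).unwrap();
    assert_eq!(cfg.relative_attention_max_distance, 128); // previously Option<usize>
    assert!(cfg.use_cache); // previously Option<bool>
}

Compared with the old `Option<T>` fields, the default now lives in one place next to the field instead of being re-stated with `unwrap_or` at every use site.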
@@ -40,14 +55,14 @@ impl Default for Config {
             num_decoder_layers: None,
             num_heads: 8,
             relative_attention_num_buckets: 32,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
-            use_cache: Some(true),
+            use_cache: true,
             pad_token_id: 0,
             eos_token_id: 1,
         }
@@ -65,16 +80,16 @@ impl Config {
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
             initializer_factor: 1.0,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
             layer_norm_epsilon: 1e-6,
             num_decoder_layers: Some(12),
             num_heads: 12,
             num_layers: 12,
             pad_token_id: 0,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             relative_attention_num_buckets: 32,
-            use_cache: Some(true),
+            use_cache: true,
             vocab_size: 32128,
         }
     }
@@ -199,7 +214,7 @@ impl T5Attention {
             d_kv: cfg.d_kv,
             relative_attention_bias,
             relative_attention_num_buckets: cfg.relative_attention_num_buckets,
-            relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
             inner_dim,
         })
     }
@@ -345,7 +360,7 @@ impl T5Block {
     fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = vb.pp("layer");
         let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
-        let cross_attn = if cfg.is_decoder.unwrap_or(false) {
+        let cross_attn = if cfg.is_decoder {
             Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
         } else {
             None
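With `is_decoder` now a plain bool, the conditional cross-attention load above drops its `unwrap_or(false)`. A reduced sketch of the pattern, with a toy type standing in for `T5LayerCrossAttention` (the real `load` takes a `VarBuilder` and `Config`):

struct CrossAttention;

impl CrossAttention {
    fn load() -> Result<Self, String> {
        Ok(CrossAttention)
    }
}

// Decoder blocks carry a cross-attention layer over the encoder output;
// encoder blocks do not, so the field stays None for them.
fn load_cross_attn(is_decoder: bool) -> Result<Option<CrossAttention>, String> {
    if is_decoder {
        Ok(Some(CrossAttention::load()?))
    } else {
        Ok(None)
    }
}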
@@ -402,23 +417,20 @@ impl T5Stack {
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
-
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
             (hidden_states, position_bias) =
                 block.forward(&hidden_states, position_bias.as_ref())?
         }
-        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
-        Ok(hidden_states)
+        self.final_layer_norm.forward(&hidden_states)
     }
 }
 
 #[derive(Debug)]
 pub struct T5EncoderModel {
     encoder: T5Stack,
-    pub device: Device,
+    device: Device,
 }
 
 impl T5EncoderModel {
@@ -433,7 +445,10 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let encoder_outputs = self.encoder.forward(input_ids)?;
-        Ok(encoder_outputs)
+        self.encoder.forward(input_ids)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
     }
 }
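The hunk above is the commit's second bullet in action: the `device` field goes private and callers borrow it through `device()`, so the stored `Device` can be read but never reassigned from outside the module. A minimal sketch of the shape, with a toy stand-in for candle's `Device`:

#[derive(Debug, Clone)]
enum Device {
    Cpu,
}

struct Model {
    device: Device, // private: readable via the getter, never replaceable
}

impl Model {
    fn new(device: Device) -> Self {
        Self { device }
    }

    // Borrowing getter, mirroring T5EncoderModel::device() above.
    fn device(&self) -> &Device {
        &self.device
    }
}

fn main() {
    let model = Model::new(Device::Cpu);
    println!("running on {:?}", model.device());
}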