T5 tweaks (#831)

* Use default values rather than options.

* Avoid exposing the device field.

* More tweaks.
Laurent Mazare (committed via GitHub)
2023-09-13 08:37:04 +02:00
parent d801e1d564
commit e4553fb355
2 changed files with 35 additions and 32 deletions


@@ -30,7 +30,7 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
-    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    /// The model repository to use on the HuggingFace hub.
     #[arg(long)]
     model_id: Option<String>,
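Worth noting why a doc comment is worth tweaking at all: clap's derive macro reuses it verbatim as the `--help` text for the flag. A minimal sketch (only `Args` and `model_id` come from the hunk above; the rest is scaffolding):

```rust
use clap::Parser;

#[derive(Parser)]
struct Args {
    // clap's derive API turns the doc comment below into the help text
    // printed for `--model-id`, which is why the wording was tightened.
    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,
}

fn main() {
    let args = Args::parse();
    // The flag stays optional; a default is chosen later at the use site.
    println!("model_id: {:?}", args.model_id);
}
```

Running the binary with `--help` now prints the new one-line description instead of the stale sentence-transformers link.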
@@ -94,22 +94,10 @@ impl Args {
 }
 
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
     let args = Args::parse();
-    let _guard = if args.tracing {
-        println!("tracing...");
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
     let start = std::time::Instant::now();
 
     let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let device = &model.device;
     let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
     let tokenizer = tokenizer
         .with_padding(None)
@@ -120,7 +108,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
     println!("Loaded and encoded {:?}", start.elapsed());
     for idx in 0..args.n {
         let start = std::time::Instant::now();
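For reference, this is the whole path the hunk touches, prompt to batched tensor; after this change the device comes from the new accessor rather than the formerly public field. A hedged sketch (`encode_prompt` is a made-up helper; imports assume the `candle` alias for `candle-core` used in this repo):

```rust
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use tokenizers::Tokenizer;

// Hypothetical helper mirroring the hunk above: prompt -> token ids ->
// (1, seq_len) tensor. The tensor must be created on the same device as
// the model's weights, or the forward pass fails with a device mismatch.
fn encode_prompt(tokenizer: &Tokenizer, prompt: &str, device: &Device) -> Result<Tensor> {
    let tokens = tokenizer
        .encode(prompt, true) // `true` = add special tokens
        .map_err(E::msg)?     // tokenizers' error type needs the anyhow bridge
        .get_ids()
        .to_vec();
    Ok(Tensor::new(&tokens[..], device)?.unsqueeze(0)?)
}
```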


@@ -6,6 +6,18 @@ use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module
 use serde::Deserialize;
 use std::sync::Arc;
 
+fn default_relative_attention_max_distance() -> usize {
+    128
+}
+
+fn default_is_decoder() -> bool {
+    false
+}
+
+fn default_use_cache() -> bool {
+    true
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
@@ -16,15 +28,18 @@ pub struct Config {
     num_decoder_layers: Option<usize>,
     num_heads: usize,
     relative_attention_num_buckets: usize,
-    relative_attention_max_distance: Option<usize>,
+    #[serde(default = "default_relative_attention_max_distance")]
+    relative_attention_max_distance: usize,
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
     #[serde(default)]
     feed_forward_proj: Activation,
-    is_decoder: Option<bool>,
+    #[serde(default = "default_is_decoder")]
+    is_decoder: bool,
     is_encoder_decoder: bool,
-    use_cache: Option<bool>,
+    #[serde(default = "default_use_cache")]
+    use_cache: bool,
     pad_token_id: usize,
     eos_token_id: usize,
 }
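This hunk is the substance of "use default values rather than options": `#[serde(default = "...")]` fills a missing JSON field once, at deserialization, so the struct stores a plain value and no consumer ever unwraps an `Option`. A standalone sketch of the mechanism (field set cut down to two):

```rust
use serde::Deserialize;

fn default_use_cache() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct Config {
    vocab_size: usize,
    // If the JSON omits `use_cache`, serde calls `default_use_cache()`
    // once here, instead of every consumer writing `.unwrap_or(true)`.
    #[serde(default = "default_use_cache")]
    use_cache: bool,
}

fn main() -> serde_json::Result<()> {
    // `use_cache` is absent, as it often is in Hugging Face config.json files.
    let cfg: Config = serde_json::from_str(r#"{ "vocab_size": 32128 }"#)?;
    assert!(cfg.use_cache);
    Ok(())
}
```

The same trick is what lets the later hunks drop `.unwrap_or(128)` and `.unwrap_or(false)` at the use sites.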
@@ -40,14 +55,14 @@ impl Default for Config {
             num_decoder_layers: None,
             num_heads: 8,
             relative_attention_num_buckets: 32,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
-            use_cache: Some(true),
+            use_cache: true,
             pad_token_id: 0,
             eos_token_id: 1,
         }
@@ -65,16 +80,16 @@ impl Config {
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
             initializer_factor: 1.0,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
             layer_norm_epsilon: 1e-6,
             num_decoder_layers: Some(12),
             num_heads: 12,
             num_layers: 12,
             pad_token_id: 0,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             relative_attention_num_buckets: 32,
-            use_cache: Some(true),
+            use_cache: true,
             vocab_size: 32128,
         }
     }
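With the `Option`s gone, `Default::default()` and the named constructor assign plain literals, and in-crate code can build variants with struct-update syntax instead of sprinkling `Some(..)`. A toy illustration (tiny stand-in struct, not the real `Config`):

```rust
#[derive(Debug, Clone)]
struct MiniConfig {
    num_heads: usize,
    use_cache: bool,
}

impl Default for MiniConfig {
    fn default() -> Self {
        Self {
            num_heads: 8,
            use_cache: true, // plain `true`, no longer `Some(true)`
        }
    }
}

fn main() {
    // Struct-update syntax: override one field, inherit the rest.
    let cfg = MiniConfig {
        num_heads: 12,
        ..MiniConfig::default()
    };
    assert!(cfg.use_cache);
    println!("{cfg:?}");
}
```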
@@ -199,7 +214,7 @@ impl T5Attention {
             d_kv: cfg.d_kv,
             relative_attention_bias,
             relative_attention_num_buckets: cfg.relative_attention_num_buckets,
-            relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
             inner_dim,
         })
     }
@@ -345,7 +360,7 @@ impl T5Block {
     fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = vb.pp("layer");
         let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
-        let cross_attn = if cfg.is_decoder.unwrap_or(false) {
+        let cross_attn = if cfg.is_decoder {
             Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
         } else {
             None
@@ -402,23 +417,20 @@ impl T5Stack {
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
-
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
             (hidden_states, position_bias) =
                 block.forward(&hidden_states, position_bias.as_ref())?
         }
 
-        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
-        Ok(hidden_states)
+        self.final_layer_norm.forward(&hidden_states)
     }
 }
 
 #[derive(Debug)]
 pub struct T5EncoderModel {
     encoder: T5Stack,
-    pub device: Device,
+    device: Device,
 }
 
 impl T5EncoderModel {
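One of the "more tweaks" appears twice in this file: a trailing `let x = f()?; Ok(x)` pair collapses into returning the call directly, since the final expression already has the right `Result` type. A self-contained before/after sketch (toy function, not taken from the diff):

```rust
use std::num::ParseIntError;

// Before: `?` unwraps the Result, only for `Ok(..)` to re-wrap it.
fn parse_old(s: &str) -> Result<i64, ParseIntError> {
    let n = s.trim().parse::<i64>()?;
    Ok(n)
}

// After: the call is the tail expression; same behavior, one less binding.
fn parse_new(s: &str) -> Result<i64, ParseIntError> {
    s.trim().parse::<i64>()
}

fn main() {
    assert_eq!(parse_old(" 42 "), parse_new(" 42 "));
}
```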
@@ -433,7 +445,10 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let encoder_outputs = self.encoder.forward(input_ids)?;
-        Ok(encoder_outputs)
+        self.encoder.forward(input_ids)
     }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
 }
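"Avoid exposing the device field" is plain encapsulation: the field goes private and the `device()` getter becomes the public surface, so callers such as the example's `main` depend on a method rather than on the struct layout. A minimal sketch of the resulting API shape (`Encoder` is a stand-in type; imports assume the repo's `candle` alias for `candle-core`):

```rust
use candle::{Device, Result, Tensor};

pub struct Encoder {
    device: Device, // private: `model.device` no longer compiles outside this module
}

impl Encoder {
    /// Read-only accessor replacing the public field.
    pub fn device(&self) -> &Device {
        &self.device
    }
}

// Downstream code keeps working, one method call away from the field.
fn batch_of_ones(model: &Encoder) -> Result<Tensor> {
    Tensor::ones((1, 4), candle::DType::U32, model.device())
}
```

Returning `&Device` rather than a clone matches how the call site feeds it straight into `Tensor::new`.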