Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
T5 tweaks (#831)
* Use default values rather than options.
* Avoid exposing the device field.
* More tweaks.
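The first bullet describes a serde-level change: config fields that were `Option<T>` and had to be unwrapped with a fallback at every use site become plain values whose defaults are applied during deserialization. A minimal sketch of the before/after pattern (illustrative single-field structs, not the real `Config`):

use serde::Deserialize;

// Before: the field may be absent, so every consumer repeats the fallback.
#[derive(Debug, Deserialize)]
struct ConfigBefore {
    relative_attention_max_distance: Option<usize>,
}

fn max_distance_before(cfg: &ConfigBefore) -> usize {
    cfg.relative_attention_max_distance.unwrap_or(128)
}

// After: a default function fills the value in when the JSON key is missing,
// so consumers just read a usize.
fn default_relative_attention_max_distance() -> usize {
    128
}

#[derive(Debug, Deserialize)]
struct ConfigAfter {
    #[serde(default = "default_relative_attention_max_distance")]
    relative_attention_max_distance: usize,
}

fn max_distance_after(cfg: &ConfigAfter) -> usize {
    cfg.relative_attention_max_distance
}

The t5.rs hunks below apply this pattern to the real Config and its call sites.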
@@ -30,7 +30,7 @@ struct Args {
    #[arg(long)]
    tracing: bool,

-   /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+   /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

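The only change in this hunk is the doc comment on `model_id`; with clap's derive API that comment is what `--help` prints for the flag, so the stale sentence-transformers wording is replaced. A minimal sketch of the mechanism (stand-alone struct, assuming clap with the derive feature):

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,
}

fn main() {
    // `cargo run -- --help` shows the doc comment above as the help text for --model-id.
    let args = Args::parse();
    println!("{args:?}");
}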
@@ -94,22 +94,10 @@ impl Args {
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-   let device = &model.device;
    let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
    let tokenizer = tokenizer
        .with_padding(None)
@@ -120,7 +108,7 @@ fn main() -> Result<()> {
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
-   let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+   let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
    println!("Loaded and encoded {:?}", start.elapsed());
    for idx in 0..args.n {
        let start = std::time::Instant::now();
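For context, the line that changes here is the one that turns the token ids into the model input: a rank-1 `u32` tensor on the model's device, with `unsqueeze(0)` adding the batch dimension. A minimal sketch of that step on its own (assuming the candle-core crate, the CPU device, and made-up token ids):

use candle_core::{Device, Result, Tensor};

fn to_batch_of_one(tokens: &[u32], device: &Device) -> Result<Tensor> {
    // [seq_len] -> [1, seq_len]: the encoder expects a leading batch dimension.
    Tensor::new(tokens, device)?.unsqueeze(0)
}

fn main() -> Result<()> {
    let tokens: Vec<u32> = vec![37, 9, 12, 1];
    let token_ids = to_batch_of_one(&tokens, &Device::Cpu)?;
    let (batch, seq_len) = token_ids.dims2()?;
    assert_eq!((batch, seq_len), (1, 4));
    Ok(())
}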
@@ -6,6 +6,18 @@ use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module
use serde::Deserialize;
use std::sync::Arc;

+fn default_relative_attention_max_distance() -> usize {
+    128
+}
+
+fn default_is_decoder() -> bool {
+    false
+}
+
+fn default_use_cache() -> bool {
+    true
+}
+
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
@@ -16,15 +28,18 @@ pub struct Config {
    num_decoder_layers: Option<usize>,
    num_heads: usize,
    relative_attention_num_buckets: usize,
-   relative_attention_max_distance: Option<usize>,
+   #[serde(default = "default_relative_attention_max_distance")]
+   relative_attention_max_distance: usize,
    dropout_rate: f64,
    layer_norm_epsilon: f64,
    initializer_factor: f64,
    #[serde(default)]
    feed_forward_proj: Activation,
-   is_decoder: Option<bool>,
+   #[serde(default = "default_is_decoder")]
+   is_decoder: bool,
    is_encoder_decoder: bool,
-   use_cache: Option<bool>,
+   #[serde(default = "default_use_cache")]
+   use_cache: bool,
    pad_token_id: usize,
    eos_token_id: usize,
}
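What the `#[serde(default = "...")]` attributes buy in practice: a config.json that omits these keys still deserializes, with the default functions supplying the values. A small self-contained demonstration of the same mechanism (illustrative mini-struct and serde_json, not the actual loading code):

use serde::Deserialize;

fn default_relative_attention_max_distance() -> usize {
    128
}

fn default_use_cache() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct MiniConfig {
    vocab_size: usize,
    #[serde(default = "default_relative_attention_max_distance")]
    relative_attention_max_distance: usize,
    #[serde(default = "default_use_cache")]
    use_cache: bool,
}

fn main() -> serde_json::Result<()> {
    // Neither optional key is present in the JSON...
    let cfg: MiniConfig = serde_json::from_str(r#"{ "vocab_size": 32128 }"#)?;
    // ...yet the fields are plain values, so no Option unwrapping is needed downstream.
    assert_eq!(cfg.relative_attention_max_distance, 128);
    assert!(cfg.use_cache);
    println!("{cfg:?}");
    Ok(())
}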
@@ -40,14 +55,14 @@ impl Default for Config {
            num_decoder_layers: None,
            num_heads: 8,
            relative_attention_num_buckets: 32,
-           relative_attention_max_distance: Some(128),
+           relative_attention_max_distance: 128,
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: Activation::Relu,
-           is_decoder: Some(false),
+           is_decoder: false,
            is_encoder_decoder: true,
-           use_cache: Some(true),
+           use_cache: true,
            pad_token_id: 0,
            eos_token_id: 1,
        }
@@ -65,16 +80,16 @@ impl Config {
            eos_token_id: 1,
            feed_forward_proj: Activation::Relu,
            initializer_factor: 1.0,
-           is_decoder: Some(false),
+           is_decoder: false,
            is_encoder_decoder: true,
            layer_norm_epsilon: 1e-6,
            num_decoder_layers: Some(12),
            num_heads: 12,
            num_layers: 12,
            pad_token_id: 0,
-           relative_attention_max_distance: Some(128),
+           relative_attention_max_distance: 128,
            relative_attention_num_buckets: 32,
-           use_cache: Some(true),
+           use_cache: true,
            vocab_size: 32128,
        }
    }
@@ -199,7 +214,7 @@ impl T5Attention {
            d_kv: cfg.d_kv,
            relative_attention_bias,
            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
-           relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
+           relative_attention_max_distance: cfg.relative_attention_max_distance,
            inner_dim,
        })
    }
@@ -345,7 +360,7 @@ impl T5Block {
    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let vb = vb.pp("layer");
        let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
-       let cross_attn = if cfg.is_decoder.unwrap_or(false) {
+       let cross_attn = if cfg.is_decoder {
            Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
        } else {
            None
@@ -402,23 +417,20 @@ impl T5Stack {

    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let input_embeds = self.shared.as_ref().forward(input_ids)?;
        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);

        let mut hidden_states = input_embeds;
        let mut position_bias = None;
        for block in self.block.iter() {
            (hidden_states, position_bias) =
                block.forward(&hidden_states, position_bias.as_ref())?
        }
-       let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
-       Ok(hidden_states)
+       self.final_layer_norm.forward(&hidden_states)
    }
}

#[derive(Debug)]
pub struct T5EncoderModel {
    encoder: T5Stack,
-   pub device: Device,
+   device: Device,
}

impl T5EncoderModel {
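This hunk and the next also drop a redundant `let` plus `Ok(...)` around calls that already return `Result<Tensor>`; the call can be returned directly. A tiny sketch of the equivalence, using a stand-in layer type rather than the real ones:

use candle_core::{Result, Tensor};

struct Norm;

impl Norm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}

// Before: bind the Result's value, then wrap it again.
fn forward_old(norm: &Norm, xs: &Tensor) -> Result<Tensor> {
    let ys = norm.forward(xs)?;
    Ok(ys)
}

// After: the inner call already produces Result<Tensor>, so return it as-is.
fn forward_new(norm: &Norm, xs: &Tensor) -> Result<Tensor> {
    norm.forward(xs)
}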
@@ -433,7 +445,10 @@ impl T5EncoderModel {
    }

    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-       let encoder_outputs = self.encoder.forward(input_ids)?;
-       Ok(encoder_outputs)
+       self.encoder.forward(input_ids)
    }
+
+   pub fn device(&self) -> &Device {
+       &self.device
+   }
}
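Because the field is private after this change, callers reach the device through the new accessor, exactly as the example binary now does when it builds its input tensor. A short sketch of the caller side (the import path for `T5EncoderModel` is assumed, since file paths are not shown in this view):

use candle_core::{Result, Tensor};
// Assumed import path; adjust to wherever the t5 module lives in this revision.
use t5::T5EncoderModel;

fn embed(model: &T5EncoderModel, tokens: &[u32]) -> Result<Tensor> {
    // `model.device` no longer compiles from outside the defining module; use the accessor.
    let token_ids = Tensor::new(tokens, model.device())?.unsqueeze(0)?;
    model.forward(&token_ids)
}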