mirror of https://github.com/huggingface/candle.git
T5 tweaks (#831)
* Use default values rather than options.
* Avoid exposing the device field.
* More tweaks.
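The "default values rather than options" item relies on serde's field-default attribute pointing at a free function, so a missing key in the JSON config is filled in at deserialization time instead of being carried around as an `Option`. A minimal sketch of the pattern, assuming `serde` (with derive) and `serde_json`; the `SketchConfig` struct is illustrative and not part of the commit:

use serde::Deserialize;

// Fallback used when the key is missing from the JSON config.
fn default_use_cache() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct SketchConfig {
    // A missing `use_cache` key now deserializes to `true` instead of `None`.
    #[serde(default = "default_use_cache")]
    use_cache: bool,
}

fn main() -> Result<(), serde_json::Error> {
    let cfg: SketchConfig = serde_json::from_str("{}")?;
    assert!(cfg.use_cache);
    let cfg: SketchConfig = serde_json::from_str(r#"{"use_cache": false}"#)?;
    assert!(!cfg.use_cache);
    Ok(())
}

Bare `#[serde(default)]`, as already used for `feed_forward_proj`, falls back to the type's `Default` impl; the named-function form is what lets `use_cache` default to `true` and `relative_attention_max_distance` to 128 in the hunks below.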
@@ -30,7 +30,7 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
-    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    /// The model repository to use on the HuggingFace hub.
     #[arg(long)]
     model_id: Option<String>,
 
@@ -94,22 +94,10 @@ impl Args {
 }
 
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
     let args = Args::parse();
-    let _guard = if args.tracing {
-        println!("tracing...");
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
     let start = std::time::Instant::now();
 
     let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let device = &model.device;
     let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
     let tokenizer = tokenizer
         .with_padding(None)
@@ -120,7 +108,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
     println!("Loaded and encoded {:?}", start.elapsed());
     for idx in 0..args.n {
         let start = std::time::Instant::now();
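The call site above switches from the formerly public `device` field to the new `device()` accessor. A minimal sketch of the encapsulation, assuming the `candle_core` crate (the example crate imports it under the alias `candle`); `ToyModel` and `encode` are stand-ins, only the `device()` method mirrors the API added in this commit:

use candle_core::{Device, Result, Tensor};

struct ToyModel {
    device: Device, // no longer `pub`: callers cannot reach into the field
}

impl ToyModel {
    // Borrowing accessor, mirroring the `device()` method added below.
    fn device(&self) -> &Device {
        &self.device
    }
}

fn encode(model: &ToyModel, tokens: &[u32]) -> Result<Tensor> {
    // Same shape as the updated call site: go through the getter, not the field.
    Tensor::new(tokens, model.device())?.unsqueeze(0)
}

fn main() -> Result<()> {
    let model = ToyModel { device: Device::Cpu };
    let ids = encode(&model, &[0u32, 1, 2])?;
    println!("{:?}", ids.shape());
    Ok(())
}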
@@ -6,6 +6,18 @@ use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module
 use serde::Deserialize;
 use std::sync::Arc;
 
+fn default_relative_attention_max_distance() -> usize {
+    128
+}
+
+fn default_is_decoder() -> bool {
+    false
+}
+
+fn default_use_cache() -> bool {
+    true
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
@@ -16,15 +28,18 @@ pub struct Config {
     num_decoder_layers: Option<usize>,
     num_heads: usize,
     relative_attention_num_buckets: usize,
-    relative_attention_max_distance: Option<usize>,
+    #[serde(default = "default_relative_attention_max_distance")]
+    relative_attention_max_distance: usize,
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
     #[serde(default)]
     feed_forward_proj: Activation,
-    is_decoder: Option<bool>,
+    #[serde(default = "default_is_decoder")]
+    is_decoder: bool,
     is_encoder_decoder: bool,
-    use_cache: Option<bool>,
+    #[serde(default = "default_use_cache")]
+    use_cache: bool,
     pad_token_id: usize,
     eos_token_id: usize,
 }
@@ -40,14 +55,14 @@ impl Default for Config {
             num_decoder_layers: None,
             num_heads: 8,
             relative_attention_num_buckets: 32,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
-            use_cache: Some(true),
+            use_cache: true,
             pad_token_id: 0,
             eos_token_id: 1,
         }
@@ -65,16 +80,16 @@ impl Config {
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
             initializer_factor: 1.0,
-            is_decoder: Some(false),
+            is_decoder: false,
             is_encoder_decoder: true,
             layer_norm_epsilon: 1e-6,
             num_decoder_layers: Some(12),
             num_heads: 12,
             num_layers: 12,
             pad_token_id: 0,
-            relative_attention_max_distance: Some(128),
+            relative_attention_max_distance: 128,
             relative_attention_num_buckets: 32,
-            use_cache: Some(true),
+            use_cache: true,
             vocab_size: 32128,
         }
     }
@@ -199,7 +214,7 @@ impl T5Attention {
             d_kv: cfg.d_kv,
             relative_attention_bias,
             relative_attention_num_buckets: cfg.relative_attention_num_buckets,
-            relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
             inner_dim,
         })
     }
@@ -345,7 +360,7 @@ impl T5Block {
     fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = vb.pp("layer");
         let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
-        let cross_attn = if cfg.is_decoder.unwrap_or(false) {
+        let cross_attn = if cfg.is_decoder {
             Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
         } else {
             None
@@ -402,23 +417,20 @@ impl T5Stack {
 
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
-
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
             (hidden_states, position_bias) =
                 block.forward(&hidden_states, position_bias.as_ref())?
         }
-        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
-        Ok(hidden_states)
+        self.final_layer_norm.forward(&hidden_states)
     }
 }
 
 #[derive(Debug)]
 pub struct T5EncoderModel {
     encoder: T5Stack,
-    pub device: Device,
+    device: Device,
 }
 
 impl T5EncoderModel {
@@ -433,7 +445,10 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let encoder_outputs = self.encoder.forward(input_ids)?;
-        Ok(encoder_outputs)
+        self.encoder.forward(input_ids)
     }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
 }
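The two forward methods above also drop the `let ... ?; Ok(...)` dance in favour of returning the tail expression, since the inner call already yields the right `Result`. A tiny standalone sketch of the idiom; all names here are made up for illustration:

fn last_step(x: i32) -> Result<i32, String> {
    x.checked_add(1).ok_or_else(|| "overflow".to_string())
}

// Old shape: unwrap with `?` only to wrap the value in `Ok` again.
fn forward_verbose(x: i32) -> Result<i32, String> {
    let y = last_step(x)?;
    Ok(y)
}

// New shape: the inner call already has the right return type, so hand it back as-is.
fn forward_direct(x: i32) -> Result<i32, String> {
    last_step(x)
}

fn main() {
    assert_eq!(forward_verbose(1), forward_direct(1));
}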