Extract T5 module and add main function to use it (#829)

* Extract t5 out of musicgen

* Add main for t5 module
Juarez Bochi (committed by GitHub)
2023-09-12 23:14:05 -07:00
parent e82fcf1c59
commit 9daa6dbe87
8 changed files with 184 additions and 21 deletions


@@ -13,7 +13,6 @@ extern crate accelerate_src;
 mod encodec_model;
 mod musicgen_model;
 mod nn;
-mod t5_model;
 use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};


@@ -1,9 +1,10 @@
-use crate::{encodec_model, t5_model};
+use crate::encodec_model;
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
     VarBuilder,
 };
+use candle_transformers::models::t5;
 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
 #[derive(Debug)]
 pub struct MusicgenForConditionalGeneration {
-    pub text_encoder: crate::t5_model::T5EncoderModel,
+    pub text_encoder: t5::T5EncoderModel,
     pub audio_encoder: crate::encodec_model::EncodecModel,
     pub decoder: MusicgenForCausalLM,
     cfg: GenConfig,
@@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
 #[derive(Debug, Clone, PartialEq)]
 pub struct GenConfig {
     musicgen: Config,
-    t5: crate::t5_model::Config,
+    t5: t5::Config,
     encodec: crate::encodec_model::Config,
 }
@@ -387,7 +388,7 @@ impl GenConfig {
     pub fn small() -> Self {
         Self {
             musicgen: Config::musicgen_small(),
-            t5: t5_model::Config::musicgen_small(),
+            t5: t5::Config::musicgen_small(),
             encodec: encodec_model::Config::musicgen_small(),
         }
     }
@@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
     }

     pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
-        let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
+        let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
         let audio_encoder =
             encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
         let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
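With the encoder now living in `candle_transformers::models::t5`, it can be reused outside the musicgen example. A minimal sketch of loading it standalone (`VarBuilder::zeros` stands in for real weight loading, purely for illustration):

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::t5;

fn load_encoder(cfg: &t5::Config) -> Result<t5::T5EncoderModel> {
    // Zero-initialized parameters, just to show the API surface;
    // real callers build the VarBuilder from safetensors files instead.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    t5::T5EncoderModel::load(vb, cfg)
}
```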


@@ -0,0 +1,17 @@
# candle-t5

Generates embeddings using a T5 model. It doesn't support generation yet.

```bash
$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
Loaded and encoded 2.014244792s
[[[-0.3174, -0.1462,  0.0065, ..., -0.0579, -0.0581,  0.1387],
  [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
  [-0.0591, -0.0213, -0.0241, ..., -0.0210,  0.0491, -0.0300],
  ...
  [-0.4333,  0.0027, -0.0609, ...,  0.3069, -0.2252,  0.3306],
  [-0.1458,  0.1323, -0.0138, ...,  0.3000, -0.4550, -0.0384],
  [ 0.0397,  0.0485, -0.2373, ...,  0.2578, -0.2650, -0.4356]]]
Tensor[[1, 9, 1024], f32]
Took 2.1363425s
```
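When `--model-id` is omitted, the example falls back to the defaults wired into `build_model_and_tokenizer` below: `t5-small` at revision `refs/pr/15` (a Hub branch that, presumably for this reason, carries `model.safetensors`). With `--prompt` also omitted it embeds the built-in default prompt, so the shortest invocation is just:

```bash
$ cargo run --example t5
```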


@@ -0,0 +1,134 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_transformers::models::t5;

use anyhow::{anyhow, Error as E, Result};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use tokenizers::Tokenizer;

const DTYPE: DType = DType::F32;
const DEFAULT_PROMPT: &str = "Translate English to German: That is good.";

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Run offline (you must have the files already cached)
    #[arg(long)]
    offline: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    /// Compute embeddings for this prompt or use the DEFAULT_PROMPT.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
        let device = candle_examples::device(self.cpu)?;
        let default_model = "t5-small".to_string();
        let default_revision = "refs/pr/15".to_string();
        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, default_revision),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
            let cache = Cache::default().repo(repo);
            (
                cache
                    .get("config.json")
                    .ok_or(anyhow!("Missing config file in cache"))?,
                cache
                    .get("tokenizer.json")
                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
                cache
                    .get("model.safetensors")
                    .ok_or(anyhow!("Missing weights file in cache"))?,
            )
        } else {
            let api = Api::new()?;
            let api = api.repo(repo);
            (
                api.get("config.json")?,
                api.get("tokenizer.json")?,
                api.get("model.safetensors")?,
            )
        };

        let config = std::fs::read_to_string(config_filename)?;
        let config: t5::Config = serde_json::from_str(&config)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
        let weights = weights.deserialize()?;
        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
        let model = t5::T5EncoderModel::load(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    println!("Loaded and encoded {:?}", start.elapsed());
    for idx in 0..args.n {
        let start = std::time::Instant::now();
        let ys = model.forward(&token_ids)?;
        if idx == 0 {
            println!("{ys}");
        }
        println!("Took {:?}", start.elapsed());
    }
    Ok(())
}


@@ -18,6 +18,7 @@ intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
+serde = { workspace = true }

 [dev-dependencies]
 anyhow = { workspace = true }


@@ -1,7 +1,10 @@
 use candle::Tensor;
+use serde::Deserialize;

-#[derive(Debug, Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
 pub enum Activation {
+    #[default]
     Gelu,
     Relu,
     Elu(f64),
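With these derives, `Activation` parses straight from the lowercase strings found in Hugging Face config files, and `#[default]` makes `Gelu` the fallback. A quick sketch (assuming `serde_json` is available as a dependency):

```rust
use candle_nn::Activation;

fn main() -> serde_json::Result<()> {
    // rename_all = "lowercase" maps the JSON string "relu" to the Relu variant.
    let act: Activation = serde_json::from_str("\"relu\"")?;
    assert_eq!(act, Activation::Relu);
    // #[default] on Gelu is what a #[serde(default)] field falls back to.
    assert_eq!(Activation::default(), Activation::Gelu);
    Ok(())
}
```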


@@ -7,4 +7,5 @@ pub mod llama;
 pub mod quantized_llama;
 pub mod segment_anything;
 pub mod stable_diffusion;
+pub mod t5;
 pub mod whisper;


@@ -1,11 +1,12 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
+use serde::Deserialize;
 use std::sync::Arc;

-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
     d_model: usize,
@@ -15,14 +16,15 @@ pub struct Config {
     num_decoder_layers: Option<usize>,
     num_heads: usize,
     relative_attention_num_buckets: usize,
-    relative_attention_max_distance: usize,
+    relative_attention_max_distance: Option<usize>,
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
+    #[serde(default)]
     feed_forward_proj: Activation,
-    is_decoder: bool,
+    is_decoder: Option<bool>,
     is_encoder_decoder: bool,
-    use_cache: bool,
+    use_cache: Option<bool>,
     pad_token_id: usize,
     eos_token_id: usize,
 }
@@ -38,14 +40,14 @@ impl Default for Config {
             num_decoder_layers: None,
             num_heads: 8,
             relative_attention_num_buckets: 32,
-            relative_attention_max_distance: 128,
+            relative_attention_max_distance: Some(128),
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
-            is_decoder: false,
+            is_decoder: Some(false),
             is_encoder_decoder: true,
-            use_cache: true,
+            use_cache: Some(true),
             pad_token_id: 0,
             eos_token_id: 1,
         }
@@ -63,16 +65,16 @@ impl Config {
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
             initializer_factor: 1.0,
-            is_decoder: false,
+            is_decoder: Some(false),
             is_encoder_decoder: true,
             layer_norm_epsilon: 1e-6,
             num_decoder_layers: Some(12),
             num_heads: 12,
             num_layers: 12,
             pad_token_id: 0,
-            relative_attention_max_distance: 128,
+            relative_attention_max_distance: Some(128),
             relative_attention_num_buckets: 32,
-            use_cache: true,
+            use_cache: Some(true),
             vocab_size: 32128,
         }
     }
@@ -197,7 +199,7 @@ impl T5Attention {
             d_kv: cfg.d_kv,
             relative_attention_bias,
             relative_attention_num_buckets: cfg.relative_attention_num_buckets,
-            relative_attention_max_distance: cfg.relative_attention_max_distance,
+            relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
             inner_dim,
         })
     }
@@ -343,7 +345,7 @@ impl T5Block {
     fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = vb.pp("layer");
         let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
-        let cross_attn = if cfg.is_decoder {
+        let cross_attn = if cfg.is_decoder.unwrap_or(false) {
             Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
         } else {
             None
@@ -417,6 +419,7 @@ impl T5Stack {
 pub struct T5EncoderModel {
     shared: Arc<Embedding>,
     encoder: T5Stack,
+    pub device: Device,
 }

 impl T5EncoderModel {
@@ -424,7 +427,11 @@ impl T5EncoderModel {
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
         let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
-        Ok(Self { shared, encoder })
+        Ok(Self {
+            shared,
+            encoder,
+            device: vb.device().clone(),
+        })
     }

     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
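The `Option`-typed fields and `#[serde(default)]` above are what let a stock Hub `config.json` deserialize even when it omits `is_decoder`, `use_cache`, `relative_attention_max_distance`, or `feed_forward_proj`; the code falls back via `unwrap_or` at the point of use. A sketch with a pared-down, hypothetical config (assuming `serde_json` and `anyhow` are available):

```rust
use candle_transformers::models::t5;

fn main() -> anyhow::Result<()> {
    // Minimal t5-small-style config; note the absent optional fields.
    let json = r#"{
        "vocab_size": 32128,
        "d_model": 512,
        "d_kv": 64,
        "d_ff": 2048,
        "num_layers": 6,
        "num_heads": 8,
        "relative_attention_num_buckets": 32,
        "dropout_rate": 0.1,
        "layer_norm_epsilon": 1e-6,
        "initializer_factor": 1.0,
        "is_encoder_decoder": true,
        "pad_token_id": 0,
        "eos_token_id": 1
    }"#;
    let cfg: t5::Config = serde_json::from_str(json)?;
    println!("{cfg:?}");
    Ok(())
}
```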