mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Extract T5 module and add main function to use it (#829)
* Extract t5 out of musicgen * Add main for t5 module
This commit is contained in:
@ -13,7 +13,6 @@ extern crate accelerate_src;
|
||||
mod encodec_model;
|
||||
mod musicgen_model;
|
||||
mod nn;
|
||||
mod t5_model;
|
||||
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
use crate::{encodec_model, t5_model};
|
||||
use crate::encodec_model;
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
||||
VarBuilder,
|
||||
};
|
||||
use candle_transformers::models::t5;
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MusicgenForConditionalGeneration {
|
||||
pub text_encoder: crate::t5_model::T5EncoderModel,
|
||||
pub text_encoder: t5::T5EncoderModel,
|
||||
pub audio_encoder: crate::encodec_model::EncodecModel,
|
||||
pub decoder: MusicgenForCausalLM,
|
||||
cfg: GenConfig,
|
||||
@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct GenConfig {
|
||||
musicgen: Config,
|
||||
t5: crate::t5_model::Config,
|
||||
t5: t5::Config,
|
||||
encodec: crate::encodec_model::Config,
|
||||
}
|
||||
|
||||
@ -387,7 +388,7 @@ impl GenConfig {
|
||||
pub fn small() -> Self {
|
||||
Self {
|
||||
musicgen: Config::musicgen_small(),
|
||||
t5: t5_model::Config::musicgen_small(),
|
||||
t5: t5::Config::musicgen_small(),
|
||||
encodec: encodec_model::Config::musicgen_small(),
|
||||
}
|
||||
}
|
||||
@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||
let audio_encoder =
|
||||
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
||||
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||
|
17
candle-examples/examples/t5/README.md
Normal file
17
candle-examples/examples/t5/README.md
Normal file
@ -0,0 +1,17 @@
|
||||
# candle-t5
|
||||
|
||||
Generates embeddings using a T5 model. It doesn't support generation yet.
|
||||
|
||||
```bash
|
||||
$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
|
||||
Loaded and encoded 2.014244792s
|
||||
[[[-0.3174, -0.1462, 0.0065, ..., -0.0579, -0.0581, 0.1387],
|
||||
[-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
|
||||
[-0.0591, -0.0213, -0.0241, ..., -0.0210, 0.0491, -0.0300],
|
||||
...
|
||||
[-0.4333, 0.0027, -0.0609, ..., 0.3069, -0.2252, 0.3306],
|
||||
[-0.1458, 0.1323, -0.0138, ..., 0.3000, -0.4550, -0.0384],
|
||||
[ 0.0397, 0.0485, -0.2373, ..., 0.2578, -0.2650, -0.4356]]]
|
||||
Tensor[[1, 9, 1024], f32]
|
||||
Took 2.1363425s
|
||||
```
|
134
candle-examples/examples/t5/main.rs
Normal file
134
candle-examples/examples/t5/main.rs
Normal file
@ -0,0 +1,134 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use candle_transformers::models::t5;
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
// Element type handed to `VarBuilder::from_safetensors` when the model weights are loaded.
const DTYPE: DType = DType::F32;
|
||||
// Prompt that gets encoded when `--prompt` is not supplied on the command line.
const DEFAULT_PROMPT: &str = "Translate English to German: That is good.";
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Run offline (you must have the files already cached)
|
||||
#[arg(long)]
|
||||
offline: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// Compute embeddings for this prompt or use the DEFAULT_PROMPT.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
|
||||
let device = candle_examples::device(self.cpu)?;
|
||||
let default_model = "t5-small".to_string();
|
||||
let default_revision = "refs/pr/15".to_string();
|
||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
|
||||
let cache = Cache::default().repo(repo);
|
||||
(
|
||||
cache
|
||||
.get("config.json")
|
||||
.ok_or(anyhow!("Missing config file in cache"))?,
|
||||
cache
|
||||
.get("tokenizer.json")
|
||||
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
||||
cache
|
||||
.get("model.safetensors")
|
||||
.ok_or(anyhow!("Missing weights file in cache"))?,
|
||||
)
|
||||
} else {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
(
|
||||
api.get("config.json")?,
|
||||
api.get("tokenizer.json")?,
|
||||
api.get("model.safetensors")?,
|
||||
)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: t5::Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let model = t5::T5EncoderModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
println!("Loaded and encoded {:?}", start.elapsed());
|
||||
for idx in 0..args.n {
|
||||
let start = std::time::Instant::now();
|
||||
let ys = model.forward(&token_ids)?;
|
||||
if idx == 0 {
|
||||
println!("{ys}");
|
||||
}
|
||||
println!("Took {:?}", start.elapsed());
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -18,6 +18,7 @@ intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
@ -1,7 +1,10 @@
|
||||
use candle::Tensor;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Activation {
|
||||
#[default]
|
||||
Gelu,
|
||||
Relu,
|
||||
Elu(f64),
|
||||
|
@ -7,4 +7,5 @@ pub mod llama;
|
||||
pub mod quantized_llama;
|
||||
pub mod segment_anything;
|
||||
pub mod stable_diffusion;
|
||||
pub mod t5;
|
||||
pub mod whisper;
|
||||
|
@ -1,11 +1,12 @@
|
||||
// T5 Text Encoder
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
d_model: usize,
|
||||
@ -15,14 +16,15 @@ pub struct Config {
|
||||
num_decoder_layers: Option<usize>,
|
||||
num_heads: usize,
|
||||
relative_attention_num_buckets: usize,
|
||||
relative_attention_max_distance: usize,
|
||||
relative_attention_max_distance: Option<usize>,
|
||||
dropout_rate: f64,
|
||||
layer_norm_epsilon: f64,
|
||||
initializer_factor: f64,
|
||||
#[serde(default)]
|
||||
feed_forward_proj: Activation,
|
||||
is_decoder: bool,
|
||||
is_decoder: Option<bool>,
|
||||
is_encoder_decoder: bool,
|
||||
use_cache: bool,
|
||||
use_cache: Option<bool>,
|
||||
pad_token_id: usize,
|
||||
eos_token_id: usize,
|
||||
}
|
||||
@ -38,14 +40,14 @@ impl Default for Config {
|
||||
num_decoder_layers: None,
|
||||
num_heads: 8,
|
||||
relative_attention_num_buckets: 32,
|
||||
relative_attention_max_distance: 128,
|
||||
relative_attention_max_distance: Some(128),
|
||||
dropout_rate: 0.1,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
initializer_factor: 1.0,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
is_decoder: false,
|
||||
is_decoder: Some(false),
|
||||
is_encoder_decoder: true,
|
||||
use_cache: true,
|
||||
use_cache: Some(true),
|
||||
pad_token_id: 0,
|
||||
eos_token_id: 1,
|
||||
}
|
||||
@ -63,16 +65,16 @@ impl Config {
|
||||
eos_token_id: 1,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
initializer_factor: 1.0,
|
||||
is_decoder: false,
|
||||
is_decoder: Some(false),
|
||||
is_encoder_decoder: true,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
num_decoder_layers: Some(12),
|
||||
num_heads: 12,
|
||||
num_layers: 12,
|
||||
pad_token_id: 0,
|
||||
relative_attention_max_distance: 128,
|
||||
relative_attention_max_distance: Some(128),
|
||||
relative_attention_num_buckets: 32,
|
||||
use_cache: true,
|
||||
use_cache: Some(true),
|
||||
vocab_size: 32128,
|
||||
}
|
||||
}
|
||||
@ -197,7 +199,7 @@ impl T5Attention {
|
||||
d_kv: cfg.d_kv,
|
||||
relative_attention_bias,
|
||||
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
|
||||
relative_attention_max_distance: cfg.relative_attention_max_distance,
|
||||
relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
|
||||
inner_dim,
|
||||
})
|
||||
}
|
||||
@ -343,7 +345,7 @@ impl T5Block {
|
||||
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = vb.pp("layer");
|
||||
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
|
||||
let cross_attn = if cfg.is_decoder {
|
||||
let cross_attn = if cfg.is_decoder.unwrap_or(false) {
|
||||
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
|
||||
} else {
|
||||
None
|
||||
@ -417,6 +419,7 @@ impl T5Stack {
|
||||
pub struct T5EncoderModel {
|
||||
shared: Arc<Embedding>,
|
||||
encoder: T5Stack,
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
impl T5EncoderModel {
|
||||
@ -424,7 +427,11 @@ impl T5EncoderModel {
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared = Arc::new(shared);
|
||||
let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
|
||||
Ok(Self { shared, encoder })
|
||||
Ok(Self {
|
||||
shared,
|
||||
encoder,
|
||||
device: vb.device().clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
Reference in New Issue
Block a user