diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 3794c22d..df8c3135 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -13,7 +13,6 @@ extern crate accelerate_src;
 mod encodec_model;
 mod musicgen_model;
 mod nn;
-mod t5_model;
 
 use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
 
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 7e272fd7..d6d8ae15 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -1,9 +1,10 @@
-use crate::{encodec_model, t5_model};
+use crate::encodec_model;
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
     VarBuilder,
 };
+use candle_transformers::models::t5;
 
 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
 
 #[derive(Debug)]
 pub struct MusicgenForConditionalGeneration {
-    pub text_encoder: crate::t5_model::T5EncoderModel,
+    pub text_encoder: t5::T5EncoderModel,
     pub audio_encoder: crate::encodec_model::EncodecModel,
     pub decoder: MusicgenForCausalLM,
     cfg: GenConfig,
@@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
 #[derive(Debug, Clone, PartialEq)]
 pub struct GenConfig {
     musicgen: Config,
-    t5: crate::t5_model::Config,
+    t5: t5::Config,
     encodec: crate::encodec_model::Config,
 }
 
@@ -387,7 +388,7 @@ impl GenConfig {
     pub fn small() -> Self {
         Self {
             musicgen: Config::musicgen_small(),
-            t5: t5_model::Config::musicgen_small(),
+            t5: t5::Config::musicgen_small(),
             encodec: encodec_model::Config::musicgen_small(),
         }
     }
@@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
     }
 
     pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
-        let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
+        let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
         let audio_encoder =
             encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
         let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
new file mode 100644
index 00000000..66952395
--- /dev/null
+++ b/candle-examples/examples/t5/README.md
@@ -0,0 +1,25 @@
+# candle-t5
+
+Generates embeddings using a T5 model. It doesn't support generation yet.
+
+```bash
+$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
+Loaded and encoded 2.014244792s
+[[[-0.3174, -0.1462,  0.0065, ..., -0.0579, -0.0581,  0.1387],
+  [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
+  [-0.0591, -0.0213, -0.0241, ..., -0.0210,  0.0491, -0.0300],
+  ...
+  [-0.4333,  0.0027, -0.0609, ...,  0.3069, -0.2252,  0.3306],
+  [-0.1458,  0.1323, -0.0138, ...,  0.3000, -0.4550, -0.0384],
+  [ 0.0397,  0.0485, -0.2373, ...,  0.2578, -0.2650, -0.4356]]]
+Tensor[[1, 9, 1024], f32]
+Took 2.1363425s
+```
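+
+The config, tokenizer and weights are fetched from the HuggingFace hub on the
+first run. Once they are cached, the example can also be run without network
+access by passing the `--offline` flag:
+
+```bash
+$ cargo run --example t5 -- --offline --prompt 'how tall is obama'
+```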
\ No newline at end of file
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
new file mode 100644
index 00000000..bcba846d
--- /dev/null
+++ b/candle-examples/examples/t5/main.rs
@@ -0,0 +1,136 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use candle_transformers::models::t5;
+
+use anyhow::{anyhow, Error as E, Result};
+use candle::{DType, Tensor};
+use candle_nn::VarBuilder;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+const DTYPE: DType = DType::F32;
+const DEFAULT_PROMPT: &str = "Translate English to German: That is good.";
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// Compute embeddings for this prompt or use the DEFAULT_PROMPT.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+}
+
+impl Args {
+    fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "t5-small".to_string();
+        let default_revision = "refs/pr/15".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default().repo(repo);
+            (
+                cache
+                    .get("config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get("tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get("model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            (
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: t5::Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
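+        // Note: memory-mapping the weights file is unsafe, as the mapping could be
+        // invalidated if the underlying file is modified while it is in use.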
+        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let model = t5::T5EncoderModel::load(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let start = std::time::Instant::now();
+
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+    let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    println!("Loaded and encoded {:?}", start.elapsed());
+    for idx in 0..args.n {
+        let start = std::time::Instant::now();
+        let ys = model.forward(&token_ids)?;
+        if idx == 0 {
+            println!("{ys}");
+        }
+        println!("Took {:?}", start.elapsed());
+    }
+    Ok(())
+}
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index 5062a717..2a8ec8ce 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -18,6 +18,7 @@ intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
+serde = { workspace = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 0db3edc9..6e442458 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -1,7 +1,11 @@
 use candle::Tensor;
+use serde::Deserialize;
 
-#[derive(Debug, Clone, Copy, PartialEq)]
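+// Lowercase so that the serialized names match those used in HF config.json files, e.g. "relu".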
+#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
 pub enum Activation {
+    #[default]
     Gelu,
     Relu,
     Elu(f64),
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index f33d01e6..e2e0bf81 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -7,4 +7,5 @@ pub mod llama;
 pub mod quantized_llama;
 pub mod segment_anything;
 pub mod stable_diffusion;
+pub mod t5;
 pub mod whisper;
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-transformers/src/models/t5.rs
similarity index 94%
rename from candle-examples/examples/musicgen/t5_model.rs
rename to candle-transformers/src/models/t5.rs
index 22f0a4f5..1454b7cc 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -1,11 +1,12 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
+use serde::Deserialize;
 use std::sync::Arc;
 
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
     d_model: usize,
@@ -15,14 +16,15 @@ pub struct Config {
     num_decoder_layers: Option<usize>,
     num_heads: usize,
     relative_attention_num_buckets: usize,
-    relative_attention_max_distance: usize,
+    relative_attention_max_distance: Option<usize>,
    dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
+    #[serde(default)]
     feed_forward_proj: Activation,
-    is_decoder: bool,
+    is_decoder: Option<bool>,
     is_encoder_decoder: bool,
-    use_cache: bool,
+    use_cache: Option<bool>,
     pad_token_id: usize,
     eos_token_id: usize,
 }
@@ -38,14 +40,14 @@ impl Default for Config {
             num_decoder_layers: None,
             num_heads: 8,
             relative_attention_num_buckets: 32,
-            relative_attention_max_distance: 128,
+            relative_attention_max_distance: Some(128),
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
-            is_decoder: false,
+            is_decoder: Some(false),
             is_encoder_decoder: true,
-            use_cache: true,
+            use_cache: Some(true),
             pad_token_id: 0,
             eos_token_id: 1,
         }
@@ -63,16 +65,16 @@ impl Config {
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
             initializer_factor: 1.0,
-            is_decoder: false,
+            is_decoder: Some(false),
             is_encoder_decoder: true,
             layer_norm_epsilon: 1e-6,
             num_decoder_layers: Some(12),
             num_heads: 12,
             num_layers: 12,
             pad_token_id: 0,
-            relative_attention_max_distance: 128,
+            relative_attention_max_distance: Some(128),
             relative_attention_num_buckets: 32,
-            use_cache: true,
+            use_cache: Some(true),
             vocab_size: 32128,
         }
     }
@@ -197,7 +199,7 @@ impl T5Attention {
             d_kv: cfg.d_kv,
             relative_attention_bias,
             relative_attention_num_buckets: cfg.relative_attention_num_buckets,
-            relative_attention_max_distance: cfg.relative_attention_max_distance,
+            relative_attention_max_distance: cfg.relative_attention_max_distance.unwrap_or(128),
             inner_dim,
         })
     }
@@ -343,7 +345,7 @@ impl T5Block {
     fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = vb.pp("layer");
         let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
-        let cross_attn = if cfg.is_decoder {
+        let cross_attn = if cfg.is_decoder.unwrap_or(false) {
             Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
         } else {
             None
@@ -417,6 +419,7 @@ impl T5Stack {
 pub struct T5EncoderModel {
     shared: Arc<Embedding>,
     encoder: T5Stack,
+    pub device: Device,
 }
 
 impl T5EncoderModel {
@@ -424,7 +427,12 @@ impl T5EncoderModel {
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
         let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
-        Ok(Self { shared, encoder })
+        Ok(Self {
+            shared,
+            encoder,
+            device: vb.device().clone(),
+        })
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
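+        // Maps token ids of shape (batch, seq_len) to embeddings of shape (batch, seq_len, d_model).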