Extract T5 module and add main function to use it (#829)

* Extract t5 out of musicgen

* Add main for t5 module
This commit is contained in:
Juarez Bochi
2023-09-12 23:14:05 -07:00
committed by GitHub
parent e82fcf1c59
commit 9daa6dbe87
8 changed files with 184 additions and 21 deletions

View File

@ -0,0 +1,17 @@
# candle-t5
Generates embeddings using a T5 model. It doesn't support generation yet.
```bash
$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
Loaded and encoded 2.014244792s
[[[-0.3174, -0.1462, 0.0065, ..., -0.0579, -0.0581, 0.1387],
[-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
[-0.0591, -0.0213, -0.0241, ..., -0.0210, 0.0491, -0.0300],
...
[-0.4333, 0.0027, -0.0609, ..., 0.3069, -0.2252, 0.3306],
[-0.1458, 0.1323, -0.0138, ..., 0.3000, -0.4550, -0.0384],
[ 0.0397, 0.0485, -0.2373, ..., 0.2578, -0.2650, -0.4356]]]
Tensor[[1, 9, 1024], f32]
Took 2.1363425s
```

View File

@ -0,0 +1,134 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::t5;
use anyhow::{anyhow, Error as E, Result};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
const DEFAULT_PROMPT: &str = "Translate English to German: That is good.";
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Run offline (you must have the files already cached)
#[arg(long)]
offline: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// Compute embeddings for this prompt or use the DEFAULT_PROMPT.
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let default_model = "t5-small".to_string();
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
let cache = Cache::default().repo(repo);
(
cache
.get("config.json")
.ok_or(anyhow!("Missing config file in cache"))?,
cache
.get("tokenizer.json")
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
cache
.get("model.safetensors")
.ok_or(anyhow!("Missing weights file in cache"))?,
)
} else {
let api = Api::new()?;
let api = api.repo(repo);
(
api.get("config.json")?,
api.get("tokenizer.json")?,
api.get("model.safetensors")?,
)
};
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let model = t5::T5EncoderModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
Ok(())
}