Some polish.

This commit is contained in:
Nicolas Patry
2023-07-05 07:41:14 +00:00
parent 963c75cb89
commit d8f75ceeaa

View File

@ -621,18 +621,26 @@ struct Args {
#[arg(long)]
offline: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "This is an example sentence")]
prompt: String,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
}
#[tokio::main]
async fn main() -> Result<()> {
use tokenizers::Tokenizer;
let start = std::time::Instant::now();
println!("Building {:?}", start.elapsed());
let args = Args::parse();
let device = if args.cpu {
@ -672,29 +680,25 @@ async fn main() -> Result<()> {
api.get(&repo, "model.safetensors").await?,
)
};
println!("Building {:?}", start.elapsed());
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
println!("Config loaded {:?}", start.elapsed());
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
println!("Tokenizer loaded {:?}", start.elapsed());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
let model = BertModel::load(&vb, &config)?;
println!("Loaded {:?}", start.elapsed());
let tokens = tokenizer
.encode("This is an example sentence", true)
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("Loaded and encoded {:?}", start.elapsed());
for _ in 0..100 {
for _ in 0..args.n {
let start = std::time::Instant::now();
let _ys = model.forward(&token_ids, &token_type_ids)?;
println!("Took {:?}", start.elapsed());