mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Some polish.
This commit is contained in:
@ -621,18 +621,26 @@ struct Args {
|
||||
#[arg(long)]
|
||||
offline: bool,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "This is an example sentence")]
|
||||
prompt: String,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
let start = std::time::Instant::now();
|
||||
println!("Building {:?}", start.elapsed());
|
||||
|
||||
let args = Args::parse();
|
||||
let device = if args.cpu {
|
||||
@ -672,29 +680,25 @@ async fn main() -> Result<()> {
|
||||
api.get(&repo, "model.safetensors").await?,
|
||||
)
|
||||
};
|
||||
println!("Building {:?}", start.elapsed());
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
println!("Config loaded {:?}", start.elapsed());
|
||||
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
println!("Tokenizer loaded {:?}", start.elapsed());
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
||||
let model = BertModel::load(&vb, &config)?;
|
||||
println!("Loaded {:?}", start.elapsed());
|
||||
|
||||
let tokens = tokenizer
|
||||
.encode("This is an example sentence", true)
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
||||
let token_type_ids = token_ids.zeros_like()?;
|
||||
println!("Loaded and encoded {:?}", start.elapsed());
|
||||
for _ in 0..100 {
|
||||
for _ in 0..args.n {
|
||||
let start = std::time::Instant::now();
|
||||
let _ys = model.forward(&token_ids, &token_type_ids)?;
|
||||
println!("Took {:?}", start.elapsed());
|
||||
|
Reference in New Issue
Block a user