mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Some polish.
This commit is contained in:
@ -621,18 +621,26 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
offline: bool,
|
offline: bool,
|
||||||
|
|
||||||
|
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
|
/// The prompt to tokenize and run through the model.
|
||||||
|
#[arg(long, default_value = "This is an example sentence")]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The number of times to run the prompt.
|
||||||
|
#[arg(long, default_value = "1")]
|
||||||
|
n: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
println!("Building {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let device = if args.cpu {
|
let device = if args.cpu {
|
||||||
@ -672,29 +680,25 @@ async fn main() -> Result<()> {
|
|||||||
api.get(&repo, "model.safetensors").await?,
|
api.get(&repo, "model.safetensors").await?,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
println!("Building {:?}", start.elapsed());
|
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
println!("Config loaded {:?}", start.elapsed());
|
|
||||||
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||||
println!("Tokenizer loaded {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
||||||
let model = BertModel::load(&vb, &config)?;
|
let model = BertModel::load(&vb, &config)?;
|
||||||
println!("Loaded {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let tokens = tokenizer
|
let tokens = tokenizer
|
||||||
.encode("This is an example sentence", true)
|
.encode(args.prompt, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
||||||
let token_type_ids = token_ids.zeros_like()?;
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
println!("Loaded and encoded {:?}", start.elapsed());
|
println!("Loaded and encoded {:?}", start.elapsed());
|
||||||
for _ in 0..100 {
|
for _ in 0..args.n {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let _ys = model.forward(&token_ids, &token_type_ids)?;
|
let _ys = model.forward(&token_ids, &token_type_ids)?;
|
||||||
println!("Took {:?}", start.elapsed());
|
println!("Took {:?}", start.elapsed());
|
||||||
|
Reference in New Issue
Block a user