Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Adding offline mode.
@@ -1,7 +1,7 @@
 #![allow(dead_code)]
-use anyhow::{Error as E, Result};
+use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
-use candle_hub::{api::Api, Repo, RepoType};
+use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -617,6 +617,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
     #[arg(long)]
     model_id: Option<String>,
 
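For reference, a minimal stand-alone sketch of how the new flag surfaces through clap, limited to the fields visible in this hunk (the real Args struct in the example has more options, and the command-line invocation in the comment is an assumption, not taken from this commit):

use clap::Parser;

// Trimmed-down stand-in for the example's Args, kept to the fields shown above.
#[derive(Parser, Debug)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Run offline (you must have the files already cached)
    #[arg(long)]
    offline: bool,

    #[arg(long)]
    model_id: Option<String>,
}

fn main() {
    // `--offline` is a plain boolean switch, e.g. `... -- --offline` on the command line.
    let args = Args::parse();
    println!("offline = {}", args.offline);
}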
@@ -627,6 +631,8 @@ struct Args {
 #[tokio::main]
 async fn main() -> Result<()> {
+    use tokenizers::Tokenizer;
     let start = std::time::Instant::now();
+    println!("Building {:?}", start.elapsed());
 
     let args = Args::parse();
     let device = if args.cpu {
@@ -644,21 +650,41 @@ async fn main() -> Result<()> {
         (None, None) => (default_model, default_revision),
     };
 
-    let api = Api::new()?;
     let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-    println!("building the model");
-    let config_filename = api.get(&repo, "config.json").await?;
+    let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
+        let cache = Cache::default();
+        (
+            cache
+                .get(&repo, "config.json")
+                .ok_or(anyhow!("Missing config file in cache"))?,
+            cache
+                .get(&repo, "tokenizer.json")
+                .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+            cache
+                .get(&repo, "model.safetensors")
+                .ok_or(anyhow!("Missing weights file in cache"))?,
+        )
+    } else {
+        let api = Api::new()?;
+        (
+            api.get(&repo, "config.json").await?,
+            api.get(&repo, "tokenizer.json").await?,
+            api.get(&repo, "model.safetensors").await?,
+        )
+    };
+    println!("Building {:?}", start.elapsed());
     let config = std::fs::read_to_string(config_filename)?;
     let config: Config = serde_json::from_str(&config)?;
-    let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    println!("Config loaded {:?}", start.elapsed());
     let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+    println!("Tokenizer loaded {:?}", start.elapsed());
+
-    let weights = api.get(&repo, "model.safetensors").await?;
-    let weights = unsafe { candle::safetensors::MmapedFile::new(weights)? };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
     let model = BertModel::load(&vb, &config)?;
     println!("Loaded {:?}", start.elapsed());
 
     let tokens = tokenizer
         .encode("This is an example sentence", true)
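The heart of the change is the branch that resolves the three files either from the local cache or through the hub API. A hedged sketch of that pattern as a stand-alone helper, using only the candle_hub calls that appear in the diff (Cache::default, Cache::get returning an Option, Api::new, Api::get); the retrieve function itself is hypothetical and only for illustration:

use anyhow::{anyhow, Result};
use candle_hub::{api::Api, Cache, Repo};
use std::path::PathBuf;

// Hypothetical helper: resolve one repo file either from the on-disk cache (offline)
// or by letting the hub API fetch it (online). Mirrors the if/else added above.
async fn retrieve(repo: &Repo, filename: &str, offline: bool) -> Result<PathBuf> {
    if offline {
        // Offline: never touch the network; fail if the file was not cached beforehand.
        Cache::default()
            .get(repo, filename)
            .ok_or(anyhow!("Missing {filename} in cache"))
    } else {
        // Online: the API downloads the file (or reuses a previous download).
        let api = Api::new()?;
        Ok(api.get(repo, filename).await?)
    }
}

In main this would be called as retrieve(&repo, "config.json", args.offline).await? and likewise for the tokenizer and weights; the commit itself instead builds the Api once and fetches all three files inside a single else branch.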
@@ -666,9 +692,13 @@ async fn main() -> Result<()> {
         .get_ids()
         .to_vec();
     let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
+    println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
     let ys = model.forward(&token_ids, &token_type_ids)?;
+    println!("{ys}");
+    println!("Loaded and encoded {:?}", start.elapsed());
     for _ in 0..100 {
         let start = std::time::Instant::now();
         let _ys = model.forward(&token_ids, &token_type_ids)?;
         println!("Took {:?}", start.elapsed());
         // println!("Ys {:?}", ys.shape());
     }
     Ok(())
 }
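The loop at the end measures per-iteration latency by re-running the same forward pass and printing the elapsed time each round. A generic sketch of that timing pattern (the bench helper and the dummy workload are illustrative, not part of the example):

use std::time::Instant;

// Illustrative helper: repeat a fallible step `iters` times and print how long each
// run took, the same measurement pattern as the `for _ in 0..100` loop above.
fn bench<F>(iters: usize, mut step: F) -> anyhow::Result<()>
where
    F: FnMut() -> anyhow::Result<()>,
{
    for _ in 0..iters {
        let start = Instant::now();
        step()?;
        println!("Took {:?}", start.elapsed());
    }
    Ok(())
}

fn main() -> anyhow::Result<()> {
    // Stand-in workload; in the example the step would be
    // `model.forward(&token_ids, &token_type_ids)?`.
    bench(5, || {
        let _ = (0..1_000_000u64).sum::<u64>();
        Ok(())
    })
}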