mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Use the hub model file when possible. (#1190)
* Use the hub model file when possible.
* And add a mention in the main readme.
This commit is contained in:
@@ -35,19 +35,37 @@ struct Args {
|
||||
normalize_embeddings: bool,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: String,
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
model: String,
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
let model = match &self.model {
|
||||
Some(model_file) => std::path::PathBuf::from(model_file),
|
||||
None => Api::new()?
|
||||
.repo(Repo::new(
|
||||
"jinaai/jina-embeddings-v2-base-en".to_string(),
|
||||
RepoType::Model,
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let tokenizer = match &self.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => Api::new()?
|
||||
.repo(Repo::new(
|
||||
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||
RepoType::Model,
|
||||
))
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let device = candle_examples::device(self.cpu)?;
|
||||
let config = Config::v2_base();
|
||||
let tokenizer = tokenizers::Tokenizer::from_file(&self.tokenizer).map_err(E::msg)?;
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[&self.model], DType::F32, &device)? };
|
||||
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let model = BertModel::new(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
|
Reference in New Issue
Block a user