mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Fix cargo fmt. (#2383)
* Fix cargo fmt. * Clippy fix. * Cosmetic tweaks.
This commit is contained in:
@ -47,33 +47,39 @@ struct Args {
|
|||||||
impl Args {
|
impl Args {
|
||||||
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
|
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
let default = "jinaai/jina-embeddings-v2-base-en".to_string();
|
let model_name = match self.model.as_ref() {
|
||||||
let model_name = match &self.model {
|
Some(model) => model.to_string(),
|
||||||
Some(model) => model,
|
None => "jinaai/jina-embeddings-v2-base-en".to_string(),
|
||||||
None => &default,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let model = match &self.model_file {
|
let model = match &self.model_file {
|
||||||
Some(model_file) => std::path::PathBuf::from(model_file),
|
Some(model_file) => std::path::PathBuf::from(model_file),
|
||||||
None => Api::new()?
|
None => Api::new()?
|
||||||
.repo(Repo::new(
|
.repo(Repo::new(model_name.to_string(), RepoType::Model))
|
||||||
model_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
))
|
|
||||||
.get("model.safetensors")?,
|
.get("model.safetensors")?,
|
||||||
};
|
};
|
||||||
let tokenizer = match &self.tokenizer {
|
let tokenizer = match &self.tokenizer {
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
None => Api::new()?
|
None => Api::new()?
|
||||||
.repo(Repo::new(
|
.repo(Repo::new(model_name.to_string(), RepoType::Model))
|
||||||
model_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
))
|
|
||||||
.get("tokenizer.json")?,
|
.get("tokenizer.json")?,
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(self.cpu)?;
|
let device = candle_examples::device(self.cpu)?;
|
||||||
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
let config = Config::new(tokenizer.get_vocab_size(true), 768, 12, 12, 3072, candle_nn::Activation::Gelu, 8192, 2, 0.02, 1e-12, 0, PositionEmbeddingType::Alibi);
|
let config = Config::new(
|
||||||
|
tokenizer.get_vocab_size(true),
|
||||||
|
768,
|
||||||
|
12,
|
||||||
|
12,
|
||||||
|
3072,
|
||||||
|
candle_nn::Activation::Gelu,
|
||||||
|
8192,
|
||||||
|
2,
|
||||||
|
0.02,
|
||||||
|
1e-12,
|
||||||
|
0,
|
||||||
|
PositionEmbeddingType::Alibi,
|
||||||
|
);
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||||
let model = BertModel::new(vb, &config)?;
|
let model = BertModel::new(vb, &config)?;
|
||||||
Ok((model, tokenizer))
|
Ok((model, tokenizer))
|
||||||
@ -124,7 +130,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
println!("normalized_embeddings: {embeddings}");
|
println!("normalized_embeddings: {embeddings}");
|
||||||
}
|
}
|
||||||
println!("Took {:?}", start.elapsed());
|
println!("Took {:?}", start.elapsed());
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
let sentences = [
|
let sentences = [
|
||||||
"The cat sits outside",
|
"The cat sits outside",
|
||||||
|
@ -48,6 +48,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
|
Reference in New Issue
Block a user