Fix cargo fmt. (#2383)

* Fix cargo fmt.

* Clippy fix.

* Cosmetic tweaks.
This commit is contained in:
Laurent Mazare
2024-08-01 13:19:41 +01:00
committed by GitHub
parent 2e9c010609
commit 9ca277a9d7
2 changed files with 20 additions and 14 deletions

View File

@ -47,33 +47,39 @@ struct Args {
impl Args {
    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
        use hf_hub::{api::sync::Api, Repo, RepoType};
        let model_name = match self.model.as_ref() {
            Some(model) => model.to_string(),
            None => "jinaai/jina-embeddings-v2-base-en".to_string(),
        };
        let model = match &self.model_file {
            Some(model_file) => std::path::PathBuf::from(model_file),
            None => Api::new()?
                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                .get("model.safetensors")?,
        };
        let tokenizer = match &self.tokenizer {
            Some(file) => std::path::PathBuf::from(file),
            None => Api::new()?
                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                .get("tokenizer.json")?,
        };
        let device = candle_examples::device(self.cpu)?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
        let config = Config::new(
            tokenizer.get_vocab_size(true),
            768,
            12,
            12,
            3072,
            candle_nn::Activation::Gelu,
            8192,
            2,
            0.02,
            1e-12,
            0,
            PositionEmbeddingType::Alibi,
        );
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
        let model = BertModel::new(vb, &config)?;
        Ok((model, tokenizer))
@ -124,7 +130,6 @@ fn main() -> anyhow::Result<()> {
            println!("normalized_embeddings: {embeddings}");
        }
        println!("Took {:?}", start.elapsed());
    } else {
        let sentences = [
            "The cat sits outside",

View File

@ -48,6 +48,7 @@ impl Config {
        }
    }

    #[allow(clippy::too_many_arguments)]
    pub fn new(
        vocab_size: usize,
        hidden_size: usize,