mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Merge pull request #92 from LaurentMazare/sync_hub
Creating new sync Api for `candle-hub`.
This commit is contained in:
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||
use candle_hub::{api::Api, Cache, Repo, RepoType};
|
||||
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use clap::Parser;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
@ -656,7 +656,7 @@ struct Args {
|
||||
}
|
||||
|
||||
impl Args {
|
||||
async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
||||
let device = if self.cpu {
|
||||
Device::Cpu
|
||||
} else {
|
||||
@ -688,9 +688,9 @@ impl Args {
|
||||
} else {
|
||||
let api = Api::new()?;
|
||||
(
|
||||
api.get(&repo, "config.json").await?,
|
||||
api.get(&repo, "tokenizer.json").await?,
|
||||
api.get(&repo, "model.safetensors").await?,
|
||||
api.get(&repo, "config.json")?,
|
||||
api.get(&repo, "tokenizer.json")?,
|
||||
api.get(&repo, "model.safetensors")?,
|
||||
)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
@ -705,12 +705,11 @@ impl Args {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let args = Args::parse();
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
|
||||
if let Some(prompt) = args.prompt {
|
||||
|
Reference in New Issue
Block a user