mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Creating new sync Api for candle-hub
.
- `api::Api` -> `api::tokio::api` (And created new `api::sync::Api`). - Remove `tokio` from all our examples. - Using similar codebase for now instead of ureq (for simplicity).
This commit is contained in:
@ -23,7 +23,6 @@ candle-hub = { path = "../candle-hub" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
rand = "0.8.5"
|
||||
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
||||
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
|
||||
wav = "1.0.0"
|
||||
|
||||
[features]
|
||||
|
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||
use candle_hub::{api::Api, Cache, Repo, RepoType};
|
||||
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use clap::Parser;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
@ -645,7 +645,7 @@ struct Args {
|
||||
}
|
||||
|
||||
impl Args {
|
||||
async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
||||
let device = if self.cpu {
|
||||
Device::Cpu
|
||||
} else {
|
||||
@ -677,9 +677,9 @@ impl Args {
|
||||
} else {
|
||||
let api = Api::new()?;
|
||||
(
|
||||
api.get(&repo, "config.json").await?,
|
||||
api.get(&repo, "tokenizer.json").await?,
|
||||
api.get(&repo, "model.safetensors").await?,
|
||||
api.get(&repo, "config.json")?,
|
||||
api.get(&repo, "tokenizer.json")?,
|
||||
api.get(&repo, "model.safetensors")?,
|
||||
)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
@ -694,12 +694,11 @@ impl Args {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let args = Args::parse();
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
|
||||
if let Some(prompt) = args.prompt {
|
||||
|
@ -19,7 +19,7 @@ use clap::Parser;
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
use candle::{DType, Device, Tensor, D};
|
||||
use candle_hub::{api::Api, Repo, RepoType};
|
||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
@ -465,8 +465,7 @@ struct Args {
|
||||
prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
@ -489,13 +488,13 @@ async fn main() -> Result<()> {
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||
println!("building the model");
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(&repo, rfilename).await?;
|
||||
let filename = api.get(&repo, rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_hub::{api::Api, Repo, RepoType};
|
||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||
use clap::Parser;
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -253,8 +253,7 @@ struct Args {
|
||||
filters: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = if args.cpu {
|
||||
Device::Cpu
|
||||
@ -276,7 +275,7 @@ async fn main() -> Result<()> {
|
||||
config_filename.push("config.json");
|
||||
let mut tokenizer_filename = path.clone();
|
||||
tokenizer_filename.push("tokenizer.json");
|
||||
let mut model_filename = path.clone();
|
||||
let mut model_filename = path;
|
||||
model_filename.push("model.safetensors");
|
||||
(
|
||||
config_filename,
|
||||
@ -288,9 +287,9 @@ async fn main() -> Result<()> {
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let api = Api::new()?;
|
||||
(
|
||||
api.get(&repo, "config.json").await?,
|
||||
api.get(&repo, "tokenizer.json").await?,
|
||||
api.get(&repo, "model.safetensors").await?,
|
||||
api.get(&repo, "config.json")?,
|
||||
api.get(&repo, "tokenizer.json")?,
|
||||
api.get(&repo, "model.safetensors")?,
|
||||
if let Some(input) = args.input {
|
||||
std::path::PathBuf::from(input)
|
||||
} else {
|
||||
@ -298,8 +297,7 @@ async fn main() -> Result<()> {
|
||||
api.get(
|
||||
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
|
||||
"samples_jfk.wav",
|
||||
)
|
||||
.await?
|
||||
)?
|
||||
},
|
||||
)
|
||||
};
|
||||
|
Reference in New Issue
Block a user