Merge pull request #92 from LaurentMazare/sync_hub

Adds a new synchronous `Api` (`api::sync::Api`) to `candle-hub`; the call sites below drop `async`/`.await` accordingly.
This commit is contained in:
Nicolas Patry
2023-07-07 00:10:47 +02:00
committed by GitHub
9 changed files with 719 additions and 29 deletions

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
use anyhow::{anyhow, Error as E, Result};
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use candle_hub::{api::Api, Cache, Repo, RepoType};
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
use clap::Parser;
use serde::Deserialize;
use std::collections::HashMap;
@ -656,7 +656,7 @@ struct Args {
}
impl Args {
async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
let device = if self.cpu {
Device::Cpu
} else {
@ -688,9 +688,9 @@ impl Args {
} else {
let api = Api::new()?;
(
api.get(&repo, "config.json").await?,
api.get(&repo, "tokenizer.json").await?,
api.get(&repo, "model.safetensors").await?,
api.get(&repo, "config.json")?,
api.get(&repo, "tokenizer.json")?,
api.get(&repo, "model.safetensors")?,
)
};
let config = std::fs::read_to_string(config_filename)?;
@ -705,12 +705,11 @@ impl Args {
}
}
#[tokio::main]
async fn main() -> Result<()> {
fn main() -> Result<()> {
let start = std::time::Instant::now();
let args = Args::parse();
let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
if let Some(prompt) = args.prompt {

View File

@ -19,7 +19,7 @@ use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
use candle::{DType, Device, Tensor, D};
use candle_hub::{api::Api, Repo, RepoType};
use candle_hub::{api::sync::Api, Repo, RepoType};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -465,8 +465,7 @@ struct Args {
prompt: Option<String>,
}
#[tokio::main]
async fn main() -> Result<()> {
fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
@ -489,13 +488,13 @@ async fn main() -> Result<()> {
let api = Api::new()?;
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
println!("building the model");
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(&repo, rfilename).await?;
let filename = api.get(&repo, rfilename)?;
filenames.push(filename);
}

View File

@ -11,7 +11,7 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_hub::{api::Api, Repo, RepoType};
use candle_hub::{api::sync::Api, Repo, RepoType};
use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
@ -253,8 +253,7 @@ struct Args {
filters: String,
}
#[tokio::main]
async fn main() -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
let device = if args.cpu {
Device::Cpu
@ -276,7 +275,7 @@ async fn main() -> Result<()> {
config_filename.push("config.json");
let mut tokenizer_filename = path.clone();
tokenizer_filename.push("tokenizer.json");
let mut model_filename = path.clone();
let mut model_filename = path;
model_filename.push("model.safetensors");
(
config_filename,
@ -288,9 +287,9 @@ async fn main() -> Result<()> {
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
(
api.get(&repo, "config.json").await?,
api.get(&repo, "tokenizer.json").await?,
api.get(&repo, "model.safetensors").await?,
api.get(&repo, "config.json")?,
api.get(&repo, "tokenizer.json")?,
api.get(&repo, "model.safetensors")?,
if let Some(input) = args.input {
std::path::PathBuf::from(input)
} else {
@ -298,8 +297,7 @@ async fn main() -> Result<()> {
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
"samples_jfk.wav",
)
.await?
)?
},
)
};