diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 348e9a55..55929c33 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -8,12 +8,12 @@ use std::path::PathBuf; use candle_transformers::models::t5; -use anyhow::{anyhow, Error as E, Result}; +use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; use clap::Parser; -use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; +use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; const DTYPE: DType = DType::F32; @@ -25,10 +25,6 @@ struct Args { #[arg(long)] cpu: bool, - /// Run offline (you must have the files already cached) - #[arg(long)] - offline: bool, - /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, @@ -80,7 +76,7 @@ struct Args { struct T5ModelBuilder { device: Device, config: t5::Config, - weights_filename: PathBuf, + weights_filename: Vec, } impl T5ModelBuilder { @@ -95,28 +91,21 @@ impl T5ModelBuilder { (None, None) => (default_model, default_revision), }; - let repo = Repo::with_revision(model_id, RepoType::Model, revision); - let (config_filename, tokenizer_filename, weights_filename) = if args.offline { - let cache = Cache::default().repo(repo); - ( - cache - .get("config.json") - .ok_or(anyhow!("Missing config file in cache"))?, - cache - .get("tokenizer.json") - .ok_or(anyhow!("Missing tokenizer file in cache"))?, - cache - .get("model.safetensors") - .ok_or(anyhow!("Missing weights file in cache"))?, - ) + let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision); + let api = Api::new()?; + let api = api.repo(repo); + let config_filename = api.get("config.json")?; + let tokenizer_filename = api.get("tokenizer.json")?; + let weights_filename = if model_id == "google/flan-t5-xxl" { + vec![ + api.get("model-00001-of-00005.safetensors")?, + api.get("model-00002-of-00005.safetensors")?, + api.get("model-00003-of-00005.safetensors")?, + api.get("model-00004-of-00005.safetensors")?, + api.get("model-00005-of-00005.safetensors")?, + ] } else { - let api = Api::new()?; - let api = api.repo(repo); - ( - api.get("config.json")?, - api.get("tokenizer.json")?, - api.get("model.safetensors")?, - ) + vec![api.get("model.safetensors")?] }; let config = std::fs::read_to_string(config_filename)?; let mut config: t5::Config = serde_json::from_str(&config)?; @@ -133,18 +122,30 @@ impl T5ModelBuilder { } pub fn build_encoder(&self) -> Result { - let weights = - unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device); + let weights = self + .weights_filename + .iter() + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) + .collect::>>()?; + let weights = weights + .iter() + .map(|w| w.deserialize()) + .collect::>>()?; + let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); Ok(t5::T5EncoderModel::load(vb, &self.config)?) } pub fn build_conditional_generation(&self) -> Result { - let weights = - unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device); + let weights = self + .weights_filename + .iter() + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) + .collect::>>()?; + let weights = weights + .iter() + .map(|w| w.deserialize()) + .collect::>>()?; + let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) } }