Add the t5-xxl version. (#924)

This commit is contained in:
Laurent Mazare
2023-09-21 14:48:13 +01:00
committed by GitHub
parent b43ca493f6
commit aa8ec06fd2

View File

@ -8,12 +8,12 @@ use std::path::PathBuf;
use candle_transformers::models::t5; use candle_transformers::models::t5;
use anyhow::{anyhow, Error as E, Result}; use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use clap::Parser; use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32; const DTYPE: DType = DType::F32;
@ -25,10 +25,6 @@ struct Args {
#[arg(long)] #[arg(long)]
cpu: bool, cpu: bool,
/// Run offline (you must have the files already cached)
#[arg(long)]
offline: bool,
/// Enable tracing (generates a trace-timestamp.json file). /// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)] #[arg(long)]
tracing: bool, tracing: bool,
@ -80,7 +76,7 @@ struct Args {
struct T5ModelBuilder { struct T5ModelBuilder {
device: Device, device: Device,
config: t5::Config, config: t5::Config,
weights_filename: PathBuf, weights_filename: Vec<PathBuf>,
} }
impl T5ModelBuilder { impl T5ModelBuilder {
@ -95,28 +91,21 @@ impl T5ModelBuilder {
(None, None) => (default_model, default_revision), (None, None) => (default_model, default_revision),
}; };
let repo = Repo::with_revision(model_id, RepoType::Model, revision); let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = if args.offline { let api = Api::new()?;
let cache = Cache::default().repo(repo); let api = api.repo(repo);
( let config_filename = api.get("config.json")?;
cache let tokenizer_filename = api.get("tokenizer.json")?;
.get("config.json") let weights_filename = if model_id == "google/flan-t5-xxl" {
.ok_or(anyhow!("Missing config file in cache"))?, vec![
cache api.get("model-00001-of-00005.safetensors")?,
.get("tokenizer.json") api.get("model-00002-of-00005.safetensors")?,
.ok_or(anyhow!("Missing tokenizer file in cache"))?, api.get("model-00003-of-00005.safetensors")?,
cache api.get("model-00004-of-00005.safetensors")?,
.get("model.safetensors") api.get("model-00005-of-00005.safetensors")?,
.ok_or(anyhow!("Missing weights file in cache"))?, ]
)
} else { } else {
let api = Api::new()?; vec![api.get("model.safetensors")?]
let api = api.repo(repo);
(
api.get("config.json")?,
api.get("tokenizer.json")?,
api.get("model.safetensors")?,
)
}; };
let config = std::fs::read_to_string(config_filename)?; let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?; let mut config: t5::Config = serde_json::from_str(&config)?;
@ -133,18 +122,30 @@ impl T5ModelBuilder {
} }
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> { pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
let weights = let weights = self
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? }; .weights_filename
let weights = weights.deserialize()?; .iter()
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device); .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
Ok(t5::T5EncoderModel::load(vb, &self.config)?) Ok(t5::T5EncoderModel::load(vb, &self.config)?)
} }
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> { pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
let weights = let weights = self
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? }; .weights_filename
let weights = weights.deserialize()?; .iter()
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device); .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
} }
} }