mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the t5-xxl version. (#924)
This commit is contained in:
@ -8,12 +8,12 @@ use std::path::PathBuf;
|
||||
|
||||
use candle_transformers::models::t5;
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
@ -25,10 +25,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Run offline (you must have the files already cached)
|
||||
#[arg(long)]
|
||||
offline: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
@ -80,7 +76,7 @@ struct Args {
|
||||
struct T5ModelBuilder {
|
||||
device: Device,
|
||||
config: t5::Config,
|
||||
weights_filename: PathBuf,
|
||||
weights_filename: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl T5ModelBuilder {
|
||||
@ -95,28 +91,21 @@ impl T5ModelBuilder {
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
|
||||
let cache = Cache::default().repo(repo);
|
||||
(
|
||||
cache
|
||||
.get("config.json")
|
||||
.ok_or(anyhow!("Missing config file in cache"))?,
|
||||
cache
|
||||
.get("tokenizer.json")
|
||||
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
||||
cache
|
||||
.get("model.safetensors")
|
||||
.ok_or(anyhow!("Missing weights file in cache"))?,
|
||||
)
|
||||
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config_filename = api.get("config.json")?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let weights_filename = if model_id == "google/flan-t5-xxl" {
|
||||
vec![
|
||||
api.get("model-00001-of-00005.safetensors")?,
|
||||
api.get("model-00002-of-00005.safetensors")?,
|
||||
api.get("model-00003-of-00005.safetensors")?,
|
||||
api.get("model-00004-of-00005.safetensors")?,
|
||||
api.get("model-00005-of-00005.safetensors")?,
|
||||
]
|
||||
} else {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
(
|
||||
api.get("config.json")?,
|
||||
api.get("tokenizer.json")?,
|
||||
api.get("model.safetensors")?,
|
||||
)
|
||||
vec![api.get("model.safetensors")?]
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||
@ -133,18 +122,30 @@ impl T5ModelBuilder {
|
||||
}
|
||||
|
||||
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
|
||||
let weights =
|
||||
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
|
||||
let weights = self
|
||||
.weights_filename
|
||||
.iter()
|
||||
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|w| w.deserialize())
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let weights =
|
||||
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
|
||||
let weights = self
|
||||
.weights_filename
|
||||
.iter()
|
||||
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|w| w.deserialize())
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user