mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the t5-xxl version. (#924)
This commit is contained in:
@ -8,12 +8,12 @@ use std::path::PathBuf;
|
|||||||
|
|
||||||
use candle_transformers::models::t5;
|
use candle_transformers::models::t5;
|
||||||
|
|
||||||
use anyhow::{anyhow, Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
@ -25,10 +25,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
/// Run offline (you must have the files already cached)
|
|
||||||
#[arg(long)]
|
|
||||||
offline: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
@ -80,7 +76,7 @@ struct Args {
|
|||||||
struct T5ModelBuilder {
|
struct T5ModelBuilder {
|
||||||
device: Device,
|
device: Device,
|
||||||
config: t5::Config,
|
config: t5::Config,
|
||||||
weights_filename: PathBuf,
|
weights_filename: Vec<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl T5ModelBuilder {
|
impl T5ModelBuilder {
|
||||||
@ -95,28 +91,21 @@ impl T5ModelBuilder {
|
|||||||
(None, None) => (default_model, default_revision),
|
(None, None) => (default_model, default_revision),
|
||||||
};
|
};
|
||||||
|
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
|
let api = Api::new()?;
|
||||||
let cache = Cache::default().repo(repo);
|
let api = api.repo(repo);
|
||||||
(
|
let config_filename = api.get("config.json")?;
|
||||||
cache
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
.get("config.json")
|
let weights_filename = if model_id == "google/flan-t5-xxl" {
|
||||||
.ok_or(anyhow!("Missing config file in cache"))?,
|
vec![
|
||||||
cache
|
api.get("model-00001-of-00005.safetensors")?,
|
||||||
.get("tokenizer.json")
|
api.get("model-00002-of-00005.safetensors")?,
|
||||||
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
api.get("model-00003-of-00005.safetensors")?,
|
||||||
cache
|
api.get("model-00004-of-00005.safetensors")?,
|
||||||
.get("model.safetensors")
|
api.get("model-00005-of-00005.safetensors")?,
|
||||||
.ok_or(anyhow!("Missing weights file in cache"))?,
|
]
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
let api = Api::new()?;
|
vec![api.get("model.safetensors")?]
|
||||||
let api = api.repo(repo);
|
|
||||||
(
|
|
||||||
api.get("config.json")?,
|
|
||||||
api.get("tokenizer.json")?,
|
|
||||||
api.get("model.safetensors")?,
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||||
@ -133,18 +122,30 @@ impl T5ModelBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
|
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
|
||||||
let weights =
|
let weights = self
|
||||||
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
|
.weights_filename
|
||||||
let weights = weights.deserialize()?;
|
.iter()
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
|
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||||
|
.collect::<candle::Result<Vec<_>>>()?;
|
||||||
|
let weights = weights
|
||||||
|
.iter()
|
||||||
|
.map(|w| w.deserialize())
|
||||||
|
.collect::<candle::Result<Vec<_>>>()?;
|
||||||
|
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||||
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
|
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
|
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||||
let weights =
|
let weights = self
|
||||||
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
|
.weights_filename
|
||||||
let weights = weights.deserialize()?;
|
.iter()
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
|
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||||
|
.collect::<candle::Result<Vec<_>>>()?;
|
||||||
|
let weights = weights
|
||||||
|
.iter()
|
||||||
|
.map(|w| w.deserialize())
|
||||||
|
.collect::<candle::Result<Vec<_>>>()?;
|
||||||
|
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user