mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Infer the config for llama2-c. (#1208)
This commit is contained in:
@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
.extension()
|
.extension()
|
||||||
.map_or(false, |v| v == "safetensors");
|
.map_or(false, |v| v == "safetensors");
|
||||||
let (model, config) = if is_gguf {
|
let (model, config) = if is_gguf {
|
||||||
let config = Config::tiny();
|
|
||||||
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
|
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
|
||||||
|
let (_vocab_size, dim) = vb
|
||||||
|
.get_no_shape("model.embed_tokens.weight")?
|
||||||
|
.shape()
|
||||||
|
.dims2()?;
|
||||||
|
let config = match dim {
|
||||||
|
64 => Config::tiny_260k(),
|
||||||
|
288 => Config::tiny_15m(),
|
||||||
|
512 => Config::tiny_42m(),
|
||||||
|
768 => Config::tiny_110m(),
|
||||||
|
_ => anyhow::bail!("no config for dim {dim}"),
|
||||||
|
};
|
||||||
let freq_cis_real = vb
|
let freq_cis_real = vb
|
||||||
.get(
|
.get(
|
||||||
(config.seq_len, config.head_size() / 2),
|
(config.seq_len, config.head_size() / 2),
|
||||||
@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||||
(model, config)
|
(model, config)
|
||||||
} else if is_safetensors {
|
} else if is_safetensors {
|
||||||
let config = Config::tiny();
|
let config = Config::tiny_15m();
|
||||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||||
|
@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
|||||||
);
|
);
|
||||||
let varmap = candle_nn::VarMap::new();
|
let varmap = candle_nn::VarMap::new();
|
||||||
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
|
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
|
||||||
let config = Config::tiny();
|
let config = Config::tiny_15m();
|
||||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
|
|
||||||
|
@ -17,7 +17,20 @@ pub struct Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn tiny() -> Self {
|
pub fn tiny_260k() -> Self {
|
||||||
|
Self {
|
||||||
|
dim: 64,
|
||||||
|
hidden_dim: 768,
|
||||||
|
n_layers: 5,
|
||||||
|
n_heads: 8,
|
||||||
|
n_kv_heads: 4,
|
||||||
|
vocab_size: 32000,
|
||||||
|
seq_len: 512,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tiny_15m() -> Self {
|
||||||
Self {
|
Self {
|
||||||
dim: 288,
|
dim: 288,
|
||||||
hidden_dim: 768,
|
hidden_dim: 768,
|
||||||
@ -29,6 +42,32 @@ impl Config {
|
|||||||
norm_eps: 1e-5,
|
norm_eps: 1e-5,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tiny_42m() -> Self {
|
||||||
|
Self {
|
||||||
|
dim: 512,
|
||||||
|
hidden_dim: 768,
|
||||||
|
n_layers: 8,
|
||||||
|
n_heads: 8,
|
||||||
|
n_kv_heads: 8,
|
||||||
|
vocab_size: 32000,
|
||||||
|
seq_len: 1024,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tiny_110m() -> Self {
|
||||||
|
Self {
|
||||||
|
dim: 768,
|
||||||
|
hidden_dim: 768,
|
||||||
|
n_layers: 12,
|
||||||
|
n_heads: 12,
|
||||||
|
n_kv_heads: 12,
|
||||||
|
vocab_size: 32000,
|
||||||
|
seq_len: 1024,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -77,6 +77,16 @@ impl VarBuilder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
|
||||||
|
let path = self.path(name);
|
||||||
|
match self.data.get(&path) {
|
||||||
|
None => {
|
||||||
|
candle::bail!("cannot find tensor {name}")
|
||||||
|
}
|
||||||
|
Some(qtensor) => Ok(qtensor.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> &Device {
|
pub fn device(&self) -> &Device {
|
||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user