diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index a3f01ae2..0ceb27af 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .extension() .map_or(false, |v| v == "safetensors"); let (model, config) = if is_gguf { - let config = Config::tiny(); let vb = qmodel::VarBuilder::from_gguf(config_path)?; + let (_vocab_size, dim) = vb + .get_no_shape("model.embed_tokens.weight")? + .shape() + .dims2()?; + let config = match dim { + 64 => Config::tiny_260k(), + 288 => Config::tiny_15m(), + 512 => Config::tiny_42m(), + 768 => Config::tiny_110m(), + _ => anyhow::bail!("no config for dim {dim}"), + }; let freq_cis_real = vb .get( (config.seq_len, config.head_size() / 2), @@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); (model, config) } else if is_safetensors { - let config = Config::tiny(); + let config = Config::tiny_15m(); let tensors = candle::safetensors::load(config_path, &device)?; let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let cache = model::Cache::new(true, &config, vb.pp("rot"))?; diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs index 150a3272..b2aa0889 100644 --- a/candle-examples/examples/llama2-c/training.rs +++ b/candle-examples/examples/llama2-c/training.rs @@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> { ); let varmap = candle_nn::VarMap::new(); let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device); - let config = Config::tiny(); + let config = Config::tiny_15m(); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let batch_iter = 
candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 07a6e2f2..753770fb 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -17,7 +17,20 @@ pub struct Config { } impl Config { - pub fn tiny() -> Self { + pub fn tiny_260k() -> Self { + Self { + dim: 64, + hidden_dim: 768, + n_layers: 5, + n_heads: 8, + n_kv_heads: 4, + vocab_size: 32000, + seq_len: 512, + norm_eps: 1e-5, + } + } + + pub fn tiny_15m() -> Self { Self { dim: 288, hidden_dim: 768, @@ -29,6 +42,32 @@ impl Config { norm_eps: 1e-5, } } + + pub fn tiny_42m() -> Self { + Self { + dim: 512, + hidden_dim: 768, + n_layers: 8, + n_heads: 8, + n_kv_heads: 8, + vocab_size: 32000, + seq_len: 1024, + norm_eps: 1e-5, + } + } + + pub fn tiny_110m() -> Self { + Self { + dim: 768, + hidden_dim: 768, + n_layers: 12, + n_heads: 12, + n_kv_heads: 12, + vocab_size: 32000, + seq_len: 1024, + norm_eps: 1e-5, + } + } } #[derive(Clone)] diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index 259496d6..810802e8 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -77,6 +77,16 @@ impl VarBuilder { } } + pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> { + let path = self.path(name); + match self.data.get(&path) { + None => { + candle::bail!("cannot find tensor {name}") + } + Some(qtensor) => Ok(qtensor.clone()), + } + } + pub fn device(&self) -> &Device { &self.device }