Infer the config for llama2-c. (#1208)

This commit is contained in:
Laurent Mazare
2023-10-28 20:00:39 +02:00
committed by GitHub
parent 95a857cf57
commit 012ae0090e
4 changed files with 63 additions and 4 deletions

View File

@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
let config = Config::tiny();
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
.dims2()?;
let config = match dim {
64 => Config::tiny_260k(),
288 => Config::tiny_15m(),
512 => Config::tiny_42m(),
768 => Config::tiny_110m(),
_ => anyhow::bail!("no config for dim {dim}"),
};
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny();
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;

View File

@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let config = Config::tiny_15m();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);