OLMo 2 model (#2954)

* OLMo 2 model

* Add OLMo 2 to the example

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Jani Monoses
2025-05-14 20:18:02 +03:00
committed by GitHub
parent 6bd61727bc
commit 450a49ed1a
4 changed files with 376 additions and 18 deletions

View File

@ -3,7 +3,7 @@
OLMo is a series of Open Language Models designed to enable the science of language models.
- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Papers:** [OLMo](https://arxiv.org/abs/2402.00838) [OLMo 2](https://arxiv.org/abs/2501.00656)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->

View File

@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::olmo::{Config, Model as OLMo};
use candle_transformers::models::olmo2::{Config as Config2, Model as OLMo2};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
/// The model architectures this example can drive.
///
/// Each variant wraps one of the model types imported from
/// `candle_transformers`; callers `match` on it to dispatch `forward`.
enum Model {
/// Wraps `candle_transformers::models::olmo::Model`.
OLMo(OLMo),
/// Wraps `candle_transformers::models::olmo2::Model`.
OLMo2(OLMo2),
}
struct TextGeneration {
@ -82,6 +84,7 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::OLMo(m) => m.forward(&input, start_pos)?,
Model::OLMo2(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
@ -129,6 +132,8 @@ enum Which {
W7bTwin2T,
#[value(name = "1.7-7b")]
V1_7W7b,
#[value(name = "2-1b")]
V2W1b,
}
#[derive(Parser, Debug)]
@ -220,6 +225,7 @@ fn main() -> Result<()> {
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
Which::V2W1b => "allenai/OLMo-2-0425-1B-Instruct".to_string(),
},
};
@ -238,33 +244,36 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
Which::W1b => {
Which::W1b | Which::V2W1b => {
vec![repo.get("model.safetensors")?]
}
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
let config_filename = repo.get("config.json")?;
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = {
let config_filename = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
config
};
let device = candle_examples::device(args.cpu)?;
let model = {
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.model {
Which::W1b | Which::W7b | Which::W7bTwin2T | Which::V1_7W7b => {
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
}
Which::V2W1b => {
let config: Config2 = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo2::new(&config, vb)?;
Model::OLMo2(model)
}
};
println!("loaded the model in {:?}", start.elapsed());