Add support for Helium-v1. (#2932)

This commit is contained in:
Laurent Mazare
2025-04-30 19:38:44 +02:00
committed by GitHub
parent 5029ac52bb
commit 38fc86621c

View File

@ -7,7 +7,10 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use clap::Parser; use clap::Parser;
use candle_transformers::models::helium::{Config, Model}; use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
use candle_transformers::models::llama::{
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
};
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream; use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
#[derive(Debug, Clone)]
enum Model {
V1 { model: ModelV1, cache: CacheV1 },
Preview(ModelPreview),
}
impl Model {
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
let model = match self {
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
Model::Preview(m) => m.forward(input, start_pos)?,
};
Ok(model)
}
}
#[derive(Debug, Clone)]
enum Config {
V1(ConfigV1),
Preview(ConfigPreview),
}
impl Config {
fn bos_token_id(&self) -> Option<u32> {
match self {
Config::V1(c) => c.bos_token_id,
Config::Preview(c) => Some(c.bos_token_id),
}
}
fn eos_token_id(&self) -> Option<LlamaEosToks> {
match self {
Config::V1(c) => c.eos_token_id.clone(),
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
}
}
}
struct TextGeneration { struct TextGeneration {
model: Model, model: Model,
device: Device, device: Device,
@ -106,7 +147,15 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?; let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
generated_tokens += 1; generated_tokens += 1;
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id { let is_eos = self
.config
.eos_token_id()
.as_ref()
.is_some_and(|v| match v {
LlamaEosToks::Single(eos) => *eos == next_token,
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
});
if Some(next_token) == self.config.bos_token_id() || is_eos {
break; break;
} }
if let Some(t) = self.tokenizer.next_token(next_token)? { if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -131,6 +180,8 @@ impl TextGeneration {
enum Which { enum Which {
#[value(name = "v1-preview")] #[value(name = "v1-preview")]
V1Preview, V1Preview,
#[value(name = "v1")]
V1,
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -144,9 +195,6 @@ struct Args {
#[arg(long)] #[arg(long)]
tracing: bool, tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)] #[arg(long)]
prompt: String, prompt: String,
@ -171,7 +219,7 @@ struct Args {
sample_len: usize, sample_len: usize,
/// The model size to use. /// The model size to use.
#[arg(long, default_value = "v1-preview")] #[arg(long, default_value = "v1")]
which: Which, which: Which,
#[arg(long)] #[arg(long)]
@ -230,6 +278,7 @@ fn main() -> Result<()> {
None => { None => {
let name = match args.which { let name = match args.which {
Which::V1Preview => "kyutai/helium-1-preview-2b", Which::V1Preview => "kyutai/helium-1-preview-2b",
Which::V1 => "kyutai/helium-1-2b",
}; };
name.to_string() name.to_string()
} }
@ -254,18 +303,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config: Config = match args.config { let config_file = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?, Some(config_file) => std::path::PathBuf::from(config_file),
None => { None => repo.get("config.json")?,
let config_file = repo.get("config.json")?; };
serde_json::from_slice(&std::fs::read(config_file)?)? let config = match args.which {
} Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
}; };
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let (model, device) = { let (model, device) = {
let dtype = device.bf16_default_to_f32(); let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?; let model = match &config {
Config::V1(c) => {
let c = c.clone().into_config(false);
let model = ModelV1::load(vb, &c)?;
let cache = CacheV1::new(true, dtype, &c, &device)?;
Model::V1 { model, cache }
}
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
};
(model, device) (model, device)
}; };