mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add support for Helium-v1. (#2932)
This commit is contained in:
@ -7,7 +7,10 @@ extern crate accelerate_src;
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle_transformers::models::helium::{Config, Model};
|
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
|
||||||
|
use candle_transformers::models::llama::{
|
||||||
|
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
|
||||||
|
};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
|
|||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Model {
|
||||||
|
V1 { model: ModelV1, cache: CacheV1 },
|
||||||
|
Preview(ModelPreview),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
|
||||||
|
let model = match self {
|
||||||
|
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
|
||||||
|
Model::Preview(m) => m.forward(input, start_pos)?,
|
||||||
|
};
|
||||||
|
Ok(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Config {
|
||||||
|
V1(ConfigV1),
|
||||||
|
Preview(ConfigPreview),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
fn bos_token_id(&self) -> Option<u32> {
|
||||||
|
match self {
|
||||||
|
Config::V1(c) => c.bos_token_id,
|
||||||
|
Config::Preview(c) => Some(c.bos_token_id),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn eos_token_id(&self) -> Option<LlamaEosToks> {
|
||||||
|
match self {
|
||||||
|
Config::V1(c) => c.eos_token_id.clone(),
|
||||||
|
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Model,
|
model: Model,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -106,7 +147,15 @@ impl TextGeneration {
|
|||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
generated_tokens += 1;
|
generated_tokens += 1;
|
||||||
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
|
let is_eos = self
|
||||||
|
.config
|
||||||
|
.eos_token_id()
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|v| match v {
|
||||||
|
LlamaEosToks::Single(eos) => *eos == next_token,
|
||||||
|
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
|
||||||
|
});
|
||||||
|
if Some(next_token) == self.config.bos_token_id() || is_eos {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
@ -131,6 +180,8 @@ impl TextGeneration {
|
|||||||
enum Which {
|
enum Which {
|
||||||
#[value(name = "v1-preview")]
|
#[value(name = "v1-preview")]
|
||||||
V1Preview,
|
V1Preview,
|
||||||
|
#[value(name = "v1")]
|
||||||
|
V1,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -144,9 +195,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
use_flash_attn: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
||||||
@ -171,7 +219,7 @@ struct Args {
|
|||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
/// The model size to use.
|
/// The model size to use.
|
||||||
#[arg(long, default_value = "v1-preview")]
|
#[arg(long, default_value = "v1")]
|
||||||
which: Which,
|
which: Which,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -230,6 +278,7 @@ fn main() -> Result<()> {
|
|||||||
None => {
|
None => {
|
||||||
let name = match args.which {
|
let name = match args.which {
|
||||||
Which::V1Preview => "kyutai/helium-1-preview-2b",
|
Which::V1Preview => "kyutai/helium-1-preview-2b",
|
||||||
|
Which::V1 => "kyutai/helium-1-2b",
|
||||||
};
|
};
|
||||||
name.to_string()
|
name.to_string()
|
||||||
}
|
}
|
||||||
@ -254,18 +303,27 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config: Config = match args.config {
|
let config_file = match args.config {
|
||||||
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
Some(config_file) => std::path::PathBuf::from(config_file),
|
||||||
None => {
|
None => repo.get("config.json")?,
|
||||||
let config_file = repo.get("config.json")?;
|
};
|
||||||
serde_json::from_slice(&std::fs::read(config_file)?)?
|
let config = match args.which {
|
||||||
}
|
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
|
||||||
|
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (model, device) = {
|
let (model, device) = {
|
||||||
let dtype = device.bf16_default_to_f32();
|
let dtype = device.bf16_default_to_f32();
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = match &config {
|
||||||
|
Config::V1(c) => {
|
||||||
|
let c = c.clone().into_config(false);
|
||||||
|
let model = ModelV1::load(vb, &c)?;
|
||||||
|
let cache = CacheV1::new(true, dtype, &c, &device)?;
|
||||||
|
Model::V1 { model, cache }
|
||||||
|
}
|
||||||
|
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
|
||||||
|
};
|
||||||
(model, device)
|
(model, device)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user