mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Support different mamba models. (#1471)
This commit is contained in:
@ -65,6 +65,8 @@ We also provide a some command line based examples using state of the art models
|
||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets.
|
||||
- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
|
||||
implementation of the Mamba state space model.
|
||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||
better performance than all publicly available 13b models as of 2023-09-28.
|
||||
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
||||
@ -177,6 +179,7 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Falcon.
|
||||
- StarCoder.
|
||||
- Phi 1, 1.5, and 2.
|
||||
- Minimal Mamba
|
||||
- Mistral 7b v0.1.
|
||||
- Mixtral 8x7b v0.1.
|
||||
- StableLM-3B-4E1T.
|
||||
|
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
mod model;
|
||||
use model::{Config, Model};
|
||||
@ -111,6 +111,46 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
||||
enum Which {
|
||||
Mamba130m,
|
||||
Mamba370m,
|
||||
Mamba790m,
|
||||
Mamba1_4b,
|
||||
Mamba2_8b,
|
||||
Mamba2_8bSlimPj,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Which {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m => "state-spaces/mamba-130m",
|
||||
Self::Mamba370m => "state-spaces/mamba-370m",
|
||||
Self::Mamba790m => "state-spaces/mamba-790m",
|
||||
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
|
||||
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
|
||||
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
|
||||
}
|
||||
}
|
||||
|
||||
fn revision(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m
|
||||
| Self::Mamba370m
|
||||
| Self::Mamba790m
|
||||
| Self::Mamba1_4b
|
||||
| Self::Mamba2_8b
|
||||
| Self::Mamba2_8bSlimPj => "refs/pr/1",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -141,11 +181,14 @@ struct Args {
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "state-spaces/mamba-130m")]
|
||||
model_id: String,
|
||||
#[arg(long, default_value = "mamba130m")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long, default_value = "refs/pr/1")]
|
||||
revision: String,
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
@ -194,9 +237,11 @@ fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
args.model_id
|
||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
args.revision
|
||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
|
Reference in New Issue
Block a user