diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index ea7f70eb..80865304 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -67,6 +67,8 @@ enum Which { Mixtral, #[value(name = "mixtral-instruct")] MixtralInstruct, + #[value(name = "llama3-8b")] + L8b, } impl Which { @@ -82,7 +84,8 @@ impl Which { | Self::L13bCode | Self::L34bCode | Self::Leo7b - | Self::Leo13b => false, + | Self::Leo13b + | Self::L8b => false, // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the // same way. Starling is a fine tuned version of OpenChat. Self::OpenChat35 @@ -116,7 +119,8 @@ impl Which { | Self::Mistral7bInstruct | Self::Mistral7bInstructV02 | Self::OpenChat35 - | Self::Starling7bAlpha => false, + | Self::Starling7bAlpha + | Self::L8b => false, Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true, } } @@ -140,7 +144,8 @@ impl Which { | Self::Mistral7bInstruct | Self::Mistral7bInstructV02 | Self::Zephyr7bAlpha - | Self::Zephyr7bBeta => false, + | Self::Zephyr7bBeta + | Self::L8b => false, Self::OpenChat35 | Self::Starling7bAlpha => true, } } @@ -167,6 +172,7 @@ impl Which { | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1", Which::OpenChat35 => "openchat/openchat_3.5", Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha", + Self::L8b => "meta-llama/Meta-Llama-3-8B", } } } @@ -322,6 +328,11 @@ impl Args { "TheBloke/Starling-LM-7B-alpha-GGUF", "starling-lm-7b-alpha.Q4_K_M.gguf", ), + // TODO: swap to TheBloke model when available + Which::L8b => ( + "QuantFactory/Meta-Llama-3-8B-GGUF", + "Meta-Llama-3-8B.Q4_K_S.gguf", + ), }; let api = hf_hub::api::sync::Api::new()?; let api = api.model(repo.to_string()); @@ -420,7 +431,8 @@ fn main() -> anyhow::Result<()> { | Which::L13bCode | Which::L34bCode | Which::Leo7b - | Which::Leo13b => 1, + | Which::Leo13b + | Which::L8b => 1, Which::Mixtral | Which::MixtralInstruct | Which::Mistral7b @@ 
-537,11 +549,14 @@ fn main() -> anyhow::Result<()> { std::io::stdout().flush()?; } - let eos_token = if args.which.is_open_chat() { - "<|end_of_turn|>" - } else { - "</s>" + let eos_token = match args.which { + Which::L8b => "<|end_of_text|>", + _ => match args.which.is_open_chat() { + true => "<|end_of_turn|>", + false => "</s>", + }, }; + let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); let start_post_prompt = std::time::Instant::now(); let mut sampled = 0;