mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Generic implementation of vecdot for q80. (#596)
* Generic implementation of vecdot for q80. * Add support for code-llama 7b. * Support more code-llama.
This commit is contained in:
@ -195,10 +195,10 @@ impl WeightMap {
|
||||
}
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(head_dim: usize) -> Result<(Tensor, Tensor)> {
|
||||
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
||||
@ -214,7 +214,7 @@ impl ModelWeights {
|
||||
fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim)?;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||
@ -287,7 +287,10 @@ impl ModelWeights {
|
||||
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
|
||||
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
|
||||
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim)?;
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
@ -399,6 +402,12 @@ enum Which {
|
||||
L13bChat,
|
||||
#[value(name = "70b-chat")]
|
||||
L70bChat,
|
||||
#[value(name = "7b-code")]
|
||||
L7bCode,
|
||||
#[value(name = "13b-code")]
|
||||
L13bCode,
|
||||
#[value(name = "32b-code")]
|
||||
L34bCode,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -486,6 +495,9 @@ impl Args {
|
||||
"TheBloke/Llama-2-70B-Chat-GGML",
|
||||
"llama-2-70b-chat.ggmlv3.q4_0.bin",
|
||||
),
|
||||
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
|
||||
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
|
||||
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(repo.to_string());
|
||||
@ -607,7 +619,13 @@ fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
println!("params: {:?}", model.hparams);
|
||||
let default_gqa = match args.which {
|
||||
Which::L7b | Which::L13b | Which::L7bChat | Which::L13bChat => 1,
|
||||
Which::L7b
|
||||
| Which::L13b
|
||||
| Which::L7bChat
|
||||
| Which::L13bChat
|
||||
| Which::L7bCode
|
||||
| Which::L13bCode
|
||||
| Which::L34bCode => 1,
|
||||
Which::L70b | Which::L70bChat => 8,
|
||||
};
|
||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||
|
Reference in New Issue
Block a user