Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits: llama-v3-m...phi2-gguf (3 commits)

Commits: 3754b834f4, d79041d94d, af11b2d461
@ -375,9 +375,9 @@ git submodule update --init
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
```

This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
```
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
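For context, a minimal standalone sketch of how a build step could pick up this variable and hand it to nvcc via `-ccbin`; this is illustrative only and not candle's actual build script (the fallback to the older NVCC_CCBIN name is an assumption):

```rust
// Illustrative sketch: honor CANDLE_NVCC_CCBIN, falling back to NVCC_CCBIN,
// and pass the chosen host compiler to nvcc with -ccbin. Not candle's build.rs.
use std::env;
use std::process::Command;

fn main() {
    // Prefer the candle-specific variable, then the generic one.
    let ccbin = env::var("CANDLE_NVCC_CCBIN")
        .or_else(|_| env::var("NVCC_CCBIN"))
        .ok();

    let mut cmd = Command::new("nvcc");
    cmd.arg("--version");
    if let Some(ccbin) = ccbin {
        // Tell nvcc which host compiler to use, e.g. a gcc-10 install.
        cmd.arg("-ccbin").arg(ccbin);
    }
    match cmd.status() {
        Ok(status) => println!("nvcc exited with {status}"),
        Err(err) => println!("could not run nvcc: {err}"),
    }
}
```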

#### Linking error on windows when running rustdoc or mdbook tests
@ -624,7 +624,7 @@ impl Tensor {
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let sigmoid_arg = (*node / arg)?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
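For reference, a standalone check of the gradient identity in the comment above, together with the sigmoid(x) = silu(x) / x shortcut that the new line relies on (plain Rust, no candle dependency):

```rust
// Numerical check of:  d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
// and of the shortcut  sigmoid(x) = silu(x) / x  used to avoid recomputing exp.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn silu(x: f64) -> f64 {
    x * sigmoid(x)
}

fn main() {
    let eps = 1e-6;
    for &x in &[-3.0_f64, -0.5, 0.1, 2.0] {
        let numeric = (silu(x + eps) - silu(x - eps)) / (2.0 * eps);
        let s = silu(x) / x; // sigmoid recovered from the forward output
        let analytic = s * (1.0 + x * (1.0 - s));
        assert!((numeric - analytic).abs() < 1e-5);
        println!("x = {x:5.2}  d/dx silu = {analytic:.6}");
    }
}
```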
@ -178,8 +178,8 @@ fn mul_mat_vec_via_q8_1(
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
if b_size == 0 || b_size > 4 {
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);

@ -204,16 +204,14 @@ fn mul_mat_vec_via_q8_1(
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
let nblocks = if b_size == 1 {
nrows as u32
} else {
(nrows as u32 + 1) / 2
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};

@ -400,7 +398,7 @@ impl QCudaStorage {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
8
4
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
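A standalone sketch of the launch-configuration logic after this change, assuming WARP_SIZE = 32 (in the real code it comes from the CUDA kernels); illustrative only:

```rust
// Sketch of the launch configuration once only batch sizes 1..=4 are supported.
const WARP_SIZE: u32 = 32; // assumption for this sketch

#[derive(Debug)]
struct LaunchCfg {
    grid_dim: (u32, u32, u32),
    block_dim: (u32, u32, u32),
}

fn launch_cfg(nrows: usize, b_size: usize) -> Result<LaunchCfg, String> {
    if b_size == 0 || b_size > 4 {
        return Err(format!("only bsize between 1 and 4 are supported, got {b_size}"));
    }
    // One block per row for b_size == 1, otherwise two rows share a block.
    let nblocks = if b_size == 1 {
        nrows as u32
    } else {
        (nrows as u32 + 1) / 2
    };
    Ok(LaunchCfg {
        grid_dim: (nblocks, 1, 1),
        block_dim: (WARP_SIZE, 4, 1),
    })
}

fn main() {
    for b_size in 1..=4 {
        println!("{:?}", launch_cfg(4096, b_size).unwrap());
    }
}
```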
@ -193,25 +193,17 @@ fn qmm_batch(dev: &Device) -> Result<()> {
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
let mm4 = rhs.forward(&lhs4)?;
assert_eq!(mm4.shape().dims(), [12, 6]);
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
// We use a different kernel for sizes from 1 to 8 on cuda which explains
// the difference here.
assert!(0. < diff4 && diff4 < 1e-4)
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff4, 0.0)
assert_eq!(diff3, 0.0)
};
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
};
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff4, 0.0);
Ok(())
}
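The test above relies on the fact that stacking inputs along the batch dimension and multiplying by the same weights must reproduce the individual results row by row (exactly in f32; only within a small tolerance on the quantized CUDA path). A plain-Rust illustration of that property:

```rust
// Multiplying a stacked batch by the same weights gives the same rows as
// multiplying each input separately.
fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let cols = b[0].len();
    a.iter()
        .map(|row| {
            (0..cols)
                .map(|j| row.iter().zip(b).map(|(x, brow)| x * brow[j]).sum())
                .collect()
        })
        .collect()
}

fn main() {
    let w = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]; // 3x2 weight
    let x1 = vec![vec![0.5, -1.0, 2.0]]; // one 1x3 input row
    let x2 = vec![vec![1.5, 0.0, -2.0]];
    let batch: Vec<Vec<f32>> = x1.iter().chain(x2.iter()).cloned().collect();

    let y1 = matmul(&x1, &w);
    let y2 = matmul(&x2, &w);
    let yb = matmul(&batch, &w);

    assert_eq!(yb[0], y1[0]);
    assert_eq!(yb[1], y2[0]);
    println!("batched rows match individual rows: {yb:?}");
}
```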
@ -31,8 +31,6 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
enum Which {
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V3Instruct,
|
||||
#[value(name = "solar-10.7b")]
|
||||
Solar10_7B,
|
||||
#[value(name = "tiny-llama-1.1b-chat")]
|
||||
@ -47,8 +45,8 @@ struct Args {
|
||||
cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
@ -92,11 +90,11 @@ struct Args {
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 128)]
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
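The repeat_penalty/repeat_last_n flags above control a logit penalty on recently generated tokens. A hedged sketch of one common formulation (llama.cpp style); the helper actually used by the example may differ in details:

```rust
// Logits of tokens seen in the last `repeat_last_n` positions are pushed
// towards "less likely". A penalty of 1.0 leaves the logits untouched.
fn apply_repeat_penalty(logits: &mut [f32], penalty: f32, recent_tokens: &[usize]) {
    for &tok in recent_tokens {
        if let Some(l) = logits.get_mut(tok) {
            if *l >= 0.0 {
                *l /= penalty;
            } else {
                *l *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5, 3.0];
    // Pretend tokens 0 and 1 were generated within the last 64 steps.
    apply_repeat_penalty(&mut logits, 1.1, &[0, 1]);
    println!("{logits:?}"); // token 0 and 1 become less likely
}
```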
@ -120,18 +118,13 @@ fn main() -> Result<()> {
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => match args.which {
|
||||
Which::V3 | Which::V3Instruct => DType::BF16,
|
||||
Which::V1 | Which::V2 | Which::Solar10_7B | Which::TinyLlama1_1BChat => DType::F16,
|
||||
},
|
||||
None => DType::F16,
|
||||
};
|
||||
let (llama, tokenizer_filename, mut cache, config) = {
|
||||
let (llama, tokenizer_filename, mut cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
|
||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
||||
});
|
||||
@ -145,7 +138,7 @@ fn main() -> Result<()> {
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
|
||||
let filenames = match args.which {
|
||||
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
|
||||
Which::V1 | Which::V2 | Which::Solar10_7B => {
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||
@ -153,12 +146,10 @@ fn main() -> Result<()> {
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
|
||||
(Llama::load(vb, &config)?, tokenizer_filename, cache)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = config
|
||||
.eos_token_id
|
||||
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
|
||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
@ -169,7 +160,7 @@ fn main() -> Result<()> {
|
||||
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
|
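Making the temperature an Option (as in the hunk above) separates greedy decoding from temperature sampling. A dependency-free sketch of that distinction; the uniform draw is passed in explicitly, and the real LogitsProcessor may handle edge cases differently:

```rust
// None means greedy argmax decoding, Some(t) means sampling from softmax(logits / t).
fn sample(logits: &[f32], temperature: Option<f64>, u: f64) -> usize {
    match temperature {
        None => {
            // Greedy: index of the largest logit.
            logits
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.total_cmp(b.1))
                .map(|(i, _)| i)
                .unwrap()
        }
        Some(t) => {
            let scaled: Vec<f64> = logits.iter().map(|&l| (l as f64 / t).exp()).collect();
            let total: f64 = scaled.iter().sum();
            let mut acc = 0.0;
            for (i, p) in scaled.iter().enumerate() {
                acc += p / total;
                if u < acc {
                    return i;
                }
            }
            scaled.len() - 1
        }
    }
}

fn main() {
    let logits = [1.0, 3.0, 0.5];
    println!("greedy -> {}", sample(&logits, None, 0.0));
    println!("t = 0.8, u = 0.9 -> {}", sample(&logits, Some(0.8), 0.9));
}
```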
@ -67,8 +67,8 @@ enum Which {
|
||||
Mixtral,
|
||||
#[value(name = "mixtral-instruct")]
|
||||
MixtralInstruct,
|
||||
#[value(name = "llama3-8b")]
|
||||
L8b,
|
||||
#[value(name = "phi-2")]
|
||||
Phi2,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
@ -85,7 +85,7 @@ impl Which {
|
||||
| Self::L34bCode
|
||||
| Self::Leo7b
|
||||
| Self::Leo13b
|
||||
| Self::L8b => false,
|
||||
| Self::Phi2 => false,
|
||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
||||
// same way. Starling is a fine tuned version of OpenChat.
|
||||
Self::OpenChat35
|
||||
@ -119,8 +119,8 @@ impl Which {
|
||||
| Self::Mistral7bInstruct
|
||||
| Self::Mistral7bInstructV02
|
||||
| Self::OpenChat35
|
||||
| Self::Starling7bAlpha
|
||||
| Self::L8b => false,
|
||||
| Self::Phi2
|
||||
| Self::Starling7bAlpha => false,
|
||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||
}
|
||||
}
|
||||
@ -143,36 +143,36 @@ impl Which {
|
||||
| Self::Mistral7b
|
||||
| Self::Mistral7bInstruct
|
||||
| Self::Mistral7bInstructV02
|
||||
| Self::Phi2
|
||||
| Self::Zephyr7bAlpha
|
||||
| Self::Zephyr7bBeta
|
||||
| Self::L8b => false,
|
||||
| Self::Zephyr7bBeta => false,
|
||||
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn tokenizer_repo(&self) -> &'static str {
|
||||
match self {
|
||||
Which::L7b
|
||||
| Which::L13b
|
||||
| Which::L70b
|
||||
| Which::L7bChat
|
||||
| Which::L13bChat
|
||||
| Which::L70bChat
|
||||
| Which::L7bCode
|
||||
| Which::L13bCode
|
||||
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
|
||||
Which::Leo7b => "LeoLM/leo-hessianai-7b",
|
||||
Which::Leo13b => "LeoLM/leo-hessianai-13b",
|
||||
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
|
||||
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
Which::Mistral7b
|
||||
| Which::Mistral7bInstruct
|
||||
| Which::Mistral7bInstructV02
|
||||
| Which::Zephyr7bAlpha
|
||||
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
||||
Which::OpenChat35 => "openchat/openchat_3.5",
|
||||
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
||||
Self::L8b => "meta-llama/Meta-Llama-3-8B",
|
||||
Self::L7b
|
||||
| Self::L13b
|
||||
| Self::L70b
|
||||
| Self::L7bChat
|
||||
| Self::L13bChat
|
||||
| Self::L70bChat
|
||||
| Self::L7bCode
|
||||
| Self::L13bCode
|
||||
| Self::L34bCode => "hf-internal-testing/llama-tokenizer",
|
||||
Self::Leo7b => "LeoLM/leo-hessianai-7b",
|
||||
Self::Leo13b => "LeoLM/leo-hessianai-13b",
|
||||
Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
|
||||
Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
Self::Mistral7b
|
||||
| Self::Mistral7bInstruct
|
||||
| Self::Mistral7bInstructV02
|
||||
| Self::Zephyr7bAlpha
|
||||
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
||||
Self::OpenChat35 => "openchat/openchat_3.5",
|
||||
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
||||
Self::Phi2 => "microsoft/phi-2",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -328,11 +328,7 @@ impl Args {
|
||||
"TheBloke/Starling-LM-7B-alpha-GGUF",
|
||||
"starling-lm-7b-alpha.Q4_K_M.gguf",
|
||||
),
|
||||
// TODO: swap to TheBloke model when available
|
||||
Which::L8b => (
|
||||
"QuantFactory/Meta-Llama-3-8B-GGUF",
|
||||
"Meta-Llama-3-8B.Q4_K_S.gguf",
|
||||
),
|
||||
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(repo.to_string());
|
||||
@ -432,7 +428,7 @@ fn main() -> anyhow::Result<()> {
|
||||
| Which::L34bCode
|
||||
| Which::Leo7b
|
||||
| Which::Leo13b
|
||||
| Which::L8b => 1,
|
||||
| Which::Phi2 => 1,
|
||||
Which::Mixtral
|
||||
| Which::MixtralInstruct
|
||||
| Which::Mistral7b
|
||||
@ -549,14 +545,11 @@ fn main() -> anyhow::Result<()> {
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
let eos_token = match args.which {
|
||||
Which::L8b => "<|end_of_text|>",
|
||||
_ => match args.which.is_open_chat() {
|
||||
true => "<|end_of_turn|>",
|
||||
false => "</s>",
|
||||
},
|
||||
let eos_token = if args.which.is_open_chat() {
|
||||
"<|end_of_turn|>"
|
||||
} else {
|
||||
"</s>"
|
||||
};
|
||||
|
||||
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
|
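The EOS handling above picks a per-model end-of-sequence string and then resolves it through the tokenizer vocabulary. A small sketch with a stand-in HashMap vocabulary (the token ids below are illustrative):

```rust
use std::collections::HashMap;

// Pick the per-model EOS string, then look it up in the vocabulary.
fn eos_token_id(is_open_chat: bool, vocab: &HashMap<String, u32>) -> Option<u32> {
    let eos = if is_open_chat { "<|end_of_turn|>" } else { "</s>" };
    vocab.get(eos).copied()
}

fn main() {
    let mut vocab = HashMap::new();
    vocab.insert("</s>".to_string(), 2);
    vocab.insert("<|end_of_turn|>".to_string(), 32000);
    println!("{:?}", eos_token_id(false, &vocab));
    println!("{:?}", eos_token_id(true, &vocab));
}
```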
@ -2972,330 +2972,6 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 5
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 6
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 7
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 8
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
|
||||
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
|
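These kernels consume activations quantized into q8_1 blocks (see quantize_q8_1 above). A simplified plain-Rust sketch of the idea, blockwise int8 quantization with a kept block sum; this is an illustration, not the exact ggml block_q8_1 layout:

```rust
// Each block of 32 floats is stored as int8 values plus a scale; the block
// sum is kept so dot products against other quantized blocks can be
// corrected cheaply.
const BLOCK: usize = 32;

struct BlockQ8 {
    scale: f32,
    sum: f32,
    qs: [i8; BLOCK],
}

fn quantize_block(xs: &[f32; BLOCK]) -> BlockQ8 {
    let amax = xs.iter().fold(0f32, |m, x| m.max(x.abs()));
    let scale = if amax == 0.0 { 1.0 } else { amax / 127.0 };
    let mut qs = [0i8; BLOCK];
    for (q, x) in qs.iter_mut().zip(xs) {
        *q = (x / scale).round() as i8;
    }
    let sum = xs.iter().sum();
    BlockQ8 { scale, sum, qs }
}

fn main() {
    let mut xs = [0f32; BLOCK];
    for (i, x) in xs.iter_mut().enumerate() {
        *x = (i as f32 - 16.0) / 4.0;
    }
    let b = quantize_block(&xs);
    // Rough round-trip check on the first element.
    let x0 = b.qs[0] as f32 * b.scale;
    println!("scale = {:.4}, sum = {:.2}, x[0] ~ {:.3} (orig {:.3})", b.scale, b.sum, x0, xs[0]);
}
```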
@ -16,8 +16,6 @@ pub struct LlamaConfig {
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
@ -36,8 +34,6 @@ impl LlamaConfig {
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
bos_token_id: self.bos_token_id,
|
||||
eos_token_id: self.eos_token_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -53,8 +49,6 @@ pub struct Config {
|
||||
pub use_flash_attn: bool,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -69,8 +63,6 @@ impl Config {
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-6,
|
||||
rope_theta: 10_000.0,
|
||||
bos_token_id: None,
|
||||
eos_token_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,8 +77,6 @@ impl Config {
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10_000.0,
|
||||
bos_token_id: None,
|
||||
eos_token_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -116,6 +106,7 @@ impl Cache {
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||
Ok(Self {
|
||||
@ -175,10 +166,16 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
|
||||
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
|
||||
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
candle_nn::rotary_emb::rope(x, &cos, &sin)
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
}
|
||||
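For reference, a plain-Rust sketch of the rotary embedding computed by the fallback path above: the two halves of each head vector are rotated pairwise by angles pos * theta_i with theta_i = base^(-2i/d) and base = 10000, mirroring the cat(&[-x2, x1]) formulation:

```rust
// Rotary embedding for a single head vector x of even length d at position `pos`.
fn rope(x: &[f32], pos: usize, base: f32) -> Vec<f32> {
    let d = x.len();
    let half = d / 2;
    // Angles, duplicated for both halves exactly like cat(&[idx_theta, idx_theta]).
    let angle = |i: usize| pos as f32 * base.powf(-2.0 * (i % half) as f32 / d as f32);
    (0..d)
        .map(|i| {
            let (cos, sin) = (angle(i).cos(), angle(i).sin());
            let rotate = if i < half { -x[i + half] } else { x[i - half] };
            x[i] * cos + rotate * sin
        })
        .collect()
}

fn main() {
    let x = vec![1.0, 0.0, 0.0, 1.0];
    for pos in 0..3 {
        // At pos 0 the output equals the input (cos = 1, sin = 0).
        println!("pos {pos}: {:?}", rope(&x, pos, 10000.0));
    }
}
```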
|
||||
fn forward(
|
||||
@ -196,12 +193,10 @@ impl CausalSelfAttention {
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
@ -1,6 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::quantized_nn::RmsNorm;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
@ -29,13 +28,13 @@ impl QMatMul {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
struct MlpSilu {
|
||||
feed_forward_w1: QMatMul,
|
||||
feed_forward_w2: QMatMul,
|
||||
feed_forward_w3: QMatMul,
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
impl Module for MlpSilu {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let w1 = self.feed_forward_w1.forward(xs)?;
|
||||
let w3 = self.feed_forward_w3.forward(xs)?;
|
||||
@ -45,16 +44,31 @@ impl Module for Mlp {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum MlpOrMoe {
|
||||
Mlp(Mlp),
|
||||
struct MlpSimple {
|
||||
fc1: QMatMul,
|
||||
fc2: QMatMul,
|
||||
act: candle_nn::Activation,
|
||||
}
|
||||
|
||||
impl Module for MlpSimple {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?.apply(&self.act)?;
|
||||
self.fc2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Mlp {
|
||||
Silu(MlpSilu),
|
||||
Simple(MlpSimple),
|
||||
MoE {
|
||||
n_expert_used: usize,
|
||||
feed_forward_gate_inp: QMatMul,
|
||||
experts: Vec<Mlp>,
|
||||
experts: Vec<MlpSilu>,
|
||||
},
|
||||
}
|
||||
|
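A dependency-free sketch of the two non-MoE variants distinguished by this enum: the SwiGLU-style llama block w2(silu(w1 x) * w3 x) versus the simple fc2(gelu(fc1 x)) block used for phi-2. Linear layers are stand-in closures here rather than QMatMul:

```rust
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

fn gelu(x: f32) -> f32 {
    // tanh approximation ("NewGelu").
    0.5 * x * (1.0 + (0.7978845608 * (x + 0.044715 * x * x * x)).tanh())
}

// SwiGLU-style MLP: w2(silu(w1 x) * w3 x).
fn mlp_silu(
    x: &[f32],
    w1: impl Fn(&[f32]) -> Vec<f32>,
    w2: impl Fn(&[f32]) -> Vec<f32>,
    w3: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let gate: Vec<f32> = w1(x).iter().map(|&v| silu(v)).collect();
    let up = w3(x);
    let h: Vec<f32> = gate.iter().zip(&up).map(|(g, u)| g * u).collect();
    w2(&h)
}

// Simple MLP: fc2(gelu(fc1 x)).
fn mlp_simple(x: &[f32], fc1: impl Fn(&[f32]) -> Vec<f32>, fc2: impl Fn(&[f32]) -> Vec<f32>) -> Vec<f32> {
    let h: Vec<f32> = fc1(x).iter().map(|&v| gelu(v)).collect();
    fc2(&h)
}

fn main() {
    // Identity "linear layers" keep the example tiny; real layers are quantized matmuls.
    let id = |x: &[f32]| x.to_vec();
    let x = [0.5f32, -1.0, 2.0];
    println!("silu mlp:   {:?}", mlp_silu(&x, id, id, id));
    println!("simple mlp: {:?}", mlp_simple(&x, id, id));
}
```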
||||
impl Module for MlpOrMoe {
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::MoE {
|
||||
@ -119,20 +133,48 @@ impl Module for MlpOrMoe {
|
||||
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
|
||||
Ok(ys)
|
||||
}
|
||||
Self::Mlp(mlp) => mlp.forward(xs),
|
||||
Self::Silu(mlp) => mlp.forward(xs),
|
||||
Self::Simple(mlp) => mlp.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Norm {
|
||||
Rms(crate::quantized_nn::RmsNorm),
|
||||
Layer(candle_nn::LayerNorm),
|
||||
}
|
||||
|
||||
impl Module for Norm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::Rms(m) => m.forward(xs),
|
||||
Self::Layer(m) => m.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rms_norm(q: QTensor, eps: f64) -> Result<Norm> {
|
||||
let rms = crate::quantized_nn::RmsNorm::from_qtensor(q, eps)?;
|
||||
Ok(Norm::Rms(rms))
|
||||
}
|
||||
|
||||
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<Norm> {
|
||||
let w = w.dequantize(&w.device())?;
|
||||
let b = b.dequantize(&b.device())?;
|
||||
let ln = candle_nn::LayerNorm::new(w, b, eps);
|
||||
Ok(Norm::Layer(ln))
|
||||
}
|
||||
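A scalar sketch of the two normalizations the Norm enum wraps: RMSNorm rescales by the root-mean-square only, while LayerNorm also subtracts the mean and adds a bias. Weights and bias are plain slices here rather than dequantized tensors:

```rust
fn rms_norm(xs: &[f32], w: &[f32], eps: f32) -> Vec<f32> {
    let ms = xs.iter().map(|x| x * x).sum::<f32>() / xs.len() as f32;
    let inv = 1.0 / (ms + eps).sqrt();
    xs.iter().zip(w).map(|(x, w)| x * inv * w).collect()
}

fn layer_norm(xs: &[f32], w: &[f32], b: &[f32], eps: f32) -> Vec<f32> {
    let mean = xs.iter().sum::<f32>() / xs.len() as f32;
    let var = xs.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / xs.len() as f32;
    let inv = 1.0 / (var + eps).sqrt();
    xs.iter()
        .zip(w.iter().zip(b))
        .map(|(x, (w, b))| (x - mean) * inv * w + b)
        .collect()
}

fn main() {
    let xs = [1.0f32, 2.0, 3.0, 4.0];
    let w = [1.0f32; 4];
    let b = [0.0f32; 4];
    println!("rms:   {:?}", rms_norm(&xs, &w, 1e-5));
    println!("layer: {:?}", layer_norm(&xs, &w, &b, 1e-5));
}
```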
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerWeights {
|
||||
attention_wq: QMatMul,
|
||||
attention_wk: QMatMul,
|
||||
attention_wv: QMatMul,
|
||||
attention_wo: QMatMul,
|
||||
attention_norm: RmsNorm,
|
||||
mlp_or_moe: MlpOrMoe,
|
||||
ffn_norm: RmsNorm,
|
||||
attention_norm: Norm,
|
||||
mlp: Mlp,
|
||||
ffn_norm: Norm,
|
||||
n_head: usize,
|
||||
n_kv_head: usize,
|
||||
head_dim: usize,
|
||||
@ -230,7 +272,7 @@ impl LayerWeights {
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
layers: Vec<LayerWeights>,
|
||||
norm: RmsNorm,
|
||||
norm: Norm,
|
||||
output: QMatMul,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
span: tracing::Span,
|
||||
@ -256,6 +298,99 @@ fn precomput_freqs_cis(
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum Architecture {
|
||||
Llama,
|
||||
Phi2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct MetadataConfig {
|
||||
n_expert: usize,
|
||||
n_expert_used: usize,
|
||||
head_count: usize,
|
||||
head_count_kv: usize,
|
||||
block_count: usize,
|
||||
embedding_length: usize,
|
||||
rope_dim: usize,
|
||||
rms_norm_eps: f64,
|
||||
rope_freq_base: f32,
|
||||
architecture: Architecture,
|
||||
}
|
||||
|
||||
impl MetadataConfig {
|
||||
fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
let architecture = match md_get("general.architecture")
|
||||
.and_then(|v| v.to_string())
|
||||
.map(|v| v.as_str())
|
||||
{
|
||||
Ok("phi2") => Architecture::Phi2,
|
||||
Err(_) | Ok(_) => Architecture::Llama,
|
||||
};
|
||||
|
||||
let config = match architecture {
|
||||
Architecture::Phi2 => {
|
||||
let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
|
||||
let rms_norm_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
|
||||
Self {
|
||||
n_expert: 1,
|
||||
n_expert_used: 1,
|
||||
head_count,
|
||||
head_count_kv,
|
||||
block_count,
|
||||
embedding_length,
|
||||
rope_freq_base: 10_000.,
|
||||
rope_dim,
|
||||
rms_norm_eps,
|
||||
architecture,
|
||||
}
|
||||
}
|
||||
Architecture::Llama => {
|
||||
let n_expert = md_get("llama.expert_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let n_expert_used = md_get("llama.expert_used_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
|
||||
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
|
||||
let rms_norm_eps =
|
||||
md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
Self {
|
||||
n_expert,
|
||||
n_expert_used,
|
||||
head_count,
|
||||
head_count_kv,
|
||||
block_count,
|
||||
embedding_length,
|
||||
rope_freq_base,
|
||||
rope_dim,
|
||||
rms_norm_eps,
|
||||
architecture,
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
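A small sketch of the metadata dispatch performed by MetadataConfig::from_gguf, with a plain string map standing in for gguf_file::Content; only the architecture-prefixed key selection is shown:

```rust
use std::collections::HashMap;

// `general.architecture` selects which key prefix ("llama." or "phi2.") the
// remaining hyperparameters are read from.
fn head_count(md: &HashMap<String, String>) -> Result<usize, String> {
    let arch = md
        .get("general.architecture")
        .map(String::as_str)
        .unwrap_or("llama");
    let key = match arch {
        "phi2" => "phi2.attention.head_count",
        _ => "llama.attention.head_count",
    };
    md.get(key)
        .ok_or_else(|| format!("cannot find {key} in metadata"))?
        .parse::<usize>()
        .map_err(|e| e.to_string())
}

fn main() {
    let mut md = HashMap::new();
    md.insert("general.architecture".to_string(), "phi2".to_string());
    md.insert("phi2.attention.head_count".to_string(), "32".to_string());
    println!("{:?}", head_count(&md)); // Ok(32)
}
```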
|
||||
impl ModelWeights {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
@ -263,7 +398,7 @@ impl ModelWeights {
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
|
||||
let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let norm = rms_norm(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let output = ct.remove("output.weight")?;
|
||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||
for layer_idx in 0..ct.hparams.n_layer {
|
||||
@ -272,11 +407,11 @@ impl ModelWeights {
|
||||
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
|
||||
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
|
||||
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
|
||||
let mlp_or_moe = {
|
||||
let mlp = {
|
||||
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
|
||||
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
|
||||
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
|
||||
MlpOrMoe::Mlp(Mlp {
|
||||
Mlp::Silu(MlpSilu {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
@ -292,9 +427,9 @@ impl ModelWeights {
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
|
||||
mlp_or_moe,
|
||||
ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
|
||||
attention_norm: rms_norm(attention_norm, 1e-5)?,
|
||||
mlp,
|
||||
ffn_norm: rms_norm(ffn_norm, 1e-5)?,
|
||||
n_head: ct.hparams.n_head as usize,
|
||||
n_kv_head: ct.hparams.n_head as usize / gqa,
|
||||
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
||||
@ -325,78 +460,71 @@ impl ModelWeights {
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
let cfg = MetadataConfig::from_gguf(&ct)?;
|
||||
|
||||
// Parameter extraction from metadata.
|
||||
let n_expert = md_get("llama.expert_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let n_expert_used = md_get("llama.expert_used_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
|
||||
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
|
||||
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(cfg.rope_dim, cfg.rope_freq_base, device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let norm = RmsNorm::from_qtensor(
|
||||
let norm = rms_norm(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
cfg.rms_norm_eps,
|
||||
)?;
|
||||
let output = ct.tensor(reader, "output.weight", device)?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let mut layers = Vec::with_capacity(cfg.block_count);
|
||||
for layer_idx in 0..cfg.block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||
let attention_wo =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||
let mlp_or_moe = if n_expert <= 1 {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
MlpOrMoe::Mlp(Mlp {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
})
|
||||
let mlp = if cfg.n_expert <= 1 {
|
||||
match cfg.architecture {
|
||||
Architecture::Llama => {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
Mlp::Silu(MlpSilu {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
})
|
||||
}
|
||||
Architecture::Phi2 => {
|
||||
let fc1 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
let fc2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
Mlp::Simple(MlpSimple {
|
||||
fc1: QMatMul::from_qtensor(fc1)?,
|
||||
fc2: QMatMul::from_qtensor(fc2)?,
|
||||
act: candle_nn::Activation::NewGelu,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let feed_forward_gate_inp =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
|
||||
let mut experts = Vec::with_capacity(n_expert);
|
||||
for i in 0..n_expert {
|
||||
let mut experts = Vec::with_capacity(cfg.n_expert);
|
||||
for i in 0..cfg.n_expert {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
|
||||
experts.push(Mlp {
|
||||
experts.push(MlpSilu {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
})
|
||||
}
|
||||
MlpOrMoe::MoE {
|
||||
n_expert_used,
|
||||
Mlp::MoE {
|
||||
n_expert_used: cfg.n_expert_used,
|
||||
feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
|
||||
experts,
|
||||
}
|
||||
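A sketch of the routing step behind Mlp::MoE: softmax over the gate logits, keep the top n_expert_used experts, and mix their outputs. Whether the selected weights are renormalized over the chosen subset differs between implementations; this sketch does not renormalize. Experts are stand-in closures:

```rust
fn moe_forward(
    x: &[f32],
    gate_logits: &[f32],
    experts: &[Box<dyn Fn(&[f32]) -> Vec<f32>>],
    n_expert_used: usize,
) -> Vec<f32> {
    // Softmax over the gate logits.
    let max = gate_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = gate_logits.iter().map(|l| (l - max).exp()).collect();
    let total: f32 = exps.iter().sum();

    // Expert indices sorted by decreasing routing weight.
    let mut order: Vec<usize> = (0..experts.len()).collect();
    order.sort_by(|&a, &b| exps[b].total_cmp(&exps[a]));

    let mut out = vec![0f32; x.len()];
    for &i in order.iter().take(n_expert_used) {
        let w = exps[i] / total;
        for (o, v) in out.iter_mut().zip(experts[i](x)) {
            *o += w * v;
        }
    }
    out
}

fn main() {
    let double: Box<dyn Fn(&[f32]) -> Vec<f32>> = Box::new(|x| x.iter().map(|v| v * 2.0).collect());
    let negate: Box<dyn Fn(&[f32]) -> Vec<f32>> = Box::new(|x| x.iter().map(|v| -v).collect());
    // Only the highest-weight expert contributes when n_expert_used = 1.
    let out = moe_forward(&[1.0, 2.0], &[2.0, 0.0], &[double, negate], 1);
    println!("{out:?}");
}
```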
@ -412,12 +540,12 @@ impl ModelWeights {
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
|
||||
mlp_or_moe,
|
||||
ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
attention_norm: rms_norm(attention_norm, cfg.rms_norm_eps)?,
|
||||
mlp,
|
||||
ffn_norm: rms_norm(ffn_norm, cfg.rms_norm_eps)?,
|
||||
n_head: cfg.head_count,
|
||||
n_kv_head: cfg.head_count_kv,
|
||||
head_dim: cfg.embedding_length / cfg.head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
@ -430,7 +558,7 @@ impl ModelWeights {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
tok_embeddings: Embedding::new(tok_embeddings, cfg.embedding_length),
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output)?,
|
||||
@ -473,7 +601,7 @@ impl ModelWeights {
|
||||
let _enter = layer.span_mlp.enter();
|
||||
let residual = &x;
|
||||
let x = layer.ffn_norm.forward(&x)?;
|
||||
let x = layer.mlp_or_moe.forward(&x)?;
|
||||
let x = layer.mlp.forward(&x)?;
|
||||
let x = (x + residual)?;
|
||||
layer_in = x
|
||||
}
|
||||
|