Compare commits

..

3 Commits

Author SHA1 Message Date
3754b834f4 More prep work for phi. 2024-04-17 10:23:15 +02:00
d79041d94d Rework the MLP bit. 2024-04-17 09:28:50 +02:00
af11b2d461 Prepare for supporting phi-2 properly in the quantized model. 2024-04-17 09:14:38 +02:00
9 changed files with 272 additions and 499 deletions

View File

@ -375,9 +375,9 @@ git submodule update --init
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
```
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Linking error on windows when running rustdoc or mdbook tests

View File

@ -624,7 +624,7 @@ impl Tensor {
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let sigmoid_arg = (*node / arg)?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}

View File

@ -178,8 +178,8 @@ fn mul_mat_vec_via_q8_1(
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
if b_size == 0 || b_size > 4 {
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
@ -204,16 +204,14 @@ fn mul_mat_vec_via_q8_1(
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
let nblocks = if b_size == 1 {
nrows as u32
} else {
(nrows as u32 + 1) / 2
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
@ -400,7 +398,7 @@ impl QCudaStorage {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
8
4
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,

View File

@ -193,25 +193,17 @@ fn qmm_batch(dev: &Device) -> Result<()> {
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
let mm4 = rhs.forward(&lhs4)?;
assert_eq!(mm4.shape().dims(), [12, 6]);
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
// We use a different kernel for sizes from 1 to 8 on cuda which explains
// the difference here.
assert!(0. < diff4 && diff4 < 1e-4)
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff4, 0.0)
assert_eq!(diff3, 0.0)
};
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
};
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff4, 0.0);
Ok(())
}

View File

@ -31,8 +31,6 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
enum Which {
V1,
V2,
V3,
V3Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -47,8 +45,8 @@ struct Args {
cpu: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
@ -92,11 +90,11 @@ struct Args {
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 128)]
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
@ -120,18 +118,13 @@ fn main() -> Result<()> {
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => match args.which {
Which::V3 | Which::V3Instruct => DType::BF16,
Which::V1 | Which::V2 | Which::Solar10_7B | Which::TinyLlama1_1BChat => DType::F16,
},
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache, config) = {
let (llama, tokenizer_filename, mut cache) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -145,7 +138,7 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
Which::V1 | Which::V2 | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@ -153,12 +146,10 @@ fn main() -> Result<()> {
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
(Llama::load(vb, &config)?, tokenizer_filename, cache)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -169,7 +160,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;

View File

@ -67,8 +67,8 @@ enum Which {
Mixtral,
#[value(name = "mixtral-instruct")]
MixtralInstruct,
#[value(name = "llama3-8b")]
L8b,
#[value(name = "phi-2")]
Phi2,
}
impl Which {
@ -85,7 +85,7 @@ impl Which {
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b
| Self::L8b => false,
| Self::Phi2 => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -119,8 +119,8 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Starling7bAlpha
| Self::L8b => false,
| Self::Phi2
| Self::Starling7bAlpha => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
@ -143,36 +143,36 @@ impl Which {
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Phi2
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::L8b => false,
| Self::Zephyr7bBeta => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
fn tokenizer_repo(&self) -> &'static str {
match self {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
Which::Leo7b => "LeoLM/leo-hessianai-7b",
Which::Leo13b => "LeoLM/leo-hessianai-13b",
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => "hf-internal-testing/llama-tokenizer",
Self::Leo7b => "LeoLM/leo-hessianai-7b",
Self::Leo13b => "LeoLM/leo-hessianai-13b",
Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Self::OpenChat35 => "openchat/openchat_3.5",
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::Phi2 => "microsoft/phi-2",
}
}
}
@ -328,11 +328,7 @@ impl Args {
"TheBloke/Starling-LM-7B-alpha-GGUF",
"starling-lm-7b-alpha.Q4_K_M.gguf",
),
// TODO: swap to TheBloke model when available
Which::L8b => (
"QuantFactory/Meta-Llama-3-8B-GGUF",
"Meta-Llama-3-8B.Q4_K_S.gguf",
),
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -432,7 +428,7 @@ fn main() -> anyhow::Result<()> {
| Which::L34bCode
| Which::Leo7b
| Which::Leo13b
| Which::L8b => 1,
| Which::Phi2 => 1,
Which::Mixtral
| Which::MixtralInstruct
| Which::Mistral7b
@ -549,14 +545,11 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
let eos_token = match args.which {
Which::L8b => "<|end_of_text|>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",
false => "</s>",
},
let eos_token = if args.which.is_open_chat() {
"<|end_of_turn|>"
} else {
"</s>"
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;

View File

@ -2972,330 +2972,6 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
// batch size = 5
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda5(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<5, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
// batch size = 6
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda6(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<6, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
// batch size = 7
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda7(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<7, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
// batch size = 8
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda8(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<8, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
const int ix = blockDim.x*blockIdx.x + threadIdx.x;

View File

@ -16,8 +16,6 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}
fn default_rope() -> f32 {
@ -36,8 +34,6 @@ impl LlamaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
}
}
}
@ -53,8 +49,6 @@ pub struct Config {
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}
impl Config {
@ -69,8 +63,6 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-6,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}
@ -85,8 +77,6 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}
}
@ -116,6 +106,7 @@ impl Cache {
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
@ -175,10 +166,16 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
candle_nn::rotary_emb::rope(x, &cos, &sin)
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
}
fn forward(
@ -196,12 +193,10 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;

View File

@ -1,6 +1,5 @@
use std::collections::HashMap;
use crate::quantized_nn::RmsNorm;
use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor};
@ -29,13 +28,13 @@ impl QMatMul {
}
#[derive(Debug, Clone)]
struct Mlp {
struct MlpSilu {
feed_forward_w1: QMatMul,
feed_forward_w2: QMatMul,
feed_forward_w3: QMatMul,
}
impl Module for Mlp {
impl Module for MlpSilu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let w1 = self.feed_forward_w1.forward(xs)?;
let w3 = self.feed_forward_w3.forward(xs)?;
@ -45,16 +44,31 @@ impl Module for Mlp {
}
#[derive(Debug, Clone)]
enum MlpOrMoe {
Mlp(Mlp),
struct MlpSimple {
fc1: QMatMul,
fc2: QMatMul,
act: candle_nn::Activation,
}
impl Module for MlpSimple {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.apply(&self.act)?;
self.fc2.forward(&xs)
}
}
#[derive(Debug, Clone)]
enum Mlp {
Silu(MlpSilu),
Simple(MlpSimple),
MoE {
n_expert_used: usize,
feed_forward_gate_inp: QMatMul,
experts: Vec<Mlp>,
experts: Vec<MlpSilu>,
},
}
impl Module for MlpOrMoe {
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::MoE {
@ -119,20 +133,48 @@ impl Module for MlpOrMoe {
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
Ok(ys)
}
Self::Mlp(mlp) => mlp.forward(xs),
Self::Silu(mlp) => mlp.forward(xs),
Self::Simple(mlp) => mlp.forward(xs),
}
}
}
#[derive(Debug, Clone)]
enum Norm {
Rms(crate::quantized_nn::RmsNorm),
Layer(candle_nn::LayerNorm),
}
impl Module for Norm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Rms(m) => m.forward(xs),
Self::Layer(m) => m.forward(xs),
}
}
}
fn rms_norm(q: QTensor, eps: f64) -> Result<Norm> {
let rms = crate::quantized_nn::RmsNorm::from_qtensor(q, eps)?;
Ok(Norm::Rms(rms))
}
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<Norm> {
let w = w.dequantize(&w.device())?;
let b = b.dequantize(&b.device())?;
let ln = candle_nn::LayerNorm::new(w, b, eps);
Ok(Norm::Layer(ln))
}
#[derive(Debug, Clone)]
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
attention_norm: RmsNorm,
mlp_or_moe: MlpOrMoe,
ffn_norm: RmsNorm,
attention_norm: Norm,
mlp: Mlp,
ffn_norm: Norm,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
@ -230,7 +272,7 @@ impl LayerWeights {
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
norm: RmsNorm,
norm: Norm,
output: QMatMul,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
@ -256,6 +298,99 @@ fn precomput_freqs_cis(
Ok((cos, sin))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Architecture {
Llama,
Phi2,
}
#[derive(Debug, Clone)]
struct MetadataConfig {
n_expert: usize,
n_expert_used: usize,
head_count: usize,
head_count_kv: usize,
block_count: usize,
embedding_length: usize,
rope_dim: usize,
rms_norm_eps: f64,
rope_freq_base: f32,
architecture: Architecture,
}
impl MetadataConfig {
fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let architecture = match md_get("general.architecture")
.and_then(|v| v.to_string())
.map(|v| v.as_str())
{
Ok("phi2") => Architecture::Phi2,
Err(_) | Ok(_) => Architecture::Llama,
};
let config = match architecture {
Architecture::Phi2 => {
let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
let rms_norm_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
Self {
n_expert: 1,
n_expert_used: 1,
head_count,
head_count_kv,
block_count,
embedding_length,
rope_freq_base: 10_000.,
rope_dim,
rms_norm_eps,
architecture,
}
}
Architecture::Llama => {
let n_expert = md_get("llama.expert_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps =
md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
Self {
n_expert,
n_expert_used,
head_count,
head_count_kv,
block_count,
embedding_length,
rope_freq_base,
rope_dim,
rms_norm_eps,
architecture,
}
}
};
Ok(config)
}
}
impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
@ -263,7 +398,7 @@ impl ModelWeights {
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
let norm = rms_norm(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
for layer_idx in 0..ct.hparams.n_layer {
@ -272,11 +407,11 @@ impl ModelWeights {
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
let mlp_or_moe = {
let mlp = {
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
MlpOrMoe::Mlp(Mlp {
Mlp::Silu(MlpSilu {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
@ -292,9 +427,9 @@ impl ModelWeights {
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
mlp_or_moe,
ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
attention_norm: rms_norm(attention_norm, 1e-5)?,
mlp,
ffn_norm: rms_norm(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@ -325,78 +460,71 @@ impl ModelWeights {
reader: &mut R,
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let cfg = MetadataConfig::from_gguf(&ct)?;
// Parameter extraction from metadata.
let n_expert = md_get("llama.expert_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let (cos, sin) = precomput_freqs_cis(cfg.rope_dim, cfg.rope_freq_base, device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::from_qtensor(
let norm = rms_norm(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
cfg.rms_norm_eps,
)?;
let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let mut layers = Vec::with_capacity(cfg.block_count);
for layer_idx in 0..cfg.block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let mlp_or_moe = if n_expert <= 1 {
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 =
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
MlpOrMoe::Mlp(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
})
let mlp = if cfg.n_expert <= 1 {
match cfg.architecture {
Architecture::Llama => {
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 =
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
Mlp::Silu(MlpSilu {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
})
}
Architecture::Phi2 => {
let fc1 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let fc2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
Mlp::Simple(MlpSimple {
fc1: QMatMul::from_qtensor(fc1)?,
fc2: QMatMul::from_qtensor(fc2)?,
act: candle_nn::Activation::NewGelu,
})
}
}
} else {
let feed_forward_gate_inp =
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
let mut experts = Vec::with_capacity(n_expert);
for i in 0..n_expert {
let mut experts = Vec::with_capacity(cfg.n_expert);
for i in 0..cfg.n_expert {
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
let feed_forward_w2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
let feed_forward_w3 =
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
experts.push(Mlp {
experts.push(MlpSilu {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
})
}
MlpOrMoe::MoE {
n_expert_used,
Mlp::MoE {
n_expert_used: cfg.n_expert_used,
feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
experts,
}
@ -412,12 +540,12 @@ impl ModelWeights {
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
mlp_or_moe,
ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
attention_norm: rms_norm(attention_norm, cfg.rms_norm_eps)?,
mlp,
ffn_norm: rms_norm(ffn_norm, cfg.rms_norm_eps)?,
n_head: cfg.head_count,
n_kv_head: cfg.head_count_kv,
head_dim: cfg.embedding_length / cfg.head_count,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
@ -430,7 +558,7 @@ impl ModelWeights {
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
tok_embeddings: Embedding::new(tok_embeddings, cfg.embedding_length),
layers,
norm,
output: QMatMul::from_qtensor(output)?,
@ -473,7 +601,7 @@ impl ModelWeights {
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp_or_moe.forward(&x)?;
let x = layer.mlp.forward(&x)?;
let x = (x + residual)?;
layer_in = x
}