First pass (Quantized scaffolding work done + quantized example scaffolding).

This commit is contained in:
Nicolas Patry
2023-11-01 15:10:11 +01:00
parent 4525b7b52a
commit 2d84c16fed
14 changed files with 247 additions and 62 deletions

View File

@ -16,7 +16,7 @@ struct RmsNorm {
impl RmsNorm {
fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
let scale = scale.dequantize(scale.device())?;
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span })
}
@ -257,8 +257,8 @@ impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
@ -278,22 +278,22 @@ impl ModelWeights {
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
let output = ct.tensor(reader, "output.weight")?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps)?;
let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");