From 3071134788334c972d9e356f53887d2b2ff026b7 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 16 Aug 2023 12:41:07 +0100
Subject: [PATCH] Get the ggml based llama to generate some text. (#464)

* Add more stats to the ggml example.
* Build a quantized model from the file content.
* Move the tensor retrieval in the main crate.
* Start adding the forward pass.
* Add more to the forward pass of the quantized llama.
* Apply the attention layers.
* Add the sampling loop.
* Get the sampling loop to work.
* Minor tweak.
* Add a quantize/dequantize test.
* Bugfix.
* Add a comment + swap the order.
* Bugfixes.
---
 candle-core/src/error.rs                |   4 +
 candle-core/src/quantized/ggml_file.rs  |  18 +-
 candle-core/src/quantized/k_quants.rs   |  13 +-
 candle-core/src/quantized/mod.rs        |  31 ++-
 candle-core/tests/quantized_tests.rs    |  35 ++-
 candle-examples/examples/ggml/main.rs   | 311 +++++++++++++++++++++++-
 candle-examples/examples/llama/model.rs |   6 +-
 7 files changed, 381 insertions(+), 37 deletions(-)

diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index c18b43c6..1cf20a84 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -210,6 +210,10 @@ impl Error {
         Self::Wrapped(Box::new(err))
     }
 
+    pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+        Self::Msg(err.to_string())
+    }
+
     pub fn bt(self) -> Self {
         let backtrace = std::backtrace::Backtrace::capture();
         match backtrace.status() {
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index ee23cdde..7afb8670 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -3,6 +3,7 @@
 use super::{k_quants, GgmlDType};
 use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
+use std::collections::HashMap;
 
 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -163,6 +164,9 @@ fn read_one_tensor(
     let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
     let mut dims = vec![0u32; n_dims as usize];
     reader.read_u32_into::<LittleEndian>(&mut dims)?;
+    // The dimensions are stored in reverse order, see for example:
+    // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
+    dims.reverse();
     let mut name = vec![0u8; name_len as usize];
     reader.read_exact(&mut name)?;
     let name = String::from_utf8_lossy(&name).into_owned();
@@ -174,7 +178,6 @@ fn read_one_tensor(
     let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
     let tensor_elems = dims.iter().product::<usize>();
     let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
-    println!("{name} {ggml_dtype:?} {dims:?}");
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
@@ -188,7 +191,7 @@ pub struct Content {
     pub magic: VersionedMagic,
     pub hparams: HParams,
     pub vocab: Vocab,
-    pub tensors: Vec<(String, super::QTensor)>,
+    pub tensors: HashMap<String, super::QTensor>,
 }
 
 impl Content {
@@ -199,11 +202,11 @@ impl Content {
         let magic = VersionedMagic::read(reader)?;
         let hparams = HParams::read(reader)?;
         let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
-        let mut tensors = vec![];
+        let mut tensors = HashMap::new();
         while reader.stream_position()? != last_position {
             let (name, tensor) = read_one_tensor(reader, magic)?;
-            tensors.push((name, tensor))
+            tensors.insert(name, tensor);
         }
         Ok(Self {
             magic,
@@ -212,4 +215,11 @@ impl Content {
             tensors,
         })
     }
+
+    pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
+        match self.tensors.remove(name) {
+            None => crate::bail!("cannot find tensor with name '{name}'"),
+            Some(tensor) => Ok(tensor),
+        }
+    }
 }
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 53f3dc65..f7611897 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -531,20 +531,21 @@ impl GgmlType for BlockQ4_0 {
     // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
     fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
         let k = ys.len();
-        if k % QK4_0 != 0 {
-            crate::bail!("dequantize_row_q4_0: {k} is not divisible by {QK4_0}")
+        let qk = Self::BLCK_SIZE;
+        if k % qk != 0 {
+            crate::bail!("dequantize_row_q4_0: {k} is not divisible by {qk}")
         }
 
-        let nb = k / QK4_0;
+        let nb = k / qk;
         for i in 0..nb {
             let d = xs[i].d.to_f32();
 
-            for j in 0..(QK4_0 / 2) {
+            for j in 0..(qk / 2) {
                 let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
                 let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
 
-                ys[i * QK4_0 + j] = (x0 as f32) * d;
-                ys[i * QK4_0 + j + QK4_0 / 2] = (x1 as f32) * d;
+                ys[i * qk + j] = (x0 as f32) * d;
+                ys[i * qk + j + qk / 2] = (x1 as f32) * d;
             }
         }
         Ok(())
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 842b519b..52dddcf5 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -50,7 +50,8 @@ impl GgmlDType {
         Ok(dtype)
     }
 
-    fn type_size(&self) -> usize {
+    /// The type size for blocks in bytes.
+    pub fn type_size(&self) -> usize {
         use k_quants::*;
         match self {
             Self::F32 => 4,
@@ -71,7 +72,8 @@ impl GgmlDType {
         }
     }
 
-    fn blck_size(&self) -> usize {
+    /// The block size, i.e. the number of elements stored in each block.
+    pub fn blck_size(&self) -> usize {
         match self {
             Self::F32 => 1,
             Self::F16 => 1,
@@ -143,16 +145,15 @@ impl QTensor {
     }
 }
 
-#[derive(Debug, Clone)]
-pub struct QMatMul(std::sync::Arc<QTensor>);
+pub struct QMatMul(std::sync::Arc<Box<QTensor>>);
 
 impl QMatMul {
-    pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
-        Self(qtensor)
+    pub fn from_qtensor(qtensor: QTensor) -> Self {
+        Self(std::sync::Arc::new(Box::new(qtensor)))
     }
 }
 
-impl crate::CustomOp1 for QMatMul {
+impl crate::CustomOp1 for QTensor {
     fn name(&self) -> &'static str {
         "qmatmul"
     }
@@ -166,17 +167,15 @@
             crate::bail!("input tensor is not contiguous {layout:?}")
         }
         let src_shape = layout.shape();
-        let (k, n) = self.0.shape.dims2()?;
+        // self is transposed so n is first then k.
+        let (n, k) = self.shape.dims2()?;
         if src_shape.rank() < 2 {
             crate::bail!("input tensor has only one dimension {layout:?}")
         }
         let mut dst_shape = src_shape.dims().to_vec();
         let last_k = dst_shape.pop().unwrap();
         if last_k != k {
-            crate::bail!(
-                "input tensor {layout:?} incompatible with {:?}",
-                self.0.shape
-            )
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
         }
         dst_shape.push(n);
         let dst_shape = Shape::from(dst_shape);
@@ -184,7 +183,7 @@
         let storage =
             &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
         let mut dst_storage = vec![0f32; dst_shape.elem_count()];
-        self.0.matmul_t(
+        self.matmul_t(
             (dst_shape.elem_count() / n, k, n),
             storage,
             &mut dst_storage,
@@ -192,3 +191,9 @@
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
 }
+
+impl QMatMul {
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.custom_op1_arc(self.0.clone())
+    }
+}
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 2c05abb4..babd71a8 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -30,9 +30,9 @@ fn quantized_matmul() -> Result<()> {
         ]
     );
 
-    let qtensor = quantized::QTensor::new(rhs_t, (64, 4));
-    let op = quantized::QMatMul::new(std::sync::Arc::new(qtensor));
-    let res = tensor_lhs.custom_op1(op)?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
+    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         res.to_vec2::<f32>()?,
         &[
@@ -44,3 +44,32 @@
 
     Ok(())
 }
+
+#[test]
+fn quantize_q4_0() -> Result<()> {
+    use k_quants::BlockQ4_0;
+
+    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
+    let mut dst = vec![0f32; 32 * 4];
+    let mut quant = vec![BlockQ4_0::zeros(); 4];
+    BlockQ4_0::from_float(&src, &mut quant)?;
+    BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
+    assert_eq!(
+        dst,
+        &[
+            -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
+            11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
+            23.25, 27.125, 27.125, 27.125, 27.125, 31.0, 31.0, 31.5, 31.5, 31.5, 31.5, 39.375,
+            39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 47.25, 47.25, 47.25, 47.25,
+            47.25, 47.25, 47.25, 47.25, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125,
+            55.125, 63.0, 63.0, 63.0, 63.0, 59.375, 59.375, 71.25, 71.25, 71.25, 71.25, 71.25,
+            71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 83.125, 83.125, 83.125, 83.125,
+            83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 95.0, 95.0, 95.0, 95.0,
+            95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 111.125, 111.125,
+            111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125,
+            111.125, 111.125, 111.125, 111.125, 111.125, 127.0, 127.0, 127.0, 127.0, 127.0, 127.0,
+            127.0, 127.0
+        ]
+    );
+    Ok(())
+}
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 9e3e1ba6..912bc53a 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -1,8 +1,236 @@
-use anyhow::Result;
+#![allow(dead_code)]
 use clap::Parser;
-use std::fs::File;
+use std::collections::HashMap;
+use std::io::Write;
 
 use candle::quantized::ggml_file::Content;
+use candle::quantized::{QMatMul, QTensor};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::Embedding;
+use candle_transformers::generation::LogitsProcessor;
+
+const MAX_SEQ_LEN: usize = 4096;
+const DEFAULT_PROMPT: &str = "My favorite theorem is ";
+
+struct RmsNorm {
+    scale: Tensor,
+    eps: f64,
+}
+
+impl RmsNorm {
+    fn new(scale: QTensor) -> Result<Self> {
+        let scale = scale.dequantize(&Device::Cpu)?;
+        Ok(Self { scale, eps: 1e-5 })
+    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let (b_sz, seq_len, hidden_size) = x.dims3()?;
+        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
+        let size = self.scale.dims1()?;
+        let scale = self
+            .scale
+            .to_dtype(DType::F32)?
+            .broadcast_as((b_sz, seq_len, size))?;
+        let x = (scale * x_normed)?;
+        Ok(x)
+    }
+}
+
+struct LayerWeights {
+    attention_wq: QMatMul,
+    attention_wk: QMatMul,
+    attention_wv: QMatMul,
+    attention_wo: QMatMul,
+    attention_norm: RmsNorm,
+    feed_forward_w1: QMatMul,
+    feed_forward_w2: QMatMul,
+    feed_forward_w3: QMatMul,
+    ffn_norm: RmsNorm,
+    n_head: usize,
+    head_dim: usize,
+    cos: Tensor,
+    sin: Tensor,
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
+impl LayerWeights {
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
+        let cos = self.cos.narrow(0, index_pos, seq_len)?;
+        let sin = self.sin.narrow(0, index_pos, seq_len)?;
+        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
+        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
+        let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
+        let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
+        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
+        Ok(rope)
+    }
+
+    fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let (b_sz, seq_len, n_embd) = x.dims3()?;
+        let q = self.attention_wq.forward(x)?;
+        let k = self.attention_wk.forward(x)?;
+        let v = self.attention_wv.forward(x)?;
+
+        let q = q
+            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .transpose(1, 2)?;
+
+        let q = self.apply_rotary_emb(&q, index_pos)?;
+        let k = self.apply_rotary_emb(&k, index_pos)?;
+
+        // TODO: KV cache.
+
+        // If we start supporting MQA, we need to repeat the k and v tensors here.
+
+        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+        let mask = mask.broadcast_as(att.shape())?;
+        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+        // Convert to contiguous as matmul doesn't support strided vs for now.
+        let y = att.matmul(&v.contiguous()?)?;
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
+        let y = self.attention_wo.forward(&y)?;
+        Ok(y)
+    }
+}
+
+struct ModelWeights {
+    tok_embeddings: Embedding,
+    layers: Vec<LayerWeights>,
+    norm: RmsNorm,
+    // TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
+    output: candle_nn::Linear,
+    masks: HashMap<usize, Tensor>,
+}
+
+struct WeightMap(HashMap<String, QTensor>);
+impl WeightMap {
+    fn get(&mut self, name: &str) -> Result<QTensor> {
+        match self.0.remove(name) {
+            None => candle::bail!("cannot find tensor with name '{name}'"),
+            Some(tensor) => Ok(tensor),
+        }
+    }
+}
+
+impl ModelWeights {
+    fn new(mut ct: Content) -> Result<Self> {
+        let cpu = &Device::Cpu;
+        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
+
+        // precompute freqs_cis
+        let theta: Vec<_> = (0..head_dim)
+            .step_by(2)
+            .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
+            .collect();
+        let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
+        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
+            .to_dtype(DType::F32)?
+            .reshape((MAX_SEQ_LEN, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        // This is different from the paper, see:
+        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
+        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
+        let cos = idx_theta.cos()?;
+        let sin = idx_theta.sin()?;
+
+        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
+        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+        let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
+        let output = ct.remove("output.weight")?;
+        let output = candle_nn::Linear::new(output.dequantize(cpu)?, None);
+        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
+        for layer_idx in 0..ct.hparams.n_layer {
+            let prefix = format!("layers.{layer_idx}");
+            let attention_wq = ct.remove(&format!("layers.{layer_idx}.attention.wq.weight"))?;
+            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
+            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
+            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
+            let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
+            let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
+            let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
+            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
+            layers.push(LayerWeights {
+                attention_wq: QMatMul::from_qtensor(attention_wq),
+                attention_wk: QMatMul::from_qtensor(attention_wk),
+                attention_wv: QMatMul::from_qtensor(attention_wv),
+                attention_wo: QMatMul::from_qtensor(attention_wo),
+                attention_norm: RmsNorm::new(attention_norm)?,
+                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
+                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
+                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
+                ffn_norm: RmsNorm::new(ffn_norm)?,
+                n_head: ct.hparams.n_head as usize,
+                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
+                cos: cos.clone(),
+                sin: sin.clone(),
+            })
+        }
+        Ok(Self {
+            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
+            layers,
+            norm,
+            output,
+            masks: HashMap::new(),
+        })
+    }
+
+    fn mask(&mut self, t: usize) -> Result<Tensor> {
+        if let Some(mask) = self.masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
+                .collect();
+            let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+            self.masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+
+    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let (_b_sz, seq_len) = x.dims2()?;
+        let mask = self.mask(seq_len)?;
+        let mut layer_in = self.tok_embeddings.forward(x)?;
+        for (_layer_idx, layer) in self.layers.iter().enumerate() {
+            let x = layer_in;
+            let residual = &x;
+            let x = layer.attention_norm.forward(&x)?;
+            let attn = layer.forward_attn(&x, &mask, index_pos)?;
+            let x = (attn + residual)?;
+
+            // MLP
+            let residual = &x;
+            let x = layer.ffn_norm.forward(&x)?;
+            let w1 = layer.feed_forward_w1.forward(&x)?;
+            let w3 = layer.feed_forward_w3.forward(&x)?;
+            let mlp = layer
+                .feed_forward_w2
+                .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
+            layer_in = (mlp + residual)?;
+        }
+        let x = self.norm.forward(&layer_in)?;
+        let x = x.i((.., seq_len - 1, ..))?;
+        self.output.forward(&x)
+    }
+}
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -10,19 +238,90 @@ struct Args {
     /// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
     #[arg(long)]
     model: String,
+
+    /// The initial prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, default_value_t = 100)]
+    sample_len: usize,
+
+    /// The tokenizer config in json format.
+    #[arg(long)]
+    tokenizer: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
 }
 
-fn main() -> Result<()> {
+fn main() -> anyhow::Result<()> {
+    use tokenizers::Tokenizer;
     let args = Args::parse();
-    let mut file = File::open(args.model)?;
+    let mut file = std::fs::File::open(args.model)?;
     let start = std::time::Instant::now();
     let model = Content::read(&mut file)?;
+
+    let mut total_size_in_bytes = 0;
+    for (_, tensor) in model.tensors.iter() {
+        let elem_count = tensor.shape().elem_count();
+        total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+    }
+    let total_size = if total_size_in_bytes < 1_000 {
+        format!("{}B", total_size_in_bytes)
+    } else if total_size_in_bytes < 1_000_000 {
+        format!("{:.2}KB", total_size_in_bytes as f64 / 1e3)
+    } else if total_size_in_bytes < 1_000_000_000 {
+        format!("{:.2}MB", total_size_in_bytes as f64 / 1e6)
+    } else {
+        format!("{:.2}GB", total_size_in_bytes as f64 / 1e9)
+    };
+
     println!(
-        "Loaded {:?} tensors in {:?}",
+        "loaded {:?} tensors ({}) in {:.2}s",
         model.tensors.len(),
-        start.elapsed()
+        total_size,
+        start.elapsed().as_secs_f32(),
     );
+    println!("params: {:?}", model.hparams);
+    let mut model = ModelWeights::new(model)?;
+    println!("model built");
+
+    let tokenizer = Tokenizer::from_file(args.tokenizer).map_err(anyhow::Error::msg)?;
+    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
+    let mut tokens = tokenizer
+        .encode(prompt, true)
+        .map_err(anyhow::Error::msg)?
+        .get_ids()
+        .to_vec();
+    let mut index_pos = 0;
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    for _index in 0..args.sample_len {
+        let context_size = tokens.len();
+        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+        let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
+        let logits = model.forward(&input, index_pos)?;
+        let logits = logits.squeeze(0)?;
+        index_pos += ctxt.len();
+
+        let next_token = logits_processor.sample(&logits)?;
+        tokens.push(next_token);
+
+        // Extracting the last token as a string is complicated, here we just apply some simple
+        // heuristics as it seems to work well enough for this example. See the following for more
+        // details:
+        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+        if let Some(text) = tokenizer.id_to_token(next_token) {
+            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+            print!("{text}");
+            std::io::stdout().flush()?;
+        }
+    }
     Ok(())
 }
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 0da3697f..940c980c 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -140,10 +140,6 @@ impl Cache {
     }
 }
 
-fn silu(xs: &Tensor) -> Result<Tensor> {
-    xs / (xs.neg()?.exp()? + 1.0)?
-}
-
 fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
     let span = tracing::span!(tracing::Level::TRACE, "linear");
     let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
@@ -358,7 +354,7 @@ struct Mlp {
 impl Mlp {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }