From 6b98b66eb36a484f1a65fbc1c528a8e0b90a1419 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 31 Jul 2023 20:43:57 +0100 Subject: [PATCH] Remove the end of text tokens. (#289) --- candle-core/src/safetensors.rs | 60 ++++++++++++++++++++++- candle-examples/examples/llama2-c/main.rs | 3 +- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 0e1cc655..06b9b23b 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -109,6 +109,46 @@ fn convert_slice(data: &[u8], shape: &[usize], device: &Device) -> } } +fn convert_slice_with_cast Result>( + data: &[u8], + shape: &[usize], + device: &Device, + conv: F, +) -> Result { + let size_in_bytes = std::mem::size_of::(); + let elem_count = data.len() / size_in_bytes; + if (data.as_ptr() as usize) % size_in_bytes == 0 { + // SAFETY This is safe because we just checked that this + // was correctly aligned. + let data: &[T] = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) }; + let data = data.iter().map(|t| conv(*t)).collect::>>()?; + Tensor::from_vec(data, shape, device) + } else { + // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast + // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access + let mut c: Vec = Vec::with_capacity(elem_count); + // SAFETY: We just created c, so the allocated memory is necessarily + // contiguous and non overlapping with the view's data. + // We're downgrading the `c` pointer from T to u8, which removes alignment + // constraints. + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); + c.set_len(elem_count) + } + let c = c.into_iter().map(conv).collect::>>()?; + Tensor::from_vec(c, shape, device) + } +} + +fn convert_with_cast_ Result>( + view: &st::TensorView<'_>, + device: &Device, + conv: F, +) -> Result { + convert_slice_with_cast::(view.data(), view.shape(), device, conv) +} + fn convert_(view: &st::TensorView<'_>, device: &Device) -> Result { convert_slice::(view.data(), view.shape(), device) } @@ -158,11 +198,29 @@ impl Tensor { fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { st::Dtype::U8 => convert_::(view, device), - st::Dtype::U32 => convert_::(view, device), + st::Dtype::U16 => { + let conv = |x| Ok(u32::from(x)); + convert_with_cast_::(view, device, conv) + } + st::Dtype::U32 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), st::Dtype::F32 => convert_::(view, device), st::Dtype::F64 => convert_::(view, device), + st::Dtype::I32 => { + let conv = |x| { + u32::try_from(x) + .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}"))) + }; + convert_with_cast_::(view, device, conv) + } + st::Dtype::I64 => { + let conv = |x| { + u32::try_from(x) + .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}"))) + }; + convert_with_cast_::(view, device, conv) + } dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 65641b3c..d710652f 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -266,7 +266,8 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) let file = std::io::BufReader::new(file); let mut tokens = vec![]; for line in file.lines() { - let line = tokenizer.encode(line?, false).map_err(E::msg)?; + let line = line?.replace("<|endoftext|>", ""); + let line = tokenizer.encode(line, false).map_err(E::msg)?; tokens.push(line.get_ids().to_vec()) } let tokens = tokens.concat();