Remove the end of text tokens. (#289)

This commit is contained in:
Laurent Mazare
2023-07-31 20:43:57 +01:00
committed by GitHub
parent 9ae1f6afee
commit 6b98b66eb3
2 changed files with 61 additions and 2 deletions

View File

@ -109,6 +109,46 @@ fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) ->
}
}
fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
data: &[u8],
shape: &[usize],
device: &Device,
conv: F,
) -> Result<Tensor> {
let size_in_bytes = std::mem::size_of::<T>();
let elem_count = data.len() / size_in_bytes;
if (data.as_ptr() as usize) % size_in_bytes == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[T] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
Tensor::from_vec(data, shape, device)
} else {
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
let mut c: Vec<T> = Vec::with_capacity(elem_count);
// SAFETY: We just created c, so the allocated memory is necessarily
// contiguous and non overlapping with the view's data.
// We're downgrading the `c` pointer from T to u8, which removes alignment
// constraints.
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
c.set_len(elem_count)
}
let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
Tensor::from_vec(c, shape, device)
}
}
fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
view: &st::TensorView<'_>,
device: &Device,
conv: F,
) -> Result<Tensor> {
convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_slice::<T>(view.data(), view.shape(), device)
}
@ -158,11 +198,29 @@ impl Tensor {
fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),
st::Dtype::U32 => convert_::<u8>(view, device),
st::Dtype::U16 => {
let conv = |x| Ok(u32::from(x));
convert_with_cast_::<u16, u32, _>(view, device, conv)
}
st::Dtype::U32 => convert_::<u32>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
st::Dtype::I32 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i32, u32, _>(view, device, conv)
}
st::Dtype::I64 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i64, u32, _>(view, device, conv)
}
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}

View File

@ -266,7 +266,8 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
let file = std::io::BufReader::new(file);
let mut tokens = vec![];
for line in file.lines() {
let line = tokenizer.encode(line?, false).map_err(E::msg)?;
let line = line?.replace("<|endoftext|>", "");
let line = tokenizer.encode(line, false).map_err(E::msg)?;
tokens.push(line.get_ids().to_vec())
}
let tokens = tokens.concat();