mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Remove the end of text tokens. (#289)
This commit is contained in:
@ -109,6 +109,46 @@ fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) ->
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
|
||||
data: &[u8],
|
||||
shape: &[usize],
|
||||
device: &Device,
|
||||
conv: F,
|
||||
) -> Result<Tensor> {
|
||||
let size_in_bytes = std::mem::size_of::<T>();
|
||||
let elem_count = data.len() / size_in_bytes;
|
||||
if (data.as_ptr() as usize) % size_in_bytes == 0 {
|
||||
// SAFETY This is safe because we just checked that this
|
||||
// was correctly aligned.
|
||||
let data: &[T] =
|
||||
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
|
||||
let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
|
||||
Tensor::from_vec(data, shape, device)
|
||||
} else {
|
||||
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
|
||||
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
|
||||
let mut c: Vec<T> = Vec::with_capacity(elem_count);
|
||||
// SAFETY: We just created c, so the allocated memory is necessarily
|
||||
// contiguous and non overlapping with the view's data.
|
||||
// We're downgrading the `c` pointer from T to u8, which removes alignment
|
||||
// constraints.
|
||||
unsafe {
|
||||
std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
|
||||
c.set_len(elem_count)
|
||||
}
|
||||
let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
|
||||
Tensor::from_vec(c, shape, device)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
|
||||
view: &st::TensorView<'_>,
|
||||
device: &Device,
|
||||
conv: F,
|
||||
) -> Result<Tensor> {
|
||||
convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
|
||||
}
|
||||
|
||||
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
convert_slice::<T>(view.data(), view.shape(), device)
|
||||
}
|
||||
@ -158,11 +198,29 @@ impl Tensor {
|
||||
fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
match view.dtype() {
|
||||
st::Dtype::U8 => convert_::<u8>(view, device),
|
||||
st::Dtype::U32 => convert_::<u8>(view, device),
|
||||
st::Dtype::U16 => {
|
||||
let conv = |x| Ok(u32::from(x));
|
||||
convert_with_cast_::<u16, u32, _>(view, device, conv)
|
||||
}
|
||||
st::Dtype::U32 => convert_::<u32>(view, device),
|
||||
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
|
||||
st::Dtype::F16 => convert_::<half::f16>(view, device),
|
||||
st::Dtype::F32 => convert_::<f32>(view, device),
|
||||
st::Dtype::F64 => convert_::<f64>(view, device),
|
||||
st::Dtype::I32 => {
|
||||
let conv = |x| {
|
||||
u32::try_from(x)
|
||||
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
|
||||
};
|
||||
convert_with_cast_::<i32, u32, _>(view, device, conv)
|
||||
}
|
||||
st::Dtype::I64 => {
|
||||
let conv = |x| {
|
||||
u32::try_from(x)
|
||||
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
|
||||
};
|
||||
convert_with_cast_::<i64, u32, _>(view, device, conv)
|
||||
}
|
||||
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
|
||||
}
|
||||
}
|
||||
|
@ -266,7 +266,8 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
|
||||
let file = std::io::BufReader::new(file);
|
||||
let mut tokens = vec![];
|
||||
for line in file.lines() {
|
||||
let line = tokenizer.encode(line?, false).map_err(E::msg)?;
|
||||
let line = line?.replace("<|endoftext|>", "");
|
||||
let line = tokenizer.encode(line, false).map_err(E::msg)?;
|
||||
tokens.push(line.get_ids().to_vec())
|
||||
}
|
||||
let tokens = tokens.concat();
|
||||
|
Reference in New Issue
Block a user