mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add one-hot/cold encoding (#1489)
* add one-hot encoding * one_hot: improve error handling, use generic to_vecN::<D> Bails if the index value is equal to or greater than the depth value, which would result in an out-of-bounds error. A redundant check is added to ensure the index value does not exceed the length of the one-hot matrix size, which would also result in an out-of-bounds error. Bails if the index value is less than -1. If the index value is -1, then it ignores the setting of the on_value for the index value. Only values that are less than -1 are considered errors. * one-hot: use two generics, one_hot::<I, O>, for input and output data types Separating the input and output data types allows the input tensor indices to be a different data type than the output encoded tensor data type. For example, one_hot::<i64, u8>(...) will take an input tensor of i64 values and encode the output tensor using u8 values. The generic I::DTYPE must match the data type of the input indices, otherwise the method will bail. Additionally, this method adds an `allow_f64` option to enable the input indices data type to be f64 values. f64 values are disabled by default. TODO: indices data type and the generic I data type are currently not compile-time checked. * one_hot: remove input generic, use indices dtype matching This commit removes the to_f64() type cast and explicitly matches the DType from the input tensor. Currently, only U8, U32 and I64 is supported for input tensors. The match arms on the dtype is verbose. It would be nice to use a generic type with the WithDtype traitbound to pass to the to_vecN method and then return an inner value. Open to suggestions for better approaches here to reduce the match arm verbosity. * one_hot: use flat_map iterator over dims instead of nested for loop This commit replaces the nested for loops with an flat map iter over the dimensions of the input tensor. This commit also adds a test for a rank 3 input tensor. * one_hot: use mandatory on/off-values, remove const msgs This commit also updates doc tests, comments and test cases. * Small cleanups. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
293
candle-nn/src/encoding.rs
Normal file
293
candle-nn/src/encoding.rs
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
//! Encoding Utilities. (e.g., one-hot/cold encoding)
|
||||||
|
|
||||||
|
use candle::{bail, DType, Result, Tensor, WithDType};
|
||||||
|
|
||||||
|
/// One-hot/cold encoding.
|
||||||
|
///
|
||||||
|
/// Given an input tensor of indices, this function returns a tensor of the same shape as the input
|
||||||
|
/// tensor with an additional dimension of the given depth size. The values in the returned tensor are
|
||||||
|
/// all set to the `off_value` except for the positions represented by the indices, which are set to the `on_value`.
|
||||||
|
///
|
||||||
|
/// This method returns a tensor with a rank that is one rank larger than the input tensor.
|
||||||
|
///
|
||||||
|
/// As an example, the following tensor will be encoded to a one-hot matrix:
|
||||||
|
///
|
||||||
|
/// `[[0i64, 2], [1, -1]]`
|
||||||
|
///
|
||||||
|
/// with a depth of 4 will be encoded to:
|
||||||
|
///
|
||||||
|
/// `[[[1, 0, 0, 0], [0, 0, 1, 0]], [[0, 1, 0, 0], [0, 0, 0, 0]]]`
|
||||||
|
///
|
||||||
|
/// When the input tensor index has a value of -1, the corresponding one-hot vector will be ignored,
|
||||||
|
/// resulting in a vector of values set to the `off_value`.
|
||||||
|
///
|
||||||
|
///
|
||||||
|
/// This method supports one-cold encoding by setting `on_value` to `0` and `off_value` to `1`.
|
||||||
|
/// By default `on_value` is `1` and `off_value` is `0`.
|
||||||
|
///
|
||||||
|
/// Other encoding values can be used by setting `on_value` and `off_value` to the desired values.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ## One-hot encoding
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use candle::{Shape, Tensor, Device};
|
||||||
|
/// use candle_nn::encoding::one_hot;
|
||||||
|
///
|
||||||
|
/// let device = candle::Device::Cpu;
|
||||||
|
///
|
||||||
|
/// let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device).unwrap();
|
||||||
|
/// let depth = 4;
|
||||||
|
/// let one_hot = one_hot(indices, depth, 1f32, 0f32).unwrap();
|
||||||
|
///
|
||||||
|
/// let expected_matrix = [
|
||||||
|
/// [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
|
||||||
|
/// [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
|
||||||
|
/// ];
|
||||||
|
///
|
||||||
|
/// assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));
|
||||||
|
///
|
||||||
|
/// let matrix = one_hot.to_vec3::<f32>().unwrap();
|
||||||
|
///
|
||||||
|
/// assert_eq!(matrix, expected_matrix);
|
||||||
|
///```
|
||||||
|
/// ## One-cold Encoding
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use candle::{Shape, Tensor, Device};
|
||||||
|
/// use candle_nn::encoding::one_hot;
|
||||||
|
///
|
||||||
|
///
|
||||||
|
/// let device = candle::Device::Cpu;
|
||||||
|
/// let depth = 4;
|
||||||
|
/// let indices = Tensor::new(vec![vec![0u8, 2], vec![1, 3]], &device).unwrap();
|
||||||
|
/// let one_cold = one_hot(indices, depth, 0u8, 1u8).unwrap();
|
||||||
|
///
|
||||||
|
/// let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 0]]];
|
||||||
|
///
|
||||||
|
/// assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));
|
||||||
|
///
|
||||||
|
/// let matrix = one_cold.to_vec3::<u8>().unwrap();
|
||||||
|
///
|
||||||
|
/// assert_eq!(matrix, expected_matrix);
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
///
|
||||||
|
/// # Bails
|
||||||
|
///
|
||||||
|
/// This method bails if:
|
||||||
|
/// - The input tensor has a rank greater than 3.
|
||||||
|
/// - One of the index value is less than -1.
|
||||||
|
/// - One of the index value is greater than or equal to the depth value.
|
||||||
|
/// - The input data type is not `U8`, `U32`, or `I64`.
|
||||||
|
///
|
||||||
|
/// # API Design
|
||||||
|
///
|
||||||
|
/// The api design for this method is loosely based on the [TensorFlow One-Hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) method.
|
||||||
|
pub fn one_hot<D: WithDType>(
|
||||||
|
indices: Tensor,
|
||||||
|
depth: usize,
|
||||||
|
on_value: D,
|
||||||
|
off_value: D,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let dtype = indices.dtype();
|
||||||
|
let rank = indices.rank();
|
||||||
|
|
||||||
|
match rank {
|
||||||
|
0 => {
|
||||||
|
let mut v = vec![off_value; depth];
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let vi = indices.to_vec0::<u8>()?;
|
||||||
|
set_usize_value(vi as usize, 0, depth, &mut v, on_value)?;
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let vi = indices.to_vec0::<u32>()?;
|
||||||
|
set_usize_value(vi as usize, 0, depth, &mut v, on_value)?;
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let vi = indices.to_vec0::<i64>()?;
|
||||||
|
set_int64_value(vi, 0, depth, &mut v, on_value)?;
|
||||||
|
}
|
||||||
|
d => unsupported_dtype(d)?,
|
||||||
|
};
|
||||||
|
Tensor::from_vec(v, (depth,), indices.device())
|
||||||
|
}
|
||||||
|
1 => {
|
||||||
|
let dim1 = indices.dims1()?;
|
||||||
|
let mut v = vec![off_value; depth * dim1];
|
||||||
|
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let indices = indices.to_vec1::<i64>()?;
|
||||||
|
for (i, &index) in indices.iter().enumerate() {
|
||||||
|
set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let indices = indices.to_vec1::<i64>()?;
|
||||||
|
for (i, &index) in indices.iter().enumerate() {
|
||||||
|
set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let indices = indices.to_vec1::<i64>()?;
|
||||||
|
for (i, &index) in indices.iter().enumerate() {
|
||||||
|
set_int64_value(index, i * depth, depth, &mut v, on_value)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d => unsupported_dtype(d)?,
|
||||||
|
};
|
||||||
|
Tensor::from_vec(v, (dim1, depth), indices.device())
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
let (dim1, dim2) = indices.dims2()?;
|
||||||
|
let mut v = vec![off_value; depth * dim1 * dim2];
|
||||||
|
let idx = |i: usize, j: usize, depth: usize, dim2: usize| -> usize {
|
||||||
|
i * depth * dim2 + j * depth
|
||||||
|
};
|
||||||
|
let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j)));
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let index = indices.to_vec2::<u8>()?;
|
||||||
|
for (i, j) in iter {
|
||||||
|
set_usize_value(
|
||||||
|
index[i][j] as usize,
|
||||||
|
idx(i, j, depth, dim2),
|
||||||
|
depth,
|
||||||
|
&mut v,
|
||||||
|
on_value,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let index = indices.to_vec2::<u32>()?;
|
||||||
|
for (i, j) in iter {
|
||||||
|
set_usize_value(
|
||||||
|
index[i][j] as usize,
|
||||||
|
idx(i, j, depth, dim2),
|
||||||
|
depth,
|
||||||
|
&mut v,
|
||||||
|
on_value,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let index = indices.to_vec2::<i64>()?;
|
||||||
|
for (i, j) in iter {
|
||||||
|
set_int64_value(
|
||||||
|
index[i][j],
|
||||||
|
idx(i, j, depth, dim2),
|
||||||
|
depth,
|
||||||
|
&mut v,
|
||||||
|
on_value,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d => unsupported_dtype(d)?,
|
||||||
|
};
|
||||||
|
Tensor::from_vec(v, (dim1, dim2, depth), indices.device())
|
||||||
|
}
|
||||||
|
3 => {
|
||||||
|
let (dim1, dim2, dim3) = indices.dims3()?;
|
||||||
|
let mut v = vec![off_value; depth * dim1 * dim2 * dim3];
|
||||||
|
let idx =
|
||||||
|
|i: usize, j: usize, k: usize, depth: usize, dim2: usize, dim3: usize| -> usize {
|
||||||
|
i * depth * dim2 * dim3 + j * depth * dim3 + k * depth
|
||||||
|
};
|
||||||
|
let iter = (0..dim1)
|
||||||
|
.flat_map(|i| (0..dim2).flat_map(move |j| (0..dim3).map(move |k| (i, j, k))));
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let index = indices.to_vec3::<u8>()?;
|
||||||
|
for (i, j, k) in iter {
|
||||||
|
set_usize_value(
|
||||||
|
index[i][j][k] as usize,
|
||||||
|
idx(i, j, k, depth, dim2, dim3),
|
||||||
|
depth,
|
||||||
|
&mut v,
|
||||||
|
on_value,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let index = indices.to_vec3::<u32>()?;
|
||||||
|
for (i, j, k) in iter {
|
||||||
|
set_usize_value(
|
||||||
|
index[i][j][k] as usize,
|
||||||
|
idx(i, j, k, depth, dim2, dim3),
|
||||||
|
depth,
|
||||||
|
&mut v,
|
||||||
|
on_value,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let index = indices.to_vec3::<i64>()?;
|
||||||
|
for (i, j, k) in iter {
|
||||||
|
set_int64_value(
|
||||||
|
index[i][j][k],
|
||||||
|
idx(i, j, k, depth, dim2, dim3),
|
||||||
|
depth,
|
||||||
|
&mut v,
|
||||||
|
on_value,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d => unsupported_dtype(d)?,
|
||||||
|
};
|
||||||
|
Tensor::from_vec(v, (dim1, dim2, dim3, depth), indices.device())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
bail!("one_hot: rank {} is not supported.", rank)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unsupported_dtype(dtype: DType) -> Result<()> {
|
||||||
|
bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set unsigned usize index values to the given value.
|
||||||
|
fn set_usize_value<D: WithDType>(
|
||||||
|
value: usize,
|
||||||
|
idx: usize,
|
||||||
|
depth: usize,
|
||||||
|
v: &mut Vec<D>,
|
||||||
|
on_value: D,
|
||||||
|
) -> Result<()> {
|
||||||
|
if value >= depth {
|
||||||
|
bail!("one_hot: index value {value} exceeds depth {depth}")
|
||||||
|
}
|
||||||
|
let idx = idx + value;
|
||||||
|
if idx >= v.len() {
|
||||||
|
bail!("one_hot: index out of bounds {idx}, len {}", v.len());
|
||||||
|
}
|
||||||
|
v[idx] = on_value;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set signed integer index values to the given value.
|
||||||
|
// Signed integer values are only permitted for `-1` values.
|
||||||
|
// Otherwise, the value must be positive for unsigned usize values.
|
||||||
|
// This method will only case i64 values to usize values if the value is positive,
|
||||||
|
// otherwise the method will bail.
|
||||||
|
fn set_int64_value<D: WithDType>(
|
||||||
|
value: i64,
|
||||||
|
idx: usize,
|
||||||
|
depth: usize,
|
||||||
|
v: &mut Vec<D>,
|
||||||
|
on_value: D,
|
||||||
|
) -> Result<()> {
|
||||||
|
// Skip for an entire row of off_values
|
||||||
|
if value == -1 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
if value < -1 {
|
||||||
|
bail!(
|
||||||
|
"one_hot: invalid negative index value {value}, expected a positive index value or -1"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
set_usize_value(value as usize, idx, depth, v, on_value)
|
||||||
|
}
|
@ -2,6 +2,7 @@ pub mod activation;
|
|||||||
pub mod batch_norm;
|
pub mod batch_norm;
|
||||||
pub mod conv;
|
pub mod conv;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
|
pub mod encoding;
|
||||||
pub mod func;
|
pub mod func;
|
||||||
pub mod group_norm;
|
pub mod group_norm;
|
||||||
pub mod init;
|
pub mod init;
|
||||||
|
120
candle-nn/tests/one_hot.rs
Normal file
120
candle-nn/tests/one_hot.rs
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
use candle::{Result, Shape, Tensor};
|
||||||
|
use candle_nn::encoding::one_hot;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_i64_one_hot() -> Result<()> {
|
||||||
|
let device = candle::Device::Cpu;
|
||||||
|
|
||||||
|
let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?;
|
||||||
|
let depth = 4;
|
||||||
|
|
||||||
|
let on_value = 1.0;
|
||||||
|
let off_value = 0.0;
|
||||||
|
|
||||||
|
let one_hot = one_hot::<f32>(indices, depth, on_value, off_value)?;
|
||||||
|
|
||||||
|
let expected_matrix = [
|
||||||
|
[[1., 0., 0., 0.], [0., 0., 1., 0.]],
|
||||||
|
[[0., 1., 0., 0.], [0., 0., 0., 0.]],
|
||||||
|
];
|
||||||
|
|
||||||
|
assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));
|
||||||
|
|
||||||
|
let matrix = one_hot.to_vec3::<f32>()?;
|
||||||
|
|
||||||
|
assert_eq!(matrix, expected_matrix);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rank_3_one_hot() -> Result<()> {
|
||||||
|
let device = candle::Device::Cpu;
|
||||||
|
|
||||||
|
let indices = Tensor::new(
|
||||||
|
vec![
|
||||||
|
vec![vec![0i64, 1], vec![2, 3]],
|
||||||
|
vec![vec![3, 1], vec![1, -1]],
|
||||||
|
],
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
let depth = 4;
|
||||||
|
|
||||||
|
let on_value = 1.0;
|
||||||
|
let off_value = 0.0;
|
||||||
|
|
||||||
|
let one_hot = one_hot::<f32>(indices, depth, on_value, off_value)?;
|
||||||
|
|
||||||
|
let expected_matrix = Tensor::new(
|
||||||
|
vec![
|
||||||
|
vec![
|
||||||
|
vec![vec![1f32, 0., 0., 0.], vec![0., 1., 0., 0.]],
|
||||||
|
vec![vec![0., 0., 1., 0.], vec![0., 0., 0., 1.]],
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
vec![vec![0., 0., 0., 1.], vec![0., 1., 0., 0.]],
|
||||||
|
vec![vec![0., 1., 0., 0.], vec![0., 0., 0., 0.]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_eq!(one_hot.shape(), expected_matrix.shape());
|
||||||
|
assert_eq!(one_hot.dims(), expected_matrix.dims());
|
||||||
|
|
||||||
|
let matrix = one_hot.get(1)?.to_vec3::<f32>()?;
|
||||||
|
let expected_matrix = expected_matrix.get(1)?.to_vec3::<f32>()?;
|
||||||
|
|
||||||
|
assert_eq!(matrix, expected_matrix);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_u8_one_cold() -> Result<()> {
|
||||||
|
let device = candle::Device::Cpu;
|
||||||
|
let depth = 4;
|
||||||
|
let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?;
|
||||||
|
|
||||||
|
let on_value = 0u8;
|
||||||
|
let off_value = 1;
|
||||||
|
|
||||||
|
// Note that the method does not require the turbofish operator, as the type is inferred from the on_value.
|
||||||
|
let one_cold = one_hot(indices, depth, on_value, off_value)?;
|
||||||
|
|
||||||
|
let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 1]]];
|
||||||
|
|
||||||
|
assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));
|
||||||
|
|
||||||
|
let matrix = one_cold.to_vec3::<u8>()?;
|
||||||
|
|
||||||
|
assert_eq!(matrix, expected_matrix);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_iter() -> Result<()> {
|
||||||
|
let device = candle::Device::Cpu;
|
||||||
|
let depth = 4;
|
||||||
|
let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?;
|
||||||
|
let matrix = indices.to_vec2::<i64>()?;
|
||||||
|
let (dim1, dim2) = indices.dims2()?;
|
||||||
|
|
||||||
|
let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j)));
|
||||||
|
|
||||||
|
let mut v = vec![0; depth * dim1 * dim2];
|
||||||
|
|
||||||
|
for (i, j) in iter {
|
||||||
|
let idx = i * depth * dim2 + j * depth;
|
||||||
|
v[idx] = matrix[i][j];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, row) in matrix.iter().enumerate() {
|
||||||
|
for (j, &value) in row.iter().enumerate() {
|
||||||
|
let idx = i * depth * dim2 + j * depth;
|
||||||
|
assert_eq!(v[idx], value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
Reference in New Issue
Block a user