mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Simplify the one-hot implementation, support arbitrary rank. (#1514)
* Simplify the one-hot implementation, support arbitrary rank. * More cleanup.
This commit is contained in:
@ -77,7 +77,6 @@ use candle::{bail, DType, Result, Tensor, WithDType};
|
||||
/// # Bails
|
||||
///
|
||||
/// This method bails if:
|
||||
/// - The input tensor has a rank greater than 3.
|
||||
/// - One of the index values is less than -1.
|
||||
/// - One of the index values is greater than or equal to the depth value.
|
||||
/// - The input data type is not `U8`, `U32`, or `I64`.
|
||||
@ -91,195 +90,44 @@ pub fn one_hot<D: WithDType>(
|
||||
on_value: D,
|
||||
off_value: D,
|
||||
) -> Result<Tensor> {
|
||||
let dtype = indices.dtype();
|
||||
let rank = indices.rank();
|
||||
|
||||
match rank {
|
||||
0 => {
|
||||
let mut v = vec![off_value; depth];
|
||||
match dtype {
|
||||
DType::U8 => {
|
||||
let vi = indices.to_vec0::<u8>()?;
|
||||
set_usize_value(vi as usize, 0, depth, &mut v, on_value)?;
|
||||
}
|
||||
DType::U32 => {
|
||||
let vi = indices.to_vec0::<u32>()?;
|
||||
set_usize_value(vi as usize, 0, depth, &mut v, on_value)?;
|
||||
}
|
||||
DType::I64 => {
|
||||
let vi = indices.to_vec0::<i64>()?;
|
||||
set_int64_value(vi, 0, depth, &mut v, on_value)?;
|
||||
}
|
||||
d => unsupported_dtype(d)?,
|
||||
};
|
||||
Tensor::from_vec(v, (depth,), indices.device())
|
||||
let mut target_shape = indices.dims().to_vec();
|
||||
target_shape.push(depth);
|
||||
let indices = indices.flatten_all()?;
|
||||
let mut out = vec![off_value; depth * indices.elem_count()];
|
||||
match indices.dtype() {
|
||||
DType::U8 => {
|
||||
let indices = indices.to_vec1::<u8>()?;
|
||||
for (i, &index) in indices.iter().enumerate() {
|
||||
set_at_index(index, i * depth, depth, &mut out, on_value)?;
|
||||
}
|
||||
}
|
||||
1 => {
|
||||
let dim1 = indices.dims1()?;
|
||||
let mut v = vec![off_value; depth * dim1];
|
||||
|
||||
match dtype {
|
||||
DType::U8 => {
|
||||
let indices = indices.to_vec1::<i64>()?;
|
||||
for (i, &index) in indices.iter().enumerate() {
|
||||
set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?;
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
let indices = indices.to_vec1::<i64>()?;
|
||||
for (i, &index) in indices.iter().enumerate() {
|
||||
set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?;
|
||||
}
|
||||
}
|
||||
DType::I64 => {
|
||||
let indices = indices.to_vec1::<i64>()?;
|
||||
for (i, &index) in indices.iter().enumerate() {
|
||||
set_int64_value(index, i * depth, depth, &mut v, on_value)?;
|
||||
}
|
||||
}
|
||||
d => unsupported_dtype(d)?,
|
||||
};
|
||||
Tensor::from_vec(v, (dim1, depth), indices.device())
|
||||
DType::U32 => {
|
||||
let indices = indices.to_vec1::<u32>()?;
|
||||
for (i, &index) in indices.iter().enumerate() {
|
||||
set_at_index(index, i * depth, depth, &mut out, on_value)?;
|
||||
}
|
||||
}
|
||||
2 => {
|
||||
let (dim1, dim2) = indices.dims2()?;
|
||||
let mut v = vec![off_value; depth * dim1 * dim2];
|
||||
let idx = |i: usize, j: usize, depth: usize, dim2: usize| -> usize {
|
||||
i * depth * dim2 + j * depth
|
||||
};
|
||||
let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j)));
|
||||
match dtype {
|
||||
DType::U8 => {
|
||||
let index = indices.to_vec2::<u8>()?;
|
||||
for (i, j) in iter {
|
||||
set_usize_value(
|
||||
index[i][j] as usize,
|
||||
idx(i, j, depth, dim2),
|
||||
depth,
|
||||
&mut v,
|
||||
on_value,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
let index = indices.to_vec2::<u32>()?;
|
||||
for (i, j) in iter {
|
||||
set_usize_value(
|
||||
index[i][j] as usize,
|
||||
idx(i, j, depth, dim2),
|
||||
depth,
|
||||
&mut v,
|
||||
on_value,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
DType::I64 => {
|
||||
let index = indices.to_vec2::<i64>()?;
|
||||
for (i, j) in iter {
|
||||
set_int64_value(
|
||||
index[i][j],
|
||||
idx(i, j, depth, dim2),
|
||||
depth,
|
||||
&mut v,
|
||||
on_value,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
d => unsupported_dtype(d)?,
|
||||
};
|
||||
Tensor::from_vec(v, (dim1, dim2, depth), indices.device())
|
||||
DType::I64 => {
|
||||
let indices = indices.to_vec1::<i64>()?;
|
||||
for (i, &index) in indices.iter().enumerate() {
|
||||
set_at_index(index, i * depth, depth, &mut out, on_value)?;
|
||||
}
|
||||
}
|
||||
3 => {
|
||||
let (dim1, dim2, dim3) = indices.dims3()?;
|
||||
let mut v = vec![off_value; depth * dim1 * dim2 * dim3];
|
||||
let idx =
|
||||
|i: usize, j: usize, k: usize, depth: usize, dim2: usize, dim3: usize| -> usize {
|
||||
i * depth * dim2 * dim3 + j * depth * dim3 + k * depth
|
||||
};
|
||||
let iter = (0..dim1)
|
||||
.flat_map(|i| (0..dim2).flat_map(move |j| (0..dim3).map(move |k| (i, j, k))));
|
||||
match dtype {
|
||||
DType::U8 => {
|
||||
let index = indices.to_vec3::<u8>()?;
|
||||
for (i, j, k) in iter {
|
||||
set_usize_value(
|
||||
index[i][j][k] as usize,
|
||||
idx(i, j, k, depth, dim2, dim3),
|
||||
depth,
|
||||
&mut v,
|
||||
on_value,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
let index = indices.to_vec3::<u32>()?;
|
||||
for (i, j, k) in iter {
|
||||
set_usize_value(
|
||||
index[i][j][k] as usize,
|
||||
idx(i, j, k, depth, dim2, dim3),
|
||||
depth,
|
||||
&mut v,
|
||||
on_value,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
DType::I64 => {
|
||||
let index = indices.to_vec3::<i64>()?;
|
||||
for (i, j, k) in iter {
|
||||
set_int64_value(
|
||||
index[i][j][k],
|
||||
idx(i, j, k, depth, dim2, dim3),
|
||||
depth,
|
||||
&mut v,
|
||||
on_value,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
d => unsupported_dtype(d)?,
|
||||
};
|
||||
Tensor::from_vec(v, (dim1, dim2, dim3, depth), indices.device())
|
||||
dtype => {
|
||||
bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
|
||||
}
|
||||
_ => {
|
||||
bail!("one_hot: rank {} is not supported.", rank)
|
||||
}
|
||||
}
|
||||
};
|
||||
Tensor::from_vec(out, target_shape, indices.device())
|
||||
}
|
||||
|
||||
fn unsupported_dtype(dtype: DType) -> Result<()> {
|
||||
bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
|
||||
}
|
||||
|
||||
// Set unsigned usize index values to the given value.
|
||||
fn set_usize_value<D: WithDType>(
|
||||
value: usize,
|
||||
idx: usize,
|
||||
depth: usize,
|
||||
v: &mut Vec<D>,
|
||||
on_value: D,
|
||||
) -> Result<()> {
|
||||
if value >= depth {
|
||||
bail!("one_hot: index value {value} exceeds depth {depth}")
|
||||
}
|
||||
let idx = idx + value;
|
||||
if idx >= v.len() {
|
||||
bail!("one_hot: index out of bounds {idx}, len {}", v.len());
|
||||
}
|
||||
v[idx] = on_value;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Set signed integer index values to the given value.
|
||||
// Signed integer values are only permitted for `-1` values.
|
||||
// Otherwise, the value must be positive for unsigned usize values.
|
||||
// This method will only cast i64 values to usize values if the value is positive,
|
||||
// otherwise the method will bail.
|
||||
fn set_int64_value<D: WithDType>(
|
||||
value: i64,
|
||||
idx: usize,
|
||||
fn set_at_index<D: WithDType, I: Into<i64>>(
|
||||
value: I,
|
||||
offset: usize,
|
||||
depth: usize,
|
||||
v: &mut Vec<D>,
|
||||
on_value: D,
|
||||
) -> Result<()> {
|
||||
let value = value.into();
|
||||
// Skip for an entire row of off_values
|
||||
if value == -1 {
|
||||
return Ok(());
|
||||
@ -289,5 +137,14 @@ fn set_int64_value<D: WithDType>(
|
||||
"one_hot: invalid negative index value {value}, expected a positive index value or -1"
|
||||
);
|
||||
}
|
||||
set_usize_value(value as usize, idx, depth, v, on_value)
|
||||
let value = value as usize;
|
||||
if value >= depth {
|
||||
bail!("one_hot: index value {value} exceeds depth {depth}")
|
||||
}
|
||||
let idx = offset + value;
|
||||
if idx >= v.len() {
|
||||
bail!("one_hot: index out of bounds {idx}, len {}", v.len());
|
||||
}
|
||||
v[idx] = on_value;
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user