mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select. * Also remove the cuda kernels.
This commit is contained in:
@ -842,45 +842,35 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the
|
||||
/// Returns a tensor with the values from the `self` tensor at the index corresponding to the
|
||||
/// values hold in the `ids` tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `self` - A tensor with dimensions `v, h`.
|
||||
/// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
|
||||
/// * `rhs` - A tensor with dimensions `v, h`.
|
||||
///
|
||||
/// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
|
||||
/// vocabulary size, and `h` the hidden size.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
|
||||
/// let emb = Tensor::embedding(&ids, &rhs)?;
|
||||
/// let emb = values.embedding(&ids)?;
|
||||
/// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
|
||||
if !rhs.is_contiguous() {
|
||||
Err(Error::RequiresContiguous { op: "embedding" }.bt())?
|
||||
} else if rhs.rank() != 2 || ids.rank() != 1 {
|
||||
pub fn embedding(&self, ids: &Self) -> Result<Self> {
|
||||
if self.rank() != 2 || ids.rank() != 1 {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: ids.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
lhs: self.shape().clone(),
|
||||
rhs: ids.shape().clone(),
|
||||
op: "embedding",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let ids_shape = ids.shape();
|
||||
let seq_len = ids_shape.dims1()?;
|
||||
let (_, hidden_size) = rhs.dims2()?;
|
||||
let storage = ids
|
||||
.storage()
|
||||
.embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
|
||||
let shape: Shape = (seq_len, hidden_size).into();
|
||||
let op = BackpropOp::new2(ids, rhs, Op::Embedding);
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
self.index_select(ids, 0)
|
||||
}
|
||||
|
||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
|
Reference in New Issue
Block a user