mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Adapt the cuda bits.
This commit is contained in:
@ -101,14 +101,9 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
}
|
||||
}
|
||||
|
||||
fn take_impl1<T: Copy>(
|
||||
vs: &[T],
|
||||
ids: &[u32],
|
||||
layout: &Layout,
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<T>> {
|
||||
fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
|
||||
// TODO: Optimize for the case where ids are contiguous.
|
||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
||||
let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
|
||||
for index in layout.strided_index() {
|
||||
let index = ids[index].try_into()?;
|
||||
@ -610,15 +605,9 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embedding(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
vs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
|
||||
map1!(rhs, take_impl1, ids, layout, rhs_l)
|
||||
}
|
||||
|
||||
pub(crate) fn matmul(
|
||||
|
Reference in New Issue
Block a user