mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Helper function to build 3d arrays.
This commit is contained in:
@ -61,6 +61,25 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
|
||||
for &[[[S; N3]; N2]; N1]
|
||||
{
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from((N1, N2, N3)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let mut vec = Vec::new();
|
||||
vec.reserve(N1 * N2 * N3);
|
||||
for i1 in 0..N1 {
|
||||
for i2 in 0..N2 {
|
||||
vec.extend(self[i1][i2])
|
||||
}
|
||||
}
|
||||
S::to_cpu_storage_owned(vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub fn new_cuda(ordinal: usize) -> Result<Self> {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||
|
@ -72,9 +72,8 @@ fn tensor_2d_transpose() -> Result<()> {
|
||||
|
||||
#[test]
|
||||
fn softmax() -> Result<()> {
|
||||
let data = &[3f32, 1., 4., 1., 5., 9., 2., 1., 7., 8., 2., 8.];
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = tensor.reshape((2, 2, 3))?;
|
||||
let t0 = tensor.log()?.softmax(0)?;
|
||||
let t1 = tensor.log()?.softmax(1)?;
|
||||
let t2 = tensor.log()?.softmax(2)?;
|
||||
|
Reference in New Issue
Block a user