Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 19:18:50 +00:00.
fix: add missing gpu fill_* (#996)
This commit is contained in:
@ -8,6 +8,31 @@ fn zeros(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn ones(device: &Device) -> Result<()> {
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
|
||||||
|
[[1, 1, 1], [1, 1, 1]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
|
||||||
|
[[1, 1, 1], [1, 1, 1]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
|
||||||
|
[[1, 1, 1], [1, 1, 1]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
||||||
|
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||||
|
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn add_mul(device: &Device) -> Result<()> {
|
fn add_mul(device: &Device) -> Result<()> {
|
||||||
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||||
let dim1 = tensor.dims1()?;
|
let dim1 = tensor.dims1()?;
|
||||||
@ -966,6 +991,7 @@ fn randn(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||||
|
test_device!(ones, ones_cpu, ones_gpu);
|
||||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||||
test_device!(narrow, narrow_cpu, narrow_gpu);
|
test_device!(narrow, narrow_cpu, narrow_gpu);
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
#include<stdint.h>
|
||||||
#include "cuda_fp16.h"
|
#include "cuda_fp16.h"
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@ -6,6 +7,14 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
|
|||||||
buf[i] = value;
|
buf[i] = value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Type-specific entry points: each `extern "C"` kernel forwards to the
// shared `fill_with` template for its element type, giving the host side a
// stable, unmangled symbol per dtype.
extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) {
    fill_with(buf, value, numel);
}
extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) {
    fill_with(buf, value, numel);
}
extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) {
    fill_with(buf, value, numel);
}
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) {
    fill_with(buf, value, numel);
}
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) {
    fill_with(buf, value, numel);
}
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) {
    fill_with(buf, value, numel);
}

// bf16 is only compiled in for devices of compute capability >= 8.0,
// mirroring the original's `__CUDA_ARCH__` guard around the header too.
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) {
    fill_with(buf, value, numel);
}
#endif
|
||||||
|
Reference in New Issue
Block a user