mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add an enum for scalar values. (#2909)
* Add a scalar enum type. * Add a bit more to the scalar type. * Small tweak. * More scalar usage.
This commit is contained in:
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
@ -4,20 +4,20 @@ using namespace metal;
|
||||
|
||||
template<typename T> METAL_FUNC void fill_with(
|
||||
device T *out,
|
||||
constant float &value,
|
||||
constant T &value,
|
||||
constant size_t &numel,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= numel) {
|
||||
return;
|
||||
}
|
||||
out[tid] = static_cast<T>(value);
|
||||
out[tid] = value;
|
||||
}
|
||||
|
||||
#define FILL_OP(NAME, T) \
|
||||
kernel void fill_##NAME( \
|
||||
device T *out, \
|
||||
constant float &value, \
|
||||
constant T &value, \
|
||||
constant size_t &numel, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
|
@ -2570,7 +2570,7 @@ pub fn call_const_fill(
|
||||
name: &'static str,
|
||||
length: usize,
|
||||
output: &Buffer,
|
||||
v: f32,
|
||||
v: impl EncoderParam,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
||||
let encoder = ep.encoder();
|
||||
|
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
|
||||
|
||||
#[test]
|
||||
fn const_fill() {
|
||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
||||
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
|
||||
let dev = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = dev.new_command_queue();
|
||||
@ -2357,11 +2357,15 @@ fn const_fill() {
|
||||
command_buffer.wait_until_completed();
|
||||
read_to_vec::<T>(&buffer, len)
|
||||
}
|
||||
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
|
||||
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
|
||||
name: &'static str,
|
||||
f: F,
|
||||
) {
|
||||
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
||||
let value = rand::thread_rng().gen_range(1. ..19.);
|
||||
let value = f(value);
|
||||
let v = constant_fill::<T>(name, len, value);
|
||||
assert_eq!(v, vec![f(value); len])
|
||||
assert_eq!(v, vec![value; len])
|
||||
}
|
||||
test::<u8, _>("fill_u8", |v| v as u8);
|
||||
test::<u32, _>("fill_u32", |v| v as u32);
|
||||
|
@ -88,9 +88,13 @@ primitive!(bool);
|
||||
primitive!(usize);
|
||||
primitive!(i32);
|
||||
primitive!(i64);
|
||||
primitive!(u8);
|
||||
primitive!(u32);
|
||||
primitive!(u64);
|
||||
primitive!(f32);
|
||||
primitive!(f64);
|
||||
primitive!(half::bf16);
|
||||
primitive!(half::f16);
|
||||
|
||||
pub struct BufferOffset<'a> {
|
||||
pub buffer: &'a Buffer,
|
||||
|
Reference in New Issue
Block a user