mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Tweak some metal tests. (#2528)
This commit is contained in:
@ -2372,16 +2372,11 @@ pub fn call_const_fill(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (output, v, length));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -2309,66 +2309,32 @@ fn conv_transpose1d_u32() {
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
||||
let dev = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = dev.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let buffer = dev.new_buffer(
|
||||
(len * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModePrivate,
|
||||
);
|
||||
|
||||
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec::<T>(&buffer, len)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn const_fill() {
|
||||
let fills = [
|
||||
"fill_u8",
|
||||
"fill_u32",
|
||||
"fill_i64",
|
||||
"fill_f16",
|
||||
"fill_bf16",
|
||||
"fill_f32",
|
||||
];
|
||||
|
||||
for name in fills {
|
||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
||||
let dev = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = dev.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let buffer = dev.new_buffer(
|
||||
(len * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModePrivate,
|
||||
);
|
||||
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
read_to_vec::<T>(&buffer, len)
|
||||
}
|
||||
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
|
||||
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
||||
let value = rand::thread_rng().gen_range(1. ..19.);
|
||||
|
||||
match name {
|
||||
"fill_u8" => {
|
||||
let v = constant_fill::<u8>(name, len, value);
|
||||
assert_eq!(v, vec![value as u8; len])
|
||||
}
|
||||
"fill_u32" => {
|
||||
let v = constant_fill::<u32>(name, len, value);
|
||||
assert_eq!(v, vec![value as u32; len])
|
||||
}
|
||||
"fill_i64" => {
|
||||
let v = constant_fill::<i64>(name, len, value);
|
||||
assert_eq!(v, vec![value as i64; len])
|
||||
}
|
||||
"fill_f16" => {
|
||||
let v = constant_fill::<f16>(name, len, value);
|
||||
assert_eq!(v, vec![f16::from_f32(value); len])
|
||||
}
|
||||
"fill_bf16" => {
|
||||
let v = constant_fill::<bf16>(name, len, value);
|
||||
assert_eq!(v, vec![bf16::from_f32(value); len])
|
||||
}
|
||||
"fill_f32" => {
|
||||
let v = constant_fill::<f32>(name, len, value);
|
||||
assert_eq!(v, vec![value; len])
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
let v = constant_fill::<T>(name, len, value);
|
||||
assert_eq!(v, vec![f(value); len])
|
||||
}
|
||||
test::<u8, _>("fill_u8", |v| v as u8);
|
||||
test::<u32, _>("fill_u32", |v| v as u32);
|
||||
test::<i64, _>("fill_i64", |v| v as i64);
|
||||
test::<f16, _>("fill_f16", f16::from_f32);
|
||||
test::<bf16, _>("fill_bf16", bf16::from_f32);
|
||||
test::<f32, _>("fill_f32", |v| v);
|
||||
}
|
||||
|
Reference in New Issue
Block a user