Tweak some metal tests. (#2528)

This commit is contained in:
Laurent Mazare
2024-10-02 10:22:31 +02:00
committed by GitHub
parent a2bcc227df
commit fd08d3d0a4
2 changed files with 23 additions and 62 deletions

View File

@ -2372,16 +2372,11 @@ pub fn call_const_fill(
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?; let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (output, v, length)); set_params!(encoder, (output, v, length));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(()) Ok(())
} }

View File

@ -2309,66 +2309,32 @@ fn conv_transpose1d_u32() {
assert_eq!(results, expected); assert_eq!(results, expected);
} }
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let buffer = dev.new_buffer(
(len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModePrivate,
);
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
#[test] #[test]
fn const_fill() { fn const_fill() {
let fills = [ fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
"fill_u8", let dev = device();
"fill_u32", let kernels = Kernels::new();
"fill_i64", let command_queue = dev.new_command_queue();
"fill_f16", let command_buffer = command_queue.new_command_buffer();
"fill_bf16", let buffer = dev.new_buffer(
"fill_f32", (len * std::mem::size_of::<T>()) as u64,
]; MTLResourceOptions::StorageModePrivate,
);
for name in fills { call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16); let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.); let value = rand::thread_rng().gen_range(1. ..19.);
let v = constant_fill::<T>(name, len, value);
match name { assert_eq!(v, vec![f(value); len])
"fill_u8" => {
let v = constant_fill::<u8>(name, len, value);
assert_eq!(v, vec![value as u8; len])
}
"fill_u32" => {
let v = constant_fill::<u32>(name, len, value);
assert_eq!(v, vec![value as u32; len])
}
"fill_i64" => {
let v = constant_fill::<i64>(name, len, value);
assert_eq!(v, vec![value as i64; len])
}
"fill_f16" => {
let v = constant_fill::<f16>(name, len, value);
assert_eq!(v, vec![f16::from_f32(value); len])
}
"fill_bf16" => {
let v = constant_fill::<bf16>(name, len, value);
assert_eq!(v, vec![bf16::from_f32(value); len])
}
"fill_f32" => {
let v = constant_fill::<f32>(name, len, value);
assert_eq!(v, vec![value; len])
}
_ => unimplemented!(),
};
} }
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);
test::<i64, _>("fill_i64", |v| v as i64);
test::<f16, _>("fill_f16", f16::from_f32);
test::<bf16, _>("fill_bf16", bf16::from_f32);
test::<f32, _>("fill_f32", |v| v);
} }