mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Tweak some metal tests. (#2528)
This commit is contained in:
@ -2372,16 +2372,11 @@ pub fn call_const_fill(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (output, v, length));
|
set_params!(encoder, (output, v, length));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2309,66 +2309,32 @@ fn conv_transpose1d_u32() {
|
|||||||
assert_eq!(results, expected);
|
assert_eq!(results, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
|
||||||
let dev = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = dev.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
|
|
||||||
let buffer = dev.new_buffer(
|
|
||||||
(len * std::mem::size_of::<T>()) as u64,
|
|
||||||
MTLResourceOptions::StorageModePrivate,
|
|
||||||
);
|
|
||||||
|
|
||||||
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
|
|
||||||
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
read_to_vec::<T>(&buffer, len)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn const_fill() {
|
fn const_fill() {
|
||||||
let fills = [
|
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
||||||
"fill_u8",
|
let dev = device();
|
||||||
"fill_u32",
|
let kernels = Kernels::new();
|
||||||
"fill_i64",
|
let command_queue = dev.new_command_queue();
|
||||||
"fill_f16",
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
"fill_bf16",
|
let buffer = dev.new_buffer(
|
||||||
"fill_f32",
|
(len * std::mem::size_of::<T>()) as u64,
|
||||||
];
|
MTLResourceOptions::StorageModePrivate,
|
||||||
|
);
|
||||||
for name in fills {
|
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
read_to_vec::<T>(&buffer, len)
|
||||||
|
}
|
||||||
|
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
|
||||||
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
||||||
let value = rand::thread_rng().gen_range(1. ..19.);
|
let value = rand::thread_rng().gen_range(1. ..19.);
|
||||||
|
let v = constant_fill::<T>(name, len, value);
|
||||||
match name {
|
assert_eq!(v, vec![f(value); len])
|
||||||
"fill_u8" => {
|
|
||||||
let v = constant_fill::<u8>(name, len, value);
|
|
||||||
assert_eq!(v, vec![value as u8; len])
|
|
||||||
}
|
|
||||||
"fill_u32" => {
|
|
||||||
let v = constant_fill::<u32>(name, len, value);
|
|
||||||
assert_eq!(v, vec![value as u32; len])
|
|
||||||
}
|
|
||||||
"fill_i64" => {
|
|
||||||
let v = constant_fill::<i64>(name, len, value);
|
|
||||||
assert_eq!(v, vec![value as i64; len])
|
|
||||||
}
|
|
||||||
"fill_f16" => {
|
|
||||||
let v = constant_fill::<f16>(name, len, value);
|
|
||||||
assert_eq!(v, vec![f16::from_f32(value); len])
|
|
||||||
}
|
|
||||||
"fill_bf16" => {
|
|
||||||
let v = constant_fill::<bf16>(name, len, value);
|
|
||||||
assert_eq!(v, vec![bf16::from_f32(value); len])
|
|
||||||
}
|
|
||||||
"fill_f32" => {
|
|
||||||
let v = constant_fill::<f32>(name, len, value);
|
|
||||||
assert_eq!(v, vec![value; len])
|
|
||||||
}
|
|
||||||
_ => unimplemented!(),
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
test::<u8, _>("fill_u8", |v| v as u8);
|
||||||
|
test::<u32, _>("fill_u32", |v| v as u32);
|
||||||
|
test::<i64, _>("fill_i64", |v| v as i64);
|
||||||
|
test::<f16, _>("fill_f16", f16::from_f32);
|
||||||
|
test::<bf16, _>("fill_bf16", bf16::from_f32);
|
||||||
|
test::<f32, _>("fill_f32", |v| v);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user