Merge pull request #1323 from huggingface/metal3

Adding the test scaffolding.
This commit is contained in:
Nicolas Patry
2023-11-27 13:06:01 +01:00
committed by GitHub
7 changed files with 127 additions and 54 deletions

View File

@ -514,7 +514,6 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let device = self.device();
let mut buffer = device.new_buffer(dst_el, dtype);
let out = self.to_cpu_storage()?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(left, right) => crate::bail!("index select metal {left:?} {right:?}"),
@ -616,17 +615,17 @@ impl BackendStorage for MetalStorage {
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;

View File

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,6 +15,12 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}