mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Merge pull request #1323 from huggingface/metal3
Adding the test scaffolding.
This commit is contained in:
@ -514,7 +514,6 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let out = self.to_cpu_storage()?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(left, right) => crate::bail!("index select metal {left:?} {right:?}"),
|
||||
@ -616,17 +615,17 @@ impl BackendStorage for MetalStorage {
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
@ -4,7 +4,7 @@ use crate::{Result, Tensor};
|
||||
macro_rules! test_device {
|
||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
|
||||
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
|
||||
#[test]
|
||||
fn $test_cpu() -> Result<()> {
|
||||
$fn_name(&Device::Cpu)
|
||||
@ -15,6 +15,12 @@ macro_rules! test_device {
|
||||
fn $test_cuda() -> Result<()> {
|
||||
$fn_name(&Device::new_cuda(0)?)
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
#[test]
|
||||
fn $test_metal() -> Result<()> {
|
||||
$fn_name(&Device::new_metal(0)?)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user