mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Fixing tests + matmul from MFA
This commit is contained in:
@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
@ -276,17 +276,17 @@ impl BackendStorage for MetalStorage {
|
||||
self.device.wait_until_completed();
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
|
||||
DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
|
||||
DType::F32 => {
|
||||
let vec = buffer.read_to_vec(length / size);
|
||||
let vec = read_to_vec(&buffer, length / size);
|
||||
// println!("Got back {:?}", &vec[..1]);
|
||||
Ok(CpuStorage::F32(vec))
|
||||
}
|
||||
DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
|
||||
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
|
||||
}
|
||||
}
|
||||
|
||||
@ -944,6 +944,8 @@ impl BackendStorage for MetalStorage {
|
||||
};
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
// println!("MATMUL {b} {m} {n} {k}");
|
||||
// println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride());
|
||||
command_buffer.set_label("matmul");
|
||||
candle_metal_kernels::call_gemm(
|
||||
&self.device.device,
|
||||
@ -952,16 +954,17 @@ impl BackendStorage for MetalStorage {
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&lhs_l.stride(),
|
||||
lhs_l.start_offset(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
&rhs_l.stride(),
|
||||
rhs_l.start_offset(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// Create kernel
|
||||
command_buffer.commit();
|
||||
self.device.wait_until_completed();
|
||||
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||
}
|
||||
@ -1138,3 +1141,10 @@ impl BackendDevice for MetalDevice {
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
@ -2,6 +2,13 @@ use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
||||
|
||||
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let ptr = data.as_ptr() as *const core::ffi::c_void;
|
||||
@ -47,7 +54,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(v.len())
|
||||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
||||
@ -72,7 +79,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(x.len())
|
||||
read_to_vec(&output, x.len())
|
||||
}
|
||||
|
||||
fn run_strided<T: Clone>(
|
||||
@ -103,7 +110,7 @@ fn run_strided<T: Clone>(
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(v.len())
|
||||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -261,7 +268,7 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<U>(v.len())
|
||||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -311,7 +318,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
output.read_to_vec::<T>(v.len())
|
||||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
fn run_affine_strided<T: Clone>(
|
||||
@ -347,7 +354,7 @@ fn run_affine_strided<T: Clone>(
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
let len: usize = shape.iter().product();
|
||||
output.read_to_vec::<T>(len)
|
||||
read_to_vec(&output, len)
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -468,7 +475,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
dst_buffer.read_to_vec::<T>(dst_el)
|
||||
read_to_vec(&dst_buffer, dst_el)
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -534,7 +541,7 @@ fn index_add() {
|
||||
let expected = vec![
|
||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||
];
|
||||
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
||||
let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
@ -574,7 +581,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
output.read_to_vec::<T>(out_length)
|
||||
read_to_vec(&output, out_length)
|
||||
}
|
||||
|
||||
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
|
||||
@ -598,7 +605,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
output.read_to_vec::<T>(v.len())
|
||||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -720,7 +727,7 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
output.read_to_vec::<T>(length)
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -744,3 +751,92 @@ fn where_cond() {
|
||||
);
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
fn run_gemm<T: Clone>(
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: Vec<usize>,
|
||||
lhs_offset: usize,
|
||||
rhs: &[T],
|
||||
rhs_stride: Vec<usize>,
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let lhs = device.new_buffer_with_data(
|
||||
lhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(lhs) as u64,
|
||||
options,
|
||||
);
|
||||
let rhs = device.new_buffer_with_data(
|
||||
rhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(rhs) as u64,
|
||||
options,
|
||||
);
|
||||
let length = b * m * n;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs_stride,
|
||||
lhs_offset,
|
||||
&lhs,
|
||||
&rhs_stride,
|
||||
rhs_offset,
|
||||
&rhs,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gemm() {
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
|
||||
let (b, m, n, k) = (2, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![
|
||||
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
|
||||
518.0, 548.0, 578.0
|
||||
]
|
||||
);
|
||||
|
||||
// OFFSET
|
||||
let (b, m, n, k) = (2, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
|
||||
let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||
);
|
||||
}
|
||||
|
@ -144,7 +144,6 @@ impl RotaryEmbedding {
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let sin = freqs.sin()?;
|
||||
let cos = freqs.cos()?;
|
||||
// todo!("{}", sin);
|
||||
Ok(Self { sin, cos })
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user