mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Fixing tests + matmul from MFA
This commit is contained in:
@ -276,17 +276,17 @@ impl BackendStorage for MetalStorage {
|
||||
self.device.wait_until_completed();
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
|
||||
DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
|
||||
DType::F32 => {
|
||||
let vec = buffer.read_to_vec(length / size);
|
||||
let vec = read_to_vec(&buffer, length / size);
|
||||
// println!("Got back {:?}", &vec[..1]);
|
||||
Ok(CpuStorage::F32(vec))
|
||||
}
|
||||
DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
|
||||
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
|
||||
}
|
||||
}
|
||||
|
||||
@ -944,6 +944,8 @@ impl BackendStorage for MetalStorage {
|
||||
};
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
// println!("MATMUL {b} {m} {n} {k}");
|
||||
// println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride());
|
||||
command_buffer.set_label("matmul");
|
||||
candle_metal_kernels::call_gemm(
|
||||
&self.device.device,
|
||||
@ -952,16 +954,17 @@ impl BackendStorage for MetalStorage {
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&lhs_l.stride(),
|
||||
lhs_l.start_offset(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
&rhs_l.stride(),
|
||||
rhs_l.start_offset(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// Create kernel
|
||||
command_buffer.commit();
|
||||
self.device.wait_until_completed();
|
||||
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||
}
|
||||
@ -1138,3 +1141,10 @@ impl BackendDevice for MetalDevice {
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
||||
|
Reference in New Issue
Block a user