Debugging index_add.

This commit is contained in:
Ivar Flakstad
2023-11-03 12:08:58 +01:00
parent f57e3164ae
commit 0794e70a19
5 changed files with 183 additions and 28 deletions

View File

@ -9,7 +9,7 @@ use half::{bf16, f16};
use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::{MTLResourceOptions, Buffer};
use metal::{Buffer, MTLResourceOptions};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -48,12 +48,9 @@ impl MetalDevice {
self.registry_id()
}
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64;
self.device.new_buffer(
size,
MTLResourceOptions::empty(),
)
self.device.new_buffer(size, MTLResourceOptions::empty())
}
}
@ -80,9 +77,11 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self.dtype{
DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
dtype => todo!("Unsupported dtype {dtype:?}")
match self.dtype {
DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
dtype => todo!("Unsupported dtype {dtype:?}"),
}
}
@ -123,7 +122,11 @@ impl BackendStorage for MetalStorage {
let mut buffer = device.new_buffer(el_count, dtype);
todo!("Implement the kernel calling");
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
Ok(Self { buffer, device, dtype })
Ok(Self {
buffer,
device,
dtype,
})
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@ -295,7 +298,11 @@ impl MetalStorage {
});
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
println!(
"Didn't implemented non contiguous matmul yet {:?} {:?}",
lhs_l.is_contiguous(),
rhs_l.is_contiguous()
);
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
@ -361,7 +368,6 @@ impl MetalStorage {
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
@ -446,13 +452,25 @@ impl BackendDevice for MetalDevice {
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)

View File

@ -349,12 +349,9 @@ impl crate::CustomOp1 for QTensor {
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device{
Ok((
device.storage_from_cpu_storage(&cpu_storage)?,
dst_shape,
))
}else{
if let Device::Metal(device) = &self.device {
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
} else {
crate::bail!("qtensor not on metal device")
}
}