mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Debugging index_add.
This commit is contained in:
@ -9,7 +9,7 @@ use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use metal::mps::{Float32, MPSDataType};
|
||||
use metal::{MTLResourceOptions, Buffer};
|
||||
use metal::{Buffer, MTLResourceOptions};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -48,12 +48,9 @@ impl MetalDevice {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{
|
||||
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = (element_count * dtype.size_in_bytes()) as u64;
|
||||
self.device.new_buffer(
|
||||
size,
|
||||
MTLResourceOptions::empty(),
|
||||
)
|
||||
self.device.new_buffer(size, MTLResourceOptions::empty())
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,9 +77,11 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
match self.dtype{
|
||||
DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
|
||||
dtype => todo!("Unsupported dtype {dtype:?}")
|
||||
match self.dtype {
|
||||
DType::F32 => Ok(CpuStorage::F32(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
||||
)),
|
||||
dtype => todo!("Unsupported dtype {dtype:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,7 +122,11 @@ impl BackendStorage for MetalStorage {
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
todo!("Implement the kernel calling");
|
||||
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
|
||||
Ok(Self { buffer, device, dtype })
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
@ -295,7 +298,11 @@ impl MetalStorage {
|
||||
});
|
||||
}
|
||||
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
|
||||
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
|
||||
println!(
|
||||
"Didn't implemented non contiguous matmul yet {:?} {:?}",
|
||||
lhs_l.is_contiguous(),
|
||||
rhs_l.is_contiguous()
|
||||
);
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
@ -361,7 +368,6 @@ impl MetalStorage {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
type Storage = MetalStorage;
|
||||
|
||||
@ -446,13 +452,25 @@ impl BackendDevice for MetalDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
|
||||
fn rand_uniform(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
|
||||
fn rand_normal(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
|
@ -349,12 +349,9 @@ impl crate::CustomOp1 for QTensor {
|
||||
// )?;
|
||||
let cpu_storage = crate::CpuStorage::F32(dst_storage);
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
if let Device::Metal(device) = &self.device{
|
||||
Ok((
|
||||
device.storage_from_cpu_storage(&cpu_storage)?,
|
||||
dst_shape,
|
||||
))
|
||||
}else{
|
||||
if let Device::Metal(device) = &self.device {
|
||||
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
|
||||
} else {
|
||||
crate::bail!("qtensor not on metal device")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user