Cleanup fixed a few ops removed debugging scaffolding.

This commit is contained in:
Nicolas Patry
2023-11-10 23:00:32 +01:00
committed by Nicolas Patry
parent 7cfffcac10
commit 2813fb5dbc
7 changed files with 28 additions and 55 deletions

View File

@ -105,8 +105,6 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
// TODO Is this necessary
// self.buffer.synchronize();
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
@ -140,6 +138,7 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
assert!(layout.is_contiguous());
assert!(layout.start_offset() == 0);
assert_eq!(dtype, DType::F32);
let mut buffer = device.new_buffer(el, self.dtype);
@ -173,10 +172,10 @@ impl BackendStorage for MetalStorage {
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
// debug!("TODO reduce_op {op:?} {sum_dims:?}");
assert!(sum_dims.len() == 1);
assert!(sum_dims[0] == layout.shape().rank() - 1);
assert!(layout.is_contiguous());
assert!(layout.start_offset() == 0);
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
@ -269,13 +268,6 @@ impl BackendStorage for MetalStorage {
command_buffer.commit();
command_buffer.wait_until_completed();
// command_buffer.wait_until_scheduled();
// debug!(
// "cast {:?} - {:?} - {:?}",
// dtype,
// self.buffer.length(),
// buffer.length()
// );
Ok(Self {
buffer,
device: device.clone(),
@ -290,7 +282,7 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@ -300,6 +292,7 @@ impl BackendStorage for MetalStorage {
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
@ -337,7 +330,9 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
{
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@ -380,10 +375,10 @@ impl BackendStorage for MetalStorage {
lhs_l.dims(),
&self.buffer,
&lhs_l.stride(),
lhs_l.start_offset(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&mut buffer,
)
.map_err(MetalError::from)?;
@ -420,11 +415,14 @@ impl BackendStorage for MetalStorage {
"where_u8_f32",
&dims,
&self.buffer,
(layout.stride(), layout.start_offset()),
(
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
),
&t.buffer,
(&t_l.stride(), t_l.start_offset()),
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(&f_l.stride(), f_l.start_offset()),
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&mut buffer,
)
.map_err(MetalError::from)?;
@ -511,7 +509,9 @@ impl BackendStorage for MetalStorage {
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
assert!(src_l.is_contiguous());
assert!(src_l.start_offset() == 0);
assert!(ids_l.is_contiguous());
assert!(ids_l.start_offset() == 0);
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.shape().elem_count();
@ -681,6 +681,7 @@ impl BackendStorage for MetalStorage {
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let el_count = src_shape.elem_count();
// todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
if el_count == 0 {
return Ok(());
}
@ -699,15 +700,13 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
&self.buffer,
&src_l.stride(),
src_l.start_offset(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&mut dst.buffer,
dst_offset,
)
.map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_completed();
// todo!("Output {:?}", dst.buffer.read_to_vec::<f32>(10));
// }
Ok(())
}
}
@ -732,24 +731,11 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue();
// let command_buffer = _command_queue.new_owned_command_buffer();
let kernels = Arc::new(Kernels::new());
Ok(Self {
device,
command_queue,
// command_buffer,
kernels,
})
}
@ -819,9 +805,6 @@ impl BackendDevice for MetalDevice {
option,
),
};
// TODO is that necessary ?
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
// debug!("Allocate 2 - buffer size {}", buffer.length());
Ok(Self::Storage {
buffer,
device: self.clone(),

View File

@ -157,8 +157,6 @@ pub(crate) fn from_storage<S: Into<Shape>>(
) -> Tensor {
let dtype = storage.dtype();
let device = storage.device();
let shape = shape.into();
// println!("{:?} {storage:?}", shape);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(RwLock::new(storage)),
@ -168,11 +166,7 @@ pub(crate) fn from_storage<S: Into<Shape>>(
dtype,
device,
};
let result = Tensor(Arc::new(tensor_));
// todo!(" from_storage");
// let result = result.to_device(&Device::Cpu).unwrap();
// todo!(" {result}");
result
Tensor(Arc::new(tensor_))
}
impl Tensor {
@ -1869,7 +1863,10 @@ impl Tensor {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.

View File

@ -329,18 +329,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.get_ids()
.to_vec();
println!("{tokens:?}");
let start_gen = std::time::Instant::now();
for index in 0..1 {
for index in 0.. {
if tokens.len() >= config.seq_len {
break;
}
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
// println!("Input {}", input);
// println!("Input {}", input.to_device(&candle::Device::Cpu)?);
let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {

View File

@ -150,7 +150,7 @@ macro_rules! ops{
}
pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, copy);
ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
}
pub mod binary {
ops!(add, sub, mul, div);

View File

@ -63,6 +63,7 @@ UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
@ -73,6 +74,7 @@ BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif

View File

@ -9,7 +9,6 @@ pub struct Embedding {
impl Embedding {
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
// todo!("Embedding {embeddings}");
Self {
embeddings,
hidden_size,

View File

@ -156,7 +156,6 @@ impl CausalSelfAttention {
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
todo!("X {x1}");
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@ -174,7 +173,6 @@ impl CausalSelfAttention {
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
todo!("X {q}");
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {
@ -297,7 +295,6 @@ impl Block {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
todo!("---X {}", x);
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
@ -330,7 +327,6 @@ impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
//println!("Embeddings {}", self.wte.embeddings());
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
}