From 2813fb5dbc404db927dab20b59ef3f2b9dbfc389 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 10 Nov 2023 23:00:32 +0100
Subject: [PATCH] Cleanup

fixed a few ops
removed debugging scaffolding.
---
 candle-core/src/metal_backend.rs           | 55 ++++++++--------------
 candle-core/src/tensor.rs                  | 13 ++---
 candle-examples/examples/llama2-c/main.rs  |  6 +--
 candle-metal-kernels/src/lib.rs            |  2 +-
 candle-metal-kernels/src/unary.metal       |  2 +
 candle-nn/src/embedding.rs                 |  1 -
 candle-transformers/src/models/llama2_c.rs |  4 ----
 7 files changed, 28 insertions(+), 55 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 3f58bb9b..597c2f01 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -105,8 +105,6 @@ impl BackendStorage for MetalStorage {
     }
 
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
-        // TODO Is this necessary
-        // self.buffer.synchronize();
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(
                 self.buffer.read_to_vec(self.buffer.length() as usize / 1),
@@ -140,6 +138,7 @@
         let dtype = self.dtype;
 
         assert!(layout.is_contiguous());
+        assert!(layout.start_offset() == 0);
         assert_eq!(dtype, DType::F32);
 
         let mut buffer = device.new_buffer(el, self.dtype);
@@ -173,10 +172,10 @@
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
-        // debug!("TODO reduce_op {op:?} {sum_dims:?}");
         assert!(sum_dims.len() == 1);
         assert!(sum_dims[0] == layout.shape().rank() - 1);
         assert!(layout.is_contiguous());
+        assert!(layout.start_offset() == 0);
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
@@ -269,13 +268,6 @@
         command_buffer.commit();
         command_buffer.wait_until_completed();
-        // command_buffer.wait_until_scheduled();
-        // debug!(
-        //     "cast {:?} - {:?} - {:?}",
-        //     dtype,
-        //     self.buffer.length(),
-        //     buffer.length()
-        // );
 
         Ok(Self {
             buffer,
             device: device.clone(),
@@ -290,7 +282,7 @@
         let el_count = shape.elem_count();
         let mut buffer = device.new_buffer(el_count, dtype);
         let command_buffer = device.command_queue.new_command_buffer();
-        if layout.is_contiguous() {
+        if layout.is_contiguous() && layout.start_offset() == 0 {
             use candle_metal_kernels::unary::contiguous;
 
             let kernel_name = match (B::KERNEL, dtype) {
@@ -300,6 +292,7 @@
                 ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
                 ("uneg", DType::F32) => contiguous::neg::FLOAT,
                 ("uexp", DType::F32) => contiguous::exp::FLOAT,
+                ("ulog", DType::F32) => contiguous::log::FLOAT,
                 (name, dtype) => todo!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_unary_contiguous(
@@ -337,7 +330,9 @@
         let el_count = shape.elem_count();
         let mut buffer = device.new_buffer(el_count, dtype);
         let command_buffer = device.command_queue.new_command_buffer();
-        if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+        if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
+            && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
+        {
             use candle_metal_kernels::binary::contiguous;
 
             let kernel_name = match (B::KERNEL, dtype) {
@@ -380,10 +375,10 @@
                 lhs_l.dims(),
                 &self.buffer,
                 &lhs_l.stride(),
-                lhs_l.start_offset(),
+                lhs_l.start_offset() * self.dtype.size_in_bytes(),
                 &rhs.buffer,
                 &rhs_l.stride(),
-                rhs_l.start_offset(),
+                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
                 &mut buffer,
             )
             .map_err(MetalError::from)?;
@@ -420,11 +415,14 @@
             "where_u8_f32",
             &dims,
             &self.buffer,
-            (layout.stride(), layout.start_offset()),
+            (
+                layout.stride(),
+                layout.start_offset() * self.dtype.size_in_bytes(),
+            ),
             &t.buffer,
-            (&t_l.stride(), t_l.start_offset()),
+            (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
             &f.buffer,
-            (&f_l.stride(), f_l.start_offset()),
+            (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
             &mut buffer,
         )
         .map_err(MetalError::from)?;
@@ -511,7 +509,9 @@
 
     fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         assert!(src_l.is_contiguous());
+        assert!(src_l.start_offset() == 0);
         assert!(ids_l.is_contiguous());
+        assert!(ids_l.start_offset() == 0);
         let left_size: usize = src_l.dims()[..dim].iter().product();
         let right_size: usize = src_l.dims()[dim + 1..].iter().product();
         let ids_el = ids_l.shape().elem_count();
@@ -681,6 +681,7 @@
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
         let src_shape = src_l.shape();
         let el_count = src_shape.elem_count();
+        // todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
         if el_count == 0 {
             return Ok(());
         }
@@ -699,15 +700,13 @@
             src_l.dims(),
             &self.buffer,
             &src_l.stride(),
-            src_l.start_offset(),
+            src_l.start_offset() * self.dtype.size_in_bytes(),
             &mut dst.buffer,
             dst_offset,
         )
         .map_err(MetalError::from)?;
         command_buffer.commit();
         command_buffer.wait_until_completed();
-        // todo!("Output {:?}", dst.buffer.read_to_vec::<f32>(10));
-        // }
         Ok(())
     }
 }
@@ -732,24 +731,11 @@
 impl BackendDevice for MetalDevice {
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
-        // let capture = metal::CaptureManager::shared();
-        // let descriptor = metal::CaptureDescriptor::new();
-        // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        // descriptor.set_capture_device(&device);
-        // let mut dir = std::env::current_dir()?;
-        // dir.push("out.gputrace");
-        // descriptor.set_output_url(dir);
-
-        // capture
-        //     .start_capture(&descriptor)
-        //     .map_err(MetalError::from)?;
         let command_queue = device.new_command_queue();
-        // let command_buffer = _command_queue.new_owned_command_buffer();
         let kernels = Arc::new(Kernels::new());
         Ok(Self {
             device,
             command_queue,
-            // command_buffer,
             kernels,
         })
     }
@@ -819,9 +805,6 @@
                 option,
             ),
         };
-        // TODO is that necessary ?
-        // buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
-        // debug!("Allocate 2 - buffer size {}", buffer.length());
         Ok(Self::Storage {
             buffer,
             device: self.clone(),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 3965a2ed..ce5858fa 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -157,8 +157,6 @@ pub(crate) fn from_storage<S: Into<Shape>>(
 ) -> Tensor {
     let dtype = storage.dtype();
     let device = storage.device();
-    let shape = shape.into();
-    // println!("{:?} {storage:?}", shape);
     let tensor_ = Tensor_ {
         id: TensorId::new(),
         storage: Arc::new(RwLock::new(storage)),
@@ -168,11 +166,7 @@
         dtype,
         device,
     };
-    let result = Tensor(Arc::new(tensor_));
-    // todo!(" from_storage");
-    // let result = result.to_device(&Device::Cpu).unwrap();
-    // todo!(" {result}");
-    result
+    Tensor(Arc::new(tensor_))
 }
 
 impl Tensor {
@@ -1869,7 +1863,10 @@
                 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
             }
             (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
-            (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
+            (Storage::Metal(storage), Device::Cpu) => {
+                println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
+                Storage::Cpu(storage.to_cpu_storage()?)
+            }
             (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                 // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                 // are the same.
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 11381fbc..0ceb27af 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -329,18 +329,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .get_ids()
         .to_vec();
 
-    println!("{tokens:?}");
-
     let start_gen = std::time::Instant::now();
-    for index in 0..1 {
+    for index in 0.. {
         if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        // println!("Input {}", input);
-        // println!("Input {}", input.to_device(&candle::Device::Cpu)?);
         let logits = model.forward(&input, index_pos)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e5c9fbae..7288216a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -150,7 +150,7 @@ macro_rules! ops{
 }
 pub mod unary {
-    ops!(cos, sin, exp, sqr, sqrt, neg, copy);
+    ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
 }
 pub mod binary {
     ops!(add, sub, mul, div);
 }
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index dd137599..eb6424e8 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -63,6 +63,7 @@ UNARY_OP(sqr)
 UNARY_OP(sqrt)
 UNARY_OP(neg)
 UNARY_OP(exp)
+UNARY_OP(log)
 
 UNARY(id, float, copy_float, copy_float_strided)
 UNARY(id, half, copy_half, copy_half_strided)
@@ -73,6 +74,7 @@ BFLOAT_UNARY_OP(sqr)
 BFLOAT_UNARY_OP(sqrt)
 BFLOAT_UNARY_OP(neg)
 BFLOAT_UNARY_OP(exp)
+BFLOAT_UNARY_OP(log)
 
 UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
 #endif
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index 2daac224..52968bc2 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -9,7 +9,6 @@ pub struct Embedding {
 
 impl Embedding {
     pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
-        // todo!("Embedding {embeddings}");
         Self {
             embeddings,
             hidden_size,
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index aba9a547..753770fb 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -156,7 +156,6 @@ impl CausalSelfAttention {
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
-        todo!("X {x1}");
         let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
         let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@@ -174,7 +173,6 @@ impl CausalSelfAttention {
         let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
-        todo!("X {q}");
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
 
         if self.cache.use_kv_cache {
@@ -297,7 +295,6 @@ impl Block {
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
-        todo!("---X {}", x);
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
     }
@@ -330,7 +327,6 @@ impl Llama {
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, _seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
-        //println!("Embeddings {}", self.wte.embeddings());
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;
         }