From 86e4cbbc3d61739c26da0cca93c7532023480789 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 21 Jun 2023 17:01:32 +0200
Subject: [PATCH] Adding matmul

---
 .pre-commit-config.yaml | 13 +++++++++++++
 src/tensor.rs           | 16 ++++++++--------
 2 files changed, 21 insertions(+), 8 deletions(-)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..ad3bc0c2
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+  - repo: https://github.com/Narsil/pre-commit-rust
+    rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
+    hooks:
+      - id: fmt
+        name: "Rust (fmt)"
+      - id: clippy
+        name: "Rust (clippy)"
+        args:
+          [
+            "--",
+            "-Dwarnings",
+          ]
diff --git a/src/tensor.rs b/src/tensor.rs
index e55050c6..7607171c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -178,7 +178,7 @@ impl Tensor {
         device: Device,
     ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.storage(a);
+        let storage = device.storage(a)?;
         let stride = shape.stride_contiguous();
         let is_variable = false;
         let tensor_ = Tensor_ {
@@ -514,7 +514,7 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::Matmul(lhs, rhs) => {
+                Op::Matmul(_lhs, _rhs) => {
                     // let (m, k) = lhs.shape;
                     // let n = rhs.shape.1;
                     // let strides = (m, n).strides();
@@ -539,12 +539,12 @@ impl Tensor {
                     // rhs.strides,
                     // );

-                    let lhs_grad = grad.matmul(rhs)?;
-                    let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
-                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-                    let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
-                    let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
-                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    // let lhs_grad = grad.matmul(rhs)?;
+                    // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                    // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                    // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+                    // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                    // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
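
Note on the Matmul arm left commented out above: the `rhs_grad` line it disables
(`grad.mul(lhs)?.div(&rhs.sqr()?)?`) is the division gradient, apparently a
copy-paste from the Div arm, not a matmul gradient. For `c = lhs.matmul(rhs)`
the standard backward is `d_lhs = grad.matmul(rhs^T)` and
`d_rhs = lhs^T.matmul(grad)`. A minimal sketch of what the arm could look like
once implemented, assuming a hypothetical `t()` transpose method returning
`Result<Tensor>` and reusing the `grads.entry(..).or_insert_with(..)`
accumulation pattern from the surrounding code:

    Op::Matmul(lhs, rhs) => {
        // d/d(lhs) of lhs.matmul(rhs) is grad.matmul(rhs^T).
        // `t()` is assumed here; it is not part of this patch.
        let lhs_grad = grad.matmul(&rhs.t()?)?;
        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
        // d/d(rhs) of lhs.matmul(rhs) is lhs^T.matmul(grad).
        let rhs_grad = lhs.t()?.matmul(grad)?;
        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
    }

This is only a shape-level sketch for 2D operands; batched or broadcast matmuls
would need the transpose applied to the last two dimensions only.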