Adding matmul

This commit is contained in:
Nicolas Patry
2023-06-21 17:01:32 +02:00
parent ce977b489e
commit 86e4cbbc3d
2 changed files with 21 additions and 8 deletions

13
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,13 @@
repos:
- repo: https://github.com/Narsil/pre-commit-rust
rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
hooks:
- id: fmt
name: "Rust (fmt)"
- id: clippy
name: "Rust (clippy)"
args:
[
"--",
"-Dwarnings",
]

View File

@ -178,7 +178,7 @@ impl Tensor {
device: Device, device: Device,
) -> Result<Self> { ) -> Result<Self> {
let shape = shape.into(); let shape = shape.into();
let storage = device.storage(a); let storage = device.storage(a)?;
let stride = shape.stride_contiguous(); let stride = shape.stride_contiguous();
let is_variable = false; let is_variable = false;
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
@ -514,7 +514,7 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?; let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
} }
Op::Matmul(lhs, rhs) => { Op::Matmul(_lhs, _rhs) => {
// let (m, k) = lhs.shape; // let (m, k) = lhs.shape;
// let n = rhs.shape.1; // let n = rhs.shape.1;
// let strides = (m, n).strides(); // let strides = (m, n).strides();
@ -539,12 +539,12 @@ impl Tensor {
// rhs.strides, // rhs.strides,
// ); // );
let lhs_grad = grad.matmul(rhs)?; // let lhs_grad = grad.matmul(rhs)?;
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like()); // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?; // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like()); // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
} }
Op::Affine { arg, mul, .. } => { Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?; let arg_grad = grad.affine(*mul, 0.)?;