Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)

Commit: Adding matmul
.pre-commit-config.yaml — new file, 13 additions
@@ -0,0 +1,13 @@
+repos:
+- repo: https://github.com/Narsil/pre-commit-rust
+  rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
+  hooks:
+  - id: fmt
+    name: "Rust (fmt)"
+  - id: clippy
+    name: "Rust (clippy)"
+    args:
+      [
+        "--",
+        "-Dwarnings",
+      ]
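Note: once pre-commit is installed locally, these two hooks are expected to run rustfmt and clippy before each commit; the trailing "--", "-Dwarnings" arguments are passed through to clippy so that any warning is treated as an error.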
@@ -178,7 +178,7 @@ impl Tensor {
         device: Device,
     ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.storage(a);
+        let storage = device.storage(a)?;
         let stride = shape.stride_contiguous();
         let is_variable = false;
         let tensor_ = Tensor_ {
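The only change in this hunk is the trailing `?`: `device.storage(a)` yields a `Result`, and `?` either unwraps the value or returns the error early to the caller. A minimal sketch of the pattern, with hypothetical `Device`, `Storage`, and `Error` types standing in for candle's own:

// Minimal sketch of `?`-based error propagation. `Device`, `Storage`,
// and `Error` are illustrative stand-ins, not candle's real types.
#[derive(Debug)]
struct Error(String);

type Result<T> = std::result::Result<T, Error>;

struct Device;
struct Storage;

impl Device {
    // A fallible allocation, playing the role of `device.storage(a)`.
    fn storage(&self, len: usize) -> Result<Storage> {
        if len == 0 {
            Err(Error("cannot allocate empty storage".to_string()))
        } else {
            Ok(Storage)
        }
    }
}

// `?` unwraps the Ok value or returns the Err to this function's caller.
fn make_storage(device: &Device, len: usize) -> Result<Storage> {
    let storage = device.storage(len)?;
    Ok(storage)
}

fn main() {
    let device = Device;
    assert!(make_storage(&device, 4).is_ok());
    assert!(make_storage(&device, 0).is_err());
}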
@@ -514,7 +514,7 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::Matmul(lhs, rhs) => {
+                Op::Matmul(_lhs, _rhs) => {
                     // let (m, k) = lhs.shape;
                     // let n = rhs.shape.1;
                     // let strides = (m, n).strides();
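The commented-out scaffolding above is about shapes and strides: for `lhs` of shape (m, k) and `rhs` of shape (k, n), the matmul output has shape (m, n), and a row-major contiguous (m, n) matrix has strides (n, 1). A tiny sketch of that stride rule; the `strides()` helper named in the comments is assumed here, not shown in this diff:

/// Row-major contiguous strides for a 2D shape, i.e. what a helper like
/// the `(m, n).strides()` mentioned in the comments would plausibly return.
/// Illustrative stand-in, not candle's actual implementation.
fn contiguous_strides_2d(_m: usize, n: usize) -> (usize, usize) {
    // Moving one row forward skips `n` elements; one column forward skips 1.
    (n, 1)
}

fn main() {
    assert_eq!(contiguous_strides_2d(2, 3), (3, 1));
}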
@@ -539,12 +539,12 @@ impl Tensor {
                     // rhs.strides,
                     // );
 
-                    let lhs_grad = grad.matmul(rhs)?;
-                    let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
-                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-                    let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
-                    let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
-                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    // let lhs_grad = grad.matmul(rhs)?;
+                    // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                    // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                    // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+                    // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                    // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
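The `Op::Matmul` arm is left as a placeholder here: the bindings are renamed to `_lhs`/`_rhs` to silence unused-variable warnings while the body stays commented out. For reference, the gradients an eventual matmul backward pass needs are the standard ones from matrix calculus; this is background, not code from this commit:

% Standard matmul gradients, stated for reference (not part of this commit).
% A is m x k, B is k x n, C = A B is m x n, and G = \partial L / \partial C.
\frac{\partial L}{\partial A} = G\,B^{\top} \quad (m \times k),
\qquad
\frac{\partial L}{\partial B} = A^{\top} G \quad (k \times n).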