mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Sketch the Falcon model. (#93)
* Sketch the Falcon model. * Add more substance to the falcon example. * Falcon (wip). * Falcon (wip again). * Falcon inference. * Get the weights from the api and properly generate the model. * Use the proper model. * Fix the file/revision names. * Fix bias handling. * Recompute the rot embeddings. * Fix the input shape. * Add the release-with-debug profile. * Silly bugfix. * More bugfixes. * Stricter shape checking in matmul.
This commit is contained in:
@ -476,27 +476,28 @@ impl Tensor {
|
||||
let dim = a_dims.len();
|
||||
|
||||
if dim < 2 || b_dims.len() != dim {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "matmul",
|
||||
});
|
||||
})?
|
||||
}
|
||||
|
||||
let m = a_dims[dim - 2];
|
||||
let k = a_dims[dim - 1];
|
||||
let k2 = b_dims[dim - 2];
|
||||
let n = b_dims[dim - 1];
|
||||
if k != k2 {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "matmul",
|
||||
});
|
||||
}
|
||||
|
||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
||||
if k != k2 || batching != batching_b {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "matmul",
|
||||
})?
|
||||
}
|
||||
|
||||
let storage = self.storage.matmul(
|
||||
&rhs.storage,
|
||||
@ -660,6 +661,11 @@ impl Tensor {
|
||||
self.shape().dims()
|
||||
}
|
||||
|
||||
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
|
||||
let dim = dim.to_index(self.shape(), "dim")?;
|
||||
Ok(self.dims()[dim])
|
||||
}
|
||||
|
||||
pub fn layout(&self) -> &Layout {
|
||||
&self.layout
|
||||
}
|
||||
|
Reference in New Issue
Block a user