mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
4
.github/workflows/ci_cuda.yaml
vendored
4
.github/workflows/ci_cuda.yaml
vendored
@ -8,6 +8,8 @@ jobs:
|
|||||||
start-runner:
|
start-runner:
|
||||||
name: Start self-hosted EC2 runner
|
name: Start self-hosted EC2 runner
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
# Don't run on forks, they won't have access to secrets anyway.
|
||||||
|
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
||||||
env:
|
env:
|
||||||
AWS_REGION: us-east-1
|
AWS_REGION: us-east-1
|
||||||
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
||||||
@ -70,7 +72,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
AWS_REGION: us-east-1
|
AWS_REGION: us-east-1
|
||||||
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
|
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
|
||||||
steps:
|
steps:
|
||||||
- name: Configure AWS credentials
|
- name: Configure AWS credentials
|
||||||
uses: aws-actions/configure-aws-credentials@v1
|
uses: aws-actions/configure-aws-credentials@v1
|
||||||
|
@ -28,6 +28,7 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
|
|||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
#[test]
|
#[test]
|
||||||
fn book_hub_2() {
|
fn book_hub_2() {
|
||||||
|
{
|
||||||
// ANCHOR: book_hub_2
|
// ANCHOR: book_hub_2
|
||||||
use candle::Device;
|
use candle::Device;
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
@ -45,9 +46,10 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
|
|||||||
assert_eq!(weights.len(), 206);
|
assert_eq!(weights.len(), 206);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[rustfmt::skip]
|
// #[rustfmt::skip]
|
||||||
#[test]
|
// #[test]
|
||||||
fn book_hub_3() {
|
// fn book_hub_3() {
|
||||||
|
{
|
||||||
// ANCHOR: book_hub_3
|
// ANCHOR: book_hub_3
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
@ -102,6 +104,7 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
|
|||||||
assert_eq!(view.shape(), &[768, 768]);
|
assert_eq!(view.shape(), &[768, 768]);
|
||||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
#[test]
|
#[test]
|
||||||
|
Reference in New Issue
Block a user