-
Notifications
You must be signed in to change notification settings - Fork 9
44 lines (36 loc) · 1.13 KB
/
verify_extension_build.yml
File metadata and controls
44 lines (36 loc) · 1.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
name: OEQ C++ Extension Build Verification

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]
    # Pull-request runs are triggered only when a label is added; the
    # job-level `if` below further restricts them to the `ci-ready` label.
    types: [labeled]

permissions:
  contents: read

jobs:
  verify_cuda_extension:
    # Always run for non-PR events (push to main); for pull requests, run
    # only when the label that triggered the event is `ci-ready`.
    if: ${{ github.event.label.name == 'ci-ready' || github.event_name != 'pull_request' }}
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: 'pip'
          cache-dependency-path: '**/requirements_cuda_ci.txt'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # apt-get (not apt) has a stable scripting interface, and -y makes
          # the install non-interactive instead of relying on runner config.
          sudo apt-get update
          sudo apt-get install -y nvidia-cuda-toolkit
          python -m pip install -r .github/workflows/requirements_cuda_ci.txt
          python -m pip install -e "./openequivariance[jax]"

      - name: Test CUDA extension build via import
        run: |
          pytest tests/import_test.py
          # Re-run the same import test with OEQ_JIT_EXTENSION=1 exported
          # (presumably selects the JIT-built extension path -- confirm).
          export OEQ_JIT_EXTENSION=1
          pytest tests/import_test.py

      - name: Test JAX extension build
        run: |
          # --no-build-isolation builds against the packages installed in the
          # previous step rather than in a fresh isolated build environment.
          XLA_DIRECT_DOWNLOAD=1 python -m pip install -e "./openequivariance_extjax" --no-build-isolation