Skip to content

Commit 11f4c15

Browse files
committed
CI: skip cotengra numpy install for some
1 parent 7e11612 commit 11f4c15

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
install: |
6161
apt-get update
6262
apt-get install -y --no-install-recommends python3 python3-pip
63-
pip3 install -U pip pytest numpy cotengra
63+
pip3 install -U pip pytest # numpy cotengra
6464
run: |
6565
set -e
6666
pip3 install cotengrust --find-links dist --force-reinstall

tests/test_cotengrust.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,33 @@
11
import pytest
2-
import numpy as np
3-
from numpy.testing import assert_allclose
4-
import cotengra as ctg
2+
3+
try:
4+
import cotengra as ctg
5+
6+
ctg_missing = False
7+
except ImportError:
8+
ctg_missing = True
9+
ctg = None
10+
511
import cotengrust as ctgr
612

713

14+
requires_cotengra = pytest.mark.skipif(ctg_missing, reason="requires cotengra")
15+
16+
17+
@pytest.mark.parametrize("which", ["greedy", "optimal"])
18+
def test_basic_call(which):
19+
inputs = [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'a')]
20+
output = ('b', 'd')
21+
size_dict = {'a': 2, 'b': 3, 'c': 4, 'd': 5}
22+
path = {
23+
"greedy": ctgr.optimize_greedy,
24+
"optimal": ctgr.optimize_optimal,
25+
}[
26+
which
27+
](inputs, output, size_dict)
28+
assert all(len(con) <= 2 for con in path)
29+
30+
831
def find_output_str(lhs):
932
tmp_lhs = lhs.replace(",", "")
1033
return "".join(s for s in sorted(set(tmp_lhs)) if tmp_lhs.count(s) == 1)
@@ -21,20 +44,16 @@ def eq_to_inputs_output(eq):
2144

2245

2346
def get_rand_size_dict(inputs, d_min=2, d_max=3):
47+
import random
48+
2449
size_dict = {}
2550
for term in inputs:
2651
for ix in term:
2752
if ix not in size_dict:
28-
size_dict[ix] = np.random.randint(d_min, d_max + 1)
53+
size_dict[ix] = random.randint(d_min, d_max)
2954
return size_dict
3055

3156

32-
def build_arrays(inputs, size_dict):
33-
return [
34-
np.random.randn(*[size_dict[ix] for ix in term]) for term in inputs
35-
]
36-
37-
3857
# these are taken from opt_einsum
3958
test_case_eqs = [
4059
# Test scalar-like operations
@@ -120,24 +139,26 @@ def build_arrays(inputs, size_dict):
120139
]
121140

122141

142+
@requires_cotengra
123143
@pytest.mark.parametrize("eq", test_case_eqs)
124144
@pytest.mark.parametrize("which", ["greedy", "optimal"])
125145
def test_manual_cases(eq, which):
126146
inputs, output = eq_to_inputs_output(eq)
127147
size_dict = get_rand_size_dict(inputs)
128-
arrays = build_arrays(inputs, size_dict)
129-
expected = np.einsum(eq, *arrays, optimize=True)
130148
path = {
131149
"greedy": ctgr.optimize_greedy,
132150
"optimal": ctgr.optimize_optimal,
133151
}[
134152
which
135153
](inputs, output, size_dict)
136154
assert all(len(con) <= 2 for con in path)
137-
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
138-
assert_allclose(tree.contract(arrays), expected)
155+
tree = ctg.ContractionTree.from_path(
156+
inputs, output, size_dict, path=path, check=True
157+
)
158+
assert tree.is_complete()
139159

140160

161+
@requires_cotengra
141162
@pytest.mark.parametrize("seed", range(10))
142163
@pytest.mark.parametrize("which", ["greedy", "optimal"])
143164
def test_basic_rand(seed, which):
@@ -151,22 +172,20 @@ def test_basic_rand(seed, which):
151172
d_max=3,
152173
seed=seed,
153174
)
154-
eq = ",".join(map("".join, inputs)) + "->" + "".join(output)
155-
156175
path = {
157176
"greedy": ctgr.optimize_greedy,
158177
"optimal": ctgr.optimize_optimal,
159178
}[
160179
which
161180
](inputs, output, size_dict)
162181
assert all(len(con) <= 2 for con in path)
163-
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
164-
arrays = [np.random.randn(*s) for s in shapes]
165-
assert_allclose(
166-
tree.contract(arrays), np.einsum(eq, *arrays, optimize=True)
182+
tree = ctg.ContractionTree.from_path(
183+
inputs, output, size_dict, path=path, check=True
167184
)
185+
assert tree.is_complete()
168186

169187

188+
@requires_cotengra
170189
def test_optimal_lattice_eq():
171190
inputs, output, _, size_dict = ctg.utils.lattice_equation(
172191
[4, 5], d_max=3, seed=42

0 commit comments

Comments
 (0)