11import pytest
2- import numpy as np
3- from numpy .testing import assert_allclose
4- import cotengra as ctg
2+
3+ try :
4+ import cotengra as ctg
5+
6+ ctg_missing = False
7+ except ImportError :
8+ ctg_missing = True
9+ ctg = None
10+
511import cotengrust as ctgr
612
713
14+ requires_cotengra = pytest .mark .skipif (ctg_missing , reason = "requires cotengra" )
15+
16+
17+ @pytest .mark .parametrize ("which" , ["greedy" , "optimal" ])
18+ def test_basic_call (which ):
19+ inputs = [('a' , 'b' ), ('b' , 'c' ), ('c' , 'd' ), ('d' , 'a' )]
20+ output = ('b' , 'd' )
21+ size_dict = {'a' : 2 , 'b' : 3 , 'c' : 4 , 'd' : 5 }
22+ path = {
23+ "greedy" : ctgr .optimize_greedy ,
24+ "optimal" : ctgr .optimize_optimal ,
25+ }[
26+ which
27+ ](inputs , output , size_dict )
28+ assert all (len (con ) <= 2 for con in path )
29+
30+
831def find_output_str (lhs ):
932 tmp_lhs = lhs .replace ("," , "" )
1033 return "" .join (s for s in sorted (set (tmp_lhs )) if tmp_lhs .count (s ) == 1 )
@@ -21,20 +44,16 @@ def eq_to_inputs_output(eq):
2144
2245
2346def get_rand_size_dict (inputs , d_min = 2 , d_max = 3 ):
47+ import random
48+
2449 size_dict = {}
2550 for term in inputs :
2651 for ix in term :
2752 if ix not in size_dict :
28- size_dict [ix ] = np . random .randint (d_min , d_max + 1 )
53+ size_dict [ix ] = random .randint (d_min , d_max )
2954 return size_dict
3055
3156
32- def build_arrays (inputs , size_dict ):
33- return [
34- np .random .randn (* [size_dict [ix ] for ix in term ]) for term in inputs
35- ]
36-
37-
3857# these are taken from opt_einsum
3958test_case_eqs = [
4059 # Test scalar-like operations
@@ -120,24 +139,26 @@ def build_arrays(inputs, size_dict):
120139]
121140
122141
142+ @requires_cotengra
123143@pytest .mark .parametrize ("eq" , test_case_eqs )
124144@pytest .mark .parametrize ("which" , ["greedy" , "optimal" ])
125145def test_manual_cases (eq , which ):
126146 inputs , output = eq_to_inputs_output (eq )
127147 size_dict = get_rand_size_dict (inputs )
128- arrays = build_arrays (inputs , size_dict )
129- expected = np .einsum (eq , * arrays , optimize = True )
130148 path = {
131149 "greedy" : ctgr .optimize_greedy ,
132150 "optimal" : ctgr .optimize_optimal ,
133151 }[
134152 which
135153 ](inputs , output , size_dict )
136154 assert all (len (con ) <= 2 for con in path )
137- tree = ctg .ContractionTree .from_path (inputs , output , size_dict , path = path )
138- assert_allclose (tree .contract (arrays ), expected )
155+ tree = ctg .ContractionTree .from_path (
156+ inputs , output , size_dict , path = path , check = True
157+ )
158+ assert tree .is_complete ()
139159
140160
161+ @requires_cotengra
141162@pytest .mark .parametrize ("seed" , range (10 ))
142163@pytest .mark .parametrize ("which" , ["greedy" , "optimal" ])
143164def test_basic_rand (seed , which ):
@@ -151,22 +172,20 @@ def test_basic_rand(seed, which):
151172 d_max = 3 ,
152173 seed = seed ,
153174 )
154- eq = "," .join (map ("" .join , inputs )) + "->" + "" .join (output )
155-
156175 path = {
157176 "greedy" : ctgr .optimize_greedy ,
158177 "optimal" : ctgr .optimize_optimal ,
159178 }[
160179 which
161180 ](inputs , output , size_dict )
162181 assert all (len (con ) <= 2 for con in path )
163- tree = ctg .ContractionTree .from_path (inputs , output , size_dict , path = path )
164- arrays = [np .random .randn (* s ) for s in shapes ]
165- assert_allclose (
166- tree .contract (arrays ), np .einsum (eq , * arrays , optimize = True )
182+ tree = ctg .ContractionTree .from_path (
183+ inputs , output , size_dict , path = path , check = True
167184 )
185+ assert tree .is_complete ()
168186
169187
188+ @requires_cotengra
170189def test_optimal_lattice_eq ():
171190 inputs , output , _ , size_dict = ctg .utils .lattice_equation (
172191 [4 , 5 ], d_max = 3 , seed = 42
0 commit comments