-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_weighted_cascade.py
More file actions
42 lines (35 loc) · 1.25 KB
/
test_weighted_cascade.py
File metadata and controls
42 lines (35 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import pytest
import numpy as np
from graph_helpers import load_graph_by_name
from helpers import infected_nodes
from experiment import gen_input
from graph_tool import Graph
from itertools import combinations
@pytest.fixture
def g():
    """Load the weighted 'grqc' graph for tests that request ``g``.

    The fixture uses pytest's default (function) scope, so the graph is
    reloaded for every test invocation that depends on it.
    """
    graph_name = 'grqc'
    return load_graph_by_name(graph_name, weighted=True)
@pytest.mark.parametrize("cascade_model", ['si', 'ic'])
@pytest.mark.parametrize("weighted", [True, False])
# The source vertex is drawn at collection time. A seeded generator keeps the
# parametrization (and therefore the test IDs) identical across runs and across
# pytest-xdist workers, which each collect the test module independently; an
# unseeded draw makes failures unreproducible and can make xdist abort.
@pytest.mark.parametrize(
    "source", [int(s) for s in np.random.RandomState(42).choice(1000, size=1)])
def test_gen_input(g, cascade_model, weighted, source):
    """Check that ``gen_input`` yields distinct, size-bounded cascades.

    Ten cascades are generated from the same ``source`` using either the
    graph's ``'weights'`` edge property or a uniform 0.8 edge probability.
    The test asserts that no two cascades share the same observation set
    (the first element of each ``gen_input`` result) and, for the SI model
    only, that the infected set (derived from the second element) covers at
    most 11% of the graph's vertices.
    """
    if weighted:
        p = g.edge_properties['weights']
    else:
        # Uniform edge probability when weights are not used.
        p = g.new_edge_property('float')
        p.set_value(0.8)

    n_cascades = 10
    rows = [gen_input(g, p=p, model=cascade_model, source=source, stop_fraction=0.1)
            for _ in range(n_cascades)]

    # Make sure no two cascades are the same. Cascades are random, so with
    # low probability this can fail spuriously.
    observed = [frozenset(r[0]) for r in rows]
    for (i, obs1), (j, obs2) in combinations(enumerate(observed), 2):
        assert obs1 != obs2, \
            "cascades {} and {} have identical observations".format(i, j)

    # Check cascade size; only applicable to the SI model. The 0.11 bound
    # leaves a little slack over stop_fraction=0.1.
    if cascade_model == 'si':
        n_nodes = g.num_vertices()
        for r in rows:
            n_infected = len(infected_nodes(r[1]))
            # Compare counts rather than a ratio: `a / b` on ints is floor
            # division under Python 2, which would make the check vacuous.
            assert n_infected <= 0.11 * n_nodes