-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcentrality.py
More file actions
executable file
·284 lines (251 loc) · 11.4 KB
/
centrality.py
File metadata and controls
executable file
·284 lines (251 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
#!/usr/bin/env python3
# encoding: UTF-8
from __future__ import print_function
from __future__ import division
import argparse
from operator import itemgetter
import itertools
import numpy
import graph_tools
from igraph import Graph, mean, load, InternalError
import nltk
from nltk.corpus import wordnet
import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
import matplotlib.pyplot as plot
from gensim import models
import visualize
# from WSI import get_subgraph
# Global variables
graph = None
def get_arguments():
    """
    Parse command line options.

    Returns
        argparse.Namespace with verbose, save, arbitrary, infile and ties
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", help="verbose output", action="store_true")
    parser.add_argument("-s", "--save", help="save graph plots", action="store_true")
    parser.add_argument("--arbitrary", help="break ties arbitrarily", action="store_true")
    parser.add_argument("infile", help="input file")
    # Keep the historical single-dash spelling "-ties" working, but also
    # offer the conventional GNU-style "--ties"; dest stays "ties".
    parser.add_argument("-ties", "--ties", help="Accept ties with this many members", type=int, default=1)
    return parser.parse_args()
def get_hyponyms(synset):
    """
    Collect one in-model lemma for each direct hyponym of the synset.
    Params
        synset: nltk WordNet synset
    Returns
        set of lemma-name strings (at most one per hyponym)
    """
    result = set()
    for child in synset.hyponyms():
        names = get_lemmas(child)
        # keep only the first in-model lemma of each hyponym, if any
        if names:
            result.add(names[0])
    return result
def get_lemmas(synset):
    """
    Return the synset's lemma names restricted to words present in the
    embedding vocabulary (module-level `dictionary`).
    Params
        synset: nltk WordNet synset
    """
    names = (lemma.name() for lemma in synset.lemmas())
    return [name for name in names if name in dictionary]
def is_midfrequent(synset):
    """
    Check whether the synset's first in-model lemma lies in the
    mid-frequency band (module-level `midfrequent`).
    """
    names = get_lemmas(synset)
    if not names:
        # mirrors the original short-circuit: an empty (falsy) list is returned
        return names
    return names[0] in midfrequent
def get_rank(word):
    """
    Get the frequency rank of the word in the word embedding model
    (lower index = more frequent).
    """
    entry = model.vocab[word]
    return entry.index
def get_synsets():
    """
    Get all WordNet noun synsets that match our criteria: at least five
    hyponyms present in the model, and a mid-frequency first lemma.

    Returns
        list of nltk WordNet synsets
    Raises
        SystemExit when the WordNet corpus had to be downloaded first
        (nltk must reload the fresh data, so the user is asked to rerun)
    """
    try:
        synsets = wordnet.all_synsets("n")
    except LookupError:
        # Corpus not installed: fetch it, then stop so nltk picks it up
        # on the next run.
        nltk.download("wordnet")
        print("WordNet downloaded, please restart program")
        # raise SystemExit instead of the site-module exit() builtin, which
        # is not guaranteed to exist in all execution contexts; exit status
        # stays 0, as with the original bare exit()
        raise SystemExit
    hypernyms = [synset for synset in synsets
                 if len(get_hyponyms(synset)) >= 5 and is_midfrequent(synset)]
    print("hypernyms:", len(hypernyms))
    return hypernyms
def is_equal(values):
    """
    Check whether all entries are equal.

    Params
        values: iterable of hashable items
    Returns
        True when values contains at most one distinct entry
        (an empty iterable counts as equal)
    """
    # A set collapses duplicates: zero or one distinct element means "all equal".
    # Parameter renamed from `list`, which shadowed the builtin.
    return len(set(values)) <= 1
def get_score(hypernym, centrality, max_ties):
    """
    Scoring function for the evaluation.
    Params
        hypernym: string containing a hypernym
        centrality: ordered list containing (term, centrality) pairs
        max_ties: number of ties to allow/consider
    Returns
        1 if the hypernym is most central (or tied within the allowed
        tie window), otherwise 0
    """
    # Iterative form of the original recursion: accept the hypernym if, for
    # any window size from max_ties down to 2, the top scores are all equal
    # and the hypernym appears among those candidates.
    for size in range(max_ties, 1, -1):
        head = centrality[:size]
        distinct_scores = {value for _, value in head}
        if len(distinct_scores) <= 1 and any(term == hypernym for term, _ in head):
            return True
    # Base case: strict winner-takes-all comparison.
    return hypernym == centrality[0][0]
def save_plot(method, graph, hypernym, center):
    """
    Render the graph to a PDF, highlighting the gold hypernym and the
    predicted center node.
    """
    graph.vs["color"] = "yellow"
    # include centrality in node label when it has been computed
    if "centrality" in graph.vs.attributes():
        graph.vs["label"] = ["%s (%d)" % (vertex["name"], vertex["centrality"]) for vertex in graph.vs]
    if hypernym == center:
        # correct prediction: one green diamond marks both roles
        hit = graph.vs.find(center)
        hit["color"] = "green"
        hit["shape"] = "diamond"
    else:
        if center:
            predicted = graph.vs.find(center)
            predicted["color"] = "blue"
            predicted["shape"] = "square"
        gold = graph.vs.find(hypernym)
        gold["color"] = "red"
        gold["shape"] = "up-triangle"
    visualize.plot_graph(graph, "temp/images/centrality/%s/%s.PDF" % (method, hypernym), clustering=False, layout="fr")  # or kk
def wordnet_subgraph(model, threshold, synset, method, plot=False):
    """
    Create a local graph containing the synset lemma and its hyponyms from WordNet
    and score whether the hypernym is the most central node.
    Params:
        model: a gensim KeyedVectors model
        threshold: for edge inclusion
        synset: a WordNet synset
        method: centrality measure, one of "pagerank", "degree", "betweenness"
        plot: save a graph plot and collect baseline/tie statistics into the
              module-level lists (only valid during the best-threshold pass)
    Returns:
        truthy score if the hypernym is most central (per get_score), else falsy
    """
    global fully_connected, ties, graph
    # the synset's first in-model lemma serves as the gold-standard hypernym
    hypernym = get_lemmas(synset)[0]
    hyponym_terms = get_hyponyms(synset)
    hyponym_terms.add(hypernym)  # add hypernym
    graph = graph_tools.create_local_graph(model, threshold, hyponym_terms)
    centrality = get_centrality(method)  # (term, centrality) pairs, best first
    center, center_centrality = centrality[0]
    score = get_score(hypernym, centrality, options.ties)
    if plot:
        # stats for fully connected graphs
        if graph.density() == 1:
            fully_connected.append(score)
        else:
            # Stats excluding fully connected graphs
            excluding_fully_connected.append(score)
        # calculate baseline: random pick among the candidate terms
        baseline_random.append(1 / len(hyponym_terms))
        # frequency baseline: does the most frequent term equal the hypernym?
        most_frequent = sorted(hyponym_terms, key=get_rank)
        baseline_rank.append(hypernym == most_frequent[0])
        # make plot
        if options.save:
            save_plot(method, graph, hypernym, center)
        # log statistics
        words = [word for word, _ in centrality]
        try:
            index = words.index(hypernym)
            hypernym_centrality = centrality[index][1]
            difference = center_centrality - hypernym_centrality
            # log the ranking down to just past the hypernym's position
            snip = centrality[:index + 2]
            log(hypernym, center, hypernym_centrality, center_centrality, difference, snip)
            if (hypernym != center and hypernym_centrality == center_centrality) or (hypernym == center and hypernym_centrality == centrality[1][1]):
                # tie between the hypernym and another top-ranked node
                ties.append(score)
        except ValueError as e:
            # hypernym missing from the centrality list (e.g. edgeless graph
            # returns [(None, None)]) -- skip tie/log statistics
            pass
    return score
def get_centrality(method):
    """
    Get centrality for all nodes in the graph
    Params:
        method - centrality measure to use, one of "pagerank", "degree",
                 "betweenness"
    Returns:
        list of (name, centrality) pairs sorted by centrality, highest first;
        [(None, None)] when the graph has no edges
    Raises:
        ValueError for an unsupported method
    """
    global graph
    if graph.ecount() == 0:
        # No edges: centrality is undefined, signal with a sentinel pair
        return [(None, None)]
    # Invert weights, so similar words are closer
    # NOTE(review): assumes all edge weights are strictly positive -- a
    # zero-weight edge would raise ZeroDivisionError here; confirm the graph
    # builder never emits zero similarities
    graph.es["weight"] = [1 / weight for weight in graph.es["weight"]]
    if method == "betweenness":
        graph.vs["centrality"] = graph.betweenness(weights="weight")
    elif method == "pagerank":
        graph.vs["centrality"] = graph.personalized_pagerank(weights="weight")
    elif method == "degree":
        # plain degree ignores the (inverted) weights
        graph.vs["centrality"] = graph.vs.degree()
    else:
        raise ValueError("Unsupported centrality measure: %s" % method)
    scores = [(node["name"], node["centrality"]) for node in graph.vs]
    if not options.arbitrary:
        # Sort by rank/frequency first; because Python's sort is stable,
        # centrality ties then resolve in favor of the more frequent word
        scores = sorted(scores, key=lambda item: get_rank(item[0]))
    # Then sort by centrality
    scores = sorted(scores, key=itemgetter(1), reverse=True)
    return scores
def log(*args):
    """Write one tab-separated line to the per-method log file (`output`)."""
    print("\t".join(str(arg) for arg in args), file=output)
def printtable(*args):
    """Write one tab-separated row to the summary table file (`tableout`)."""
    print("\t".join(str(cell) for cell in args), file=tableout)
if __name__ == '__main__':
    options = get_arguments()
    # tie-breaking strategy label, used in output file names
    tiebreak = "arbitrary" if options.arbitrary else "rank"
    # first path component of the input doubles as the model's display name
    model_name = options.infile.split("/")[0]
    print("loading model %s from %s" % (options.infile, model_name))
    model = models.KeyedVectors.load_word2vec_format(options.infile, binary=False)
    # model = models.KeyedVectors.load_word2vec_format(options.infile, binary=False, limit=80000)
    # vocabulary membership set consulted by get_lemmas()
    dictionary = set(model.index2word)
    # Only get mid frequent terms (ranks 1000..99999) for hypernym selection
    midfrequent = set(model.index2word[1000:100000])
    synsets = get_synsets()
    figure = plot.figure(figsize=(8, 5.5))
    plot.axis(ymax=0.5)
    # candidate similarity thresholds for edge inclusion, swept per method
    thresholds = numpy.linspace(0, 1, num=100)
    variant_description = "%s-%s-%s" % (model_name, tiebreak, options.ties)  # for logging
    tablefile = "../data/table-%s-%s.tsv" % (model_name, tiebreak)
    with open(tablefile, "w") as tableout:
        printtable("Model", "Centrality", "Accuracy", "Best-epsilon", "FC", "acc-FC", "acc-no-FC", "ties", "acc-ties")
        for method in ["pagerank", "degree", "betweenness"]:
            logfile = "../data/%s-%s.tsv" % (variant_description, method)
            with open(logfile, "w") as output:
                log("Hypernym", "Center", "Hypernym-Centrality", "Center-Centrality", "Difference", "Centrality")
                # sweep every threshold, recording the mean score over synsets
                scores = []
                for threshold in thresholds:
                    score = [wordnet_subgraph(model, threshold, synset, method) for synset in synsets]
                    score = numpy.average(score)
                    scores.append(score)
                plot.plot(thresholds, scores, label=method)
                best_score = numpy.max(scores)
                best_threshold = thresholds[numpy.argmax(scores)]
                # Plot graphs for best threshold; these module-level lists are
                # filled by wordnet_subgraph(..., plot=True)
                fully_connected = []
                excluding_fully_connected = []
                ties = []
                baseline_random = []
                baseline_rank = []
                for synset in synsets:
                    wordnet_subgraph(model, best_threshold, synset, method, plot=True)
                log("best threshold: %.3f, score: %.3f\t\t\t\t\t" % (best_threshold, best_score))
                log("ties: %d, score ties: %.3f\t\t\t\t\t" % (len(ties), numpy.average(ties)))
                log("fully connected: %d, score fully connected: %.3f\t\t\t\t\t" % (len(fully_connected), numpy.average(fully_connected)))
                log("not fully connected: %d, score excluding fully connected: %.3f\t\t\t\t\t" % (len(excluding_fully_connected), numpy.average(excluding_fully_connected)))
                log("random baseline: %.3f\t\t\t\t\t" % (numpy.average(baseline_random)))
                log("most frequent baseline: %.3f\t\t\t\t\t" % (numpy.average(baseline_rank)))
                printtable(model_name, method, best_score, best_threshold, len(fully_connected), numpy.average(
                    fully_connected), numpy.average(excluding_fully_connected), len(ties), numpy.average(ties))
    # plot baselines
    # NOTE(review): the baseline lists hold values from the LAST method's
    # best-threshold pass only -- confirm that is the intended baseline curve
    plot.plot(thresholds, [numpy.average(baseline_random) for threshold in thresholds], linestyle="--", label="random baseline")
    plot.plot(thresholds, [numpy.average(baseline_rank) for threshold in thresholds], linestyle=":", label="frequency baseline")
    plot.xlabel("Threshold")
    plot.ylabel("Average score")
    plot.legend(title="Centrality measure")
    plot.savefig("temp/images/centrality-%s.pdf" % variant_description)