#!/usr/bin/env python2.7
# encoding: utf-8
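"""Tools for building the ground truth of semantic associations.

Loads the Wikipedia linker results, extracts the verified stimulus/response
mappings, rewrites their Wikipedia links as DBpedia resources and provides
helpers to read, write and split the resulting (source, target) pairs.
"""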
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import csv
import gzip
from itertools import chain
import json
import logging
from os import path
import random
import nose
from rdflib import URIRef
from rdflib.util import from_n3
import six
import config
WP_LINKER_RES_FILENAME = 'wp_linker_results.json.gz'
logger = logging.getLogger(__name__)
_wp_linker_results = None
def load_linker_results(fn=None):
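    """Load and cache the gzipped JSON Wikipedia linker results.

    If fn is None, WP_LINKER_RES_FILENAME is read from config.DATADIR
    relative to this module. The parsed result is cached in the module
    level _wp_linker_results variable.
    """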
global _wp_linker_results
if fn is None:
fn = path.join(
path.dirname(__file__), config.DATADIR, WP_LINKER_RES_FILENAME)
if not _wp_linker_results:
with gzip.open(fn) as f:
_wp_linker_results = json.load(f)
return _wp_linker_results
def get_verified_mappings():
"""Only returns results which are out of question verified (3 positive)."""
wp_linker_results = load_linker_results()
res = {}
for hash_, mapping in six.iteritems(wp_linker_results['ratings']):
ratings = mapping['ratings']
if (
'False' not in ratings and
'Skip' not in ratings and
ratings.get('True', 0) > 2
):
res[hash_] = mapping
return res
def wiki_to_dbpedia_link(wikilink):
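    """Rewrite a Wikipedia article URL into the corresponding DBpedia IRI.

    >>> wiki_to_dbpedia_link('http://en.wikipedia.org/wiki/Berlin')
    'http://dbpedia.org/resource/Berlin'
    """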
return wikilink.replace(
'http://en.wikipedia.org/wiki/', 'http://dbpedia.org/resource/', 1)
def get_dbpedia_links_from_mappings(mappings):
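    """Return a sorted list of all DBpedia links (stimulus and response) in mappings."""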
links = set()
for _, mapping in six.iteritems(mappings):
for link_kind in ['stimulus_link', 'response_link']:
link = mapping[link_kind]
links.add(wiki_to_dbpedia_link(link))
return sorted(links)
def get_dbpedia_pairs_from_mappings(mappings):
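    """Return a sorted list of unique (stimulus, response) DBpedia link pairs."""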
pairs = set()
for _, mapping in six.iteritems(mappings):
stimulus_link = wiki_to_dbpedia_link(mapping['stimulus_link'])
response_link = wiki_to_dbpedia_link(mapping['response_link'])
pairs.add((stimulus_link, response_link))
return sorted(pairs)
# noinspection PyPep8Naming
def URIRefify(links):
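    """Turn an iterable of IRI strings into a tuple of rdflib URIRefs."""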
return tuple([URIRef(l) for l in links])
def get_semantic_associations(
fn=None, limit=None, swap_source_target=False, drop_invalid=False
):
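    """Return the ground truth semantic associations as (source, target) pairs.

    If fn is None, the pairs are derived from the verified linker mappings.
    Otherwise fn is read as an (optionally gzipped) space separated file with
    N3 encoded 'source' and 'target' columns; limit then caps the number of
    pairs read and drop_invalid skips pairs that cannot be re-serialized as
    N3 instead of raising. With swap_source_target each pair (s, t) is
    returned as (t, s).
    """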
if not fn:
verified_mappings = get_verified_mappings()
semantic_associations = get_dbpedia_pairs_from_mappings(
verified_mappings)
semantic_associations = [URIRefify(p) for p in semantic_associations]
else:
semantic_associations = []
with gzip.open(fn) if fn.endswith('.gz') else open(fn) as f:
            # expects a file with one space-separated pair of N3-encoded IRIs
            # per line
r = csv.DictReader(
f,
delimiter=b' ',
doublequote=False,
escapechar=None,
quoting=csv.QUOTE_NONE,
)
assert r.fieldnames == ['source', 'target']
for i, row in enumerate(r):
if limit and i >= limit:
break
source = from_n3(row['source'].decode('UTF-8'))
target = from_n3(row['target'].decode('UTF-8'))
for x in (source, target):
# noinspection PyBroadException
try:
x.n3()
except Exception:
if drop_invalid:
logger.warning(
'ignoring ground truth pair %r: %r cannot be '
'serialized as N3',
(row['source'], row['target']), x
)
break
else:
logger.error(
'error in ground truth pair %r: %r cannot be '
'serialized as N3',
(row['source'], row['target']), x
)
raise
else:
semantic_associations.append((source, target))
if swap_source_target:
logger.info('swapping all (source, target) pairs: (s,t) --> (t,s)')
semantic_associations = [(t, s) for s, t in semantic_associations]
return semantic_associations
def write_semantic_associations(associations, fn=None):
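    """Write the given (source, target) pairs to a CSV file.

    Each line contains one space separated pair of N3 encoded IRIs. Defaults
    to config.GT_ASSOCIATIONS_FILENAME if fn is None.
    """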
if fn is None:
fn = config.GT_ASSOCIATIONS_FILENAME
with open(fn, 'w') as f:
        # writes a file with one space-separated pair of N3-encoded IRIs
        # per line
w = csv.DictWriter(
f,
fieldnames=('source', 'target'),
delimiter=b' ',
doublequote=False,
escapechar=None,
quoting=csv.QUOTE_NONE,
)
w.writeheader()
for source, target in associations:
w.writerow({
'source': source.n3().encode('UTF-8'),
'target': target.n3().encode('UTF-8'),
})
def filter_node_pairs_split(train, test, variant):
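    """Filter the test pairs according to the splitting variant.

    'target_node_disjoint' drops test pairs whose target node also occurs as
    a target in the training set, 'node_disjoint' drops test pairs sharing
    any node with the training set, and any other variant (e.g. 'random')
    leaves the split unchanged.
    """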
assert variant in config.SPLITTING_VARIANTS
if variant == 'target_node_disjoint':
train_target_nodes = {t for s, t in train}
tmp = [(s, t) for s, t in test if t not in train_target_nodes]
logger.info(
'removed %d/%d pairs from test set because of overlapping target '
'nodes with training set',
len(test) - len(tmp), len(test)
)
test = tmp
elif variant == 'node_disjoint':
train_nodes = {n for np in train for n in np}
tmp = [
(s, t) for s, t in test
if s not in train_nodes and t not in train_nodes
]
logger.info(
'removed %d/%d pairs from test set because of overlapping '
'nodes with training set',
len(test) - len(tmp), len(test)
)
test = tmp
return train, test
@nose.tools.nottest
def split_training_test_set(associations, split=0.1, seed=42, variant='random'):
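    """Split associations into a (train, test) pair.

    Returns the first fold of a k-fold split with k = int(1 / split), so the
    default split=0.1 yields roughly a 90/10 train/test split.
    """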
return next(
k_fold_cross_validation(associations, int(1 / split), seed, variant)
)
def k_fold_cross_validation(associations, k, seed=42, variant='random'):
"""Generates k folds of train and validation sets out of associations.
>>> list(
... k_fold_cross_validation(range(6), 3)
... ) # doctest: +NORMALIZE_WHITESPACE
[([4, 1, 0, 3], [2, 5]), ([2, 5, 0, 3], [4, 1]), ([2, 5, 4, 1], [0, 3])]
"""
assert variant in config.SPLITTING_VARIANTS
assert len(associations) >= k
associations = list(associations) # don't modify input with inplace shuffle
r = random.Random(seed)
r.shuffle(associations)
part_len = len(associations) / k
partitions = []
for i in range(k):
start_idx = int(i * part_len)
end_idx = int((i + 1) * part_len)
partitions.append(associations[start_idx:end_idx])
for i in range(k):
train = list(chain(*(partitions[:i] + partitions[i + 1:])))
val = partitions[i]
yield filter_node_pairs_split(train, val, variant)
def get_20_shuffled_pages_for_local_ground_truth_re_eval(verified_mappings):
"""get 20 shuffled pages.
they were written into:
'./../eat/ver_sem_assocs/verify_semantic_associations_stimuli.txt'
"""
re_test = list({v['stimulus'] for k, v in verified_mappings.items()})
random.shuffle(re_test)
re_test = re_test[:100]
for _ in range(50):
pick_20 = list(re_test)
random.shuffle(pick_20)
pick_20 = pick_20[:20]
for s in pick_20:
print(s)
print()
def get_all_wikipedia_stimuli_for_triplerater(verified_mappings):
    """Print all Wikipedia stimulus links for the triplerater.

    Reads the re-evaluation stimuli file and prints the stimulus links of all
    verified mappings whose stimulus appears in that file.
    """
fn = './../eat/ver_sem_assocs/verify_semantic_associations_stimuli.txt'
re_test_stimuli = open(fn).read()
re_test_stimuli = [s.strip() for s in re_test_stimuli.split()]
re_test_stimuli = set([s for s in re_test_stimuli if s])
tr_test = list({
v['stimulus_link']
for k, v in verified_mappings.items()
if v['stimulus'] in re_test_stimuli
})
for s in sorted(tr_test):
print(s)
def main():
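    """Regenerate the ground truth association files and print statistics.

    Writes config.GT_ASSOCIATIONS_FILENAME and its train/test splits if they
    do not exist yet and prints counts and association strengths.
    """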
import numpy as np
import logging.config
logging.basicConfig(level=logging.INFO)
verified_mappings = get_verified_mappings()
# get_dbpedia_pairs_from_mappings(verified_mappings)
# sys.exit(0)
# get_all_wikipedia_stimuli_for_triplerater(verified_mappings)
# sys.exit(0)
# from pprint import pprint
# pprint(verified_mappings)
print("verified mappings {} ({} raw associations)".format(
len(verified_mappings),
sum([int(m['count']) for m in verified_mappings.values()]),
))
sem_assocs = get_semantic_associations(None)
if not path.isfile(config.GT_ASSOCIATIONS_FILENAME):
write_semantic_associations(sem_assocs)
print("created {}".format(config.GT_ASSOCIATIONS_FILENAME))
assert get_semantic_associations(config.GT_ASSOCIATIONS_FILENAME) == \
sem_assocs
# also write individual train and test files and print association strengths
assocs_train, assocs_test = split_training_test_set(sem_assocs)
for t, at in [('train', assocs_train), ('test', assocs_test)]:
fn = config.GT_ASSOCIATIONS_FILENAME.replace('.csv', '_%s.csv' % t)
if not path.isfile(fn):
write_semantic_associations(at, fn=fn)
print("created %s" % fn)
        # calculate human ground truth association strengths from the original mappings
_at = set((str(s), str(t)) for s, t in at)
a = np.array([
int(v['count']) / 100
for v in verified_mappings.values()
if get_dbpedia_pairs_from_mappings({'x': v})[0] in _at
])
print(t, 'avg association strength:', a.mean(), 'stddev', a.std())
print("used for training", len(assocs_train))
print("used for eval", len(assocs_test))
print("overlap train & test", len(set(assocs_train) & set(assocs_test)))
if __name__ == '__main__':
main()