-
Notifications
You must be signed in to change notification settings - Fork 2
/
s2.py
127 lines (112 loc) · 3.54 KB
/
s2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import re
import argparse
import collections
# Weight assigned to every generated rule before the tree file is scanned.
initialWeight = 1
# Amount added to a rule's weight each time the tree file exhibits that rule.
# With the value 0 the scanning pass leaves all weights at initialWeight.
stepSize = 0
def startWithTag(tag):
    """Return the rule text for a sentence that starts with ``tag``.

    The result is the literal ``S2`` left-justified in a 15-character field,
    one space, then ``_`` followed by ``tag`` (e.g. ``S2 ... _NP``).
    """
    return "%-15s %s" % ("S2", "_" + tag)
def endWithTag(tag):
    """Return the rule text for a sentence that ends with ``tag``.

    The result is ``_<tag>`` left-justified in a 15-character field (never
    truncated when longer), one space, then ``tag`` itself.
    """
    return "%-15s %s" % ("_" + tag, tag)
def tagAFollowedByTagB(tagA, tagB):
    """Return the rule text for ``tagA`` being immediately followed by ``tagB``.

    The result has three fields: ``_<tagA>`` and ``tagA``, each left-justified
    in a 15-character field and followed by one space, then ``_<tagB>``.
    """
    return "%-15s %-15s %s" % ("_" + tagA, tagA, "_" + tagB)
# Punctuation tags that are replaced by a spelled-out name; every other tag
# passes through unchanged.
_PUNCTUATION_TAGS = {
    '.': 'PERIOD',
    ':': 'COLON',
    ',': 'COMMA',
    "''": 'TWOSINGLEQUOTES',
    "``": 'TWOGRAVES',
    '(': '-LRB-',
    ')': '-RRB-',
}


def sanitize(nont):
    """Return ``nont`` stripped of surrounding whitespace, with punctuation renamed.

    Tags that consist solely of one of the punctuation strings in
    ``_PUNCTUATION_TAGS`` are replaced by their spelled-out name; any other
    tag is returned as-is after stripping.
    """
    nont = nont.strip()
    return _PUNCTUATION_TAGS.get(nont, nont)
def _recordTerminal(weights, lastTermTag, curTag):
    """Add stepSize to the rule that places ``curTag`` after ``lastTermTag``.

    When there is no previous terminal (``lastTermTag`` is falsy) the
    sentence-start rule for ``curTag`` is bumped instead.  Returns ``curTag``
    so the caller can store it as the new previous terminal.
    """
    if lastTermTag:
        weights[tagAFollowedByTagB(lastTermTag, curTag)] += stepSize
    else:
        weights[startWithTag(curTag)] += stepSize
    return curTag


def main(vocab_file, tree_file, s2_file):
    """Generate the S2 rule table from a vocab file and a tree file.

    Steps:
      1. Read ``vocab_file`` and collect the distinct tags, in first-seen
         order.  The tag of a non-blank line is element 1 of the line split
         on runs of whitespace.
      2. Create one rule per tag for "sentence starts with tag", one per tag
         for "sentence ends with tag", and one per ordered tag pair, each
         with weight ``initialWeight``.
      3. Scan ``tree_file`` character by character.  ``(TAG word)`` is treated
         as a terminal, ``(TAG (`` as a non-terminal, and ``((`` / ``()`` as
         bracket terminals.  Each terminal adds ``stepSize`` to the matching
         start/transition rule; when the parenthesis depth returns to zero the
         end rule of the last terminal is bumped.  Depth and the previous
         terminal carry over from one line to the next.
      4. Write every rule to ``s2_file`` as ``"%-10d %s"`` (weight, rule).

    Args:
        vocab_file: path of the input vocab grammar file.
        tree_file: path of the input tree file.
        s2_file: path of the output file (overwritten).

    Raises:
        IndexError: a non-blank vocab line splits into fewer than two parts.
        KeyError: the tree file yields a start/transition/end rule whose tags
            were not all present in the vocab file.
    """
    # An ordered mapping keeps first-seen order (which fixes the output rule
    # order) while making the duplicate check O(1) instead of a list scan.
    tags = collections.OrderedDict()
    with open(vocab_file, "r") as vocab:
        for rule in vocab:
            if rule.strip() == '':
                continue
            tags[re.split(r'\s+', rule)[1]] = None

    # Initialization: every rule exists up front with the base weight.
    weights = collections.OrderedDict()
    for tag in tags:
        weights[startWithTag(tag)] = initialWeight
    for tagA in tags:
        weights[endWithTag(tagA)] = initialWeight
        for tagB in tags:
            weights[tagAFollowedByTagB(tagA, tagB)] = initialWeight

    # Count the weight from the tree file.
    depth = 0  # number of currently open non-terminal parentheses
    lastTermTag = None
    with open(tree_file, "r") as trees:
        for line in trees:
            # The scanner treats the final character as the line terminator.
            # A last line without a trailing newline would otherwise lose its
            # final real character (usually the closing parenthesis).
            if not line.endswith("\n"):
                line += "\n"
            end = len(line) - 1
            i = 0
            while i < end:
                if line[i] == '(':
                    # Either a terminal like (COLON ;) or a non-terminal
                    # like (S followed by a nested bracket.
                    depth += 1
                    i += 1
                    if line[i] in ('(', ')'):
                        # Bracket terminals: "(( (" or "() )".
                        lastTermTag = _recordTerminal(
                            weights, lastTermTag, sanitize(line[i]))
                        # Skip tag, space and word so that the shared
                        # increment below steps past the closing ')'.
                        i += 3
                        depth -= 1
                    else:
                        tagStart = i
                        while i < end and line[i] != ' ':
                            i += 1
                        curTag = line[tagStart:i]
                        if i < end:
                            i += 1
                            if line[i] != '(':
                                # Terminal: the tag is followed by a word.
                                lastTermTag = _recordTerminal(
                                    weights, lastTermTag, sanitize(curTag))
                                # Only the tag matters; skip the word up to
                                # its closing ')', which is consumed here.
                                while i < end and line[i] != ')':
                                    i += 1
                                depth -= 1
                            else:
                                # Non-terminal: step back so the shared
                                # increment lands on the nested '('.
                                i -= 1
                elif line[i] == ')':
                    depth -= 1
                    if depth == 0 and lastTermTag:
                        # The outermost bracket closed: the sentence is over.
                        weights[endWithTag(lastTermTag)] += stepSize
                        lastTermTag = None
                i += 1

    with open(s2_file, "w") as out:
        for rule, weight in weights.items():
            out.write("%-10d %s\n" % (weight, rule))
if __name__ == '__main__':
    # Command-line entry point: all three paths are mandatory.
    ap = argparse.ArgumentParser(
        description="Generate an S2 grammar file from a vocab file and a tree file.")
    ap.add_argument("-v", "--vocab", dest="vocab_file", required=True, help="input vocab gr file")
    ap.add_argument("-tree", "--tree", dest="tree_file", required=True, help="input tree file")
    ap.add_argument("-s2", "--s2", dest="s2_file", required=True, help="output s2 gr file")
    args = ap.parse_args()
    main(args.vocab_file, args.tree_file, args.s2_file)