demo.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import io
import os
import time

import pandas as pd
import spacy

# Only needed for the commented-out dependency plots in printConll below:
# from spacy import displacy
# from cairosvg import svg2png

# spacy.require_gpu() / spacy.prefer_gpu()
nlp = spacy.load("en_core_web_trf")
nlp.max_length = 10000000
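# The transformer model is not bundled with spaCy; it has to be installed once, e.g.:
#   python -m spacy download en_core_web_trf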
def printConll(sent, k, file):
    """Render one spaCy sentence as CoNLL-style tab-separated rows."""
    out = ''
    # svg = displacy.render(sent, style="dep", jupyter=False)
    # svg2png(bytestring=svg, write_to="Out/pdf/" + file + "/dependency_plot" + str(k) + ".png")
    for i, word in enumerate(sent):
        if word.head == word:
            head_idx = 0  # the root heads itself in spaCy; CoNLL uses 0
        else:
            head_idx = word.head.i - sent[0].i + 1  # 1-based position within the sentence
        out += "%d\t%s\t%s\t%s\t%s\t%s\n" % (
            i + 1,       # position in the sentence (word.i would be the position in the *doc*)
            word,
            word.lemma_,
            word.tag_,   # fine-grained tag
            str(head_idx),
            word.dep_,   # dependency relation
        )
    return out
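# A quick sanity check for printConll (hypothetical snippet, not part of the
# pipeline; exact tags depend on the model version):
#
#   doc = nlp("The cat sat.")
#   print(printConll(list(doc.sents)[0], 0, "demo"))
#
# Each row is index, form, lemma, tag, head, relation separated by tabs; the
# second row should look something like "2  cat  cat  NN  3  nsubj".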
def parse_txt(data, file):
    print(file)
    res = []
    count_short = 0
    count_err = 0
    doc = nlp(data)
    lista = list(doc.sents)
    print(len(lista))
    k = 0
    for sent in lista:
        k += 1
        bug = False
        try:
            out = printConll(sent, k, file)
            conll = io.StringIO(out)
            df = pd.read_csv(conll, header=None, sep='\t')
            if len(df) > 1:
                # Positions of the punctuation marks to remove (column 5 is the
                # dependency relation, column 4 the head index).
                pun = sorted(df.index[df[5] == 'punct'].tolist())
                el_pun = [x + 1 for x in pun]  # their 1-based positions
                # If the root is a punctuation mark, the sentence is invalid.
                for p in pun:
                    if df[4][p] == 0:
                        bug = True
                # Punctuation must be a leaf: re-attach any token headed by a
                # punctuation mark to that punctuation mark's own head.
                for j in range(len(df[4])):
                    if df[4][j] != 0 and df[4][j] in el_pun:
                        prec = df[4][j]
                        df[4][j] = df[4][prec - 1]
                # Rescale the head indices to account for the removed rows.
                for j in range(len(df[4])):
                    num_v = sum(df[4][j] >= p + 1 for p in pun)
                    df[4][j] = df[4][j] - num_v
                df.drop(pun, inplace=True)
                df = df.reset_index(drop=True)
                df[0] = df.index + 1
                l = df[4].astype(int).tolist()
                l_in = range(1, len(l) + 1)
                # Sanity checks: exactly one root, no self-loops, feasible maximum head value.
                if len(l) > 3:
                    root = l.index(0) + 1
                    if root not in l:  # the root must head at least one token
                        bug = True
                    for i in range(len(l)):
                        if l[i] == l_in[i]:  # no token may head itself
                            bug = True
                            break
                    if max(l) > max(l_in):
                        bug = True
                    if l.count(0) != 1:  # exactly one root
                        bug = True
                    if not bug:
                        res.append(l)
                else:
                    count_short += 1
        except Exception:
            count_err += 1
    # Transform the treebank into strings, one space-separated head sequence per sentence.
    res = [" ".join(str(x) for x in heads) for heads in res]
    return res
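# parse_txt returns one string per accepted sentence: the space-separated head
# sequence of its dependency tree after punctuation removal, with 0 marking the
# root. For example, "2 0 2" encodes a three-token sentence whose second token
# is the root and heads tokens 1 and 3.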
########################
files = [os.path.splitext(i)[0] for i in os.listdir("./books") if i.endswith(".txt")]
os.makedirs("treebanks-demo", exist_ok=True)  # make sure the output directory exists
for file in files:
    filetxt = file + ".txt"
    with open("books/" + filetxt, encoding="windows-1252") as fp:
        data = fp.read()
    # Create one monolithic, whitespace-normalized string.
    data = data.replace("\n", " ")
    data = data.replace("\t", " ")
    data = ' '.join(data.split())
    # Avoid overloading spaCy by splitting large files into two chunks.
    if len(data) > 500000:
        dati = [data[:len(data) // 2], data[len(data) // 2:]]
    else:
        dati = [data]
    chunks = len(dati)
    # Truncate any previous output for this book.
    open("treebanks-demo/out-" + filetxt, 'w').close()
    print("chunks: " + str(chunks))
    for i in range(chunks):
        start = time.time()
        res = parse_txt(dati[i], file)
        end = time.time()
        print("Elapsed time = %s" % (end - start))
        with open("treebanks-demo/out-" + filetxt, 'a') as f:
            for line in res:
                # Skip empty or trivially short head sequences.
                if line != '' and len(line) > 2:
                    f.write("%s " % line)
                    f.write("\n")
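# Expected layout when running the script (inferred from the paths above):
#   books/            input .txt files, windows-1252 encoded
#   treebanks-demo/   output directory, one out-<book>.txt per input book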