# -*- coding: utf-8 -*-
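"""Generate and run the ALTER TABLE statements needed to rewrite the
distribution policy of hash-distributed Greenplum tables.

The "gen" subcommand scans gp_distribution_policy, strips any operator
class from each table's distribution key, and writes the corresponding
"alter table ... set with (reorganize=true) distributed by (...)"
statements to a file; with --dump_legacy_ops it targets tables whose
keys use the legacy cdbhash_*_ops operator classes, otherwise the
hash-distributed tables that do not. The "run" subcommand executes
that file with a pool of worker processes.

Illustrative invocations (connection values are made up):
    python fix_distribution_policy.py --host mdw --port 5432 \
        --dbname testdb --user gpadmin gen --dump_legacy_ops --out alter.sql
    python fix_distribution_policy.py --host mdw --port 5432 \
        --dbname testdb --user gpadmin run --input alter.sql --nproc 4
"""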
import argparse
import re
from multiprocessing import Process, Queue
import signal
import time
from pygresql.pg import DB
import sys
# worker processes, kept global so sig_handler can terminate them
procs = []

# aggregate counters filled in while generating the "gen" report
total_leafs = 0
total_norms = 0
total_roots = 0
total_norm_size = 0
total_root_size = 0

def sig_handler(sig, arg):
    global procs
    for proc in procs:
        try:
            proc.terminate()
            proc.join()
        except Exception as e:
            sys.stderr.write("Error while terminating process: %s\n" % str(e))
    sys.stderr.write("terminated by signal %s\n" % sig)
    sys.exit(127)

class ChangePolicy(object):
    def __init__(self, dbname, port, host, user, dump_legacy_ops, order_size_ascend):
        self.dbname = dbname
        self.port = int(port)
        self.host = host
        self.user = user
        self.dump_legacy_ops = dump_legacy_ops
        self.order_size_ascend = order_size_ascend
        # captures the column list inside "DISTRIBUTED BY (...)"
        self.pt = re.compile(r'[(](.*)[)]')

    def get_db_conn(self):
        db = DB(dbname=self.dbname,
                port=self.port,
                host=self.host,
                user=self.user)
        return db

    def get_regular_tables(self):
        db = self.get_db_conn()
        predicate = '' if self.dump_legacy_ops else 'not'
        # the hard-coded oid array below presumably lists the legacy
        # cdbhash_*_ops operator classes; a table matches when any of its
        # distribution-key columns uses one of them
        sql = """
        select
          '"' || pn.nspname || '"."' || pc.relname || '"' as relname,
          pg_get_table_distributedby(pc.oid) distby
        from pg_class pc,
             pg_namespace pn,
             gp_distribution_policy gdp
        where pc.oid = gdp.localoid and
              pn.oid = pc.relnamespace and
              (not pc.relhassubclass) and
              pc.oid not in (select parchildrelid from pg_partition_rule) and
              gdp.policytype = 'p' and array_length(gdp.distkey::int[], 1) > 0 and
              %s (gdp.distclass::oid[] && '{10165,10166,10167,10168,10169,10170,10171,10172,10173,10174,10175,10176,10177,10178,10179,10180,10181,10182,10183,10184,10185,10186,10187,10188,10189,10190,10191,10192,10193,10194,10195,10196,10197,10198}'::oid[])
        """ % predicate
        r = db.query(sql).getresult()
        db.close()
        return r

    def get_root_partition_tables(self):
        db = self.get_db_conn()
        predicate = '' if self.dump_legacy_ops else 'not'
        # same filter as get_regular_tables, but for partition roots
        # (relhassubclass is true)
        sql = """
        select
          '"' || pn.nspname || '"."' || pc.relname || '"' as relname,
          pg_get_table_distributedby(pc.oid) distby
        from pg_class pc,
             pg_namespace pn,
             gp_distribution_policy gdp
        where pc.oid = gdp.localoid and
              pn.oid = pc.relnamespace and
              pc.relhassubclass and
              pc.oid not in (select parchildrelid from pg_partition_rule) and
              gdp.policytype = 'p' and array_length(gdp.distkey::int[], 1) > 0 and
              %s (gdp.distclass::oid[] && '{10165,10166,10167,10168,10169,10170,10171,10172,10173,10174,10175,10176,10177,10178,10179,10180,10181,10182,10183,10184,10185,10186,10187,10188,10189,10190,10191,10192,10193,10194,10195,10196,10197,10198}'::oid[])
        """ % predicate
        r = db.query(sql).getresult()
        db.close()
        return r

    def remove_ops_ifany(self, distby):
        # input looks like: DISTRIBUTED BY (a cdbhash_int4_ops, b cdbhash_int4_ops)
        # keep only the column name from each "<column> [<opclass>]" entry so the
        # rewritten table falls back to the default operator classes
        t = self.pt.findall(distby)[0]
        cols = ", ".join([s.strip().split()[0]
                          for s in t.split(',')])
        return "distributed by (%s)" % cols

    def handle_one_table(self, name, distby):
        new_distby = self.remove_ops_ifany(distby)
        sql = """
        alter table %s set with (reorganize=true) %s;
        """ % (name, new_distby)
        return sql.strip()
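    # Illustrative output for a hypothetical table "public"."t1" keyed on a:
    #   alter table "public"."t1" set with (reorganize=true) distributed by (a);
    # set with (reorganize=true) forces a full rewrite, so every row is
    # rehashed under the new operator classes.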

    def dump_table_info(self, db, name, is_normal=True):
        if is_normal:
            sql = "select pg_relation_size('{name}'::regclass);"
            r = db.query(sql.format(name=name)).getresult()
            global total_norms
            global total_norm_size
            total_norm_size += r[0][0]
            total_norms += 1
            return "normal table, size %s" % r[0][0], r[0][0]
        else:
            # walk the partition hierarchy via pg_inherits; only the deepest
            # level (the leaf partitions, which actually hold the data) is
            # measured and counted
            sql_size = """
            with recursive cte(nlevel, table_oid) as (
              select 0, '{name}'::regclass::oid
              union all
              select nlevel+1, pi.inhrelid
              from cte, pg_inherits pi
              where cte.table_oid = pi.inhparent
            )
            select sum(pg_relation_size(table_oid))
            from cte where nlevel = (select max(nlevel) from cte);
            """
            r = db.query(sql_size.format(name=name))
            size = r.getresult()[0][0]
            sql_nleafs = """
            with recursive cte(nlevel, table_oid) as (
              select 0, '{name}'::regclass::oid
              union all
              select nlevel+1, pi.inhrelid
              from cte, pg_inherits pi
              where cte.table_oid = pi.inhparent
            )
            select count(1)
            from cte where nlevel = (select max(nlevel) from cte);
            """
            r = db.query(sql_nleafs.format(name=name))
            nleafs = r.getresult()[0][0]
            global total_leafs
            global total_roots
            global total_root_size
            total_root_size += size
            total_leafs += nleafs
            total_roots += 1
            return "partition table, %s leafs, size %s" % (nleafs, size), size

    def dump(self, fn):
        db = self.get_db_conn()
        f = open(fn, "w")
        print >>f, "-- dump %s ops " % ('legacy' if self.dump_legacy_ops else 'new')
        print >>f, "-- order table by size in %s order " % ('ascending' if self.order_size_ascend else 'descending')
        table_info = []
        # regular (non-partitioned) tables
        regular = self.get_regular_tables()
        for name, distby in regular:
            msg, size = self.dump_table_info(db, name)
            table_info.append((name, distby, size, msg))
        # partitioned tables (root relations)
        parts = self.get_root_partition_tables()
        for name, distby in parts:
            msg, size = self.dump_table_info(db, name, False)
            table_info.append((name, distby, size, msg))
        if self.order_size_ascend:
            table_info.sort(key=lambda x: x[2], reverse=False)
        else:
            table_info.sort(key=lambda x: x[2], reverse=True)
        for name, distby, size, msg in table_info:
            print >>f, "-- ", msg
            print >>f, self.handle_one_table(name, distby)
            print >>f
        f.close()
        db.close()
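    # The default descending order presumably front-loads the largest rewrites
    # so the parallel "run" phase keeps its workers busy; --order_size_ascend
    # is useful when quick feedback from small tables is preferred.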

class ConcurrentRun(object):
    def __init__(self, dbname, port, host, user, script_file, nproc):
        self.dbname = dbname
        self.port = int(port)
        self.host = host
        self.user = user
        self.script_file = script_file
        self.nproc = nproc

    def get_db_conn(self):
        db = DB(dbname=self.dbname,
                port=self.port,
                host=self.host,
                user=self.user)
        return db

    def parse_inputfile(self):
        self.sqls = Queue()
        with open(self.script_file) as f:
            for line in f:
                sql = line.strip()
                if (sql.startswith("alter table") and
                        sql.endswith(";") and
                        sql.count(";") == 1):
                    self.sqls.put(sql)
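    # Comment lines and blank lines from the "gen" output are skipped here;
    # e.g. "-- normal table, size 98304" is ignored, while a complete
    # single-statement "alter table ... ;" line is queued for the workers.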

    def run(self):
        self.parse_inputfile()
        global procs
        procs = []
        for i in range(self.nproc):
            proc = Process(target=ConcurrentRun.alter,
                           args=[self.sqls, i, self.nproc,
                                 self.dbname, self.port, self.host, self.user])
            procs.append(proc)
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()

    @staticmethod
    def alter(sqls, idx, nproc, dbname, port, host, user):
        import logging
        from Queue import Empty
        logging.basicConfig(level=logging.DEBUG, stream=sys.stdout,
                            format="%(asctime)s - %(levelname)s - %(message)s")
        logger = logging.getLogger()
        logger.info("worker[%d]: begin: " % idx)
        logger.info("worker[%d]: connect to <%s> ..." % (idx, dbname))
        db = DB(dbname=dbname,
                port=port,
                host=host,
                user=user)
        start = time.time()
        while True:
            # get_nowait() avoids the race where another worker drains the
            # queue between an empty() check and a blocking get()
            try:
                sql = sqls.get_nowait()
            except Empty:
                break
            logger.info("worker[%d]: execute alter command \"%s\" ... " % (idx, sql))
            db.query(sql)
            # the table name is the third token of "alter table <name> set ..."
            tab = sql.strip().split()[2]
            analyze_sql = "analyze %s;" % tab
            logger.info("worker[%d]: execute analyze command \"%s\" ... " % (idx, analyze_sql))
            db.query(analyze_sql)
            end = time.time()
            total_time = end - start
            logger.info("Current progress: have %d remaining, %.3f seconds passed." % (sqls.qsize(), total_time))
        db.close()
        logger.info("worker[%d]: finish." % idx)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='fix_distribution_policy')
    parser.add_argument('--host', type=str, help='Greenplum Database hostname')
    parser.add_argument('--port', type=int, help='Greenplum Database port')
    parser.add_argument('--dbname', type=str, help='Greenplum Database database name')
    parser.add_argument('--user', type=str, help='Greenplum Database user name')
    subparsers = parser.add_subparsers(help='sub-command help', dest='cmd')
    parser_gen = subparsers.add_parser('gen', help='generate alter table cmds')
    parser_run = subparsers.add_parser('run', help='run the alter table cmds')
    parser_gen.add_argument('--out', type=str, help='outfile path for the alter table commands')
    parser_gen.add_argument('--dump_legacy_ops', action='store_true', help='dump all tables with legacy distkey ops')
    parser_gen.set_defaults(dump_legacy_ops=False)
    parser_gen.add_argument('--order_size_ascend', action='store_true', help='sort the tables by size in ascending order')
    parser_gen.set_defaults(order_size_ascend=False)
    parser_run.add_argument('--nproc', type=int, default=1, help='number of concurrent processes to run the alter table commands')
    parser_run.add_argument('--input', type=str, help='the file containing all alter table commands')
    args = parser.parse_args()
    if args.cmd == 'gen':
        cp = ChangePolicy(args.dbname, args.port, args.host, args.user,
                          args.dump_legacy_ops, args.order_size_ascend)
        cp.dump(args.out)
        print "total table size (in GBytes) : %s" % (float(total_norm_size + total_root_size) / 1024.0**3)
        print "total normal table : %s" % total_norms
        print "total partition tables : %s" % total_roots
        print "total leaf partitions : %s" % total_leafs
    elif args.cmd == "run":
        signal.signal(signal.SIGTERM, sig_handler)
        signal.signal(signal.SIGINT, sig_handler)
        cr = ConcurrentRun(args.dbname, args.port, args.host, args.user,
                           args.input, args.nproc)
        cr.run()
    else:
        sys.stderr.write("unknown subcommand!\n")
        sys.exit(127)