-
Notifications
You must be signed in to change notification settings - Fork 7
/
k_pg_proxy
executable file
·291 lines (226 loc) · 8.9 KB
/
k_pg_proxy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
#!/usr/bin/env ruby
require "base64"
require "digest"
require "json"
require "openssl"
require "securerandom"
require "socket"
# Port the proxy listens on. Each proxied connection N additionally gets its
# own `kubectl port-forward` bound to PROXY_PORT + N.
PROXY_PORT = 10_000
# NOTE(review): not referenced anywhere in the visible code.
THREADS = 10
# kubectl context to use: the first command-line argument when given,
# otherwise whatever context kubectl currently has selected.
CONTEXT = ARGV.fetch(0) { `kubectl config current-context`.strip }
# Reads the client's StartupMessage and returns its parameters as a Hash,
# e.g. { "user" => "alice", "database" => "mydb" }.
#
# A client may first ask to upgrade the connection with an SSLRequest or a
# GSSENCRequest. Each is declined with a single "N" byte, after which the
# client continues in plaintext on the same connection.
#
# @param client_socket [IO] freshly accepted client connection
# @return [Hash{String => String}] startup parameters
# @raise [EOFError] if the client disconnects before sending a StartupMessage
def parse_startup_message(client_socket)
  ssl_request_code = 80_877_103
  gssenc_request_code = 80_877_104

  loop do
    # Every message starts with an Int32 length (counting itself) followed by
    # an Int32 protocol version / request code.
    header = client_socket.read(8)
    raise EOFError, "client closed the connection during startup" unless header&.bytesize == 8

    length, version = header.unpack("L>L>")
    case version
    when ssl_request_code
      puts "Handling SSLRequest"
      client_socket.write("N") # we're not accepting SSL
    when gssenc_request_code
      puts "Handling GSSENCRequest"
      client_socket.write("N") # nor GSSAPI encryption
    else
      puts "Handling StartupMessage"
      # The remaining length - 8 bytes are the NUL-separated parameters:
      # "user\x00username\x00database\x00dbname\x00\x00" -> { "user" => "username", "database" => "dbname", ... }
      return client_socket.read(length - 8).split("\x00").each_slice(2).to_a.to_h
    end
  end
end
# Sends a protocol 3.0 StartupMessage carrying the given user and database.
#
# Wire layout: Int32 total length (counting itself), Int32 protocol version
# 196608 (3 << 16 | 0), then NUL-terminated "user", user, "database",
# database, and a final NUL byte.
#
# @param pg_socket [IO] connection to the Postgres server
# @param user [String] role to log in as
# @param database [String] database to connect to
def send_startup_message(pg_socket, user:, database:)
  puts "Sending StartupMessage to #{database}"
  params = ["user", user, "database", database]
  # The length field counts bytes, not characters: names may be multi-byte
  # UTF-8, and "Z*" packs their raw bytes plus a NUL.
  message_size = 4 + 4 + params.sum { |param| param.bytesize + 1 } + 1
  startup_message = [message_size, 196_608, *params, 0].pack("L>L>#{'Z*' * params.size}C")
  pg_socket.write(startup_message)
end
# Reads one backend protocol message: a one-byte tag, an Int32 length that
# counts itself, then the payload.
#
# @return [Array(String, String)] [tag, payload]
# @raise [RuntimeError] if the stream ends before a complete message arrives
def read_pg_message(socket)
  tag = socket.read(1)
  header = socket.read(4)
  raise "unexpected end of stream while reading a Postgres message" if tag.nil? || header.nil? || header.bytesize < 4

  length = header.unpack1("L>") - 4
  payload = socket.read(length)
  raise "unexpected end of stream while reading a Postgres message" if payload.nil? || payload.bytesize < length

  [tag, payload]
end

# Reads an Authentication* ('R') message and returns [auth_type, auth_data].
# If the server sent an ErrorResponse ('E') instead, prints its fields and
# returns nil.
#
# @raise [RuntimeError] on any other message tag
def read_authentication_request(socket)
  tag, payload = read_pg_message(socket)
  if tag == "E"
    puts "Postgres Error:"
    puts payload.split("\x00")
    return
  end
  raise "expected authentication response but got #{tag}" unless tag == "R"

  [payload.unpack1("L>"), payload.byteslice(4, payload.bytesize)]
end

# Logs in on a connection that has already sent its StartupMessage, using
# SCRAM-SHA-256 (the only SASL mechanism handled here).
#
# @param pg_socket [IO] connection to the Postgres server
# @param password [String] the database user's password
# @return [true, nil] true once AuthenticationOk arrives; nil if the server
#   answered any stage with an ErrorResponse (its fields are printed)
# @raise [RuntimeError] on an unexpected message, an unsupported mechanism,
#   or a server nonce that does not extend ours
def handle_authentication(pg_socket, password:)
  scram_sha_256 = "SCRAM-SHA-256"

  puts "Handling AuthenticationSASL"
  request = read_authentication_request(pg_socket)
  return unless request

  type, mechanisms = request
  unless type == 10 && mechanisms.include?(scram_sha_256)
    raise "don't know how to handle authentication request type #{type} with payload #{mechanisms}"
  end

  puts "Sending SASLInitialResponse"
  nonce = SecureRandom.urlsafe_base64(18)
  # "n,," = no channel binding. The SCRAM username is left empty: Postgres
  # uses the user from the StartupMessage instead.
  client_first_bare = "n=,r=#{nonce}"
  first_message = "n,,#{client_first_bare}"
  length = 4 + scram_sha_256.bytesize + 1 + 4 + first_message.bytesize
  pg_socket.write(["p", length, scram_sha_256, first_message.bytesize, first_message].pack("ZL>Z*L>a*"))

  puts "Handling AuthenticationSASLContinue"
  request = read_authentication_request(pg_socket)
  return unless request

  type, server_first = request
  raise "expected authentication response but got #{type}" unless type == 11

  puts "Sending SASLResponse"
  # "r=foo,s=bar==,i=4096" -> { "r" => "foo", "s" => "bar==", "i" => "4096" }
  params = server_first.split(",").map { |pair| pair.split("=", 2) }.to_h
  server_nonce = params.fetch("r")
  salt = Base64.strict_decode64(params.fetch("s"))
  iterations = params.fetch("i").to_i
  unless server_nonce.start_with?(nonce)
    raise "expected authentication response to start with #{nonce} but it was #{server_nonce}"
  end

  # biws == base64 of "n,,"
  final_without_proof = "c=biws,r=#{server_nonce}"
  digest = OpenSSL::Digest.new("SHA256")
  salted_password = OpenSSL::PKCS5.pbkdf2_hmac(password, salt, iterations, digest.digest_length, digest)
  client_key = OpenSSL::HMAC.digest("sha256", salted_password, "Client Key")
  stored_key = OpenSSL::Digest.new("SHA256").update(client_key).digest
  auth_message = "#{client_first_bare},#{server_first},#{final_without_proof}"
  client_signature = OpenSSL::HMAC.digest("sha256", stored_key, auth_message)
  # ClientProof = ClientKey XOR ClientSignature (RFC 5802).
  proof_bytes = client_key.bytes.zip(client_signature.bytes).map { |key_byte, sig_byte| key_byte ^ sig_byte }
  response = "#{final_without_proof},p=#{Base64.strict_encode64(proof_bytes.pack('C*'))}"
  pg_socket.write(["p", response.bytesize + 4, response].pack("ZL>a*"))

  puts "Handling AuthenticationSASLFinal"
  request = read_authentication_request(pg_socket)
  return unless request

  # The server signature in the final message could be verified, but YAGNI.
  type, _server_final = request
  raise "expected AuthenticationSASLFinal response but got #{type}" unless type == 12

  puts "Handling AuthenticationOk"
  request = read_authentication_request(pg_socket)
  return unless request

  type, = request
  raise "expected AuthenticationOk response but got #{type}" unless type == 0

  true
end
# Copies bytes from +from_socket+ to +to_socket+ until the source reaches EOF,
# the source sends a bare Terminate message, or either socket errors/closes.
#
# Terminate ("X" with length 4) is only recognised when it arrives as its own
# recv chunk; it is dropped rather than forwarded.
def forward(from_socket, to_socket)
  loop do
    data = from_socket.recv(1024)
    # EOF: Ruby >= 3.3 returns nil here, older versions return "".
    return if data.nil? || data.empty?
    return if data == "X\x00\x00\x00\x04" # Terminate message

    # Unlike send, write keeps going until the whole chunk is delivered.
    to_socket.write(data)
  rescue Errno::ECONNRESET, Errno::EPIPE, IOError => e
    # `recv': stream closed in another thread (IOError)
    # `recv': closed stream (IOError)
    puts e.message unless e.message[/closed in another thread/]
    return
  end
end
# Runs kubectl against CONTEXT. +args+ are passed straight to exec (no shell),
# so client-supplied names cannot inject commands. Returns stdout as a String;
# stderr is inherited and the exit status is ignored.
def kubectl(*args)
  IO.popen(["kubectl", "--context", CONTEXT, *args], &:read)
end

# True when +name+ is a valid Kubernetes object name (lowercase RFC 1123
# subdomain, at most 253 characters). Cluster names must have this form, so
# anything else is rejected before it reaches kubectl.
def kubernetes_name?(name)
  name.is_a?(String) && name.length <= 253 && name.match?(/\A[a-z0-9]([-a-z0-9.]*[a-z0-9])?\z/)
end

# Returns the first name kubectl prints (`-o name`) for the primary Postgres
# pod of +database+, or "" if none matches. Crunchy PGO labels are tried
# first, then CloudNativePG labels.
def find_primary_pod(database)
  selectors = [
    "postgres-operator.crunchydata.com/cluster=#{database},postgres-operator.crunchydata.com/role=master",
    "cnpg.io/cluster==#{database},cnpg.io/instanceRole=primary",
  ]
  selectors.each do |selector|
    pod = kubectl("get", "pod", "-o", "name", "-l", selector).lines.first.to_s.chomp
    return pod unless pod.empty?
  end
  ""
end

# Looks up login credentials for +database+ from its Kubernetes secret.
# For a CloudNativePG "cluster" the secret is "<name>-app". For a Crunchy
# "postgrescluster" it is "<name>-pguser-<user>", where <user> is the first
# spec user (or +database+ when none is listed).
# Returns { user:, dbname:, password: }, or nil (after logging to stderr)
# when the cluster or the secret is missing.
def database_credentials(database)
  cluster_json = kubectl("get", "cluster", database, "-o", "json")
  pgo = cluster_json.empty?
  cluster_json = kubectl("get", "postgrescluster", database, "-o", "json") if pgo
  if cluster_json.empty?
    $stderr.puts "Error: cluster '#{database}' not found"
    return
  end

  if pgo
    user = JSON.parse(cluster_json).dig("spec", "users", 0, "name")
    unless user
      puts "No users found in PostgresCluster spec, using default user '#{database}'"
      user = database
    end
    secret_suffix = "pguser-#{user}"
  else
    secret_suffix = "app"
  end

  secret_name = "#{database}-#{secret_suffix}"
  secret_json = kubectl("get", "secret", secret_name, "-o", "json")
  if secret_json.empty?
    $stderr.puts "Error: secret '#{secret_name}' not found"
    return
  end

  data = JSON.parse(secret_json).fetch("data")
  {
    user: Base64.strict_decode64(data.fetch("user")),
    dbname: Base64.strict_decode64(data.fetch("dbname")),
    password: Base64.strict_decode64(data.fetch("password")),
  }
end

# Connects to the local `kubectl port-forward` listener. Connection refusals
# are retried until +timeout+ seconds have passed, because the child process
# needs a moment before it starts listening. The last refusal is re-raised.
def connect_to_port_forward(port, timeout: 10, interval: 0.2)
  deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout
  begin
    TCPSocket.new("localhost", port)
  rescue Errno::ECONNREFUSED
    raise if Process.clock_gettime(Process::CLOCK_MONOTONIC) >= deadline

    sleep interval
    retry
  end
end

# Stops the kubectl port-forward child. A process that has already exited
# (or was never started) is tolerated.
def stop_port_forward(pid)
  return unless pid

  Process.kill("QUIT", pid)
rescue SystemCallError
  nil
end

# Serves one client connection end to end:
#   1. reads the client's StartupMessage to learn the target database,
#   2. finds the primary pod and starts `kubectl port-forward` to it,
#   3. tells the client it is authenticated, then logs in to Postgres with
#      credentials from the cluster's Kubernetes secret,
#   4. pipes bytes both ways until either side stops.
# The port-forward process and both sockets are always cleaned up in `ensure`.
#
# @param client_socket [TCPSocket] accepted client connection
# @param connection_number [Integer] offset added to PROXY_PORT for this
#   connection's port-forward
def handle_connection(client_socket, connection_number)
  params = parse_startup_message(client_socket)
  # Per the protocol, the database name defaults to the user name.
  database = params.fetch("database") { params.fetch("user") }
  unless kubernetes_name?(database)
    $stderr.puts "Error: #{database.inspect} is not a valid Kubernetes cluster name"
    return
  end

  primary_pod = find_primary_pod(database)
  if primary_pod.empty?
    $stderr.puts "Error: no primary postgres pod found for #{database}"
    return
  end

  port_forward_port = PROXY_PORT + connection_number
  port_forward_pid = spawn(
    "kubectl", "--context", CONTEXT, "port-forward", primary_pod, "#{port_forward_port}:5432",
    err: File::NULL,
  )
  Process.detach(port_forward_pid)
  pg_socket = connect_to_port_forward(port_forward_port)

  # Pretend AuthenticationOk to avoid client giving up prematurely
  client_socket.write(["R", 8, 0].pack("aL>L>"))

  credentials = database_credentials(database)
  return unless credentials

  send_startup_message(pg_socket, user: credentials[:user], database: credentials[:dbname])
  unless handle_authentication(pg_socket, password: credentials[:password])
    puts "Error: Failed to connect to #{database}"
    return
  end

  client_forward = Thread.new { forward(client_socket, pg_socket) }
  pg_forward = Thread.new { forward(pg_socket, client_socket) }
  # Closing the sockets in `ensure` unblocks whichever forwarder is still running.
  sleep 1 while client_forward.alive? && pg_forward.alive?
  puts "Disconnecting from #{database}"
rescue Errno::ECONNRESET
  puts "handle_connection: Errno::ECONNRESET handling database #{database}"
rescue Errno::ECONNREFUSED
  puts "handle_connection: kubectl port-forward never accepted connections for database #{database}"
ensure
  stop_port_forward(port_forward_pid)
  client_socket&.close
  pg_socket&.close
end
puts "Listening for Postgres connections on localhost:#{PROXY_PORT}"
puts "Just pass the name of the kubernetes database and leave the rest to me!"
puts ""
puts "EXAMPLE:"
puts "psql -h localhost -p #{PROXY_PORT} -d mynewsdesk-staging"
# Bind to loopback only. Every accepted client is logged in to the database
# with the cluster's stored credentials and never has to present a password,
# so the proxy must not be reachable from other hosts.
server = TCPServer.new("127.0.0.1", PROXY_PORT)
connection_number = 0
loop do
  puts ""
  client_socket = server.accept
  connection_number += 1
  # Hand the values over as thread arguments. The block otherwise reads the
  # shared counter later, and a quick second accept could bump it first,
  # giving two connections the same port-forward port.
  Thread.new(client_socket, connection_number) { |socket, number| handle_connection(socket, number) }
rescue Interrupt
  puts ""
  puts "CTRL+C received, exiting..."
  exit
end