diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..2acce1b1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.DS_Store +/.build +/Packages +/*.xcodeproj +Package.resolved + diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..f74c0082 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2018 Qutheory, LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Package.swift b/Package.swift new file mode 100644 index 00000000..10ecb7f4 --- /dev/null +++ b/Package.swift @@ -0,0 +1,32 @@ +// swift-tools-version:4.0 +import PackageDescription + +let package = Package( + name: "PostgreSQL", + products: [ + .library(name: "PostgreSQL", targets: ["PostgreSQL"]), + ], + dependencies: [ + // ⏱ Promises and reactive-streams in Swift built for high-performance and scalability. + .package(url: "https://github.com/vapor/async.git", from: "1.0.0-rc"), + + // 🌎 Utility package containing tools for byte manipulation, Codable, OS APIs, and debugging. + .package(url: "https://github.com/vapor/core.git", from: "3.0.0-rc"), + + // πŸ”‘ Hashing (BCrypt, SHA, HMAC, etc), encryption, and randomness. + .package(url: "https://github.com/vapor/crypto.git", from: "3.0.0-rc"), + + // πŸ—„ Core services for creating database integrations. + .package(url: "https://github.com/vapor/database-kit.git", from: "1.0.0-rc"), + + // πŸ“¦ Dependency injection / inversion of control framework. + .package(url: "https://github.com/vapor/service.git", from: "1.0.0-rc"), + + // πŸ”Œ Non-blocking TCP socket layer, with event-driven server and client. + .package(url: "https://github.com/vapor/sockets.git", from: "3.0.0-rc"), + ], + targets: [ + .target(name: "PostgreSQL", dependencies: ["Async", "Bits", "Crypto", "DatabaseKit", "Service", "TCP"]), + .testTarget(name: "PostgreSQLTests", dependencies: ["PostgreSQL"]), + ] +) diff --git a/README.md b/README.md index 11fb31fa..abfc720c 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,24 @@ -(See [vapor-community/postgresql](https://github.com/vapor-community/postgresql/) for libpq-based version) +

+ PostgreSQL +
+
+ + Documentation + + + Slack Team + + + MIT License + + + Continuous Integration + + + Swift 4.1 + +

-# Pure-Swift PostgreSQL Library +
-Work in progress +See [vapor-community/postgresql](https://github.com/vapor-community/postgresql/) for `libpq` based version. diff --git a/Sources/PostgreSQL/Connection/PostgreSQLConnection+Query.swift b/Sources/PostgreSQL/Connection/PostgreSQLConnection+Query.swift new file mode 100644 index 00000000..331db76e --- /dev/null +++ b/Sources/PostgreSQL/Connection/PostgreSQLConnection+Query.swift @@ -0,0 +1,77 @@ +import Async + +extension PostgreSQLConnection { + /// Sends a parameterized PostgreSQL query command, collecting the parsed results. + public func query( + _ string: String, + _ parameters: [PostgreSQLDataCustomConvertible] = [] + ) throws -> Future<[[String: PostgreSQLData]]> { + var rows: [[String: PostgreSQLData]] = [] + return try query(string, parameters) { row in + rows.append(row) + }.map(to: [[String: PostgreSQLData]].self) { + return rows + } + } + + /// Sends a parameterized PostgreSQL query command, returning the parsed results to + /// the supplied closure. + public func query( + _ string: String, + _ parameters: [PostgreSQLDataCustomConvertible] = [], + resultFormat: PostgreSQLResultFormat = .binary(), + onRow: @escaping ([String: PostgreSQLData]) -> () + ) throws -> Future { + let parameters = try parameters.map { try $0.convertToPostgreSQLData() } + logger?.log(query: string, parameters: parameters) + let parse = PostgreSQLParseRequest( + statementName: "", + query: string, + parameterTypes: parameters.map { $0.type } + ) + let describe = PostgreSQLDescribeRequest(type: .statement, name: "") + var currentRow: PostgreSQLRowDescription? + + return send([ + .parse(parse), .describe(describe), .sync + ]) { message in + switch message { + case .parseComplete: break + case .rowDescription(let row): currentRow = row + case .parameterDescription: break + case .noData: break + default: throw PostgreSQLError(identifier: "query", reason: "Unexpected message during PostgreSQLParseRequest: \(message)", source: .capture()) + } + }.flatMap(to: Void.self) { + let resultFormats = resultFormat.formatCodeFactory(currentRow?.fields.map { $0.dataType } ?? []) + // cache so we don't compute twice + let bind = PostgreSQLBindRequest( + portalName: "", + statementName: "", + parameterFormatCodes: parameters.map { $0.format }, + parameters: parameters.map { .init(data: $0.data) }, + resultFormatCodes: resultFormats + ) + let execute = PostgreSQLExecuteRequest( + portalName: "", + maxRows: 0 + ) + return self.send([ + .bind(bind), .execute(execute), .sync + ]) { message in + switch message { + case .bindComplete: break + case .dataRow(let data): + guard let row = currentRow else { + throw PostgreSQLError(identifier: "query", reason: "Unexpected PostgreSQLDataRow without preceding PostgreSQLRowDescription.", source: .capture()) + } + let parsed = try row.parse(data: data, formatCodes: resultFormats) + onRow(parsed) + case .close: break + case .noData: break + default: throw PostgreSQLError(identifier: "query", reason: "Unexpected message during PostgreSQLParseRequest: \(message)", source: .capture()) + } + } + } + } +} diff --git a/Sources/PostgreSQL/Connection/PostgreSQLConnection+SimpleQuery.swift b/Sources/PostgreSQL/Connection/PostgreSQLConnection+SimpleQuery.swift new file mode 100644 index 00000000..391ac270 --- /dev/null +++ b/Sources/PostgreSQL/Connection/PostgreSQLConnection+SimpleQuery.swift @@ -0,0 +1,35 @@ +import Async + +extension PostgreSQLConnection { + /// Sends a simple PostgreSQL query command, collecting the parsed results. + public func simpleQuery(_ string: String) -> Future<[[String: PostgreSQLData]]> { + var rows: [[String: PostgreSQLData]] = [] + return simpleQuery(string) { row in + rows.append(row) + }.map(to: [[String: PostgreSQLData]].self) { + return rows + } + } + + /// Sends a simple PostgreSQL query command, returning the parsed results to + /// the supplied closure. + public func simpleQuery(_ string: String, onRow: @escaping ([String: PostgreSQLData]) -> ()) -> Future { + logger?.log(query: string, parameters: []) + var currentRow: PostgreSQLRowDescription? + let query = PostgreSQLQuery(query: string) + return send([.query(query)]) { message in + switch message { + case .rowDescription(let row): + currentRow = row + case .dataRow(let data): + guard let row = currentRow else { + throw PostgreSQLError(identifier: "simpleQuery", reason: "Unexpected PostgreSQLDataRow without preceding PostgreSQLRowDescription.", source: .capture()) + } + let parsed = try row.parse(data: data, formatCodes: row.fields.map { $0.formatCode }) + onRow(parsed) + case .close: break // query over, waiting for `readyForQuery` + default: throw PostgreSQLError(identifier: "simpleQuery", reason: "Unexpected message during PostgreSQLQuery: \(message)", source: .capture()) + } + } + } +} diff --git a/Sources/PostgreSQL/Connection/PostgreSQLConnection+TCP.swift b/Sources/PostgreSQL/Connection/PostgreSQLConnection+TCP.swift new file mode 100644 index 00000000..7d18b577 --- /dev/null +++ b/Sources/PostgreSQL/Connection/PostgreSQLConnection+TCP.swift @@ -0,0 +1,18 @@ +import Async +import TCP + +extension PostgreSQLConnection { + /// Connects to a Redis server using a TCP socket. + public static func connect( + hostname: String = "localhost", + port: UInt16 = 5432, + on worker: Worker, + onError: @escaping TCPSocketSink.ErrorHandler + ) throws -> PostgreSQLConnection { + let socket = try TCPSocket(isNonBlocking: true) + let client = try TCPClient(socket: socket) + try client.connect(hostname: hostname, port: port) + let stream = socket.stream(on: worker, onError: onError) + return PostgreSQLConnection(stream: stream, on: worker) + } +} diff --git a/Sources/PostgreSQL/Connection/PostgreSQLConnection.swift b/Sources/PostgreSQL/Connection/PostgreSQLConnection.swift new file mode 100644 index 00000000..c5221c8f --- /dev/null +++ b/Sources/PostgreSQL/Connection/PostgreSQLConnection.swift @@ -0,0 +1,130 @@ +import Async +import Crypto + +/// A PostgreSQL frontend client. +public final class PostgreSQLConnection { + /// Handles enqueued redis commands and responses. + private let queueStream: QueueStream + + /// If non-nil, will log queries. + public var logger: PostgreSQLLogger? + + /// Creates a new Redis client on the provided data source and sink. + init(stream: Stream, on worker: Worker) where Stream: ByteStream { + let queueStream = QueueStream() + + let serializerStream = PostgreSQLMessageSerializer().stream(on: worker) + let parserStream = PostgreSQLMessageParser().stream(on: worker) + + stream.stream(to: parserStream) + .stream(to: queueStream) + .stream(to: serializerStream) + .output(to: stream) + + self.queueStream = queueStream + } + + /// Sends `PostgreSQLMessage` to the server. + func send(_ messages: [PostgreSQLMessage], onResponse: @escaping (PostgreSQLMessage) throws -> ()) -> Future { + var error: Error? + return queueStream.enqueue(messages) { message in + switch message { + case .readyForQuery: + if let e = error { throw e } + return true + case .error(let e): error = e + case .notice(let n): print(n) + default: try onResponse(message) + } + return false // request until ready for query + } + } + + /// Sends `PostgreSQLMessage` to the server. + func send(_ message: [PostgreSQLMessage]) -> Future<[PostgreSQLMessage]> { + var responses: [PostgreSQLMessage] = [] + return send(message) { response in + responses.append(response) + }.map(to: [PostgreSQLMessage].self) { + return responses + } + } + + /// Authenticates the `PostgreSQLClient` using a username with no password. + public func authenticate(username: String, database: String? = nil, password: String? = nil) -> Future { + let startup = PostgreSQLStartupMessage.versionThree(parameters: [ + "user": username, + "database": database ?? username + ]) + var authRequest: PostgreSQLAuthenticationRequest? + return queueStream.enqueue([.startupMessage(startup)]) { message in + switch message { + case .authenticationRequest(let a): + authRequest = a + return true + default: throw PostgreSQLError(identifier: "auth", reason: "Unsupported message encountered during auth: \(message).", source: .capture()) + } + }.flatMap(to: Void.self) { + guard let auth = authRequest else { + throw PostgreSQLError(identifier: "authRequest", reason: "No authorization request / status sent.", source: .capture()) + } + + let input: [PostgreSQLMessage] + switch auth { + case .ok: + guard password == nil else { + throw PostgreSQLError(identifier: "trust", reason: "No password is required", source: .capture()) + } + input = [] + case .plaintext: + guard let password = password else { + throw PostgreSQLError(identifier: "password", reason: "Password is required", source: .capture()) + } + let passwordMessage = PostgreSQLPasswordMessage(password: password) + input = [.password(passwordMessage)] + case .md5(let salt): + guard let password = password else { + throw PostgreSQLError(identifier: "password", reason: "Password is required", source: .capture()) + } + guard let passwordData = password.data(using: .utf8) else { + throw PostgreSQLError(identifier: "passwordUTF8", reason: "Could not convert password to UTF-8 encoded Data.", source: .capture()) + } + + guard let usernameData = username.data(using: .utf8) else { + throw PostgreSQLError(identifier: "usernameUTF8", reason: "Could not convert username to UTF-8 encoded Data.", source: .capture()) + } + + let hasher = MD5() + // pwdhash = md5(password + username).hexdigest() + var passwordUsernameData = passwordData + usernameData + hasher.update(sequence: &passwordUsernameData) + hasher.finalize() + guard let pwdhash = hasher.hash.hexString.data(using: .utf8) else { + throw PostgreSQLError(identifier: "hashUTF8", reason: "Could not convert password hash to UTF-8 encoded Data.", source: .capture()) + } + hasher.reset() + // hash = β€² md 5β€² + md 5(pwdhash + salt ).hexdigest () + var saltedData = pwdhash + salt + hasher.update(sequence: &saltedData) + hasher.finalize() + let passwordMessage = PostgreSQLPasswordMessage(password: "md5" + hasher.hash.hexString) + input = [.password(passwordMessage)] + } + + return self.queueStream.enqueue(input) { message in + switch message { + case .error(let error): throw error + case .readyForQuery: return true + case .authenticationRequest: return false + case .parameterStatus, .backendKeyData: return false + default: throw PostgreSQLError(identifier: "authenticationMessage", reason: "Unexpected authentication message: \(message)", source: .capture()) + } + } + } + } + + /// Closes this client. + public func close() { + queueStream.close() + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLArrayCustomConvertible.swift b/Sources/PostgreSQL/Data/PostgreSQLArrayCustomConvertible.swift new file mode 100644 index 00000000..1bd8ec24 --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLArrayCustomConvertible.swift @@ -0,0 +1,128 @@ +import Foundation + +/// Representable by a `T[]` column on the PostgreSQL database. +public protocol PostgreSQLArrayCustomConvertible: PostgreSQLDataCustomConvertible, Codable { + /// The associated array element type + associatedtype PostgreSQLArrayElement: PostgreSQLDataCustomConvertible + + /// Convert an array of elements to self. + static func convertFromPostgreSQLArray(_ data: [PostgreSQLArrayElement]) -> Self + + /// Convert self to an array of elements. + func convertToPostgreSQLArray() -> [PostgreSQLArrayElement] +} + +extension PostgreSQLArrayCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { + return PostgreSQLArrayElement.postgreSQLDataArrayType + } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Self { + guard var value = data.data else { + throw PostgreSQLError(identifier: "nullArray", reason: "Unable to decode PostgreSQL array from `null` data.", source: .capture()) + } + + /// Extract and convert each element. + var array: [PostgreSQLArrayElement] = [] + + let hasData = value.extract(Int32.self).bigEndian + if hasData == 1 { + /// grab the array metadata from the beginning of the data + let metadata = value.extract(PostgreSQLArrayMetadata.self) + for _ in 0.. PostgreSQLData { + let elements = try convertToPostgreSQLArray().map { + try $0.convertToPostgreSQLData() + } + + var data = Data() + data += Int32(1).data // non-null + data += Int32(0).data // b + data += PostgreSQLArrayElement.postgreSQLDataType.raw.data // type + data += Int32(elements.count).data // length + data += Int32(1).data // dimensions + + for element in elements { + if let value = element.data { + data += Int32(value.count).data + data += value + } else { + data += Int32(0).data + } + } + + return PostgreSQLData(type: PostgreSQLArrayElement.postgreSQLDataArrayType, format: .binary, data: data) + } +} + +fileprivate struct PostgreSQLArrayMetadata { + /// Unknown + private let _b: Int32 + + /// The big-endian array element type + private let _type: Int32 + + /// The big-endian length of the array + private let _count: Int32 + + /// The big-endian number of dimensions + private let _dimensions: Int32 + + /// Converts the raw array elemetn type to DataType + var type: PostgreSQLDataType { + return .init(_type.bigEndian) + } + + /// The length of the array + var count: Int32 { + return _count.bigEndian + } + + /// The number of dimensions + var dimensions: Int32 { + return _dimensions.bigEndian + } +} + +extension PostgreSQLArrayMetadata: CustomStringConvertible { + /// See `CustomStringConvertible.description` + var description: String { + return "\(type)[\(count)]" + } +} + +extension Array: PostgreSQLArrayCustomConvertible where Element: Codable, Element: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLArrayCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { + return Element.postgreSQLDataArrayType + } + + /// See `PostgreSQLArrayCustomConvertible.PostgreSQLArrayElement` + public typealias PostgreSQLArrayElement = Element + + /// See `PostgreSQLArrayCustomConvertible.convertFromPostgreSQLArray(_:)` + public static func convertFromPostgreSQLArray(_ data: [Element]) -> Array { + return data + } + + /// See `PostgreSQLArrayCustomConvertible.convertToPostgreSQLArray(_:)` + public func convertToPostgreSQLArray() -> [Element] { + return self + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+BinaryFloatingPoint.swift b/Sources/PostgreSQL/Data/PostgreSQLData+BinaryFloatingPoint.swift new file mode 100644 index 00000000..8276704c --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+BinaryFloatingPoint.swift @@ -0,0 +1,86 @@ +import Foundation + +extension BinaryFloatingPoint { + /// Return's this floating point's bit width. + static var bitWidth: Int { + return exponentBitCount + significandBitCount + 1 + } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { + switch Self.bitWidth { + case 32: return .float4 + case 64: return .float8 + default: fatalError("Unsupported floating point bit width: \(Self.bitWidth)") + } + } + + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { + switch Self.bitWidth { + case 32: return ._float4 + case 64: return ._float8 + default: fatalError("Unsupported floating point bit width: \(Self.bitWidth)") + } + } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Self { + guard let value = data.data else { + throw PostgreSQLError(identifier: "binaryFloatingPoint", reason: "Could not decode \(Self.self) from `null` data.", source: .capture()) + } + switch data.format { + case .binary: + switch data.type { + case .float4: return Self.init(value.makeFloatingPoint(Float.self)) + case .float8: return Self.init(value.makeFloatingPoint(Double.self)) + case .char: return try Self.init(value.makeFixedWidthInteger(Int8.self)) + case .int2: return try Self.init(value.makeFixedWidthInteger(Int16.self)) + case .int4: return try Self.init(value.makeFixedWidthInteger(Int32.self)) + case .int8: return try Self.init(value.makeFixedWidthInteger(Int64.self)) + case .timestamp, .date, .time: + let date = try Date.convertFromPostgreSQLData(data) + return Self(date.timeIntervalSinceReferenceDate) + default: + throw PostgreSQLError( + identifier: "binaryFloatingPoint", + reason: "Could not decode \(Self.self) from binary data type: \(data.type).", + source: .capture() + ) + } + case .text: + let string = try data.decode(String.self) + guard let converted = Double(string) else { + throw PostgreSQLError(identifier: "binaryFloatingPoint", reason: "Could not decode \(Self.self) from string: \(string).", source: .capture()) + } + return Self(converted) + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return PostgreSQLData(type: Self.postgreSQLDataType, format: .binary, data: data) + } +} + +extension Double: PostgreSQLDataCustomConvertible { } +extension Float: PostgreSQLDataCustomConvertible { } + +extension Data { + /// Converts this data to a floating-point number. + internal func makeFloatingPoint(_ type: F.Type = F.self) -> F where F: FloatingPoint { + return Data(reversed()).unsafeCast() + } +} + + +extension FloatingPoint { + /// Big-endian bytes for this floating-point number. + internal var data: Data { + var bytes = [UInt8](repeating: 0, count: MemoryLayout.size) + var copy = self + memcpy(&bytes, ©, bytes.count) + return Data(bytes.reversed()) + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+Bool.swift b/Sources/PostgreSQL/Data/PostgreSQLData+Bool.swift new file mode 100644 index 00000000..58b1ae7d --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+Bool.swift @@ -0,0 +1,42 @@ +import Foundation + +extension Bool: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .bool } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return ._bool } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Bool { + guard let value = data.data else { + throw PostgreSQLError(identifier: "bool", reason: "Could not decode String from `null` data.", source: .capture()) + } + guard value.count == 1 else { + throw PostgreSQLError(identifier: "bool", reason: "Could not decode Bool from value: \(value)", source: .capture()) + } + switch data.format { + case .text: + switch value[0] { + case .t: return true + case .f: return false + default: throw PostgreSQLError(identifier: "bool", reason: "Could not decode Bool from text: \(value)", source: .capture()) + } + case .binary: + switch value[0] { + case 1: return true + case 0: return false + default: throw PostgreSQLError(identifier: "bool", reason: "Could not decode Bool from binary: \(value)", source: .capture()) + } + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return PostgreSQLData(type: .bool, format: .binary, data: self ? _true : _false) + } +} + +private let _true = Data([0x01]) +private let _false = Data([0x00]) + diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+Data.swift b/Sources/PostgreSQL/Data/PostgreSQLData+Data.swift new file mode 100644 index 00000000..369cc2e8 --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+Data.swift @@ -0,0 +1,49 @@ +import Foundation + +extension Data: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .bytea } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return ._bytea } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Data { + guard let value = data.data else { + throw PostgreSQLError(identifier: "data", reason: "Could not decode Data from `null` data.", source: .capture()) + } + + switch data.type { + case .bytea: + switch data.format { + case .text: return try Data(hexString: value[2...].makeString()) + case .binary: return value + } + default: throw PostgreSQLError(identifier: "data", reason: "Could not decode Data from data type: \(data.type)", source: .capture()) + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return PostgreSQLData(type: .bytea, format: .binary, data: self) + } +} + +extension Data { + /// Initialize data from a hex string. + internal init(hexString: String) { + var data = Data() + + var gen = hexString.makeIterator() + while let c1 = gen.next(), let c2 = gen.next() { + let s = String([c1, c2]) + guard let d = UInt8(s, radix: 16) else { + break + } + + data.append(d) + } + + self.init(data) + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+Date.swift b/Sources/PostgreSQL/Data/PostgreSQLData+Date.swift new file mode 100644 index 00000000..35003b66 --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+Date.swift @@ -0,0 +1,63 @@ +import Foundation + +extension Date: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .timestamp } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return ._timestamp } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Date { + guard let value = data.data else { + throw PostgreSQLError(identifier: "data", reason: "Could not decode String from `null` data.", source: .capture()) + } + switch data.format { + case .text: + switch data.type { + case .timestamp: return try value.makeString().parseDate(format: "yyyy-MM-dd HH:mm:ss") + case .date: return try value.makeString().parseDate(format: "yyyy-MM-dd") + case .time: return try value.makeString().parseDate(format: "HH:mm:ss") + default: throw PostgreSQLError(identifier: "date", reason: "Could not parse Date from text data type: \(data.type).", source: .capture()) + } + case .binary: + switch data.type { + case .timestamp, .time: + let microseconds = try value.makeFixedWidthInteger(Int64.self) + let seconds = microseconds / _microsecondsPerSecond + return Date(timeInterval: Double(seconds), since: _psqlDateStart) + case .date: + let days = try value.makeFixedWidthInteger(Int32.self) + let seconds = days * _secondsInDay + return Date(timeInterval: Double(seconds), since: _psqlDateStart) + default: throw PostgreSQLError(identifier: "date", reason: "Could not parse Date from binary data type: \(data.type).", source: .capture()) + } + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return PostgreSQLData(type: .timestamp, format: .text, data: Data(description.utf8)) + } +} + +private let _microsecondsPerSecond: Int64 = 1_000_000 +private let _secondsInDay: Int32 = 24 * 60 * 60 +private let _psqlDateStart = Date(timeIntervalSince1970: 946_684_800) // values are stored as seconds before or after midnight 2000-01-01 + +extension String { + /// Parses a Date from this string with the supplied date format. + fileprivate func parseDate(format: String) throws -> Date { + let formatter = DateFormatter() + if contains(".") { + formatter.dateFormat = format + ".SSSSSS" + } else { + formatter.dateFormat = format + } + formatter.timeZone = TimeZone(secondsFromGMT: 0) + guard let date = formatter.date(from: self) else { + throw PostgreSQLError(identifier: "date", reason: "Malformed date: \(self)", source: .capture()) + } + return date + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+FixedWidthInteger.swift b/Sources/PostgreSQL/Data/PostgreSQLData+FixedWidthInteger.swift new file mode 100644 index 00000000..b892df3f --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+FixedWidthInteger.swift @@ -0,0 +1,104 @@ +import Foundation + +extension FixedWidthInteger { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { + switch Self.bitWidth { + case 8: return .char + case 16: return .int2 + case 32: return .int4 + case 64: return .int8 + default: fatalError("Integer bit width not supported: \(Self.bitWidth)") + } + } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { + switch Self.bitWidth { + case 8: return ._char + case 16: return ._int2 + case 32: return ._int4 + case 64: return ._int8 + default: fatalError("Integer bit width not supported: \(Self.bitWidth)") + } + } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Self { + guard let value = data.data else { + throw PostgreSQLError(identifier: "fixedWidthInteger", reason: "Could not decode \(Self.self) from `null` data.", source: .capture()) + } + switch data.format { + case .binary: + switch data.type { + case .char: return try safeCast(value.makeFixedWidthInteger(Int8.self)) + case .int2: return try safeCast(value.makeFixedWidthInteger(Int16.self)) + case .int4: return try safeCast(value.makeFixedWidthInteger(Int32.self)) + case .int8: return try safeCast(value.makeFixedWidthInteger(Int64.self)) + default: throw DecodingError.typeMismatch(Self.self, .init(codingPath: [], debugDescription: "")) + } + case .text: + let string = try value.makeString() + guard let converted = Self(string) else { + throw PostgreSQLError(identifier: "fixedWidthInteger", reason: "Could not decode \(Self.self) from text: \(string).", source: .capture()) + } + return converted + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return PostgreSQLData(type: Self.postgreSQLDataType, format: .binary, data: self.data) + } + + + /// Safely casts one `FixedWidthInteger` to another. + internal static func safeCast(_ value: V, to type: I.Type = I.self) throws -> I where V: FixedWidthInteger, I: FixedWidthInteger { + if let existing = value as? I { + return existing + } + + guard I.bitWidth >= V.bitWidth else { + throw DecodingError.typeMismatch(type, .init(codingPath: [], debugDescription: "Bit width too wide: \(I.bitWidth) < \(V.bitWidth)")) + } + guard value <= I.max else { + throw DecodingError.typeMismatch(type, .init(codingPath: [], debugDescription: "Value too large: \(value) > \(I.max)")) + } + guard value >= I.min else { + throw DecodingError.typeMismatch(type, .init(codingPath: [], debugDescription: "Value too small: \(value) < \(I.min)")) + } + return I(value) + } +} + +extension Int: PostgreSQLDataCustomConvertible {} +extension Int8: PostgreSQLDataCustomConvertible {} +extension Int16: PostgreSQLDataCustomConvertible {} +extension Int32: PostgreSQLDataCustomConvertible {} +extension Int64: PostgreSQLDataCustomConvertible {} + +extension UInt: PostgreSQLDataCustomConvertible {} +extension UInt8: PostgreSQLDataCustomConvertible {} +extension UInt16: PostgreSQLDataCustomConvertible {} +extension UInt32: PostgreSQLDataCustomConvertible {} +extension UInt64: PostgreSQLDataCustomConvertible {} + +extension Data { + /// Converts this data to a fixed-width integer. + internal func makeFixedWidthInteger(_ type: I.Type = I.self) throws -> I where I: FixedWidthInteger { + guard count >= (I.bitWidth / 8) else { + throw PostgreSQLError(identifier: "fixedWidthData", reason: "Not enough bytes to decode \(I.self): \(count)/\(I.bitWidth / 8)", source: .capture()) + } + return unsafeCast(to: I.self).bigEndian + } +} + +extension FixedWidthInteger { + /// Big-endian bytes for this integer. + internal var data: Data { + var bytes = [UInt8](repeating: 0, count: Self.bitWidth / 8) + var intNetwork = bigEndian + memcpy(&bytes, &intNetwork, bytes.count) + return Data(bytes) + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+Optional.swift b/Sources/PostgreSQL/Data/PostgreSQLData+Optional.swift new file mode 100644 index 00000000..90f0fe56 --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+Optional.swift @@ -0,0 +1,33 @@ +import Async +import Foundation + +extension OptionalType where Self.WrappedType: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Self { + let wrapped = try WrappedType.convertFromPostgreSQLData(data) + return Self.makeOptionalType(wrapped) + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + if let wrapped = self.wrapped { + return try wrapped.convertToPostgreSQLData() + } else { + return PostgreSQLData(type: .void, format: .binary, data: nil) + } + } +} + +extension Optional: PostgreSQLDataCustomConvertible where Wrapped: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { + return Wrapped.postgreSQLDataType + } + + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { + return Wrapped.postgreSQLDataArrayType + } +} + diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+Point.swift b/Sources/PostgreSQL/Data/PostgreSQLData+Point.swift new file mode 100644 index 00000000..67b1b0cb --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+Point.swift @@ -0,0 +1,68 @@ +import Foundation + +/// A 2-dimenstional (double[2]) point. +public struct PostgreSQLPoint: Codable { + /// The point's x coordinate. + public var x: Double + + /// The point's y coordinate. + public var y: Double + + /// Create a new `Point` + public init(x: Double, y: Double) { + self.x = x + self.y = y + } +} + +extension PostgreSQLPoint: CustomStringConvertible { + /// See `CustomStringConvertible.description` + public var description: String { + return "(\(x),\(y))" + } +} + +extension PostgreSQLPoint: Equatable { + /// See `Equatable.==` + public static func ==(lhs: PostgreSQLPoint, rhs: PostgreSQLPoint) -> Bool { + return lhs.x == rhs.x && lhs.y == rhs.y + } +} + +extension PostgreSQLPoint: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .point } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return ._point } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> PostgreSQLPoint { + guard let value = data.data else { + throw PostgreSQLError(identifier: "data", reason: "Could not decode Point from `null` data.", source: .capture()) + } + switch data.type { + case .point: + switch data.format { + case .text: + let string = try value.makeString() + let parts = string.split(separator: ",") + var x = parts[0] + var y = parts[1] + assert(x.popFirst()! == "(") + assert(y.popLast()! == ")") + return .init(x: Double(x)!, y: Double(y)!) + case .binary: + let x = value[0..<8] + let y = value[8..<16] + return .init(x: x.makeFloatingPoint(), y: y.makeFloatingPoint()) + } + default: throw PostgreSQLError(identifier: "point", reason: "Could not decode Point from data type: \(data.type)", source: .capture()) + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return PostgreSQLData(type: .point, format: .binary, data: x.data + y.data) + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+String.swift b/Sources/PostgreSQL/Data/PostgreSQLData+String.swift new file mode 100644 index 00000000..fff98dc8 --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+String.swift @@ -0,0 +1,114 @@ +import Foundation + +extension String: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .text } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return ._text } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> String { + guard let value = data.data else { + throw PostgreSQLError(identifier: "string", reason: "Could not decode String from `null` data.", source: .capture()) + } + switch data.format { + case .text: + guard let string = String(data: value, encoding: .utf8) else { + throw PostgreSQLError(identifier: "string", reason: "Non-UTF8 string: \(value.hexDebug).", source: .capture()) + } + return string + case .binary: + switch data.type { + case .text, .name, .varchar, .bpchar: + guard let string = String(data: value, encoding: .utf8) else { + throw PostgreSQLError(identifier: "string", reason: "Non-UTF8 string: \(value.hexDebug).", source: .capture()) + } + return string + case .point: + let point = try PostgreSQLPoint.convertFromPostgreSQLData(data) + return point.description + case .uuid: + return try UUID.convertFromPostgreSQLData(data).uuidString + case .numeric: + /// create mutable value since we will be using `.extract` which advances the buffer's view + var value = value + + /// grab the numeric metadata from the beginning of the array + let metadata = value.extract(PostgreSQLNumericMetadata.self) + + var integer = "" + var fractional = "" + for offset in 0.. PostgreSQLData { + return PostgreSQLData(type: .text, format: .binary, data: Data(utf8)) + } +} + +/// Represents the meta information preceeding a numeric value. +/// Note: all values must be accessed adding `.bigEndian` +struct PostgreSQLNumericMetadata { + /// The number of digits after this metadata + var ndigits: Int16 + /// How many of the digits are before the decimal point (always add 1) + var weight: Int16 + /// If 1, this number is negative. Otherwise, positive. + var sign: Int16 + /// The number of sig digits after the decimal place (get rid of trailing 0s) + var dscale: Int16 +} + +extension Data { + /// Convert the row's data into a string, throwing if invalid encoding. + internal func makeString(encoding: String.Encoding = .utf8) throws -> String { + guard let string = String(data: self, encoding: encoding) else { + throw PostgreSQLError(identifier: "utf8String", reason: "Unexpected non-UTF8 string: \(hexDebug).", source: .capture()) + } + + return string + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData+UUID.swift b/Sources/PostgreSQL/Data/PostgreSQLData+UUID.swift new file mode 100644 index 00000000..9da7643b --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData+UUID.swift @@ -0,0 +1,39 @@ +import Foundation + +extension UUID: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .uuid } + + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return ._uuid } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> UUID { + guard let value = data.data else { + throw PostgreSQLError(identifier: "data", reason: "Could not decode UUID from `null` data.", source: .capture()) + } + switch data.type { + case .uuid: + switch data.format { + case .text: + let string = try value.makeString() + guard let uuid = UUID(uuidString: string) else { + throw PostgreSQLError(identifier: "uuid", reason: "Could not decode UUID from string: \(string)", source: .capture()) + } + return uuid + case .binary: return UUID(uuid: value.unsafeCast()) + } + default: throw PostgreSQLError(identifier: "uuid", reason: "Could not decode UUID from data type: \(data.type)", source: .capture()) + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + var uuid = self.uuid + let size = MemoryLayout.size(ofValue: uuid) + return PostgreSQLData(type: .uuid, format: .binary, data: withUnsafePointer(to: &uuid) { + Data(bytes: $0, count: size) + }) + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLData.swift b/Sources/PostgreSQL/Data/PostgreSQLData.swift new file mode 100644 index 00000000..fd023bd5 --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLData.swift @@ -0,0 +1,35 @@ +import Foundation + +/// Supported `PostgreSQLData` data types. +public struct PostgreSQLData { + /// The data's type. + public var type: PostgreSQLDataType + + /// The data's format. + public var format: PostgreSQLFormatCode + + /// The actual data. + public var data: Data? + + public init(type: PostgreSQLDataType, format: PostgreSQLFormatCode = .binary, data: Data? = nil) { + self.type = type + self.format = format + self.data = data + } +} + +extension PostgreSQLData: CustomStringConvertible { + /// See `CustomStringConvertible.description` + public var description: String { + return "\(type) (\(format)) \(data?.hexDebug ?? "null")" + } +} + +/// MARK: Equatable + +extension PostgreSQLData: Equatable { + /// See Equatable.== + public static func ==(lhs: PostgreSQLData, rhs: PostgreSQLData) -> Bool { + return lhs.format == rhs.format && lhs.type == rhs.type && lhs.data == rhs.data + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLDataCustomConvertible.swift b/Sources/PostgreSQL/Data/PostgreSQLDataCustomConvertible.swift new file mode 100644 index 00000000..469e0716 --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLDataCustomConvertible.swift @@ -0,0 +1,65 @@ +/// Capable of being converted to/from `PostgreSQLData` +public protocol PostgreSQLDataCustomConvertible { + /// This type's preferred data type. + static var postgreSQLDataType: PostgreSQLDataType { get } + + /// This type's preferred array type. + static var postgreSQLDataArrayType: PostgreSQLDataType { get } + + /// Creates a `Self` from the supplied `PostgreSQLData` + static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Self + + /// Converts `Self` to a `PostgreSQLData` + func convertToPostgreSQLData() throws -> PostgreSQLData +} + +extension PostgreSQLData { + /// Gets a `String` from the supplied path or throws a decoding error. + public func decode(_ type: T.Type) throws -> T where T: PostgreSQLDataCustomConvertible { + return try T.convertFromPostgreSQLData(self) + } +} + +extension PostgreSQLData: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .void } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return .void } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> PostgreSQLData { + return data + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return self + } +} + +extension RawRepresentable where RawValue: PostgreSQLDataCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + static var postgreSQLDataType: PostgreSQLDataType { + return Self.postgreSQLDataType + } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + static var postgreSQLDataArrayType: PostgreSQLDataType { + return Self.postgreSQLDataArrayType + } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Self { + let aRawValue = try RawValue.convertFromPostgreSQLData(data) + guard let enumValue = Self(rawValue: aRawValue) else { + throw PostgreSQLError(identifier: "invalidRawValue", reason: "Unable to decode RawRepresentable from the database value.", source: .capture()) + } + return enumValue + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + func convertToPostgreSQLData() throws -> PostgreSQLData { + return try self.rawValue.convertToPostgreSQLData() + } +} diff --git a/Sources/PostgreSQL/Data/PostgreSQLJSONCustomConvertible.swift b/Sources/PostgreSQL/Data/PostgreSQLJSONCustomConvertible.swift new file mode 100644 index 00000000..6c1a3dae --- /dev/null +++ b/Sources/PostgreSQL/Data/PostgreSQLJSONCustomConvertible.swift @@ -0,0 +1,38 @@ +import COperatingSystem +import Foundation + +/// Representable by a `JSONB` column on the PostgreSQL database. +public protocol PostgreSQLJSONCustomConvertible: PostgreSQLDataCustomConvertible, Codable { } + +extension PostgreSQLJSONCustomConvertible { + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataType` + public static var postgreSQLDataType: PostgreSQLDataType { return .jsonb } + + /// See `PostgreSQLDataCustomConvertible.postgreSQLDataArrayType` + public static var postgreSQLDataArrayType: PostgreSQLDataType { return ._jsonb } + + /// See `PostgreSQLDataCustomConvertible.convertFromPostgreSQLData(_:)` + public static func convertFromPostgreSQLData(_ data: PostgreSQLData) throws -> Self { + guard let value = data.data else { + throw PostgreSQLError(identifier: "data", reason: "Unable to decode PostgreSQL JSON from `null` data.", source: .capture()) + } + + switch data.type { + case .jsonb: + switch data.format { + case .text: return try JSONDecoder().decode(Self.self, from: value) + case .binary: + assert(value[0] == 0x01) + return try JSONDecoder().decode(Self.self, from: value[1...]) + } + default: throw PostgreSQLError(identifier: "json", reason: "Could not decode \(Self.self) from data type: \(data.type).", source: .capture()) + } + } + + /// See `PostgreSQLDataCustomConvertible.convertToPostgreSQLData()` + public func convertToPostgreSQLData() throws -> PostgreSQLData { + return try PostgreSQLData(type: .jsonb, format: .text, data: JSONEncoder().encode(self)) + } +} + +extension Dictionary: PostgreSQLJSONCustomConvertible where Key: Codable, Value: Codable { } diff --git a/Sources/PostgreSQL/DataType/PostgreSQLDataType.swift b/Sources/PostgreSQL/DataType/PostgreSQLDataType.swift new file mode 100644 index 00000000..c5bb2524 --- /dev/null +++ b/Sources/PostgreSQL/DataType/PostgreSQLDataType.swift @@ -0,0 +1,120 @@ +import Foundation + +/// The data type's raw object ID. +/// Use `select * from pg_type where oid = ;` to lookup more information. +public struct PostgreSQLDataType: Codable, Equatable, ExpressibleByIntegerLiteral { + /// Recognized types + public static let bool = PostgreSQLDataType(16) + public static let bytea = PostgreSQLDataType(17) + public static let char = PostgreSQLDataType(18) + public static let name = PostgreSQLDataType(19) + public static let int8 = PostgreSQLDataType(20) + public static let int2 = PostgreSQLDataType(21) + public static let int4 = PostgreSQLDataType(23) + public static let regproc = PostgreSQLDataType(24) + public static let text = PostgreSQLDataType(25) + public static let oid = PostgreSQLDataType(26) + public static let json = PostgreSQLDataType(114) + public static let pg_node_tree = PostgreSQLDataType(194) + public static let point = PostgreSQLDataType(600) + public static let float4 = PostgreSQLDataType(700) + public static let float8 = PostgreSQLDataType(701) + public static let _bool = PostgreSQLDataType(1000) + public static let _bytea = PostgreSQLDataType(1001) + public static let _char = PostgreSQLDataType(1002) + public static let _name = PostgreSQLDataType(1003) + public static let _int2 = PostgreSQLDataType(1005) + public static let _int4 = PostgreSQLDataType(1007) + public static let _text = PostgreSQLDataType(1009) + public static let _int8 = PostgreSQLDataType(1016) + public static let _point = PostgreSQLDataType(1017) + public static let _float4 = PostgreSQLDataType(1021) + public static let _float8 = PostgreSQLDataType(1022) + public static let _aclitem = PostgreSQLDataType(1034) + public static let bpchar = PostgreSQLDataType(1042) + public static let varchar = PostgreSQLDataType(1043) + public static let date = PostgreSQLDataType(1082) + public static let time = PostgreSQLDataType(1083) + public static let timestamp = PostgreSQLDataType(1114) + public static let _timestamp = PostgreSQLDataType(1115) + public static let numeric = PostgreSQLDataType(1700) + public static let void = PostgreSQLDataType(2278) + public static let uuid = PostgreSQLDataType(2950) + public static let _uuid = PostgreSQLDataType(2951) + public static let jsonb = PostgreSQLDataType(3802) + public static let _jsonb = PostgreSQLDataType(3807) + + /// See `Equatable.==` + public static func ==(lhs: PostgreSQLDataType, rhs: PostgreSQLDataType) -> Bool { + return lhs.raw == rhs.raw + } + + /// The raw data type code recognized by PostgreSQL. + public var raw: Int32 + + /// See `ExpressibleByIntegerLiteral.init(integerLiteral:)` + public init(integerLiteral value: Int32) { + self.init(value) + } + + /// Creates a new `PostgreSQLDataType` + public init(_ raw: Int32) { + self.raw = raw + } +} + +extension PostgreSQLDataType { + /// Returns the known SQL name, if one exists. + /// Note: This only supports a limited subset of all PSQL types and is meant for convenience only. + public var knownSQLName: String? { + switch self { + case .bool: return "BOOLEAN" + case .bytea: return "BYTEA" + case .char: return "CHAR" + case .name: return "NAME" + case .int8: return "BIGINT" + case .int2: return "SMALLINT" + case .int4: return "INTEGER" + case .regproc: return "REGPROC" + case .text: return "TEXT" + case .oid: return "OID" + case .json: return "JSON" + case .pg_node_tree: return "PGNODETREE" + case .point: return "POINT" + case .float4: return "REAL" + case .float8: return "DOUBLE PRECISION" + case ._bool: return "BOOLEAN[]" + case ._bytea: return "BYTEA[]" + case ._char: return "CHAR[]" + case ._name: return "NAME[]" + case ._int2: return "SMALLINT[]" + case ._int4: return "INTEGER[]" + case ._text: return "TEXT[]" + case ._int8: return "BIGINT[]" + case ._point: return "POINT[]" + case ._float4: return "REAL[]" + case ._float8: return "DOUBLE PRECISION[]" + case ._aclitem: return "ACLITEM[]" + case .bpchar: return "BPCHAR" + case .varchar: return "VARCHAR" + case .date: return "DATE" + case .time: return "TIME" + case .timestamp: return "TIMESTAMP" + case ._timestamp: return "TIMESTAMP[]" + case .numeric: return "NUMERIC" + case .void: return "VOID" + case .uuid: return "UUID" + case ._uuid: return "UUID[]" + case .jsonb: return "JSONB" + case ._jsonb: return "JSONB[]" + default: return nil + } + } +} + +extension PostgreSQLDataType: CustomStringConvertible { + /// See `CustomStringConvertible.description` + public var description: String { + return knownSQLName ?? "UNKNOWN \(raw)" + } +} diff --git a/Sources/PostgreSQL/DataType/PostgreSQLFormatCode.swift b/Sources/PostgreSQL/DataType/PostgreSQLFormatCode.swift new file mode 100644 index 00000000..fd63317a --- /dev/null +++ b/Sources/PostgreSQL/DataType/PostgreSQLFormatCode.swift @@ -0,0 +1,38 @@ +/// The format code being used for the field. +/// Currently will be zero (text) or one (binary). +/// In a RowDescription returned from the statement variant of Describe, +/// the format code is not yet known and will always be zero. +public enum PostgreSQLFormatCode: Int16, Codable { + case text = 0 + case binary = 1 +} + +public struct PostgreSQLResultFormat { + /// The format codes + internal let formatCodeFactory: ([PostgreSQLDataType]) -> [PostgreSQLFormatCode] + + /// Dynamically choose result format based on data type. + public static func dynamic(_ callback: @escaping (PostgreSQLDataType) -> PostgreSQLFormatCode) -> PostgreSQLResultFormat { + return .init { return $0.map { callback($0) } } + } + + /// Request all of the results in a specific format. + public static func all(_ code: PostgreSQLFormatCode) -> PostgreSQLResultFormat { + return .init { _ in return [code] } + } + + /// Request all of the results in a specific format. + public static func text() -> PostgreSQLResultFormat { + return .all(.text) + } + + /// Request all of the results in a specific format. + public static func binary() -> PostgreSQLResultFormat { + return .all(.binary) + } + + /// Let the server decide the formatting options. + public static func notSpecified() -> PostgreSQLResultFormat { + return .init { _ in return [] } + } +} diff --git a/Sources/PostgreSQL/Database/PostgreSQLDatabase.swift b/Sources/PostgreSQL/Database/PostgreSQLDatabase.swift new file mode 100644 index 00000000..295dc634 --- /dev/null +++ b/Sources/PostgreSQL/Database/PostgreSQLDatabase.swift @@ -0,0 +1,42 @@ +import Async + +/// Creates connections to an identified PostgreSQL database. +public final class PostgreSQLDatabase: Database { + /// This database's configuration. + public let config: PostgreSQLDatabaseConfig + + /// If non-nil, will log queries. + public var logger: PostgreSQLLogger? + + /// Creates a new `PostgreSQLDatabase`. + public init(config: PostgreSQLDatabaseConfig) { + self.config = config + } + + /// See `Database.makeConnection()` + public func makeConnection(on worker: Worker) -> Future { + do { + let client = try PostgreSQLConnection.connect(hostname: config.hostname, port: config.port, on: worker) { _, error in + print("[PostgreSQL] \(error)") + } + client.logger = logger + return client.authenticate( + username: config.username, + database: config.database, + password: config.password + ).transform(to: client) + } catch { + return Future(error: error) + } + } +} + +/// A connection created by a `PostgreSQLDatabase`. +extension PostgreSQLConnection: DatabaseConnection { } + +extension DatabaseIdentifier { + /// Default identifier for `PostgreSQLDatabase`. + public static var psql: DatabaseIdentifier { + return .init("psql") + } +} diff --git a/Sources/PostgreSQL/Database/PostgreSQLDatabaseConfig.swift b/Sources/PostgreSQL/Database/PostgreSQLDatabaseConfig.swift new file mode 100644 index 00000000..8f302388 --- /dev/null +++ b/Sources/PostgreSQL/Database/PostgreSQLDatabaseConfig.swift @@ -0,0 +1,32 @@ +/// Config options for a `PostgreSQLConnection` +public struct PostgreSQLDatabaseConfig { + /// Creates a `PostgreSQLDatabaseConfig` with default settings. + public static func `default`() -> PostgreSQLDatabaseConfig { + return .init(hostname: "localhost", port: 5432, username: "postgres") + } + + /// Destination hostname. + public let hostname: String + + /// Destination port. + public let port: UInt16 + + /// Username to authenticate. + public let username: String + + /// Optional database name to use during authentication. + /// Defaults to the username. + public let database: String? + + /// Optional password to use for authentication. + public let password: String? + + /// Creates a new `PostgreSQLDatabaseConfig`. + public init(hostname: String, port: UInt16, username: String, database: String? = nil, password: String? = nil) { + self.hostname = hostname + self.port = port + self.username = username + self.database = database + self.password = password + } +} diff --git a/Sources/PostgreSQL/Database/PostgreSQLLogger.swift b/Sources/PostgreSQL/Database/PostgreSQLLogger.swift new file mode 100644 index 00000000..c50616fc --- /dev/null +++ b/Sources/PostgreSQL/Database/PostgreSQLLogger.swift @@ -0,0 +1,5 @@ +/// Capable of logging PostgreSQL queries. +public protocol PostgreSQLLogger { + /// Logs the query and supplied parameters. + func log(query: String, parameters: [PostgreSQLData]) +} diff --git a/Sources/PostgreSQL/Exports.swift b/Sources/PostgreSQL/Exports.swift new file mode 100644 index 00000000..7cc13284 --- /dev/null +++ b/Sources/PostgreSQL/Exports.swift @@ -0,0 +1 @@ +@_exported import DatabaseKit diff --git a/Sources/PostgreSQL/Message+Parse/PostgreSQLMessageDecoder.swift b/Sources/PostgreSQL/Message+Parse/PostgreSQLMessageDecoder.swift new file mode 100644 index 00000000..c6b95fb2 --- /dev/null +++ b/Sources/PostgreSQL/Message+Parse/PostgreSQLMessageDecoder.swift @@ -0,0 +1,285 @@ +import Bits +import Foundation + +/// Non-decoder wrapper for `_PostgreSQLMessageDecoder`. +final class PostgreSQLMessageDecoder { + /// Create a new `PostgreSQLMessageDecoder` + init() {} + + /// Encodes a `PostgreSQLMessage` to `Data`. + func decode(_ data: Data) throws -> (PostgreSQLMessage, Int)? { + let decoder = _PostgreSQLMessageDecoder(data: data) + guard decoder.data.count >= 1 else { + return nil + } + + let type = try decoder.decode(Byte.self) + guard try decoder.verifyLength() else { + return nil + } + + let message: PostgreSQLMessage + switch type { + case .E: message = try .error(decoder.decode()) + case .N: message = try .notice(decoder.decode()) + case .R: message = try .authenticationRequest(decoder.decode()) + case .S: message = try .parameterStatus(decoder.decode()) + case .K: message = try .backendKeyData(decoder.decode()) + case .Z: message = try .readyForQuery(decoder.decode()) + case .T: message = try .rowDescription(decoder.decode()) + case .D: message = try .dataRow(decoder.decode()) + case .C: message = try .close(decoder.decode()) + case .one: message = .parseComplete + case .two: message = .bindComplete + case .n: message = .noData + case .t: message = try .parameterDescription(decoder.decode()) + default: + let string = String(bytes: [type], encoding: .ascii) ?? "n/a" + throw PostgreSQLError( + identifier: "decoder", + reason: "Unrecognized message type: \(string) (\(type)", + possibleCauses: ["Connected to non-PostgreSQL database"], + suggestedFixes: ["Connect to PostgreSQL database"], + source: .capture() + ) + } + return (message, decoder.data.count) + } +} + +// MARK: Decoder / Single + +fileprivate final class _PostgreSQLMessageDecoder: Decoder, SingleValueDecodingContainer { + /// See Decoder.codingPath + var codingPath: [CodingKey] + + /// See Decoder.userInfo + var userInfo: [CodingUserInfoKey: Any] + + /// The data being decoded. + var data: Data + + /// Creates a new internal `_PostgreSQLMessageDecoder`. + init(data: Data) { + self.codingPath = [] + self.userInfo = [:] + self.data = data + } + + /// Extracts and verifies the data length. + func verifyLength() throws -> Bool { + guard let length = try extractLength() else { + return false + } + + guard data.count + MemoryLayout.size >= length else { + return false + } + + return true + } + + /// Extracts an Int32 length, returning `nil` + /// if it doesn't exist. + func extractLength() throws -> Int32? { + guard data.count >= 4 else { + // need length + return nil + } + return try decode(Int32.self) + } + + /// See Encoder.singleValueContainer + func singleValueContainer() throws -> SingleValueDecodingContainer { + return self + } + + /// See SingleValueDecodingContainer.decode + func decode(_ type: UInt8.Type) throws -> UInt8 { + return self.data.unsafePopFirst() + } + + /// See SingleValueDecodingContainer.decode + func decode(_ type: Int16.Type) throws -> Int16 { + return try decode(fixedWidthInteger: Int16.self) + } + + /// See SingleValueDecodingContainer.decode + func decode(_ type: Int32.Type) throws -> Int32 { + return try decode(fixedWidthInteger: Int32.self) + } + + /// Decodes a fixed width integer. + func decode(fixedWidthInteger type: B.Type) throws -> B where B: FixedWidthInteger { + return data.extract(B.self).bigEndian + } + + /// See SingleValueDecodingContainer.decode + func decode(_ type: String.Type) throws -> String { + var bytes: [UInt8] = [] + parse: while true { + let byte = self.data.unsafePopFirst() + switch byte { + case 0: break parse // c style strings + default: bytes.append(byte) + } + } + let data = Data(bytes: bytes) + guard let string = String(data: data, encoding: .utf8) else { + throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: non-UTF8 string", source: .capture()) + } + return string + } + + /// See SingleValueDecodingContainer.decode + func decode(_ type: T.Type = T.self) throws -> T where T: Decodable { + if T.self == Data.self { + let count = try Int(decode(Int32.self)) + switch count { + case 0: return Data() as! T + case 1...: + let sub: Data = data.subdata(in: data.startIndex.. Bool { + guard data.count >= 4 else { + return false + } + + /// if Int32 decode == -1, then this should be decoding `Data?.none` + let count = data.withUnsafeBytes { (pointer: UnsafePointer) -> Int32 in + return pointer.withMemoryRebound(to: Int32.self, capacity: 1) { (pointer: UnsafePointer) -> Int32 in + return pointer.pointee.bigEndian + } + } + switch count { + case -1: + data = data.advanced(by: MemoryLayout.size) + return true + default: return false + } + } + + /// See Decoder.container + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key : CodingKey { + let container = _PostgreSQLMessageKeyedDecoder(decoder: self) + return KeyedDecodingContainer(container) + } + + /// See Decoder.unkeyedContainer + func unkeyedContainer() throws -> UnkeyedDecodingContainer { + return _PostgreSQLMessageUnkeyedDecoder(decoder: self) + } + + // Unsupported + func decode(_ type: Bool.Type) throws -> Bool { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: Int.Type) throws -> Int { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: Int8.Type) throws -> Int8 { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: Int64.Type) throws -> Int64 { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: UInt.Type) throws -> UInt { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: UInt16.Type) throws -> UInt16 { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: UInt32.Type) throws -> UInt32 { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: UInt64.Type) throws -> UInt64 { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: Float.Type) throws -> Float { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } + func decode(_ type: Double.Type) throws -> Double { throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decode type: \(type)", source: .capture()) } +} + +// MARK: Keyed + +fileprivate final class _PostgreSQLMessageKeyedDecoder: KeyedDecodingContainerProtocol where K: CodingKey { + typealias Key = K + var codingPath: [CodingKey] + var allKeys: [K] + var decoder: _PostgreSQLMessageDecoder + + /// Creates a new internal `_PostgreSQLMessageKeyedDecoder` + init(decoder: _PostgreSQLMessageDecoder) { + self.codingPath = [] + self.allKeys = [] + self.decoder = decoder + } + + // Map decode for key to decoder + + func contains(_ key: K) -> Bool { return true } + func decode(_ type: Bool.Type, forKey key: K) throws -> Bool { return try decoder.decode(Bool.self) } + func decode(_ type: Int.Type, forKey key: K) throws -> Int { return try decoder.decode(Int.self) } + func decode(_ type: Int8.Type, forKey key: K) throws -> Int8 { return try decoder.decode(Int8.self) } + func decode(_ type: Int16.Type, forKey key: K) throws -> Int16 { return try decoder.decode(Int16.self) } + func decode(_ type: Int32.Type, forKey key: K) throws -> Int32 { return try decoder.decode(Int32.self) } + func decode(_ type: Int64.Type, forKey key: K) throws -> Int64 { return try decoder.decode(Int64.self) } + func decode(_ type: UInt.Type, forKey key: K) throws -> UInt { return try decoder.decode(UInt.self) } + func decode(_ type: UInt8.Type, forKey key: K) throws -> UInt8 { return try decoder.decode(UInt8.self) } + func decode(_ type: UInt16.Type, forKey key: K) throws -> UInt16 { return try decoder.decode(UInt16.self) } + func decode(_ type: UInt32.Type, forKey key: K) throws -> UInt32 { return try decoder.decode(UInt32.self) } + func decode(_ type: UInt64.Type, forKey key: K) throws -> UInt64 { return try decoder.decode(UInt64.self) } + func decode(_ type: Float.Type, forKey key: K) throws -> Float { return try decoder.decode(Float.self) } + func decode(_ type: Double.Type, forKey key: K) throws -> Double { return try decoder.decode(Double.self) } + func decode(_ type: String.Type, forKey key: K) throws -> String { return try decoder.decode(String.self) } + func decode(_ type: T.Type, forKey key: K) throws -> T where T : Decodable { return try decoder.decode(T.self) } + func superDecoder() throws -> Decoder { return decoder } + func superDecoder(forKey key: K) throws -> Decoder { return decoder } + func decodeNil(forKey key: K) throws -> Bool { return decoder.decodeNil() } + + // Unsupported + + func nestedContainer(keyedBy type: NestedKey.Type, forKey key: K) throws -> KeyedDecodingContainer where NestedKey : CodingKey { + let container = _PostgreSQLMessageKeyedDecoder(decoder: decoder) + return KeyedDecodingContainer(container) + } + + func nestedUnkeyedContainer(forKey key: K) throws -> UnkeyedDecodingContainer { + throw PostgreSQLError(identifier: "decoder", reason: "Unsupported decoding type: nested unkeyed container", source: .capture()) + } +} + +/// MARK: Unkeyed + +fileprivate final class _PostgreSQLMessageUnkeyedDecoder: UnkeyedDecodingContainer { + var count: Int? + var isAtEnd: Bool { + return currentIndex == count + } + var currentIndex: Int + var codingPath: [CodingKey] + var decoder: _PostgreSQLMessageDecoder + + /// Creates a new internal `_PostgreSQLMessageUnkeyedDecoder` + init(decoder: _PostgreSQLMessageDecoder) { + self.codingPath = [] + self.decoder = decoder + self.count = try! Int(decoder.decode(Int16.self)) + currentIndex = 0 + } + + func decode(_ type: Bool.Type) throws -> Bool { currentIndex += 1; return try decoder.decode(Bool.self) } + func decode(_ type: Int.Type) throws -> Int { currentIndex += 1; return try decoder.decode(Int.self) } + func decode(_ type: Int8.Type) throws -> Int8 { currentIndex += 1; return try decoder.decode(Int8.self) } + func decode(_ type: Int16.Type) throws -> Int16 { currentIndex += 1; return try decoder.decode(Int16.self) } + func decode(_ type: Int32.Type) throws -> Int32 { currentIndex += 1; return try decoder.decode(Int32.self) } + func decode(_ type: Int64.Type) throws -> Int64 { currentIndex += 1; return try decoder.decode(Int64.self) } + func decode(_ type: UInt.Type) throws -> UInt { currentIndex += 1; return try decoder.decode(UInt.self) } + func decode(_ type: UInt8.Type) throws -> UInt8 { currentIndex += 1; return try decoder.decode(UInt8.self) } + func decode(_ type: UInt16.Type) throws -> UInt16 { currentIndex += 1; return try decoder.decode(UInt16.self) } + func decode(_ type: UInt32.Type) throws -> UInt32 { currentIndex += 1; return try decoder.decode(UInt32.self) } + func decode(_ type: UInt64.Type) throws -> UInt64 { currentIndex += 1; return try decoder.decode(UInt64.self) } + func decode(_ type: Float.Type) throws -> Float { currentIndex += 1; return try decoder.decode(Float.self) } + func decode(_ type: Double.Type) throws -> Double { currentIndex += 1; return try decoder.decode(Double.self) } + func decode(_ type: String.Type) throws -> String { currentIndex += 1; return try decoder.decode(String.self) } + func decodeNil() throws -> Bool {currentIndex += 1; return decoder.decodeNil() } + func decode(_ type: T.Type) throws -> T where T : Decodable { currentIndex += 1; return try decoder.decode(T.self) } + func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { currentIndex += 1; return _PostgreSQLMessageUnkeyedDecoder(decoder: decoder) } + func superDecoder() throws -> Decoder { return decoder } + func nestedContainer(keyedBy type: NestedKey.Type) throws -> KeyedDecodingContainer where NestedKey : CodingKey { + let container = _PostgreSQLMessageKeyedDecoder(decoder: decoder) + return KeyedDecodingContainer(container) + } +} diff --git a/Sources/PostgreSQL/Message+Parse/PostgreSQLMessageParser.swift b/Sources/PostgreSQL/Message+Parse/PostgreSQLMessageParser.swift new file mode 100644 index 00000000..39be26fc --- /dev/null +++ b/Sources/PostgreSQL/Message+Parse/PostgreSQLMessageParser.swift @@ -0,0 +1,48 @@ +import Async +import Bits +import Foundation + +/// Byte-stream parser for `PostgreSQLMessage` +final class PostgreSQLMessageParser: TranslatingStream { + /// Data being worked on currently. + var buffered: Data + + /// Excess data waiting to be parsed. + var excess: Data? + + /// Creates a new `PostgreSQLMessageParser`. + init() { + buffered = Data() + } + + /// See TranslatingStream.translate + func translate(input context: inout TranslatingStreamInput) throws -> TranslatingStreamOutput { + if let excess = self.excess { + self.excess = nil + return try parse(data: excess) + } else { + guard let input = context.input else { + return .insufficient() + } + return try parse(data: Data(input)) + } + } + + /// Parses the data, setting `excess` or requesting more data if insufficient. + func parse(data: Data) throws -> TranslatingStreamOutput { + let data = buffered + data + guard let (message, remaining) = try PostgreSQLMessageDecoder().decode(data) else { + buffered.append(data) + return .insufficient() + } + + buffered = .init() + if remaining > 0 { + let start = data.count - remaining + excess = data[start.. Data { + let encoder = _PostgreSQLMessageEncoder() + let identifier: Byte? + switch message { + case .startupMessage(let message): + identifier = nil + try message.encode(to: encoder) + case .query(let query): + identifier = .Q + try query.encode(to: encoder) + case .parse(let parseRequest): + identifier = .P + try parseRequest.encode(to: encoder) + case .sync: + identifier = .S + case .bind(let bind): + identifier = .B + try bind.encode(to: encoder) + case .describe(let describe): + identifier = .D + try describe.encode(to: encoder) + case .execute(let execute): + identifier = .E + try execute.encode(to: encoder) + case .password(let password): + identifier = .p + try password.encode(to: encoder) + default: throw PostgreSQLError(identifier: "encoder", reason: "Unsupported encodable type: \(type(of: message))", source: .capture()) + } + encoder.updateSize() + if let prefix = identifier { + return [prefix] + encoder.data + } else { + return encoder.data + } + } +} + +// MARK: Encoder / Single + +/// Internal `Encoder` implementation for the `PostgreSQLMessageEncoder`. +internal final class _PostgreSQLMessageEncoder: Encoder, SingleValueEncodingContainer { + /// See Encoder.codingPath + var codingPath: [CodingKey] + + /// See Encoder.userInfo + var userInfo: [CodingUserInfoKey: Any] + + /// The data currently being encoded + var data: Data + + /// Creates a new internal `_PostgreSQLMessageEncoder` + init() { + self.codingPath = [] + self.userInfo = [:] + /// Start with 4 bytes for the int32 size chunk + self.data = Data([0, 0, 0, 0]) + } + + /// Updates the int32 size chunk in the data. + func updateSize() { + let size = numericCast(data.count) as Int32 + data.withUnsafeMutableBytes { (pointer: UnsafeMutablePointer) in + pointer.pointee = size.bigEndian + } + } + + /// See Encoder.singleValueContainer + func singleValueContainer() -> SingleValueEncodingContainer { + return self + } + + /// See SingleValueEncodingContainer.encode + func encode(_ value: String) throws { + let stringData = Data(value.utf8) + self.data.append(stringData + [0]) // c style string + } + + /// See SingleValueEncodingContainer.encode + func encode(_ value: Int8) throws { + self.data.append(numericCast(value)) + } + + /// See SingleValueEncodingContainer.encode + func encode(_ value: UInt8) throws { + self.data.append(value) + } + + /// See SingleValueEncodingContainer.encode + func encode(_ value: Int16) throws { + var value = value.bigEndian + withUnsafeBytes(of: &value) { buffer in + let buffer = buffer.unsafeBaseAddress.assumingMemoryBound(to: UInt8.self) + self.data.append(buffer, count: 2) + } + } + + /// See SingleValueEncodingContainer.encode + func encode(_ value: Int32) throws { + var value = value.bigEndian + withUnsafeBytes(of: &value) { buffer in + let buffer = buffer.unsafeBaseAddress.assumingMemoryBound(to: UInt8.self) + self.data.append(buffer, count: 4) + } + } + + /// See SingleValueEncodingContainer.encode + func encode(_ value: Int64) throws { + var value = value.bigEndian + withUnsafeBytes(of: &value) { buffer in + let buffer = buffer.unsafeBaseAddress.assumingMemoryBound(to: UInt8.self) + self.data.append(buffer, count: 8) + } + } + + /// See SingleValueEncodingContainer.encode + func encode(_ value: T) throws where T : Encodable { + if T.self == Data.self { + let sub = value as! Data + try encode(Int32(sub.count)) + self.data += sub + } else { + try value.encode(to: self) + } + } + + /// See Encoder.container + func container(keyedBy type: Key.Type) -> KeyedEncodingContainer where Key : CodingKey { + let container = _PostgreSQLMessageKeyedEncoder(encoder: self) + return KeyedEncodingContainer(container) + } + + /// See Encoder.unkeyedContainer + func unkeyedContainer() -> UnkeyedEncodingContainer { + return _PostgreSQLMessageUnkeyedEncoder(encoder: self) + } + + // Unsupported + + func encode(_ value: Int) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encode(_ value: UInt) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encode(_ value: UInt16) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encode(_ value: UInt32) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encode(_ value: UInt64) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encode(_ value: Float) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encode(_ value: Double) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encode(_ value: Bool) throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: \(type(of: value))", source: .capture()) } + func encodeNil() throws { throw PostgreSQLError(identifier: "encoder", reason: "Unsupported type: nil", source: .capture()) } +} + +fileprivate final class _PostgreSQLMessageKeyedEncoder: KeyedEncodingContainerProtocol where K: CodingKey { + typealias Key = K + var codingPath: [CodingKey] + let encoder: _PostgreSQLMessageEncoder + + init(encoder: _PostgreSQLMessageEncoder) { + self.encoder = encoder + self.codingPath = [] + } + + func encodeNil(forKey key: K) throws { try encoder.encodeNil() } + func encode(_ value: Bool, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: Int, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: Int8, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: Int16, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: Int32, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: Int64, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: UInt, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: UInt8, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: UInt16, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: UInt32, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: UInt64, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: Float, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: Double, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: String, forKey key: K) throws { try encoder.encode(value) } + func encode(_ value: T, forKey key: K) throws where T : Encodable { try encoder.encode(value) } + func nestedContainer(keyedBy keyType: NestedKey.Type, forKey key: K) + -> KeyedEncodingContainer where NestedKey: CodingKey { return encoder.container(keyedBy: NestedKey.self) } + func nestedUnkeyedContainer(forKey key: K) -> UnkeyedEncodingContainer { return encoder.unkeyedContainer() } + func superEncoder() -> Encoder { return encoder } + func superEncoder(forKey key: K) -> Encoder { return encoder } + + func encodeIfPresent(_ value: T?, forKey key: K) throws where T : Encodable { + if T.self == Data.self { + if let data = value { + try encoder.encode(data) + } else { + try encoder.encode(Int32(-1)) // indicate nil data + } + } else { + if let value = value { + try encoder.encode(value) + } else { + try encoder.encodeNil() + } + } + } +} + +/// MARK: Unkeyed + +fileprivate final class _PostgreSQLMessageUnkeyedEncoder: UnkeyedEncodingContainer { + var count: Int + var codingPath: [CodingKey] + let encoder: _PostgreSQLMessageEncoder + let countOffset: Int + + init(encoder: _PostgreSQLMessageEncoder) { + self.encoder = encoder + self.codingPath = [] + self.countOffset = encoder.data.count + self.count = 0 + // will hold count + encoder.data.append(Data([0, 0])) + } + + func encodeNil() throws { try encoder.encodeNil() } + func encode(_ value: Bool) throws { count += 1; try encoder.encode(value) } + func encode(_ value: Int) throws { count += 1; try encoder.encode(value) } + func encode(_ value: Int8) throws { count += 1; try encoder.encode(value) } + func encode(_ value: Int16) throws { count += 1; try encoder.encode(value) } + func encode(_ value: Int32) throws { count += 1; try encoder.encode(value) } + func encode(_ value: Int64) throws { count += 1; try encoder.encode(value) } + func encode(_ value: UInt) throws { count += 1; try encoder.encode(value) } + func encode(_ value: UInt8) throws { count += 1; try encoder.encode(value) } + func encode(_ value: UInt16) throws { count += 1; try encoder.encode(value) } + func encode(_ value: UInt32) throws { count += 1; try encoder.encode(value) } + func encode(_ value: UInt64) throws { count += 1; try encoder.encode(value) } + func encode(_ value: Float) throws { count += 1; try encoder.encode(value) } + func encode(_ value: Double) throws { count += 1; try encoder.encode(value) } + func encode(_ value: String) throws { count += 1; try encoder.encode(value) } + func encode(_ value: T) throws where T : Encodable { count += 1; return try encoder.encode(value) } + func nestedContainer(keyedBy keyType: NestedKey.Type) + -> KeyedEncodingContainer where NestedKey: CodingKey { return encoder.container(keyedBy: NestedKey.self) } + func nestedUnkeyedContainer() -> UnkeyedEncodingContainer { return encoder.unkeyedContainer() } + func superEncoder() -> Encoder { return encoder } + + deinit { + let size = numericCast(count) as Int16 + var data = Data([0, 0]) + data.withUnsafeMutableBytes { (pointer: UnsafeMutablePointer) in + pointer.pointee = size.bigEndian + } + encoder.data.replaceSubrange(countOffset..) throws -> TranslatingStreamOutput { + if let excess = self.excess { + self.excess = nil + return serialize(data: excess) + } else { + guard let input = context.input else { + return .insufficient() + } + let data = try PostgreSQLMessageEncoder().encode(input) + return serialize(data: data) + } + } + + /// Serializes data, storing `excess` if it does not fit in the buffer. + func serialize(data: Data) -> TranslatingStreamOutput { + let count = data.copyBytes(to: buffer) + let view = ByteBuffer(start: buffer.baseAddress, count: count) + if data.count > count { + self.excess = data[count.. PostgreSQLData { + return PostgreSQLData(type: dataType, format: format, data: value) + } +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLDescribeRequest.swift b/Sources/PostgreSQL/Message/PostgreSQLDescribeRequest.swift new file mode 100644 index 00000000..4be33efa --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLDescribeRequest.swift @@ -0,0 +1,32 @@ +import Bits + +/* + Describe (F) + Byte1('D') + Identifies the message as a Describe command. + + Int32 + Length of message contents in bytes, including self. + + Byte1 + 'S' to describe a prepared statement; or 'P' to describe a portal. + + String + The name of the prepared statement or portal to describe (an empty string selects the unnamed prepared statement or portal). + + */ + +/// Identifies the message as a Describe command. +struct PostgreSQLDescribeRequest: Encodable { + /// 'S' to describe a prepared statement; or 'P' to describe a portal. + let type: PostgreSQLDescribeType + + /// The name of the prepared statement or portal to describe + /// (an empty string selects the unnamed prepared statement or portal). + var name: String +} + +enum PostgreSQLDescribeType: Byte, Encodable { + case statement = 0x53 // S + case portal = 0x50 // P +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLErrorReponse.swift b/Sources/PostgreSQL/Message/PostgreSQLErrorReponse.swift new file mode 100644 index 00000000..13190fdc --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLErrorReponse.swift @@ -0,0 +1,121 @@ +import Bits +import Debugging + +/// First message sent from the frontend during startup. +struct PostgreSQLDiagnosticResponse: Decodable, Error { + /// The diagnostic messages. + var fields: [PostgreSQLDiagnosticType: String] + + /// See Decodable.init + init(from decoder: Decoder) throws { + fields = [:] + let single = try decoder.singleValueContainer() + parse: while true { + let type = try single.decode(PostgreSQLDiagnosticType.self) + switch type { + case .end: break parse + default: + assert(fields[type] == nil) + fields[type] = try single.decode(String.self) + } + } + } +} + +extension PostgreSQLDiagnosticResponse: Debuggable { + /// See `Debuggable.readableName` + static var readableName: String { + return "PostgreSQL Diagnostic" + } + + /// See `Debuggable.reason` + var reason: String { + return (fields[.localizedSeverity] ?? "ERROR") + ": " + (fields[.message] ?? "Unknown") + } + + /// See `Debuggable.identifier` + var identifier: String { + return fields[.routine] ?? fields[.sqlState] ?? "unknown" + } + + /// See `Helpable.possibleCauses` + var possibleCauses: [String] { + var strings: [String] = [] + if let message = fields[.message] { + strings.append(message) + } + return strings + } + + /// See `Helpable.suggestedFixes` + var suggestedFixes: [String] { + var strings: [String] = [] + if let hint = fields[.hint] { + strings.append(hint) + } + return strings + } +} + +enum PostgreSQLDiagnosticType: Byte, Decodable, Hashable { + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a + //// localized translation of one of these. Always present. + case localizedSeverity = 0x53 /// S + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). + /// This is identical to the S field except that the contents are never localized. + /// This is present only in messages generated by PostgreSQL versions 9.6 and later. + case severity = 0x56 /// V + /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present. + case sqlState = 0x43 /// C + /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). + /// Always present. + case message = 0x4D /// M + /// Detail: an optional secondary error message carrying more detail about the problem. + /// Might run to multiple lines. + case detail = 0x44 /// D + /// Hint: an optional suggestion what to do about the problem. + /// This is intended to differ from Detail in that it offers advice (potentially inappropriate) + /// rather than hard facts. Might run to multiple lines. + case hint = 0x48 /// H + /// Position: the field value is a decimal ASCII integer, indicating an error cursor + /// position as an index into the original query string. The first character has index 1, + /// and positions are measured in characters not bytes. + case position = 0x50 /// P + /// Internal position: this is defined the same as the P field, but it is used when the + /// cursor position refers to an internally generated command rather than the one submitted by the client. + /// The q field will always appear when this field appears. + case internalPosition = 0x70 /// p + /// Internal query: the text of a failed internally-generated command. + /// This could be, for example, a SQL query issued by a PL/pgSQL function. + case internalQuery = 0x71 /// q + /// Where: an indication of the context in which the error occurred. + /// Presently this includes a call stack traceback of active procedural language functions and + /// internally-generated queries. The trace is one entry per line, most recent first. + case locationContext = 0x57 /// W + /// Schema name: if the error was associated with a specific database object, the name of + /// the schema containing that object, if any. + case schemaName = 0x73 /// s + /// Table name: if the error was associated with a specific table, the name of the table. + /// (Refer to the schema name field for the name of the table's schema.) + case tableName = 0x74 /// t + /// Column name: if the error was associated with a specific table column, the name of the column. + /// (Refer to the schema and table name fields to identify the table.) + case columnName = 0x63 /// c + /// Data type name: if the error was associated with a specific data type, the name of the data type. + /// (Refer to the schema name field for the name of the data type's schema.) + case dataTypeName = 0x64 /// d + /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. + /// Refer to fields listed above for the associated table or domain. (For this purpose, indexes are + /// treated as constraints, even if they weren't created with constraint syntax.) + case constraintName = 0x6E /// n + /// File: the file name of the source-code location where the error was reported. + case file = 0x46 /// F + /// Line: the line number of the source-code location where the error was reported. + case line = 0x4C /// L + /// Routine: the name of the source-code routine reporting the error. + case routine = 0x52 /// R + /// No more types. + case end = 0x00 +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLExecuteRequest.swift b/Sources/PostgreSQL/Message/PostgreSQLExecuteRequest.swift new file mode 100644 index 00000000..f94a01eb --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLExecuteRequest.swift @@ -0,0 +1,9 @@ +/// Identifies the message as an Execute command. +struct PostgreSQLExecuteRequest: Encodable { + /// The name of the destination portal (an empty string selects the unnamed portal). + var portalName: String + + /// Maximum number of rows to return, if portal contains a query that + /// returns rows (ignored otherwise). Zero denotes β€œno limit”. + var maxRows: Int32 +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLMessage.swift b/Sources/PostgreSQL/Message/PostgreSQLMessage.swift new file mode 100644 index 00000000..0fdcf522 --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLMessage.swift @@ -0,0 +1,46 @@ +import Bits + +/// A frontend or backend PostgreSQL message. +enum PostgreSQLMessage { + case startupMessage(PostgreSQLStartupMessage) + /// Identifies the message as an error. + case error(PostgreSQLDiagnosticResponse) + /// Identifies the message as a notice. + case notice(PostgreSQLDiagnosticResponse) + /// One of the various authentication request message formats. + case authenticationRequest(PostgreSQLAuthenticationRequest) + /// Identifies the message as a password response. + case password(PostgreSQLPasswordMessage) + /// Identifies the message as a run-time parameter status report. + case parameterStatus(PostgreSQLParameterStatus) + /// Identifies the message as cancellation key data. The frontend must save these values if it wishes to be able to issue CancelRequest messages later. + case backendKeyData(PostgreSQLBackendKeyData) + /// Identifies the message type. ReadyForQuery is sent whenever the backend is ready for a new query cycle. + case readyForQuery(PostgreSQLReadyForQuery) + /// Identifies the message as a simple query. + case query(PostgreSQLQuery) + /// Identifies the message as a row description. + case rowDescription(PostgreSQLRowDescription) + /// Identifies the message as a data row. + case dataRow(PostgreSQLDataRow) + /// Identifies the message as a command-completed response. + case close(PostgreSQLCloseResponse) + /// Identifies the message as a Parse command. + case parse(PostgreSQLParseRequest) + /// Identifies the message as a parameter description. + case parameterDescription(PostgreSQLParameterDescription) + /// Identifies the message as a Bind command. + case bind(PostgreSQLBindRequest) + /// Identifies the message as a Describe command. + case describe(PostgreSQLDescribeRequest) + /// Identifies the message as an Execute command. + case execute(PostgreSQLExecuteRequest) + /// Identifies the message as a Sync command. + case sync + /// Identifies the message as a Parse-complete indicator. + case parseComplete + /// Identifies the message as a Bind-complete indicator. + case bindComplete + /// Identifies the message as a no-data indicator. + case noData +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLParameterDescription.swift b/Sources/PostgreSQL/Message/PostgreSQLParameterDescription.swift new file mode 100644 index 00000000..b849e8dc --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLParameterDescription.swift @@ -0,0 +1,5 @@ +/// Identifies the message as a parameter description. +struct PostgreSQLParameterDescription: Decodable { + /// Specifies the object ID of the parameter data type. + var dataTypes: [PostgreSQLDataType] +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLParameterStatus.swift b/Sources/PostgreSQL/Message/PostgreSQLParameterStatus.swift new file mode 100644 index 00000000..f25ba11e --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLParameterStatus.swift @@ -0,0 +1,14 @@ +struct PostgreSQLParameterStatus: Decodable { + /// The name of the run-time parameter being reported. + var parameter: String + + /// The current value of the parameter. + var value: String +} + +extension PostgreSQLParameterStatus: CustomStringConvertible { + /// CustomStringConvertible.description + var description: String { + return "\(parameter): \(value)" + } +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLParameters.swift b/Sources/PostgreSQL/Message/PostgreSQLParameters.swift new file mode 100644 index 00000000..afa77c61 --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLParameters.swift @@ -0,0 +1,27 @@ +/// Represents [String: String] parameters encoded +/// as a list of strings separated by null terminators +/// and finished by a single null terminator. +struct PostgreSQLParameters: Codable, ExpressibleByDictionaryLiteral { + /// The internal parameter storage. + var storage: [String: String] + + /// See Encodable.encode + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + for (key, val) in storage { + try container.encode(key) + try container.encode(val) + } + try container.encode("") + } + + /// See ExpressibleByDictionaryLiteral.init + init(dictionaryLiteral elements: (String, String)...) { + var storage = [String: String]() + for (key, val) in elements { + storage[key] = val + } + self.storage = storage + } +} + diff --git a/Sources/PostgreSQL/Message/PostgreSQLParseRequest.swift b/Sources/PostgreSQL/Message/PostgreSQLParseRequest.swift new file mode 100644 index 00000000..a3d4be8a --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLParseRequest.swift @@ -0,0 +1,14 @@ +/// Identifies the message as a Parse command. +struct PostgreSQLParseRequest: Encodable { + /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). + var statementName: String + + /// The query string to be parsed. + var query: String + + /// The number of parameter data types specified (can be zero). + /// Note that this is not an indication of the number of parameters that might appear in the + /// query string, only the number that the frontend wants to prespecify types for. + /// Specifies the object ID of the parameter data type. Placing a zero here is equivalent to leaving the type unspecified. + var parameterTypes: [PostgreSQLDataType] +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLPasswordMessage.swift b/Sources/PostgreSQL/Message/PostgreSQLPasswordMessage.swift new file mode 100644 index 00000000..e172457e --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLPasswordMessage.swift @@ -0,0 +1,7 @@ +/// Identifies the message as a password response. Note that this is also used for +/// GSSAPI and SSPI response messages (which is really a design error, since the contained +/// data is not a null-terminated string in that case, but can be arbitrary binary data). +struct PostgreSQLPasswordMessage: Encodable { + /// The password (encrypted, if requested). + var password: String +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLQuery.swift b/Sources/PostgreSQL/Message/PostgreSQLQuery.swift new file mode 100644 index 00000000..c1965bb5 --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLQuery.swift @@ -0,0 +1,11 @@ +/// Identifies the message as a simple query. +struct PostgreSQLQuery: Encodable { + /// The query string itself. + var query: String + + /// See Encodable.encode + func encode(to encoder: Encoder) throws { + var single = encoder.singleValueContainer() + try single.encode(query) + } +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLReadyForQuery.swift b/Sources/PostgreSQL/Message/PostgreSQLReadyForQuery.swift new file mode 100644 index 00000000..5015f7b9 --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLReadyForQuery.swift @@ -0,0 +1,23 @@ +import Bits + +/// Identifies the message type. ReadyForQuery is sent whenever the backend is ready for a new query cycle. +struct PostgreSQLReadyForQuery: Decodable { + /// Current backend transaction status indicator. + /// Possible values are 'I' if idle (not in a transaction block); + /// 'T' if in a transaction block; or 'E' if in a failed transaction block + /// (queries will be rejected until block is ended). + var transactionStatus: Byte + + /// See Decodable.decode + init(from decoder: Decoder) throws { + self.transactionStatus = try decoder.singleValueContainer().decode(Byte.self) + } +} + +extension PostgreSQLReadyForQuery: CustomStringConvertible { + /// CustomStringConvertible.description + var description: String { + let char = String(bytes: [transactionStatus], encoding: .ascii) ?? "n/a" + return "transactionStatus: \(char)" + } +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLRowDescription.swift b/Sources/PostgreSQL/Message/PostgreSQLRowDescription.swift new file mode 100644 index 00000000..f02a3bd3 --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLRowDescription.swift @@ -0,0 +1,67 @@ +/// Identifies the message as a row description. +struct PostgreSQLRowDescription: Decodable { + /// The fields supplied in the row description. + var fields: [PostgreSQLRowDescriptionField] + + /// See Decodable.decode + init(from decoder: Decoder) throws { + let single = try decoder.singleValueContainer() + + /// Specifies the number of fields in a row (can be zero). + let fieldCount = try single.decode(Int16.self) + var fields: [PostgreSQLRowDescriptionField] = [] + for _ in 0.. [String: PostgreSQLData] { + return try .init(uniqueKeysWithValues: fields.enumerated().map { (i, field) in + let formatCode: PostgreSQLFormatCode + switch formatCodes.count { + case 0: formatCode = .text + case 1: formatCode = formatCodes[0] + default: formatCode = formatCodes[i] + } + let key = field.name + let value = try data.columns[i].parse(dataType: field.dataType, format: formatCode) + return (key, value) + }) + } +} + +/// MARK: Field + +/// Describes a single field returns in a `RowDescription` message. +struct PostgreSQLRowDescriptionField: Decodable { + /// The field name. + var name: String + + /// If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. + var tableObjectID: Int32 + + /// If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. + var columnAttributeNumber: Int16 + + /// The object ID of the field's data type. + var dataType: PostgreSQLDataType + + /// The data type size (see pg_type.typlen). Note that negative values denote variable-width types. + var dataTypeSize: Int16 + + /// The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. + var dataTypeModifier: Int32 + + /// The format code being used for the field. + /// Currently will be zero (text) or one (binary). + /// In a RowDescription returned from the statement variant of Describe, + /// the format code is not yet known and will always be zero. + var formatCode: PostgreSQLFormatCode +} diff --git a/Sources/PostgreSQL/Message/PostgreSQLStartupMessage.swift b/Sources/PostgreSQL/Message/PostgreSQLStartupMessage.swift new file mode 100644 index 00000000..fb89cd77 --- /dev/null +++ b/Sources/PostgreSQL/Message/PostgreSQLStartupMessage.swift @@ -0,0 +1,31 @@ +/// First message sent from the frontend during startup. +struct PostgreSQLStartupMessage: Encodable { + /// Creates a `PostgreSQLStartupMessage` with "3.0" as the protocol version. + static func versionThree(parameters: PostgreSQLParameters) -> PostgreSQLStartupMessage { + return .init(protocolVersion: 196608, parameters: parameters) + } + + /// The protocol version number. The most significant 16 bits are the major + /// version number (3 for the protocol described here). The least significant + /// 16 bits are the minor version number (0 for the protocol described here). + var protocolVersion: Int32 + + /// The protocol version number is followed by one or more pairs of parameter + /// name and value strings. A zero byte is required as a terminator after + /// the last name/value pair. Parameters can appear in any order. user is required, + /// others are optional. Each parameter is specified as: + var parameters: PostgreSQLParameters + + /// Creates a new `PostgreSQLStartupMessage`. + init(protocolVersion: Int32, parameters: PostgreSQLParameters) { + self.protocolVersion = protocolVersion + self.parameters = parameters + } + + /// See Encodable.encode + func encode(to encoder: Encoder) throws { + var single = encoder.singleValueContainer() + try single.encode(protocolVersion) + try single.encode(parameters) + } +} diff --git a/Sources/PostgreSQL/PostgreSQLError.swift b/Sources/PostgreSQL/PostgreSQLError.swift new file mode 100644 index 00000000..e58c182f --- /dev/null +++ b/Sources/PostgreSQL/PostgreSQLError.swift @@ -0,0 +1,29 @@ +import Debugging +import COperatingSystem + +/// Errors that can be thrown while working with PostgreSQL. +public struct PostgreSQLError: Debuggable { + public static let readableName = "PostgreSQL Error" + public let identifier: String + public var reason: String + public var sourceLocation: SourceLocation + public var stackTrace: [String] + public var possibleCauses: [String] + public var suggestedFixes: [String] + + /// Create a new TCP error. + public init( + identifier: String, + reason: String, + possibleCauses: [String] = [], + suggestedFixes: [String] = [], + source: SourceLocation + ) { + self.identifier = identifier + self.reason = reason + self.sourceLocation = source + self.stackTrace = PostgreSQLError.makeStackTrace() + self.possibleCauses = possibleCauses + self.suggestedFixes = suggestedFixes + } +} diff --git a/Sources/PostgreSQL/PostgreSQLProvider.swift b/Sources/PostgreSQL/PostgreSQLProvider.swift new file mode 100644 index 00000000..d2560b3b --- /dev/null +++ b/Sources/PostgreSQL/PostgreSQLProvider.swift @@ -0,0 +1,38 @@ +import Service + +/// Provides base `PostgreSQL` services such as database and connection. +public final class PostgreSQLProvider: Provider { + /// See `Provider.repositoryName` + public static let repositoryName = "fluent-postgresql" + + /// Creates a new `PostgreSQLProvider`. + public init() {} + + /// See `Provider.register` + public func register(_ services: inout Services) throws { + try services.register(DatabaseKitProvider()) + services.register(PostgreSQLDatabaseConfig.self) + services.register(PostgreSQLDatabase.self) + var databases = DatabaseConfig() + databases.add(database: PostgreSQLDatabase.self, as: .psql) + services.register(databases) + } + + /// See `Provider.boot` + public func boot(_ worker: Container) throws { } +} + +/// MARK: Services + +extension PostgreSQLDatabaseConfig: ServiceType { + /// See `ServiceType.makeService(for:)` + public static func makeService(for worker: Container) throws -> PostgreSQLDatabaseConfig { + return .default() + } +} +extension PostgreSQLDatabase: ServiceType { + /// See `ServiceType.makeService(for:)` + public static func makeService(for worker: Container) throws -> PostgreSQLDatabase { + return try .init(config: worker.make(for: PostgreSQLDatabase.self)) + } +} diff --git a/Sources/PostgreSQL/Utilities.swift b/Sources/PostgreSQL/Utilities.swift new file mode 100644 index 00000000..173a57c6 --- /dev/null +++ b/Sources/PostgreSQL/Utilities.swift @@ -0,0 +1,79 @@ +import Bits +import Foundation + +extension Data { + public var hexDebug: String { + return "0x" + map { String(format: "%02X", $0) }.joined(separator: " ") + } +} + +extension UnsafeBufferPointer { + public var unsafeBaseAddress: UnsafePointer { + guard let baseAddress = self.baseAddress else { + fatalError("Unexpected nil baseAddress for \(self)") + } + return baseAddress + } +} + +extension UnsafeRawBufferPointer { + public var unsafeBaseAddress: UnsafeRawPointer { + guard let baseAddress = self.baseAddress else { + fatalError("Unexpected nil baseAddress for \(self)") + } + return baseAddress + } +} + +extension Data { + internal mutating func unsafePopFirst() -> Byte { + guard let byte = popFirst() else { + fatalError("Unexpected end of data") + } + return byte + } + + internal mutating func skip(_ n: Int) { + guard n < count else { + self = Data() + return + } + for _ in 0..(sizeOf: T.Type) { + skip(MemoryLayout.size) + } + + /// Casts data to a supplied type. + internal mutating func extract(_ type: T.Type = T.self) -> T { + assert(MemoryLayout.size <= count, "Insufficient data to exctract: \(T.self)") + defer { skip(sizeOf: T.self) } + return withUnsafeBytes { (pointer: UnsafePointer) -> T in + return pointer.pointee + } + } + + internal mutating func extract(count: Int) -> Data { + assert(self.count >= count, "Insufficient data to extract bytes.") + defer { skip(count) } + return withUnsafeBytes({ (pointer: UnsafePointer) -> Data in + let buffer = UnsafeBufferPointer(start: pointer, count: count) + return Data(buffer) + }) + } +} + + +extension Data { + /// Casts data to a supplied type. + internal func unsafeCast(to type: T.Type = T.self) -> T { + return withUnsafeBytes { (pointer: UnsafePointer) -> T in + return pointer.pointee + } + } + + +} diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift new file mode 100644 index 00000000..d11b8d1a --- /dev/null +++ b/Tests/LinuxMain.swift @@ -0,0 +1,7 @@ +import XCTest +@testable import PostgreSQLTests + +XCTMain([ + testCase(PostgreSQLConnectionTests.allTests), + testCase(PostgreSQLMessageTests.allTests), +]) diff --git a/Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift b/Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift new file mode 100644 index 00000000..8f98521a --- /dev/null +++ b/Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift @@ -0,0 +1,312 @@ +import Async +import Foundation +import XCTest +import PostgreSQL +import TCP + +class PostgreSQLConnectionTests: XCTestCase { + func testVersion() throws { + let (client, eventLoop) = try PostgreSQLConnection.makeTest() + let results = try client.simpleQuery("SELECT version();").await(on: eventLoop) + try XCTAssert(results[0]["version"]?.decode(String.self).contains("10.") == true) + } + + func testSelectTypes() throws { + let (client, eventLoop) = try PostgreSQLConnection.makeTest() + let results = try client.query("select * from pg_type;").await(on: eventLoop) + if results.count > 350 { + let name = try results[128]["typname"]?.decode(String.self) + XCTAssert(name != "") + } else { + XCTFail("Results count not large enough: \(results.count)") + } + } + + func testParse() throws { + let (client, eventLoop) = try PostgreSQLConnection.makeTest() + let query = """ + select * from "pg_type" where "typlen" = $1 or "typlen" = $2 + """ + let rows = try client.query(query, [1, 2]).await(on: eventLoop) + + for row in rows { + try XCTAssert(row["typlen"]?.decode(Int.self) == 1 || row["typlen"]?.decode(Int.self) == 2) + } + } + + func testTypes() throws { + let (client, eventLoop) = try PostgreSQLConnection.makeTest() + let createQuery = """ + create table kitchen_sink ( + "smallint" smallint, + "integer" integer, + "bigint" bigint, + "decimal" decimal, + "numeric" numeric, + "real" real, + "double" double precision, + "varchar" varchar(64), + "char" char(4), + "text" text, + "bytea" bytea, + "timestamp" timestamp, + "date" date, + "time" time, + "boolean" boolean, + "point" point + -- "line" line, + -- "lseg" lseg, + -- "box" box, + -- "path" path, + -- "polygon" polygon, + -- "circle" circle, + -- "cidr" cidr, + -- "inet" inet, + -- "macaddr" macaddr, + -- "bit" bit(16), + -- "uuid" uuid + ); + """ + _ = try client.query("drop table if exists kitchen_sink;").await(on: eventLoop) + let createResult = try client.query(createQuery).await(on: eventLoop) + XCTAssertEqual(createResult.count, 0) + + let insertQuery = """ + insert into kitchen_sink values ( + 1, -- "smallint" smallint + 2, -- "integer" integer + 3, -- "bigint" bigint + 4, -- "decimal" decimal + 5.3, -- "numeric" numeric + 6, -- "real" real + 7, -- "double" double precision + '9', -- "varchar" varchar(64) + '10', -- "char" char(4) + '11', -- "text" text + '12', -- "bytea" bytea + now(), -- "timestamp" timestamp + current_date, -- "date" date + localtime, -- "time" time + true, -- "boolean" boolean + point(13.5,14) -- "point" point, + -- "line" line, + -- "lseg" lseg, + -- "box" box, + -- "path" path, + -- "polygon" polygon, + -- "circle" circle, + -- "cidr" cidr, + -- "inet" inet, + -- "macaddr" macaddr, + -- "bit" bit(16), + -- "uuid" uuid + ); + """ + let insertResult = try! client.query(insertQuery).await(on: eventLoop) + XCTAssertEqual(insertResult.count, 0) + let queryResult = try client.query("select * from kitchen_sink").await(on: eventLoop) + if queryResult.count == 1 { + let row = queryResult[0] + try XCTAssertEqual(row["smallint"]?.decode(Int16.self), 1) + try XCTAssertEqual(row["integer"]?.decode(Int32.self), 2) + try XCTAssertEqual(row["bigint"]?.decode(Int64.self), 3) + try XCTAssertEqual(row["decimal"]?.decode(String.self), "4") + try XCTAssertEqual(row["real"]?.decode(Float.self), 6) + try XCTAssertEqual(row["double"]?.decode(Double.self), 7) + try XCTAssertEqual(row["varchar"]?.decode(String.self), "9") + try XCTAssertEqual(row["char"]?.decode(String.self), "10 ") + try XCTAssertEqual(row["text"]?.decode(String.self), "11") + try XCTAssertEqual(row["bytea"]?.decode(Data.self), Data([0x31, 0x32])) + try XCTAssertEqual(row["boolean"]?.decode(Bool.self), true) + try XCTAssertNotNil(row["timestamp"]?.decode(Date.self)) + try XCTAssertNotNil(row["date"]?.decode(Date.self)) + try XCTAssertNotNil(row["time"]?.decode(Date.self)) + try XCTAssertEqual(row["point"]?.decode(PostgreSQLPoint.self), PostgreSQLPoint(x: 13.5, y: 14)) + } else { + XCTFail("query result count is: \(queryResult.count)") + } + } + + func testParameterizedTypes() throws { + let (client, eventLoop) = try PostgreSQLConnection.makeTest() + let createQuery = """ + create table kitchen_sink ( + "smallint" smallint, + "integer" integer, + "bigint" bigint, + "decimal" decimal, + "numeric" numeric, + "real" real, + "double" double precision, + "varchar" varchar(64), + "char" char(4), + "text" text, + "bytea" bytea, + "timestamp" timestamp, + "date" date, + "time" time, + "boolean" boolean, + "point" point, + "uuid" uuid, + "array" point[] + -- "line" line, + -- "lseg" lseg, + -- "box" box, + -- "path" path, + -- "polygon" polygon, + -- "circle" circle, + -- "cidr" cidr, + -- "inet" inet, + -- "macaddr" macaddr, + -- "bit" bit(16), + ); + """ + _ = try client.query("drop table if exists kitchen_sink;").await(on: eventLoop) + let createResult = try client.query(createQuery).await(on: eventLoop) + XCTAssertEqual(createResult.count, 0) + + let insertQuery = """ + insert into kitchen_sink values ( + $1, -- "smallint" smallint + $2, -- "integer" integer + $3, -- "bigint" bigint + $4::numeric, -- "decimal" decimal + $5, -- "numeric" numeric + $6, -- "real" real + $7, -- "double" double precision + $8, -- "varchar" varchar(64) + $9, -- "char" char(4) + $10, -- "text" text + $11, -- "bytea" bytea + $12, -- "timestamp" timestamp + $13, -- "date" date + $14, -- "time" time + $15, -- "boolean" boolean + $16, -- "point" point + $17, -- "uuid" uuid + '{"(1,2)","(3,4)"}' -- "array" point[] + -- "line" line, + -- "lseg" lseg, + -- "box" box, + -- "path" path, + -- "polygon" polygon, + -- "circle" circle, + -- "cidr" cidr, + -- "inet" inet, + -- "macaddr" macaddr, + -- "bit" bit(16), + ); + """ + + var params: [PostgreSQLDataCustomConvertible] = [] + params += Int16(1) // smallint + params += Int32(2) // integer + params += Int64(3) // bigint + params += String("123456789.123456789") // decimal + params += Double(5) // numeric + params += Float(6) // real + params += Double(7) // double + params += String("8") // varchar + params += String("9") // char + params += String("10") // text + params += Data([0x31, 0x32]) // bytea + params += Date() // timestamp + params += Date() // date + params += Date() // time + params += Bool(true) // boolean + params += PostgreSQLPoint(x: 11.4, y: 12) // point + params += UUID() // new uuid + // params.append([1,2,3] as [Int]) // new array + + let insertResult = try! client.query(insertQuery, params).await(on: eventLoop) + XCTAssertEqual(insertResult.count, 0) + + let parameterizedResult = try! client.query("select * from kitchen_sink").await(on: eventLoop) + if parameterizedResult.count == 1 { + let row = parameterizedResult[0] + try XCTAssertEqual(row["smallint"]?.decode(Int16.self), 1) + try XCTAssertEqual(row["integer"]?.decode(Int32.self), 2) + try XCTAssertEqual(row["bigint"]?.decode(Int64.self), 3) + try XCTAssertEqual(row["decimal"]?.decode(String.self), "123456789.123456789") + try XCTAssertEqual(row["real"]?.decode(Float.self), 6) + try XCTAssertEqual(row["double"]?.decode(Double.self), 7) + try XCTAssertEqual(row["varchar"]?.decode(String.self), "8") + try XCTAssertEqual(row["char"]?.decode(String.self), "9 ") + try XCTAssertEqual(row["text"]?.decode(String.self), "10") + try XCTAssertEqual(row["bytea"]?.decode(Data.self), Data([0x31, 0x32])) + try XCTAssertEqual(row["boolean"]?.decode(Bool.self), true) + try XCTAssertNotNil(row["timestamp"]?.decode(Date.self)) + try XCTAssertNotNil(row["date"]?.decode(Date.self)) + try XCTAssertNotNil(row["time"]?.decode(Date.self)) + try XCTAssertEqual(row["point"]?.decode(String.self), "(11.4,12.0)") + try XCTAssertNotNil(row["uuid"]?.decode(UUID.self)) + try XCTAssertEqual(row["array"]?.decode([PostgreSQLPoint].self).first?.x, 1.0) + } else { + XCTFail("parameterized result count is: \(parameterizedResult.count)") + } + } + + func testStruct() throws { + struct Hello: PostgreSQLJSONCustomConvertible { + var message: String + } + + let (client, eventLoop) = try! PostgreSQLConnection.makeTest() + _ = try! client.query("drop table if exists foo;").await(on: eventLoop) + let createResult = try! client.query("create table foo (id integer, dict jsonb);").await(on: eventLoop) + XCTAssertEqual(createResult.count, 0) + let insertResult = try! client.query("insert into foo values ($1, $2);", [ + Int32(1), Hello(message: "hello, world") + ]).await(on: eventLoop) + + XCTAssertEqual(insertResult.count, 0) + let parameterizedResult = try! client.query("select * from foo").await(on: eventLoop) + if parameterizedResult.count == 1 { + let row = parameterizedResult[0] + try! XCTAssertEqual(row["id"]?.decode(Int.self), 1) + try! XCTAssertEqual(row["dict"]?.decode(Hello.self).message, "hello, world") + } else { + XCTFail("parameterized result count is: \(parameterizedResult.count)") + } + } + + func testNull() throws { + let (client, eventLoop) = try PostgreSQLConnection.makeTest() + _ = try client.query("drop table if exists nulltest;").await(on: eventLoop) + let createResult = try client.query("create table nulltest (i integer not null, d timestamp);").await(on: eventLoop) + XCTAssertEqual(createResult.count, 0) + let insertResult = try! client.query("insert into nulltest (i, d) VALUES ($1, $2)", [ + PostgreSQLData(type: .int2, format: .binary, data: Data([0x00, 0x01])), + PostgreSQLData(type: .timestamp, format: .binary, data: nil), + ]).await(on: eventLoop) + XCTAssertEqual(insertResult.count, 0) + let parameterizedResult = try! client.query("select * from nulltest").await(on: eventLoop) + XCTAssertEqual(parameterizedResult.count, 1) + } + + static var allTests = [ + ("testVersion", testVersion), + ("testSelectTypes", testSelectTypes), + ("testParse", testParse), + ("testTypes", testTypes), + ("testParameterizedTypes", testParameterizedTypes), + ("testStruct", testStruct), + ("testNull", testNull), + ] +} + +extension PostgreSQLConnection { + /// Creates a test event loop and psql client. + static func makeTest() throws -> (PostgreSQLConnection, EventLoop) { + let eventLoop = try DefaultEventLoop(label: "codes.vapor.postgresql.client.test") + let client = try PostgreSQLConnection.connect(on: eventLoop) { _, error in + XCTFail("\(error)") + } + _ = try client.authenticate(username: "postgres").await(on: eventLoop) + return (client, eventLoop) + } +} + +func +=(lhs: inout [T], rhs: T) { + lhs.append(rhs) +} diff --git a/Tests/PostgreSQLTests/PostgreSQLMessageTests.swift b/Tests/PostgreSQLTests/PostgreSQLMessageTests.swift new file mode 100644 index 00000000..7a889de1 --- /dev/null +++ b/Tests/PostgreSQLTests/PostgreSQLMessageTests.swift @@ -0,0 +1,15 @@ +import Foundation +import XCTest +@testable import PostgreSQL + +class PostgreSQLMessageTests: XCTestCase { + func testExample() throws { + let startup = PostgreSQLStartupMessage.versionThree(parameters: ["user": "tanner"]) + let data = try PostgreSQLMessageEncoder().encode(.startupMessage(startup)) + XCTAssertEqual(data.hexDebug, "0x00 00 00 15 00 03 00 00 75 73 65 72 00 74 61 6E 6E 65 72 00 00") + } + + static var allTests = [ + ("testExample", testExample), + ] +} diff --git a/circle.yml b/circle.yml new file mode 100644 index 00000000..ffbf8bd4 --- /dev/null +++ b/circle.yml @@ -0,0 +1,29 @@ +version: 2 + +jobs: + macos: + macos: + xcode: "9.2" + steps: + - checkout + - run: swift build + - run: swift test + linux: + docker: + - image: norionomura/swift:swift-4.1-branch + - image: circleci/postgres:latest + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: "" + steps: + - checkout + - run: swift build + - run: swift test +workflows: + version: 2 + tests: + jobs: + - linux + # - macos + \ No newline at end of file