Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Postgres in database.ts #20

Merged
merged 7 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {Replays} from './replays';
import {ActionError, QueryHandler, Server} from './server';
import {toID, updateserver, bash, time, escapeHTML} from './utils';
import * as tables from './tables';
import {SQL} from './database';
import * as pathModule from 'path';
import IPTools from './ip-tools';
import * as crypto from 'crypto';
Expand Down Expand Up @@ -662,9 +663,9 @@ export const actions: {[k: string]: QueryHandler} = {
}
let teams = [];
try {
teams = await tables.pgdb.query(
'SELECT teamid, team, format, title as name FROM teams WHERE ownerid = $1', [this.user.id]
) ?? [];
teams = await tables.teams.selectAll<any>(
SQL`teamid, team, format, title as name`
)`WHERE ownerid = ${this.user.id}`;
} catch (e) {
Server.crashlog(e, 'a teams database query', params);
throw new ActionError('The server could not load your teams. Please try again later.');
Expand Down Expand Up @@ -693,13 +694,13 @@ export const actions: {[k: string]: QueryHandler} = {
throw new ActionError("Invalid team ID");
}
try {
const data = await tables.pgdb.query(
`SELECT ownerid, team, private as privacy FROM teams WHERE teamid = $1`, [teamid]
);
if (!data || !data.length || data[0].ownerid !== this.user.id) {
const data = await tables.teams.selectOne<any>(
SQL`ownerid, team, private as privacy`
)`WHERE teamid = ${teamid}`;
if (!data || data.ownerid !== this.user.id) {
return {team: null};
}
return data[0];
return data;
} catch (e) {
Server.crashlog(e, 'a teams database request', params);
throw new ActionError("Failed to fetch team. Please try again later.");
Expand Down
208 changes: 125 additions & 83 deletions src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ export class SQLStatement {
} else if (value === undefined) {
this.sql[this.sql.length - 1] += nextString;
} else if (Array.isArray(value)) {
if (this.sql[this.sql.length - 1].endsWith(`\``)) {
if ('"`'.includes(this.sql[this.sql.length - 1].slice(-1))) {
// "`a`, `b`" syntax
const quoteChar = this.sql[this.sql.length - 1].slice(-1);
for (const col of value) {
this.append(col, `\`, \``);
this.append(col, `${quoteChar}, ${quoteChar}`);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + nextString;
} else {
Expand All @@ -52,21 +53,21 @@ export class SQLStatement {
}
} else if (this.sql[this.sql.length - 1].endsWith('(')) {
// "(`a`, `b`) VALUES (1, 2)" syntax
this.sql[this.sql.length - 1] += `\``;
this.sql[this.sql.length - 1] += `"`;
for (const col in value) {
this.append(col, `\`, \``);
this.append(col, `", "`);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `\`) VALUES (`;
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `") VALUES (`;
for (const col in value) {
this.append(value[col], `, `);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2) + nextString;
} else if (this.sql[this.sql.length - 1].toUpperCase().endsWith(' SET ')) {
// "`a` = 1, `b` = 2" syntax
this.sql[this.sql.length - 1] += `\``;
this.sql[this.sql.length - 1] += `"`;
for (const col in value) {
this.append(col, `\` = `);
this.append(value[col], `, \``);
this.append(col, `" = `);
this.append(value[col], `, "`);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -3) + nextString;
} else {
Expand All @@ -83,27 +84,29 @@ export class SQLStatement {
* Tag function for SQL, with some magic.
*
* * `` SQL`UPDATE table SET a = ${'hello"'}` ``
* * `` 'UPDATE table SET a = "hello"' ``
* * `` `UPDATE table SET a = 'hello'` ``
*
* Values surrounded by `` \` `` become names:
* Values surrounded by `"` or `` ` `` become identifiers:
*
* * ``` SQL`SELECT * FROM \`${'table'}\`` ```
* * `` 'SELECT * FROM `table`' ``
* * ``` SQL`SELECT * FROM "${'table'}"` ```
* * `` `SELECT * FROM "table"` ``
*
* (Make sure to use `"` for Postgres and `` ` `` for MySQL.)
*
* Objects preceded by SET become setters:
*
* * `` SQL`UPDATE table SET ${{a: 1, b: 2}}` ``
* * `` 'UPDATE table SET `a` = 1, `b` = 2' ``
* * `` `UPDATE table SET "a" = 1, "b" = 2` ``
*
* Objects surrounded by `()` become keys and values:
*
* * `` SQL`INSERT INTO table (${{a: 1, b: 2}})` ``
* * `` 'INSERT INTO table (`a`, `b`) VALUES (1, 2)' ``
* * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
*
* Arrays become lists; surrounding by `` \` `` turns them into lists of names:
* Arrays become lists; surrounding by `"` or `` ` `` turns them into lists of names:
*
* * `` SQL`INSERT INTO table (\`${['a', 'b']}\`) VALUES (${[1, 2]})` ``
* * `` 'INSERT INTO table (`a`, `b`) VALUES (1, 2)' ``
* * `` SQL`INSERT INTO table ("${['a', 'b']}") VALUES (${[1, 2]})` ``
* * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
*/
export function SQL(strings: TemplateStringsArray, ...values: SQLValue[]) {
return new SQLStatement(strings, values);
Expand All @@ -113,53 +116,24 @@ export interface ResultRow {[k: string]: BasicSQLValue}

export const connectedDatabases: Database[] = [];

export class Database {
connection: mysql.Pool;
export abstract class Database<Pool extends mysql.Pool | pg.Pool = mysql.Pool | pg.Pool, OkPacket = unknown> {
connection: Pool;
prefix: string;
constructor(config: mysql.PoolOptions & {prefix?: string}) {
this.prefix = config.prefix || "";
if (config.prefix) {
config = {...config};
delete config.prefix;
}
this.connection = mysql.createPool(config);
constructor(connection: Pool, prefix = '') {
this.prefix = prefix;
this.connection = connection;
connectedDatabases.push(this);
}
resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ugh this is so much duplicate code though... I think probably it's worth leaving like this.

let sql = query.sql[0];
const values = [];
for (let i = 0; i < query.values.length; i++) {
const value = query.values[i];
if (query.sql[i + 1].startsWith('`')) {
sql = sql.slice(0, -1) + this.connection.escapeId('' + value) + query.sql[i + 1].slice(1);
} else {
sql += '?' + query.sql[i + 1];
values.push(value);
}
}
return [sql, values];
}
abstract _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]];
abstract _query(sql: string, values: BasicSQLValue[]): Promise<any>;
abstract escapeId(param: string): string;
query<T = ResultRow>(sql: SQLStatement): Promise<T[]>;
query<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
query<T = ResultRow>(sql?: SQLStatement) {
if (!sql) return (strings: any, ...rest: any) => this.query<T>(new SQLStatement(strings, rest));

return new Promise<T[]>((resolve, reject) => {
const [query, values] = this.resolveSQL(sql);
this.connection.query(query, values, (e, results: any) => {
if (e) {
return reject(new Error(`${e.message} (${query}) (${values}) [${e.code}]`));
}
if (Array.isArray(results)) {
for (const row of results) {
for (const col in row) {
if (Buffer.isBuffer(row[col])) row[col] = row[col].toString();
}
}
}
return resolve(results);
});
});
const [query, values] = this._resolveSQL(sql);
return this._query(query, values);
}
queryOne<T = ResultRow>(sql: SQLStatement): Promise<T | undefined>;
queryOne<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
Expand All @@ -168,14 +142,14 @@ export class Database {

return this.query<T>(sql).then(res => Array.isArray(res) ? res[0] : res);
}
queryExec(sql: SQLStatement): Promise<mysql.OkPacket>;
queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket>;
queryExec(sql: SQLStatement): Promise<OkPacket>;
queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacket>;
queryExec(sql?: SQLStatement) {
if (!sql) return (strings: any, ...rest: any) => this.queryExec(new SQLStatement(strings, rest));
return this.queryOne<mysql.OkPacket>(sql);
return this.queryOne<OkPacket>(sql);
}
close() {
this.connection.end();
void this.connection.end();
}
}

Expand All @@ -198,7 +172,7 @@ export class DatabaseTable<Row> {
this.primaryKeyName = primaryKeyName;
}
escapeId(param: string) {
return this.db.connection.escapeId(param);
return this.db.escapeId(param);
}

// raw
Expand All @@ -224,45 +198,52 @@ export class DatabaseTable<Row> {
selectAll<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]> {
if (!entries) entries = SQL`*`;
if (Array.isArray(entries)) entries = SQL`\`${entries}\``;
if (Array.isArray(entries)) entries = SQL`"${entries}"`;
return (strings, ...rest) =>
this.query<T>()`SELECT ${entries} FROM \`${this.name}\` ${new SQLStatement(strings, rest)}`;
this.query<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
}
selectOne<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
if (!entries) entries = SQL`*`;
if (Array.isArray(entries)) entries = SQL`\`${entries}\``;
if (Array.isArray(entries)) entries = SQL`"${entries}"`;
return (strings, ...rest) =>
this.queryOne<T>()`SELECT ${entries} FROM \`${this.name}\` ${new SQLStatement(strings, rest)} LIMIT 1`;
this.queryOne<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
}
updateAll(partialRow: PartialOrSQL<Row>):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (strings, ...rest) =>
this.queryExec()`UPDATE \`${this.name}\` SET ${partialRow as any} ${new SQLStatement(strings, rest)}`;
this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(strings, rest)}`;
}
updateOne(partialRow: PartialOrSQL<Row>):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (s, ...r) =>
this.queryExec()`UPDATE \`${this.name}\` SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`;
this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`;
}
deleteAll():
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (strings, ...rest) =>
this.queryExec()`DELETE FROM \`${this.name}\` ${new SQLStatement(strings, rest)}`;
this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
}
deleteOne():
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (strings, ...rest) =>
this.queryExec()`DELETE FROM \`${this.name}\` ${new SQLStatement(strings, rest)} LIMIT 1`;
this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
}
eval<T>():
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
return (strings, ...rest) =>
this.queryOne<{result: T}>(
)`SELECT ${new SQLStatement(strings, rest)} AS result FROM "${this.name}" LIMIT 1`
.then(row => row?.result);
}

// high-level

insert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
return this.queryExec()`INSERT INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`;
return this.queryExec()`INSERT INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
}
insertIgnore(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
return this.queryExec()`INSERT IGNORE INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`;
return this.queryExec()`INSERT IGNORE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
}
async tryInsert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
try {
Expand All @@ -279,28 +260,89 @@ export class DatabaseTable<Row> {
return this.replace(partialRow, where);
}
replace(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
return this.queryExec()`REPLACE INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`;
return this.queryExec()`REPLACE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
}
get(primaryKey: BasicSQLValue, entries?: (keyof Row & string)[] | SQLStatement) {
return this.selectOne(entries)`WHERE \`${this.primaryKeyName}\` = ${primaryKey}`;
return this.selectOne(entries)`WHERE "${this.primaryKeyName}" = ${primaryKey}`;
}
delete(primaryKey: BasicSQLValue) {
return this.deleteAll()`WHERE \`${this.primaryKeyName}\` = ${primaryKey} LIMIT 1`;
return this.deleteAll()`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
}
update(primaryKey: BasicSQLValue, data: PartialOrSQL<Row>) {
return this.updateAll(data)`WHERE \`${this.primaryKeyName}\` = ${primaryKey} LIMIT 1`;
return this.updateAll(data)`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
}
}

export class PGDatabase {
database: pg.Pool | null;
constructor(config: pg.PoolConfig | null) {
this.database = config ? new pg.Pool(config) : null;
export class MySQLDatabase extends Database<mysql.Pool, mysql.OkPacket> {
constructor(config: mysql.PoolOptions & {prefix?: string}) {
const prefix = config.prefix || "";
if (config.prefix) {
config = {...config};
delete config.prefix;
}
super(mysql.createPool(config), prefix);
}
async query<O = any>(query: string, values: BasicSQLValue[]) {
if (!this.database) return null;
const result = await this.database.query(query, values);
return result.rows as O[];
override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
let sql = query.sql[0];
const values = [];
for (let i = 0; i < query.values.length; i++) {
const value = query.values[i];
if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) {
sql = sql.slice(0, -1) + this.escapeId('' + value) + query.sql[i + 1].slice(1);
} else {
sql += '?' + query.sql[i + 1];
values.push(value);
}
}
return [sql, values];
}
override _query(query: string, values: BasicSQLValue[]): Promise<any> {
return new Promise((resolve, reject) => {
this.connection.query(query, values, (e, results: any) => {
if (e) {
return reject(new Error(`${e.message} (${query}) (${values}) [${e.code}]`));
}
if (Array.isArray(results)) {
for (const row of results) {
for (const col in row) {
if (Buffer.isBuffer(row[col])) row[col] = row[col].toString();
}
}
}
return resolve(results);
});
});
}
override escapeId(id: string) {
return this.connection.escapeId(id);
}
}

export class PGDatabase extends Database<pg.Pool, []> {
constructor(config: pg.PoolConfig) {
super(new pg.Pool(config));
}
override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
let sql = query.sql[0];
const values = [];
let paramCount = 0;
for (let i = 0; i < query.values.length; i++) {
const value = query.values[i];
if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) {
sql = sql.slice(0, -1) + this.escapeId('' + value) + query.sql[i + 1].slice(1);
} else {
paramCount++;
sql += `$${paramCount}` + query.sql[i + 1];
values.push(value);
}
}
return [sql, values];
}
override _query(query: string, values: BasicSQLValue[]) {
return this.connection.query(query, values).then(res => res.rows);
}
override escapeId(id: string) {
// @ts-expect-error @types/pg really needs to be updated
return pg.escapeIdentifier(id);
}
}
Loading
Loading