Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make providing single primary key optional, remove ~conflict, support enums #3

Merged
merged 4 commits into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.59.1
version: v1.60.1
16 changes: 6 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@ ripoffs define rows to be inserted into your database. Any number of ripoffs can
rows:
# A "users" table row identified with a UUID generated with the seed "fooBar"
users:uuid(fooBar):
# Using the map key here implicitly informs ripoff that "id" is the primary key of the table
id: users:uuid(fooBar)
email: [email protected]
# Note that ripoff will automatically set primary key columns, so you don't need to add:
# id: users:uuid(fooBar)
avatars:uuid(fooBarAvatar):
id: avatars:uuid(fooBarAvatar)
# ripoff will see this and insert the "users:uuid(fooBar)" row before this row
user_id: users:uuid(fooBar)
users:uuid(randomUser):
id: users:uuid(randomUser)
# Generate a random email with the seed "randomUser"
email: email(randomUser)
```
Expand All @@ -49,9 +47,8 @@ valueFuncs allow you to generate random data that's seeded with a static string.

ripoff provides:

- `uuid(seedString)` - generates a UUIDv4
- `uuid(seedString)` - generates a v1 UUID
- `int(seedString)` - generates an integer (note: might be awkward on auto incrementing tables)
- `literal(someId)` - returns "someId" exactly. useful if you want to hard code UUIDs/ints

and also all functions from [gofakeit](https://github.com/brianvoe/gofakeit?tab=readme-ov-file#functions) that have no arguments and return a string (called in camelcase, ex: `email(seedString)`). For the full list, see `./gofakeit.go`.

Expand All @@ -66,12 +63,10 @@ rows:
# "rowId" is the id/key of the row that rendered this template.
# You could also pass an explicit seed and use it like `users:uuid({{ .seed }})`
{{ .rowId }}:
id: {{ .rowId }}
email: {{ .email }}
# It's convenient to use the rowId as a seed to other valueFuncs.
avatar_id: avatars:uuid({{ .rowId }})
avatars:uuid({{ .rowId }}):
id: avatars:uuid({{ .rowId }})
url: {{ .avatarUrl }}
```

Expand All @@ -90,9 +85,10 @@ rows:
avatarGrayscale: false
```

## Explicitly defining primary keys
### Special template variables

ripoff will try to determine the primary key for your row by matching the row ID with a single column (see "Basic example" above). However if you use composite keys, or your primary key is a foreign key to another table (see `./testdata/dependencies`), this may not be possible. In these cases you can manually define primary keys using `~conflict: column_1, column_2, ...`.
- `rowID` - The map key of the row using this template, ex `users:uuid(fooBar)`. Useful for allowing the "caller" to provide their own ID for the "main" row being created, if there is one. Optional to use if you find it awkward.
- `enums` - A map of SQL enums names to an array of enum values. Useful for creating one row for each value of an enum (ex: each user role).

# Security

Expand Down
20 changes: 14 additions & 6 deletions cmd/ripoff/ripoff.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,8 @@ func main() {
slog.Error("Path to YAML files required")
os.Exit(1)
}
rootDirectory := path.Clean(flag.Arg(0))
totalRipoff, err := ripoff.RipoffFromDirectory(rootDirectory)
if err != nil {
slog.Error("Could not load ripoff", errAttr(err))
os.Exit(1)
}

// Start database transaction.
ctx := context.Background()
conn, err := pgx.Connect(ctx, dburl)
if err != nil {
Expand All @@ -64,6 +59,19 @@ func main() {
}
}()

enums, err := ripoff.GetEnumValues(ctx, tx)
if err != nil {
slog.Error("Could not load enums", errAttr(err))
os.Exit(1)
}

rootDirectory := path.Clean(flag.Arg(0))
totalRipoff, err := ripoff.RipoffFromDirectory(rootDirectory, enums)
if err != nil {
slog.Error("Could not load ripoff", errAttr(err))
os.Exit(1)
}

err = ripoff.RunRipoff(ctx, tx, totalRipoff)
if err != nil {
slog.Error("Could not run ripoff", errAttr(err))
Expand Down
115 changes: 99 additions & 16 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ import (

// Runs ripoff from start to finish, without committing the transaction.
func RunRipoff(ctx context.Context, tx pgx.Tx, totalRipoff RipoffFile) error {
queries, err := buildQueriesForRipoff(totalRipoff)
primaryKeys, err := getPrimaryKeys(ctx, tx)
if err != nil {
return err
}

queries, err := buildQueriesForRipoff(primaryKeys, totalRipoff)
if err != nil {
return err
}
Expand All @@ -35,6 +40,72 @@ func RunRipoff(ctx context.Context, tx pgx.Tx, totalRipoff RipoffFile) error {
return nil
}

const primaryKeysQuery = `
SELECT STRING_AGG(c.column_name, '|'), tc.table_name
FROM information_schema.table_constraints tc
JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name)
JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema
AND tc.table_name = c.table_name AND ccu.column_name = c.column_name
WHERE constraint_type = 'PRIMARY KEY'
AND tc.table_schema = 'public'
GROUP BY tc.table_name;
`

type PrimaryKeysResult map[string][]string

func getPrimaryKeys(ctx context.Context, tx pgx.Tx) (PrimaryKeysResult, error) {
rows, err := tx.Query(ctx, primaryKeysQuery)
if err != nil {
return nil, err
}
defer rows.Close()

allPrimaryKeys := PrimaryKeysResult{}

for rows.Next() {
var primaryKeys string
var tableName string
err = rows.Scan(&primaryKeys, &tableName)
if err != nil {
return nil, err
}
allPrimaryKeys[tableName] = strings.Split(primaryKeys, "|")
}
return allPrimaryKeys, nil
}

const enumValuesQuery = `
SELECT STRING_AGG(e.enumlabel, '|'), t.typname
FROM pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = 'public'
GROUP BY t.typname;
`

type EnumValuesResult map[string][]string

func GetEnumValues(ctx context.Context, tx pgx.Tx) (EnumValuesResult, error) {
rows, err := tx.Query(ctx, enumValuesQuery)
if err != nil {
return nil, err
}
defer rows.Close()

allEnumValues := EnumValuesResult{}

for rows.Next() {
var primaryKeys string
var tableName string
err = rows.Scan(&primaryKeys, &tableName)
if err != nil {
return nil, err
}
allEnumValues[tableName] = strings.Split(primaryKeys, "|")
}
return allEnumValues, nil
}

var valueFuncRegex = regexp.MustCompile(`([a-zA-Z]+)\((.*)\)$`)
var referenceRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+:`)

Expand Down Expand Up @@ -76,33 +147,45 @@ func prepareValue(rawValue string) (string, error) {
return fakerResult, nil
}

func buildQueryForRow(rowId string, row Row, dependencyGraph graph.Graph[string, string]) (string, error) {
func buildQueryForRow(primaryKeys PrimaryKeysResult, rowId string, row Row, dependencyGraph graph.Graph[string, string]) (string, error) {
parts := strings.Split(rowId, ":")
if len(parts) < 2 {
return "", fmt.Errorf("invalid id: %s", rowId)
}
table := parts[0]
primaryKeysForTable, hasPrimaryKeysForTable := primaryKeys[table]

columns := []string{}
values := []string{}
setStatements := []string{}

onConflictColumn := ""
for column, valueRaw := range row {
// Technically we allow more than strings in ripoff files for templating purposes,
// but full support (ex: escaping arrays, what to do with maps, etc.) is quite hard so tabling that for now.
value := fmt.Sprint(valueRaw)
if hasPrimaryKeysForTable {
quotedKeys := make([]string, len(primaryKeysForTable))
for i, columnPart := range primaryKeysForTable {
quotedKeys[i] = pq.QuoteIdentifier(strings.TrimSpace(columnPart))
}
onConflictColumn = strings.Join(quotedKeys, ", ")
// For UX reasons, you don't have to define primary key columns (ex: id), since we have the map key already.
if len(primaryKeysForTable) == 1 {
column := primaryKeysForTable[0]
_, hasPrimaryColumn := row[column]
if !hasPrimaryColumn {
row[column] = rowId
}
}
}

// Rows can explicitly mark what columns they should conflict with, in cases like composite primary keys.
for column, valueRaw := range row {
// Backwards compatability weirdness.
if column == "~conflict" {
// Really novice way of escaping these.
columnParts := strings.Split(value, ",")
for i, columnPart := range columnParts {
columnParts[i] = pq.QuoteIdentifier(strings.TrimSpace(columnPart))
}
onConflictColumn = strings.Join(columnParts, ", ")
continue
}

// Technically we allow more than strings in ripoff files for templating purposes,
// but full support (ex: escaping arrays, what to do with maps, etc.) is quite hard so tabling that for now.
value := fmt.Sprint(valueRaw)

// Assume that if a valueFunc is prefixed with a table name, it's a primary/foreign key.
addEdge := referenceRegex.MatchString(value)
// Don't add edges to and from the same row.
Expand All @@ -119,7 +202,7 @@ func buildQueryForRow(rowId string, row Row, dependencyGraph graph.Graph[string,
return "", err
}
// Assume this column is the primary key.
if rowId == value {
if rowId == value && onConflictColumn == "" {
onConflictColumn = pq.QuoteIdentifier(column)
}
values = append(values, pq.QuoteLiteral(valuePrepared))
Expand All @@ -145,7 +228,7 @@ func buildQueryForRow(rowId string, row Row, dependencyGraph graph.Graph[string,
}

// Returns a sorted array of queries to run based on a given ripoff file.
func buildQueriesForRipoff(totalRipoff RipoffFile) ([]string, error) {
func buildQueriesForRipoff(primaryKeys PrimaryKeysResult, totalRipoff RipoffFile) ([]string, error) {
dependencyGraph := graph.New(graph.StringHash, graph.Directed(), graph.Acyclic())
queries := map[string]string{}

Expand All @@ -159,7 +242,7 @@ func buildQueriesForRipoff(totalRipoff RipoffFile) ([]string, error) {

// Build queries.
for rowId, row := range totalRipoff.Rows {
query, err := buildQueryForRow(rowId, row, dependencyGraph)
query, err := buildQueryForRow(primaryKeys, rowId, row, dependencyGraph)
if err != nil {
return []string{}, err
}
Expand Down
4 changes: 3 additions & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ func runTestData(t *testing.T, ctx context.Context, tx pgx.Tx, testDir string) {
require.NoError(t, err)
_, err = tx.Exec(ctx, string(schemaFile))
require.NoError(t, err)
totalRipoff, err := RipoffFromDirectory(testDir)
enums, err := GetEnumValues(ctx, tx)
require.NoError(t, err)
totalRipoff, err := RipoffFromDirectory(testDir, enums)
require.NoError(t, err)
err = RunRipoff(ctx, tx, totalRipoff)
require.NoError(t, err)
Expand Down
7 changes: 4 additions & 3 deletions ripoff_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var funcMap = template.FuncMap{
var templateFileRegex = regexp.MustCompile(`^template_(\S+)\.`)

// Adds newRows to existingRows, processing templated rows when needed.
func concatRows(templates *template.Template, existingRows map[string]Row, newRows map[string]Row) error {
func concatRows(templates *template.Template, existingRows map[string]Row, newRows map[string]Row, enums EnumValuesResult) error {
for rowId, row := range newRows {
_, rowExists := existingRows[rowId]
if rowExists {
Expand All @@ -44,6 +44,7 @@ func concatRows(templates *template.Template, existingRows map[string]Row, newRo
// Templates can additionally use it to seed random generators.
templateVars := row
templateVars["rowId"] = rowId
templateVars["enums"] = enums
buf := &bytes.Buffer{}
err := templates.ExecuteTemplate(buf, templateName, templateVars)
if err != nil {
Expand All @@ -69,7 +70,7 @@ func concatRows(templates *template.Template, existingRows map[string]Row, newRo
}

// Builds a single RipoffFile from a directory of yaml files.
func RipoffFromDirectory(dir string) (RipoffFile, error) {
func RipoffFromDirectory(dir string, enums EnumValuesResult) (RipoffFile, error) {
dir = filepath.Clean(dir)

// Treat files starting with template_ as go templates.
Expand Down Expand Up @@ -116,7 +117,7 @@ func RipoffFromDirectory(dir string) (RipoffFile, error) {
}

for _, ripoff := range allRipoffs {
err = concatRows(templates, totalRipoff.Rows, ripoff.Rows)
err = concatRows(templates, totalRipoff.Rows, ripoff.Rows, enums)
if err != nil {
return RipoffFile{}, err
}
Expand Down
3 changes: 0 additions & 3 deletions testdata/basic/basic.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
rows:
uuid_users:uuid(fooBar):
id: uuid_users:uuid(fooBar)
email: [email protected]
int_users:int(fooBar):
id: int_users:int(fooBar)
email: [email protected]
uuid_users:literal(bbb2ddaa-f33a-4b85-96e7-96d77a194b61):
id: uuid_users:literal(bbb2ddaa-f33a-4b85-96e7-96d77a194b61)
email: [email protected]
1 change: 0 additions & 1 deletion testdata/bigdata/template_multi_user.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
rows:
{{ range $k, $v := (intSlice .numUsers) }}
users:uuid({{ print $.rowId $k }}):
id: users:uuid({{ print $.rowId $k }})
email: multi-user-{{ $k }}@example.com
{{ end }}
4 changes: 0 additions & 4 deletions testdata/dependencies/dependencies.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
rows:
users:uuid(fooBar):
id: users:uuid(fooBar)
email: [email protected]
avatar_id: avatars:uuid(fooBarAvatar)
avatars:uuid(fooBarAvatar):
id: avatars:uuid(fooBarAvatar)
url: image.png
avatar_modifiers:uuid(fooBarAvatar):
~conflict: id
id: avatars:uuid(fooBarAvatar)
grayscale: true
4 changes: 4 additions & 0 deletions testdata/enums/enums.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
rows:
workspaces:uuid(myWorkspace):
template: template_workspace.yml
slug: mySlug
18 changes: 18 additions & 0 deletions testdata/enums/schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
CREATE TYPE user_role AS ENUM ('admin', 'power', 'normie', 'banned');

CREATE TABLE users (
id UUID NOT NULL PRIMARY KEY,
email TEXT NOT NULL,
role user_role NOT NULL
);

CREATE TABLE workspaces (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
slug TEXT NOT NULL
);

CREATE TABLE workspace_memberships (
user_id UUID NOT NULL,
workspace_id UUID NOT NULL,
PRIMARY KEY (user_id, workspace_id)
);
14 changes: 14 additions & 0 deletions testdata/enums/template_workspace.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
rows:
{{ .rowId }}:
slug: {{ .slug }}
# Create a user for every possible role.
# "enums" is a map of SQL enums names to an array of enum values.
{{ range $user_role := .enums.user_role }}
users:uuid({{ print $.rowId $user_role }}):
# ex: [email protected]
email: "{{ $.slug }}+{{ $user_role }}@example.com"
role: {{ $user_role }}
workspace_memberships:uuid({{ print $.rowId $user_role }}):
user_id: users:uuid({{ print $.rowId $user_role }})
workspace_id: {{ $.rowId }}
{{ end }}
8 changes: 8 additions & 0 deletions testdata/enums/validate.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
WITH test AS (
SELECT array_agg(distinct role) as roles FROM users
)
-- db_test.go will automatically determine that the correct number of rows
-- were inserted, but in this case we want to make sure every users row also
-- has a distinct user role.
SELECT case when array_length(roles, 1) = 4 then 1 else 0 end,roles
FROM test;
1 change: 0 additions & 1 deletion testdata/faker/faker.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
rows:
users:uuid(fooBar):
id: users:uuid(fooBar)
email: email(fooBar)
Loading
Loading