-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
149 lines (122 loc) · 3.02 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package main
import (
"bufio"
"encoding/base64"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/hashicorp/vault/shamir"
)
// config holds the command-line settings filled in by the flag package.
var config struct {
	// in and out are the input and output file paths; the value "-" makes
	// mainErr use standard input or standard output instead of a file.
	in, out string
	// threshold and parts are handed to shamir.Split by the split
	// operation: parts is the number of shares to generate and threshold
	// is the value of the -threshold flag.
	threshold, parts int
}
// init registers the command-line flags on the default flag set and installs
// a usage function that shows the expected invocation followed by the flag
// defaults.
func init() {
	flag.StringVar(&config.in, "in", "-", "input file")
	flag.StringVar(&config.out, "out", "-", "output file")
	flag.IntVar(&config.threshold, "threshold", 2, "threshold for split operation")
	flag.IntVar(&config.parts, "parts", 3, "number of shares for split operation")
	flag.Usage = func() {
		// flag.PrintDefaults writes to the flag set's configured output
		// (standard error unless changed). Send the synopsis to the same
		// destination so the help text is not split across two streams.
		w := flag.CommandLine.Output()
		fmt.Fprintf(w, "Usage: %s [<options>] (split|combine)\n", os.Args[0])
		flag.PrintDefaults()
	}
}
// main parses the command-line flags and runs the requested operation. A
// failure is reported on standard error and the process exits with status 1.
func main() {
	flag.Parse()

	err := mainErr(flag.Args())
	if err == nil {
		return
	}

	fmt.Fprintf(os.Stderr, "error: %s\n", err)
	os.Exit(1)
}
// mainErr runs the operation named by the single positional argument in args
// ("split" or "combine"), reading from config.in and writing to config.out.
// A path of "-" selects standard input or standard output respectively.
//
// It returns an error if args does not contain exactly one known operation,
// if the input or output file cannot be opened, if the operation itself
// fails, or if the output file cannot be closed cleanly.
func mainErr(args []string) (retErr error) {
	if len(args) == 0 {
		return errors.New("no operation given")
	}
	if len(args) > 1 {
		return errors.New("too many arguments")
	}

	var op operation
	switch args[0] {
	case "combine":
		op = combine
	case "split":
		op = split
	default:
		return fmt.Errorf("invalid operation: %s", args[0])
	}

	var in io.Reader
	if config.in == "-" {
		in = os.Stdin
	} else {
		file, err := os.Open(config.in)
		if err != nil {
			return fmt.Errorf("opening input file: %s", err)
		}
		// The file is only read, so a Close failure cannot lose data.
		defer func() { _ = file.Close() }()
		in = file
	}

	var out io.Writer
	if config.out == "-" {
		out = os.Stdout
	} else {
		file, err := os.Create(config.out)
		if err != nil {
			return fmt.Errorf("opening out file: %s", err)
		}
		// A failed Close on a written file can mean the data never reached
		// disk, so report it unless an earlier error is already being
		// returned.
		defer func() {
			if closeErr := file.Close(); closeErr != nil && retErr == nil {
				retErr = fmt.Errorf("closing out file: %s", closeErr)
			}
		}()
		out = file
	}

	return op(in, out)
}
// operation is the common signature of the split and combine commands: each
// reads its input from in, writes its result to out, and returns an error on
// failure.
type operation func(in io.Reader, out io.Writer) error
// combine reads base64-encoded key shares from in, one share per line,
// reconstructs the secret with shamir.Combine and writes the raw secret bytes
// to out.
//
// Empty lines are skipped. Line numbers reported in decoding errors count
// every input line, including skipped ones.
func combine(in io.Reader, out io.Writer) error {
	// bufio.Scanner rejects lines longer than bufio.MaxScanTokenSize (64 KiB)
	// by default. split puts no bound on the size of the secret it reads, so
	// a share line can be much longer than that; raise the limit so shares of
	// large secrets can be read back.
	const maxLineLen = 1 << 30

	var shares [][]byte
	scanner := bufio.NewScanner(in)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLineLen)
	encoding := base64.StdEncoding
	for lineno := 1; scanner.Scan(); lineno++ {
		line := scanner.Bytes()
		if len(line) == 0 {
			// Tolerate blank lines (for example a trailing empty line)
			// instead of passing an empty share to shamir.Combine.
			continue
		}
		// scanner.Bytes is only valid until the next Scan, so decode into
		// a buffer owned by this share.
		share := make([]byte, encoding.DecodedLen(len(line)))
		n, err := encoding.Decode(share, line)
		if err != nil {
			return fmt.Errorf("base64 decoding input: line %d: %s", lineno, err)
		}
		shares = append(shares, share[:n])
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("reading key shares from input: %s", err)
	}

	result, err := shamir.Combine(shares)
	if err != nil {
		return fmt.Errorf("combining key shares: %s", err)
	}
	if _, err := out.Write(result); err != nil {
		return fmt.Errorf("writing secret to output: %s", err)
	}
	return nil
}
// split reads the whole secret from in, divides it with shamir.Split into
// config.parts shares using config.threshold, and writes each share to out
// as one base64-encoded line terminated by a newline.
func split(in io.Reader, out io.Writer) error {
	secret, readErr := ioutil.ReadAll(in)
	if readErr != nil {
		return fmt.Errorf("reading secret from input: %s", readErr)
	}

	shares, splitErr := shamir.Split(secret, config.parts, config.threshold)
	if splitErr != nil {
		return fmt.Errorf("splitting secret: %s", splitErr)
	}

	enc := base64.StdEncoding
	for i := range shares {
		encodedLen := enc.EncodedLen(len(shares[i]))
		// Reserve one extra byte of capacity for the line terminator so the
		// share and its newline go out in a single Write call.
		buf := make([]byte, encodedLen, encodedLen+1)
		enc.Encode(buf, shares[i])
		buf = append(buf, '\n')

		_, writeErr := out.Write(buf)
		if writeErr != nil {
			return fmt.Errorf("writing shares to output: %s", writeErr)
		}
	}
	return nil
}