diff --git a/main.go b/main.go index 7e2fb29..f1355ab 100644 --- a/main.go +++ b/main.go @@ -29,6 +29,7 @@ var ( staging = flag.Bool("s", false, "use staging CA") upstream = flag.Int("u", 8080, "upstream port") allowedPaths = flag.String("p", "", "Paths to proxy to the upstream server (all if empty)") + headers = flag.String("H", "", "Headers to add upstream") email = "tls@tinfoil.sh" ) @@ -92,8 +93,18 @@ func main() { log.Fatal(err) } + headerPairs := strings.Split(*headers, ",") + headerMap := make(map[string]string, len(headerPairs)) + for _, pair := range headerPairs { + parts := strings.Split(pair, ":") + if len(parts) != 2 { + log.Fatalf("Invalid header: %s", pair) + } + headerMap[parts[0]] = parts[1] + } + paths := strings.Split(*allowedPaths, ",") - log.Printf("Starting SEV-SNP attestation shim %s domain %s paths %s", version, domain, paths) + log.Printf("Starting SEV-SNP attestation shim %s domain %s paths %s headers %+v", version, domain, paths, headers) mux := http.NewServeMux() @@ -129,6 +140,10 @@ func main() { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { cors(w, r) + for k, v := range headerMap { + w.Header().Set(k, v) + } + if len(paths) > 0 { allowed := false for _, path := range paths {