diff --git a/main.go b/main.go index 3cf5f7a..8ff3a5a 100644 --- a/main.go +++ b/main.go @@ -83,7 +83,9 @@ func main() { os.Exit(1) } - client := &http.Client{} + client := &http.Client{ + Timeout: time.Second * 10, + } token, err := getAcessToken(client, username) if err != nil { @@ -117,7 +119,7 @@ func main() { } tsURLs := make(chan string, 2) - done := make(chan error) + done := make(chan error, 1) go streamTs(client, tsURLs, output, done) var seenURLs []string @@ -136,6 +138,7 @@ func main() { for { select { case err := <-done: + close(tsURLs) stdErr.Printf("error while streaming: %v\n", err) os.Exit(2) default: diff --git a/stream.go b/stream.go index 969f153..6e6460d 100644 --- a/stream.go +++ b/stream.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "io" "net/http" @@ -9,35 +10,50 @@ import ( func streamTs(c *http.Client, ts <-chan string, out io.Writer, done chan<- error) { for url := range ts { for { - req, err := http.NewRequest("GET", url, nil) - if err != nil { - done <- fmt.Errorf("couldn't create ts request: %w", err) - return - } + err := func() error { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &retryError{fmt.Errorf("couldn't create ts request: %w", err)} + } - res, err := c.Do(req) - if err != nil { - done <- fmt.Errorf("couldn't get ts: %w", err) - return - } + res, err := c.Do(req) + if err != nil { + return &retryError{fmt.Errorf("couldn't get ts: %w", err)} + } + defer res.Body.Close() - if res.StatusCode < 200 || res.StatusCode >= 300 { - stdErr.Printf("got non-2xx http status %s, skipping segment\n", res.Status) - break - } + if res.StatusCode < 200 || res.StatusCode >= 300 { + return &skipError{fmt.Errorf("got non-2xx http status %s", res.Status)} + } - _, err = io.Copy(out, res.Body) + _, err = io.Copy(&writerError{out}, &readerError{res.Body}) + if err != nil && !errors.Is(err, io.EOF) { + if wErr, ok := err.(*writeError); ok { + return &fatalError{fmt.Errorf("error while writing ts to output: %w", wErr.Unwrap())} + } + + return &skipError{fmt.Errorf("couldn't copy ts to output: %w", err)} + } + + return nil + }() if err != nil { - if err == io.ErrUnexpectedEOF { - stdErr.Printf("got unexpected EOF while copying ts to output, skipping segment\n") - break + if _, ok := err.(*fatalError); ok { + done <- err + for range ts { + } + return } - done <- fmt.Errorf("couldn't copy ts to output: %w", err) - return + stdErr.Printf("%v\n", err) + if _, ok := err.(*skipError); ok { + break + } + if _, ok := err.(*retryError); ok { + continue + } } - res.Body.Close() break } } diff --git a/stream_err.go b/stream_err.go new file mode 100644 index 0000000..ffd5d9f --- /dev/null +++ b/stream_err.go @@ -0,0 +1,85 @@ +package main + +import ( + "fmt" + "io" +) + +type writeError struct { + Err error +} + +func (err *writeError) Error() string { + return fmt.Sprintf("write: %v", err.Err) +} +func (err *writeError) Unwrap() error { + return err.Err +} + +type writerError struct { + w io.Writer +} + +func (w *writerError) Write(p []byte) (int, error) { + n, err := w.w.Write(p) + if err != nil { + return n, &writeError{err} + } + return n, nil +} + +type readError struct { + Err error +} + +func (err *readError) Error() string { + return fmt.Sprintf("read: %v", err.Err) +} +func (err *readError) Unwrap() error { + return err.Err +} + +type readerError struct { + r io.Reader +} + +func (r *readerError) Read(p []byte) (int, error) { + n, err := r.r.Read(p) + if err != nil { + return n, &readError{err} + } + return n, nil +} + +type retryError struct { + Err error +} + +func (err *retryError) Error() string { + return fmt.Sprintf("retrying segment: %v", err.Err) +} +func (err *retryError) Unwrap() error { + return err.Err +} + +type skipError struct { + Err error +} + +func (err *skipError) Error() string { + return fmt.Sprintf("skipping segment: %v", err.Err) +} +func (err *skipError) Unwrap() error { + return err.Err +} + +type fatalError struct { + Err error +} + +func (err *fatalError) Error() string { + return fmt.Sprintf("fatal error: %v", err.Err) +} +func (err *fatalError) Unwrap() error { + return err.Err +}