Skip to content

Commit

Permalink
使用服务端时间
Browse files Browse the repository at this point in the history
  • Loading branch information
orestonce committed Sep 9, 2024
1 parent 9fe6717 commit f0a0801
Show file tree
Hide file tree
Showing 8 changed files with 502 additions and 312 deletions.
24 changes: 23 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ func (this *DownloadEnv) runDownload(req StartDownload_Req, skipInfo SkipTsInfo)
// 下载ts
this.status.SetProgressBarTitle("[3/4]下载ts")
this.status.SpeedResetBytes()
err = this.downloader(tsList, skipInfo, tsSaveDir, encInfo, req.ThreadCount)
err = this.downloader(tsList, skipInfo, tsSaveDir, encInfo, req)
this.status.SpeedResetBytes()
if err != nil {
this.setErrMsg("下载ts文件错误: " + err.Error())
Expand Down Expand Up @@ -433,6 +433,28 @@ func (this *DownloadEnv) runDownload(req StartDownload_Req, skipInfo SkipTsInfo)
this.setErrMsg("重命名失败: " + err.Error())
return
}
if req.UseServerSideTime && len(tsFileList) > 0 {
var stat os.FileInfo
stat, err = os.Stat(tsFileList[0])
if err != nil {
this.setErrMsg("读取文件状态失败: " + err.Error())
return
}
mTime := stat.ModTime()
err = os.Chtimes(name, mTime, mTime)
if err != nil {
this.setErrMsg("更新文件创建时间失败: " + err.Error())
return
}
this.logToFile("更新mp4创建时间:" + mTime.String())
err = updateMp4CreateTime(name, mTime)
if err != nil {
this.setErrMsg("更新mp4创建时间失败: " + err.Error())
return
}
this.logToFile("更新成功")
}

if skipByHttpCodeLog.Len() > 0 {
// 写入通过http.code跳过的ts文件列表
saveFileName := name + "_" + logFileName
Expand Down
1 change: 1 addition & 0 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ func init() {
downloadCmd.Flags().BoolVarP(&gRunReq.Skip_EXT_X_DISCONTINUITY, "Skip_EXT_X_DISCONTINUITY", "", false, "跳过 #EXT-X-DISCONTINUITY 标签包裹的ts")
downloadCmd.Flags().BoolVarP(&gRunReq.DebugLog, "DebugLog", "", false, "调试日志")
downloadCmd.Flags().StringVarP(&gRunReq.TsTempDir, "TsTempDir", "", "", "临时ts文件目录")
downloadCmd.Flags().BoolVarP(&gRunReq.UseServerSideTime, "UseServerSideTime", "", false, "使用服务端提供的文件时间")
rootCmd.AddCommand(downloadCmd)
curlCmd.DisableFlagParsing = true
rootCmd.AddCommand(curlCmd)
Expand Down
80 changes: 49 additions & 31 deletions download.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type StartDownload_Req struct {
Skip_EXT_X_DISCONTINUITY bool // 跳过 #EXT-X-DISCONTINUITY 标签包裹的ts
DebugLog bool // 调试日志
TsTempDir string // 临时ts文件目录
UseServerSideTime bool // 使用服务端提供的文件时间
}

type DownloadEnv struct {
Expand Down Expand Up @@ -117,17 +118,17 @@ func (this *DownloadEnv) getEncryptInfo(m3u8Url string, html string) (info *Encr
if errMsg != "" {
return nil, errors.New(errMsg)
}
var res []byte
var httpCode int
res, httpCode, err = this.doGetRequest(keyUrl, true)
var keyContent []byte
var httpResp *http.Response
keyContent, httpResp, err = this.doGetRequest(keyUrl, true)
if err != nil {
return nil, err
}
if httpCode != 200 {
return nil, errors.New("getEncryptInfo httpCode error " + strconv.Itoa(httpCode))
if httpResp.StatusCode != 200 {
return nil, errors.New("getEncryptInfo httpCode error " + strconv.Itoa(httpResp.StatusCode))
}
if method == EncryptMethod_AES128 && len(res) != 16 { // Aes 128
return nil, errors.New("getEncryptInfo invalid key " + strconv.Quote(string(res)))
if method == EncryptMethod_AES128 && len(keyContent) != 16 { // Aes 128
return nil, errors.New("getEncryptInfo invalid key " + strconv.Quote(string(keyContent)))
}
var iv []byte
ivs := keyPart.KeyValue["IV"]
Expand All @@ -139,7 +140,7 @@ func (this *DownloadEnv) getEncryptInfo(m3u8Url string, html string) (info *Encr
}
return &EncryptInfo{
Method: method,
Key: res,
Key: keyContent,
Iv: iv,
}, nil
}
Expand Down Expand Up @@ -197,26 +198,36 @@ func getTsList(beginSeq uint64, m38uUrl string, body string) (tsList []TsInfo, e

// 下载ts文件
// @modify: 2020-08-13 修复ts格式SyncByte合并不能播放问题
func (this *DownloadEnv) downloadTsFile(ts *TsInfo, skipInfo SkipTsInfo, downloadDir string, encInfo *EncryptInfo) (err error) {
currPath := fmt.Sprintf("%s/%s", downloadDir, ts.Name)
func (this *DownloadEnv) downloadTsFile(ts *TsInfo, skipInfo SkipTsInfo, downloadDir string, encInfo *EncryptInfo, useServerSideTime bool) (err error) {
currPath := filepath.Join(downloadDir, ts.Name)
var stat os.FileInfo
stat, err = os.Stat(currPath)
if err == nil && stat.Mode().IsRegular() {
this.status.SpeedAdd1Block(stat.ModTime(), int(stat.Size()))
return nil
}
beginTime := time.Now()
data, httpCode, err := this.doGetRequest(ts.Url, false)
data, httpResp, err := this.doGetRequest(ts.Url, false)
if err != nil {
return err
}
if len(skipInfo.HttpCodeList) > 0 && isInIntSlice(httpCode, skipInfo.HttpCodeList) {
if len(skipInfo.HttpCodeList) > 0 && isInIntSlice(httpResp.StatusCode, skipInfo.HttpCodeList) {
this.status.SpeedAdd1Block(beginTime, 0)
ts.SkipByHttpCode = true
ts.HttpCode = httpCode
this.logToFile("skip ts " + strconv.Quote(ts.Name) + " byHttpCode: " + strconv.Itoa(httpCode))
ts.HttpCode = httpResp.StatusCode
this.logToFile("skip ts " + strconv.Quote(ts.Name) + " byHttpCode: " + strconv.Itoa(httpResp.StatusCode))
return nil
}
var mTime time.Time
if mStr := httpResp.Header.Get("Last-Modified"); mStr != "" && useServerSideTime {
this.logToFile("get mtime " + strconv.Quote(mStr))
mTime, err = time.Parse(time.RFC1123, mStr)
// 这个错误不重要, 所以只记录日志
if err != nil {
this.logToFile("parse mtime error " + err.Error())
mTime = time.Time{}
}
}
// 校验长度是否合法
var origData []byte
origData = data
Expand Down Expand Up @@ -254,6 +265,13 @@ func (this *DownloadEnv) downloadTsFile(ts *TsInfo, skipInfo SkipTsInfo, downloa
if err != nil {
return err
}
if mTime.IsZero() == false {
err = os.Chtimes(currPath, mTime, mTime)
// 这个错误不重要, 所以只记录日志
if err != nil {
this.logToFile("os.Chtimes error " + err.Error())
}
}
this.status.SpeedAdd1Block(beginTime, len(origData))
return nil
}
Expand All @@ -274,11 +292,11 @@ func (this *DownloadEnv) SleepDur(d time.Duration) {
}
}

func (this *DownloadEnv) downloader(tsList []TsInfo, skipInfo SkipTsInfo, downloadDir string, encInfo *EncryptInfo, threadCount int) (err error) {
if threadCount <= 0 || threadCount > 1000 {
return errors.New("DownloadEnv.threadCount invalid: " + strconv.Itoa(threadCount))
func (this *DownloadEnv) downloader(tsList []TsInfo, skipInfo SkipTsInfo, downloadDir string, encInfo *EncryptInfo, req StartDownload_Req) (err error) {
if req.ThreadCount <= 0 || req.ThreadCount > 1000 {
return errors.New("DownloadEnv.threadCount invalid: " + strconv.Itoa(req.ThreadCount))
}
task := gopool.NewThreadPool(threadCount)
task := gopool.NewThreadPool(req.ThreadCount)
var locker sync.Mutex

this.status.ResetTotalBlockCount(len(tsList))
Expand All @@ -299,7 +317,7 @@ func (this *DownloadEnv) downloader(tsList []TsInfo, skipInfo SkipTsInfo, downlo
this.SleepDur(time.Second * time.Duration(i))
atomic.AddInt32(&this.sleepTh, -1)
}
lastErr = this.downloadTsFile(ts, skipInfo, downloadDir, encInfo)
lastErr = this.downloadTsFile(ts, skipInfo, downloadDir, encInfo, req.UseServerSideTime)
if lastErr == nil {
break
}
Expand Down Expand Up @@ -364,13 +382,13 @@ func AesDecrypt(seq uint64, encrypted []byte, encInfo *EncryptInfo) ([]byte, err
func (this *DownloadEnv) sniffM3u8(urlS string) (afterUrl string, content []byte, errMsg string) {
for idx := 0; idx < 5; idx++ {
var err error
var httpCode int
content, httpCode, err = this.doGetRequest(urlS, true)
var httpResp *http.Response
content, httpResp, err = this.doGetRequest(urlS, true)
if err != nil {
return "", nil, err.Error()
}
if httpCode != 200 {
return "", nil, "invalid httpCode " + strconv.Itoa(httpCode)
if httpResp.StatusCode != 200 {
return "", nil, "invalid httpCode " + strconv.Itoa(httpResp.StatusCode)
}

if UrlHasSuffix(urlS, ".m3u8") {
Expand Down Expand Up @@ -436,10 +454,10 @@ func UrlHasSuffix(urlS string, suff string) bool {
return strings.HasSuffix(strings.ToLower(urlObj.Path), suff)
}

func (this *DownloadEnv) doGetRequest(urlS string, dumpRespBody bool) (data []byte, statusCode int, err error) {
func (this *DownloadEnv) doGetRequest(urlS string, dumpRespBody bool) (data []byte, resp *http.Response, err error) {
req, err := http.NewRequest(http.MethodGet, urlS, nil)
if err != nil {
return nil, 0, err
return nil, nil, err
}
req = req.WithContext(this.ctx)
req.Header = this.header
Expand All @@ -457,7 +475,7 @@ func (this *DownloadEnv) doGetRequest(urlS string, dumpRespBody bool) (data []by

beginTime := time.Now()

resp, err := this.nowClient.Do(req)
resp, err = this.nowClient.Do(req)
if logBuf != nil && resp != nil {
respBytes, _ := httputil.DumpResponse(resp, false)
logBuf.WriteString("httpResp:\n" + string(respBytes) + "\n")
Expand All @@ -470,7 +488,7 @@ func (this *DownloadEnv) doGetRequest(urlS string, dumpRespBody bool) (data []by
logBuf.WriteString("error1:" + err.Error() + "\n")
this.logToFile(logBuf.String())
}
return nil, 0, err
return nil, nil, err
}
defer resp.Body.Close()

Expand All @@ -487,7 +505,7 @@ func (this *DownloadEnv) doGetRequest(urlS string, dumpRespBody bool) (data []by
logBuf.WriteString(err.Error() + "\n")
this.logToFile(logBuf.String())
}
return nil, 0, err
return nil, nil, err
}
defer readCloser.Close()
case "deflate":
Expand All @@ -501,7 +519,7 @@ func (this *DownloadEnv) doGetRequest(urlS string, dumpRespBody bool) (data []by
logBuf.WriteString(err.Error() + "\n")
this.logToFile(logBuf.String())
}
return nil, 0, err
return nil, nil, err
}
content, err = this.status.SpeedReadAll(readCloser)
if logBuf != nil {
Expand All @@ -512,15 +530,15 @@ func (this *DownloadEnv) doGetRequest(urlS string, dumpRespBody bool) (data []by
logBuf.WriteString("error4:" + err.Error() + "\n")
this.logToFile(logBuf.String())
}
return nil, 0, err
return nil, nil, err
}
if logBuf != nil && dumpRespBody {
logBuf.WriteString("httpRespBody:\n" + string(content))
}
if logBuf != nil {
this.logToFile(logBuf.String())
}
return content, resp.StatusCode, nil
return content, resp, nil
}

func (this *DownloadEnv) logToFile(body string) {
Expand Down
Loading

0 comments on commit f0a0801

Please sign in to comment.