Skip to content

Commit

Permalink
fix: 调整注册中心注入方式 (adjust how the registry is injected)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbbniu committed Mar 18, 2023
1 parent 5ec0023 commit 490e6e4
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 296 deletions.
4 changes: 1 addition & 3 deletions examples/PolarisServer/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ func main() {
}
defer provider.Destroy()
// 注册中心
tars.SetRegistry(pr.New(provider, pr.WithNamespace("tars")))

comm := tars.NewCommunicator()
comm := tars.NewCommunicator(tars.WithRegistry(pr.New(provider, pr.WithNamespace("tars"))))
obj := fmt.Sprintf("TestApp.PolarisServer.HelloObj")
app := new(TestApp.HelloObj)
comm.StringToProxy(obj, app)
Expand Down
5 changes: 1 addition & 4 deletions examples/PolarisServer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ func main() {
log.Fatalf("fail to create providerAPI, err is %v", err)
}
defer provider.Destroy()
// 注册中心
tars.SetRegistry(pr.New(provider, pr.WithNamespace("tars")))

// New servant imp
imp := new(HelloObjImp)
err = imp.Init()
Expand All @@ -40,5 +37,5 @@ func main() {
app.AddServantWithContext(imp, cfg.App+"."+cfg.Server+".HelloObj")

// Run application
tars.Run()
tars.Run(tars.WithRegistry(pr.New(provider, pr.WithNamespace("tars"))))
}
255 changes: 14 additions & 241 deletions tars/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ package tars
import (
"context"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"net/http"
"os"
"os/exec"
Expand All @@ -18,14 +15,10 @@ import (
"sync/atomic"
"time"

"github.com/TarsCloud/TarsGo/tars/protocol"
"github.com/TarsCloud/TarsGo/tars/protocol/res/adminf"
"github.com/TarsCloud/TarsGo/tars/transport"
"github.com/TarsCloud/TarsGo/tars/util/conf"
"github.com/TarsCloud/TarsGo/tars/util/endpoint"
"github.com/TarsCloud/TarsGo/tars/util/grace"
"github.com/TarsCloud/TarsGo/tars/util/rogger"
"github.com/TarsCloud/TarsGo/tars/util/ssl"
"github.com/TarsCloud/TarsGo/tars/util/tools"
"go.uber.org/automaxprocs/maxprocs"
)
Expand Down Expand Up @@ -67,16 +60,6 @@ func init() {
rogger.SetLevel(rogger.ERROR)
}

// ServerConfigPath is the path of the server config file. It may be set
// programmatically before initialization; when empty, initConfig falls back
// to the -config command-line flag.
var ServerConfigPath string

// cnf holds the parsed server configuration; it stays nil until a config
// file has been successfully loaded by initConfig.
var cnf *conf.Conf

// GetConf returns the server's conf.Conf configuration. It calls Init()
// first so the config file has been parsed; the result is nil when no
// config file was loaded.
func GetConf() *conf.Conf {
	Init()
	return cnf
}

// ServerConfOption is a functional option that mutates a
// transport.TarsServerConf (see WithQueueCap for an example); options are
// applied by newTarsServerConf.
type ServerConfOption func(*transport.TarsServerConf)

func WithQueueCap(queueCap int) ServerConfOption {
Expand Down Expand Up @@ -118,224 +101,13 @@ func newTarsServerConf(proto, address string, svrCfg *serverConfig, opts ...Serv
return svrConf
}

// initConfig builds the default server (svrCfg) and client (cltCfg)
// configurations and, when a config file is available, overrides them from
// it: application/set info, logging, timeouts, TCP and TLS settings, the
// client locator/stat section, the per-adapter server objects, the local
// admin endpoint, and per-object client TLS material. A deferred goroutine
// starts the stat reporter exactly once, whether or not a file was loaded.
func initConfig() {
	defer func() {
		go func() {
			// Initialize stat reporting once, asynchronously, after the
			// configuration (or its defaults) is in place.
			_ = statInitOnce.Do(initReport)
		}()
	}()
	svrCfg = newServerConfig()
	cltCfg = newClientConfig()
	if ServerConfigPath == "" {
		// Fall back to the -config command-line flag when the path was not
		// set programmatically.
		svrFlag := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
		svrFlag.StringVar(&ServerConfigPath, "config", "", "server config path")
		svrFlag.Parse(os.Args[1:]) // parse errors intentionally ignored; path stays empty
	}

	// Without a config file, the defaults built above are kept as-is.
	if len(ServerConfigPath) == 0 {
		return
	}

	c, err := conf.NewConf(ServerConfigPath)
	if err != nil {
		TLOG.Errorf("Parse server config fail %v", err)
		return
	}
	cnf = c

	// init server config
	if strings.EqualFold(c.GetString("/tars/application<enableset>"), "Y") {
		svrCfg.Enableset = true
		svrCfg.Setdivision = c.GetString("/tars/application<setdivision>")
	}
	sMap := c.GetMap("/tars/application/server")
	svrCfg.Node = sMap["node"]
	svrCfg.App = sMap["app"]
	svrCfg.Server = sMap["server"]
	svrCfg.LocalIP = sMap["localip"]
	svrCfg.Local = c.GetString("/tars/application/server<local>")
	// svrCfg.Container = c.GetString("/tars/application<container>")

	// init log
	svrCfg.LogPath = sMap["logpath"]
	svrCfg.LogSize = tools.ParseLogSizeMb(sMap["logsize"])
	svrCfg.LogNum = tools.ParseLogNum(sMap["lognum"])
	svrCfg.LogLevel = sMap["logLevel"]
	svrCfg.Config = sMap["config"]
	svrCfg.Notify = sMap["notify"]
	svrCfg.BasePath = sMap["basepath"]
	svrCfg.DataPath = sMap["datapath"]
	svrCfg.Log = sMap["log"]

	// add version info
	svrCfg.Version = Version
	// add adapters config
	svrCfg.Adapters = make(map[string]adapterConfig)

	// Restore cached app state (e.g. last log level) from the data dir;
	// failures are ignored because the cache file is optional best-effort.
	cachePath := filepath.Join(svrCfg.DataPath, svrCfg.Server) + ".tarsdat"
	if cacheData, err := ioutil.ReadFile(cachePath); err == nil {
		json.Unmarshal(cacheData, &appCache)
	}

	// The config file's log level wins; otherwise reuse the cached one.
	if svrCfg.LogLevel == "" {
		svrCfg.LogLevel = appCache.LogLevel
	} else {
		appCache.LogLevel = svrCfg.LogLevel
	}
	rogger.SetLevel(rogger.StringToLevel(svrCfg.LogLevel))
	if svrCfg.LogPath != "" {
		TLOG.SetFileRoller(svrCfg.LogPath+"/"+svrCfg.App+"/"+svrCfg.Server, 10, 100)
	}

	// cache
	appCache.TarsVersion = Version

	// add timeout config
	svrCfg.AcceptTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<accepttimeout>", AcceptTimeout))
	svrCfg.ReadTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<readtimeout>", ReadTimeout))
	svrCfg.WriteTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<writetimeout>", WriteTimeout))
	svrCfg.HandleTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<handletimeout>", HandleTimeout))
	svrCfg.IdleTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<idletimeout>", IdleTimeout))
	svrCfg.ZombieTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<zombietimeout>", ZombieTimeout))
	svrCfg.QueueCap = c.GetIntWithDef("/tars/application/server<queuecap>", QueueCap)
	svrCfg.GracedownTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<gracedowntimeout>", GracedownTimeout))

	// add tcp config
	svrCfg.TCPReadBuffer = c.GetIntWithDef("/tars/application/server<tcpreadbuffer>", TCPReadBuffer)
	svrCfg.TCPWriteBuffer = c.GetIntWithDef("/tars/application/server<tcpwritebuffer>", TCPWriteBuffer)
	svrCfg.TCPNoDelay = c.GetBoolWithDef("/tars/application/server<tcpnodelay>", TCPNoDelay)
	// add routine number
	svrCfg.MaxInvoke = c.GetInt32WithDef("/tars/application/server<maxroutine>", MaxInvoke)
	// add adapter & report config
	svrCfg.PropertyReportInterval = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<propertyreportinterval>", PropertyReportInterval))
	svrCfg.StatReportInterval = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<statreportinterval>", StatReportInterval))
	svrCfg.MainLoopTicker = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/server<mainloopticker>", MainLoopTicker))
	svrCfg.StatReportChannelBufLen = c.GetInt32WithDef("/tars/application/server<statreportchannelbuflen>", StatReportChannelBufLen)
	// maxPackageLength
	svrCfg.MaxPackageLength = c.GetIntWithDef("/tars/application/server<maxPackageLength>", MaxPackageLength)
	protocol.SetMaxPackageLength(svrCfg.MaxPackageLength)
	// tls: a server-wide TLS config is built only when both key and cert are
	// present; it is reused below for SSL adapters without their own pair.
	svrCfg.Key = c.GetString("/tars/application/server<key>")
	svrCfg.Cert = c.GetString("/tars/application/server<cert>")
	var tlsConfig *tls.Config
	if svrCfg.Key != "" && svrCfg.Cert != "" {
		svrCfg.CA = c.GetString("/tars/application/server<ca>")
		svrCfg.VerifyClient = c.GetStringWithDef("/tars/application/server<verifyclient>", "0") != "0"
		svrCfg.Ciphers = c.GetString("/tars/application/server<ciphers>")
		tlsConfig, err = ssl.NewServerTlsConfig(svrCfg.CA, svrCfg.Cert, svrCfg.Key, svrCfg.VerifyClient, svrCfg.Ciphers)
		if err != nil {
			// A configured-but-invalid TLS setup is fatal by design.
			panic(err)
		}
	}
	svrCfg.SampleRate = c.GetFloatWithDef("/tars/application/server<samplerate>", 0)
	svrCfg.SampleType = c.GetString("/tars/application/server<sampletype>")
	svrCfg.SampleAddress = c.GetString("/tars/application/server<sampleaddress>")
	svrCfg.SampleEncoding = c.GetStringWithDef("/tars/application/server<sampleencoding>", "json")

	// init client config
	cMap := c.GetMap("/tars/application/client")
	cltCfg.Locator = cMap["locator"]
	cltCfg.Stat = cMap["stat"]
	cltCfg.Property = cMap["property"]
	cltCfg.ModuleName = cMap["modulename"]
	cltCfg.AsyncInvokeTimeout = c.GetIntWithDef("/tars/application/client<async-invoke-timeout>", AsyncInvokeTimeout)
	cltCfg.RefreshEndpointInterval = c.GetIntWithDef("/tars/application/client<refresh-endpoint-interval>", refreshEndpointInterval)
	cltCfg.ReportInterval = c.GetIntWithDef("/tars/application/client<report-interval>", reportInterval)
	cltCfg.CheckStatusInterval = c.GetIntWithDef("/tars/application/client<check-status-interval>", checkStatusInterval)
	cltCfg.KeepAliveInterval = c.GetIntWithDef("/tars/application/client<keep-alive-interval>", keepAliveInterval)

	// add client timeout
	cltCfg.ClientQueueLen = c.GetIntWithDef("/tars/application/client<clientqueuelen>", ClientQueueLen)
	cltCfg.ClientIdleTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/client<clientidletimeout>", ClientIdleTimeout))
	cltCfg.ClientReadTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/client<clientreadtimeout>", ClientReadTimeout))
	cltCfg.ClientWriteTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/client<clientwritetimeout>", ClientWriteTimeout))
	cltCfg.ClientDialTimeout = tools.ParseTimeOut(c.GetIntWithDef("/tars/application/client<clientdialtimeout>", ClientDialTimeout))
	cltCfg.ReqDefaultTimeout = c.GetInt32WithDef("/tars/application/client<reqdefaulttimeout>", ReqDefaultTimeout)
	cltCfg.ObjQueueMax = c.GetInt32WithDef("/tars/application/client<objqueuemax>", ObjQueueMax)
	// A configured client CA enables the default client-side TLS config.
	ca := c.GetString("/tars/application/client<ca>")
	if ca != "" {
		cert := c.GetString("/tars/application/client<cert>")
		key := c.GetString("/tars/application/client<key>")
		ciphers := c.GetString("/tars/application/client<ciphers>")
		clientTlsConfig, err = ssl.NewClientTlsConfig(ca, cert, key, ciphers)
		if err != nil {
			panic(err)
		}
	}

	// Build one TarsServerConf per configured server adapter.
	serList = c.GetDomain("/tars/application/server")
	for _, adapter := range serList {
		endString := c.GetString("/tars/application/server/" + adapter + "<endpoint>")
		end := endpoint.Parse(endString)
		svrObj := c.GetString("/tars/application/server/" + adapter + "<servant>")
		proto := c.GetString("/tars/application/server/" + adapter + "<protocol>")
		queuecap := c.GetIntWithDef("/tars/application/server/"+adapter+"<queuecap>", svrCfg.QueueCap)
		threads := c.GetInt("/tars/application/server/" + adapter + "<threads>")
		svrCfg.Adapters[adapter] = adapterConfig{end, proto, svrObj, threads}
		// Prefer an explicit bind address over the advertised host.
		host := end.Host
		if end.Bind != "" {
			host = end.Bind
		}
		var opts []ServerConfOption
		opts = append(opts, WithQueueCap(queuecap))
		if end.IsSSL() {
			key := c.GetString("/tars/application/server/" + adapter + "<key>")
			cert := c.GetString("/tars/application/server/" + adapter + "<cert>")
			if key != "" && cert != "" {
				// Adapter-specific TLS material overrides the server-wide
				// config. NOTE(review): this reassigns the outer `ca` that
				// was read for the client section above — confirm intended.
				ca = c.GetString("/tars/application/server/" + adapter + "<ca>")
				verifyClient := c.GetString("/tars/application/server/"+adapter+"<verifyclient>") != "0"
				ciphers := c.GetString("/tars/application/server/" + adapter + "<ciphers>")
				var adpTlsConfig *tls.Config
				adpTlsConfig, err = ssl.NewServerTlsConfig(ca, cert, key, verifyClient, ciphers)
				if err != nil {
					panic(err)
				}
				opts = append(opts, WithTlsConfig(adpTlsConfig))
			} else {
				// common tls.Config
				opts = append(opts, WithTlsConfig(tlsConfig))
			}
		}
		tarsConfig[svrObj] = newTarsServerConf(end.Proto, fmt.Sprintf("%s:%d", host, end.Port), svrCfg, opts...)
	}
	TLOG.Debug("config add ", tarsConfig)

	if len(svrCfg.Local) > 0 {
		localPoint := endpoint.Parse(svrCfg.Local)
		// The admin port does not start a goroutine pool (MaxInvoke 0).
		tarsConfig["AdminObj"] = newTarsServerConf(localPoint.Proto, fmt.Sprintf("%s:%d", localPoint.Host, localPoint.Port), svrCfg, WithMaxInvoke(0))
		svrCfg.Adapters["AdminAdapter"] = adapterConfig{localPoint, localPoint.Proto, "AdminObj", 1}
		RegisterAdmin(rogger.Admin, rogger.HandleDyeingAdmin)
	}

	// Collect per-object client auth/TLS settings.
	auths := c.GetDomain("/tars/application/client")
	for _, objName := range auths {
		authInfo := make(map[string]string)
		// authInfo["accesskey"] = c.GetString("/tars/application/client/" + objName + "<accesskey>")
		// authInfo["secretkey"] = c.GetString("/tars/application/client/" + objName + "<secretkey>")
		authInfo["ca"] = c.GetString("/tars/application/client/" + objName + "<ca>")
		authInfo["cert"] = c.GetString("/tars/application/client/" + objName + "<cert>")
		authInfo["key"] = c.GetString("/tars/application/client/" + objName + "<key>")
		authInfo["ciphers"] = c.GetString("/tars/application/client/" + objName + "<ciphers>")
		clientObjInfo[objName] = authInfo
		if authInfo["ca"] != "" {
			var objTlsConfig *tls.Config
			objTlsConfig, err = ssl.NewClientTlsConfig(authInfo["ca"], authInfo["cert"], authInfo["key"], authInfo["ciphers"])
			if err != nil {
				panic(err)
			}
			clientObjTlsConfig[objName] = objTlsConfig
		}
	}
}

// Run the application
func Run() {
func Run(opts ...Option) {
defer rogger.FlushLogger()
isShutdowning = 0
Init()
<-statInited
cfg.apply(opts)

for _, env := range os.Environ() {
if strings.HasPrefix(env, grace.InheritFdPrefix) {
Expand Down Expand Up @@ -435,18 +207,18 @@ func graceRestart() {
}

// redirect stdout/stderr to logger
cfg := GetServerConfig()
var logfile *os.File
if cfg != nil {
svrCfg := GetServerConfig()
var logFile *os.File
if svrCfg != nil {
GetLogger("")
logpath := filepath.Join(cfg.LogPath, cfg.App, cfg.Server, cfg.App+"."+cfg.Server+".log")
logfile, _ = os.OpenFile(logpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
TLOG.Debugf("redirect to %s %v", logpath, logfile)
logPath := filepath.Join(svrCfg.LogPath, svrCfg.App, svrCfg.Server, svrCfg.App+"."+svrCfg.Server+".log")
logFile, _ = os.OpenFile(logPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
TLOG.Debugf("redirect to %s %v", logPath, logFile)
}
if logfile == nil {
logfile = os.Stdout
if logFile == nil {
logFile = os.Stdout
}
files := []*os.File{os.Stdin, logfile, logfile}
files := []*os.File{os.Stdin, logFile, logFile}
for key, file := range grace.GetAllListenFiles() {
fd := fmt.Sprint(file.Fd())
newFd := len(files)
Expand Down Expand Up @@ -484,7 +256,7 @@ func graceShutdown() {
// shutdown by admin,we should need shorten the timeout
graceShutdownTimeout = tools.ParseTimeOut(GracedownTimeout)
} else {
graceShutdownTimeout = svrCfg.GracedownTimeout
graceShutdownTimeout = GetServerConfig().GracedownTimeout
}

TLOG.Infof("grace shutdown start %d in %v", pid, graceShutdownTimeout)
Expand Down Expand Up @@ -577,6 +349,7 @@ func mainLoop() {
go registryAdapters(ctx)
loop := time.NewTicker(GetServerConfig().MainLoopTicker)

adapters := GetServerConfig().Adapters
for {
select {
case <-shutdown:
Expand All @@ -586,7 +359,7 @@ func mainLoop() {
if atomic.LoadInt32(&isShutdowning) == 1 {
continue
}
for name, adapter := range svrCfg.Adapters {
for name, adapter := range adapters {
if adapter.Protocol == "not_tars" {
// TODO not_tars support
ha.KeepAlive(name)
Expand Down
3 changes: 2 additions & 1 deletion tars/communicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ func GetCommunicator() *Communicator {

// NewCommunicator returns a new communicator. A Communicator is used for communicating with
// the server side which should only init once and be global!!!
func NewCommunicator() *Communicator {
// NewCommunicator returns a new communicator for talking to the server
// side; it should only be initialized once and be global.
// NOTE(review): the options are applied to the package-level cfg rather
// than to this Communicator instance, so they appear to affect global
// client configuration — confirm this scoping is intentional.
func NewCommunicator(opts ...Option) *Communicator {
	c := new(Communicator)
	c.init()
	// Apply functional options (e.g. WithRegistry) after init.
	cfg.apply(opts)
	return c
}

Expand Down
Loading

0 comments on commit 490e6e4

Please sign in to comment.