Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a serve subcommand that will start a gRPC server #83

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions buf.gen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
version: v2
managed:
enabled: true
override:
- file_option: go_package_prefix
value: github.com/dennis-tra/nebula-crawler/proto
plugins:
- remote: buf.build/protocolbuffers/go
out: proto
opt: paths=source_relative
- remote: buf.build/connectrpc/go
out: proto
opt: paths=source_relative
inputs:
- directory: proto
10 changes: 10 additions & 0 deletions buf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# For details on buf.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-yaml
version: v2
modules:
- path: proto
lint:
use:
- STANDARD
breaking:
use:
- FILE
1 change: 1 addition & 0 deletions cmd/nebula/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ func main() {
ResolveCommand,
NetworksCommand,
HealthCommand,
ServeCommand,
},
}

Expand Down
185 changes: 185 additions & 0 deletions cmd/nebula/cmd_serve.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package main

import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"strings"
"time"

"connectrpc.com/connect"
"github.com/dennis-tra/nebula-crawler/config"
v1 "github.com/dennis-tra/nebula-crawler/proto/nebula/v1"
"github.com/dennis-tra/nebula-crawler/proto/nebula/v1/nebulav1connect"
log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"

"github.com/dennis-tra/nebula-crawler/db"
)

var serveConfig = &config.Serve{
Root: rootConfig,
Host: "localhost",
Port: 8080,
}

// ServeCommand .
var ServeCommand = &cli.Command{
Name: "serve",
Usage: "Serves data from a Nebula database",
Action: ServeAction,
Flags: []cli.Flag{
&cli.StringFlag{
Name: "host",
Usage: "Let the server listen on the specified host",
EnvVars: []string{"NEBULA_SERVE_HOST"},
Value: serveConfig.Host,
Destination: &serveConfig.Host,
},
&cli.IntFlag{
Name: "port",
Usage: "Let the server listen on the specified port",
EnvVars: []string{"NEBULA_SERVE_PORT"},
Value: serveConfig.Port,
Destination: &serveConfig.Port,
},
},
}

// ServeAction is the function that is called when running `nebula resolve`.
func ServeAction(c *cli.Context) error {
log.Infoln("Start serving Nebula data...")
defer log.Infoln("Stopped serving Nebula data.")

ctx := c.Context

// initialize a new database client based on the given configuration.
// Options are Postgres, JSON, and noop (dry-run).
dbc, err := db.NewServerClient(ctx, rootConfig.Database)
if err != nil {
return fmt.Errorf("new database client: %w", err)
}
defer func() {
if err := dbc.Close(); err != nil && !errors.Is(err, sql.ErrConnDone) && !strings.Contains(err.Error(), "use of closed network connection") {
log.WithError(err).Warnln("Failed closing database handle")
}
}()

mux := http.NewServeMux()
path, handler := nebulav1connect.NewNebulaServiceHandler(&nebulaServiceServer{
dbc: dbc,
})
mux.Handle(path, handler)

address := fmt.Sprintf("%s:%d", serveConfig.Host, serveConfig.Port)

s := http.Server{
Addr: address,
Handler: h2c.NewHandler(mux, &http2.Server{}),
}

done := make(chan struct{})
go func() {
defer close(done)
log.WithField("addr", address).Infoln("Start listening...")
if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.WithError(err).WithField("addr", address).Error("Failed to serve gRPC server")
}
}()

select {
case <-done:
case <-c.Context.Done():
}

shutdownTimeout := 30 * time.Second
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()

log.WithField("timeout", shutdownTimeout).WithField("addr", address).Infoln("Shutting down...")
if err := s.Shutdown(shutdownCtx); err != nil {
log.WithError(err).Error("Failed to shutdown gRPC server")
}

return nil
}

// petStoreServiceServer implements the PetStoreService API.
type nebulaServiceServer struct {
dbc db.ServerClient
}

var _ nebulav1connect.NebulaServiceHandler = (*nebulaServiceServer)(nil)

func (n *nebulaServiceServer) GetPeer(ctx context.Context, c *connect.Request[v1.GetPeerRequest]) (*connect.Response[v1.GetPeerResponse], error) {
log.WithField("multihash", c.Msg.MultiHash).Info("GetPeer")

dbPeer, dbProtocols, err := n.dbc.GetPeer(ctx, c.Msg.MultiHash)
if err != nil {
return nil, err
}

v1Maddrs := make([]*v1.MultiAddress, 0, len(dbPeer.R.MultiAddresses))
for _, dbMaddr := range dbPeer.R.MultiAddresses {
var asn *int32
if !dbMaddr.Asn.IsZero() {
val := int32(dbMaddr.Asn.Int)
asn = &val
}

var isCloud *int32
if !dbMaddr.IsCloud.IsZero() {
val := int32(dbMaddr.IsCloud.Int)
asn = &val
}

var country *string
if !dbMaddr.Country.IsZero() {
country = &dbMaddr.Country.String
}

var continent *string
if !dbMaddr.Continent.IsZero() {
continent = &dbMaddr.Country.String
}

var ip *string
if !dbMaddr.Addr.IsZero() {
ip = &dbMaddr.Addr.String
}

v1Maddrs = append(v1Maddrs, &v1.MultiAddress{
MultiAddress: dbMaddr.Maddr,
Asn: asn,
IsCloud: isCloud,
Country: country,
Continent: continent,
Ip: ip,
})
}

protocols := make([]string, 0, len(dbProtocols))
for _, dbProtocol := range dbProtocols {
protocols = append(protocols, dbProtocol.Protocol)
}

var av *string
if dbPeer.R.AgentVersion != nil {
av = &dbPeer.R.AgentVersion.AgentVersion
}

resp := &connect.Response[v1.GetPeerResponse]{
Msg: &v1.GetPeerResponse{
MultiHash: dbPeer.MultiHash,
AgentVersion: av,
MultiAddresses: v1Maddrs,
Protocols: protocols,
},
}

return resp, nil
}
16 changes: 16 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,19 @@ type Resolve struct {
FilePathMaxmindCountry string
FilePathMaxmindASN string
}

type Serve struct {
Root *Root

// the network interfaces that the server should to bind to
Host string

// the port that the server should bind to
Port int
}

// String prints the configuration as a json string
func (m *Serve) String() string {
data, _ := json.MarshalIndent(m, "", " ")
return string(data)
}
27 changes: 27 additions & 0 deletions db/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ type Client interface {
PersistNeighbors(ctx context.Context, crawl *models.Crawl, dbPeerID *int, peerID peer.ID, errorBits uint16, dbNeighborsIDs []int, neighbors []peer.ID) error
}

type ServerClient interface {
io.Closer
GetPeer(ctx context.Context, multiHash string) (*models.Peer, models.ProtocolSlice, error)
}

// NewClient will initialize the right database client based on the given
// configuration. This can either be a Postgres, JSON, or noop client. The noop
// client is a dummy implementation of the [Client] interface that does nothing
Expand Down Expand Up @@ -53,3 +58,25 @@ func NewClient(ctx context.Context, cfg *config.Database) (Client, error) {

return dbc, nil
}

func NewServerClient(ctx context.Context, cfg *config.Database) (ServerClient, error) {
var (
dbc ServerClient
err error
)

// dry run has precedence. Then, if a JSON output directory is given, use
// the JSON client. In any other case, use the Postgres database client.
if cfg.DryRun {
return nil, fmt.Errorf("server client not implemented")
} else if cfg.JSONOut != "" {
return nil, fmt.Errorf("server client not implemented")
} else {
dbc, err = InitDBServerClient(ctx, cfg)
}
if err != nil {
return nil, fmt.Errorf("init db client: %w", err)
}

return dbc, nil
}
105 changes: 105 additions & 0 deletions db/client_db_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package db

import (
"context"
"database/sql"
"fmt"

"github.com/dennis-tra/nebula-crawler/db/models"
log "github.com/sirupsen/logrus"
"github.com/uptrace/opentelemetry-go-extra/otelsql"
"github.com/volatiletech/sqlboiler/v4/queries/qm"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"

"github.com/dennis-tra/nebula-crawler/config"
)

type DBServerClient struct {
ctx context.Context

// Reference to the configuration
cfg *config.Database

// Database handler
dbh *sql.DB

// reference to all relevant db telemetry
telemetry *telemetry
}

var _ ServerClient = (*DBServerClient)(nil)

// InitDBServerClient establishes a database connection with the provided
// configuration
func InitDBServerClient(ctx context.Context, cfg *config.Database) (*DBServerClient, error) {
log.WithFields(log.Fields{
"host": cfg.DatabaseHost,
"port": cfg.DatabasePort,
"name": cfg.DatabaseName,
"user": cfg.DatabaseUser,
"ssl": cfg.DatabaseSSLMode,
}).Infoln("Initializing database client")

dbh, err := otelsql.Open("postgres", cfg.DatabaseSourceName(),
otelsql.WithAttributes(semconv.DBSystemPostgreSQL),
otelsql.WithMeterProvider(cfg.MeterProvider),
otelsql.WithTracerProvider(cfg.TracerProvider),
)
if err != nil {
return nil, fmt.Errorf("opening database: %w", err)
}

// Set to match the writer worker
dbh.SetMaxIdleConns(cfg.MaxIdleConns) // default is 2 which leads to many connection open/closings

otelsql.ReportDBStatsMetrics(dbh, otelsql.WithMeterProvider(cfg.MeterProvider))

// Ping database to verify connection.
if err = dbh.Ping(); err != nil {
return nil, fmt.Errorf("pinging database: %w", err)
}

telemetry, err := newTelemetry(cfg.TracerProvider, cfg.MeterProvider)
if err != nil {
return nil, fmt.Errorf("new telemetry: %w", err)
}

client := &DBServerClient{ctx: ctx, cfg: cfg, dbh: dbh, telemetry: telemetry}

return client, nil
}

func (d *DBServerClient) Close() error {
return d.dbh.Close()
}

func (d *DBServerClient) GetPeer(ctx context.Context, multiHash string) (*models.Peer, models.ProtocolSlice, error) {
// write a hand-crafted query to avoid two DB round-trips

dbPeer, err := models.Peers(
models.PeerWhere.MultiHash.EQ(multiHash),
qm.Load(models.PeerRels.AgentVersion),
qm.Load(models.PeerRels.MultiAddresses),
qm.Load(models.PeerRels.ProtocolsSet),
).One(ctx, d.dbh)
if err != nil {
return nil, nil, fmt.Errorf("getting peer: %w", err)
}

if dbPeer.R.ProtocolsSet == nil {
return dbPeer, nil, nil
}

protocolIDs := dbPeer.R.ProtocolsSet.ProtocolIds
ids := make([]int, 0, len(protocolIDs))
for _, id := range protocolIDs {
ids = append(ids, int(id))
}

dbProtocols, err := models.Protocols(models.ProtocolWhere.ID.IN(ids)).All(ctx, d.dbh)
if err != nil {
return dbPeer, nil, fmt.Errorf("getting protocols: %w", err)
}

return dbPeer, dbProtocols, nil
}
Loading
Loading