Skip to content

Commit

Permalink
Retreives vulnerability explainations from OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
djschleen authored Mar 7, 2024
1 parent 023d25b commit a00baee
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 5 deletions.
8 changes: 8 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Debug File (ossindex - railsgoat - AI Output)",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/main.go",
"args": ["--provider=ossindex", "--debug=true", "--output=ai", "scan", "./_TESTDATA_/sbom/railsgoat.cyclonedx.json"]
},
{
"name": "Debug Folder (OSV)",
"type": "go",
Expand Down
Binary file added __debug_bin1527502190
Binary file not shown.
Binary file added __debug_bin1895705154
Binary file not shown.
3 changes: 3 additions & 0 deletions enrichers/enrichmentfactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"

"github.com/devops-kung-fu/bomber/enrichers/epss"
"github.com/devops-kung-fu/bomber/enrichers/openai"
"github.com/devops-kung-fu/bomber/models"
)

Expand All @@ -13,6 +14,8 @@ func NewEnricher(name string) (enricher models.Enricher, err error) {
switch name {
case "epss":
enricher = epss.Enricher{}
case "openai":
enricher = openai.Enricher{}
default:

err = fmt.Errorf("%s is not a valid provider type", name)
Expand Down
91 changes: 90 additions & 1 deletion enrichers/openai/openai.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,102 @@
// package openai enriches vulnerability information
package openai

import "github.com/devops-kung-fu/bomber/models"
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"os"
"text/template"

openai "github.com/sashabaranov/go-openai"

"github.com/devops-kung-fu/bomber/models"
)

// Provider represents the openai enricher
type Enricher struct{}

// Enrich adds additional information to vulnerabilities
func (Enricher) Enrich(vulnerabilities []models.Vulnerability, credentials *models.Credentials) ([]models.Vulnerability, error) {
if err := validateCredentials(credentials); err != nil {
return nil, fmt.Errorf("could not validate openai credentials: %w", err)
}

for _, v := range vulnerabilities {
fetch(v, credentials)
log.Println(v.Explanation)
}
return nil, nil
}

func validateCredentials(credentials *models.Credentials) (err error) {
if credentials == nil {
return errors.New("credentials cannot be nil")
}

if credentials.OpenAIAPIKey == "" {
credentials.OpenAIAPIKey = os.Getenv("OPENAI_API_KEY")
}

if credentials.OpenAIAPIKey == "" {
err = errors.New("bomber requires an openai key to enrich vulnerability data")
}
return
}

func fetch(vulnerability models.Vulnerability, credentials *models.Credentials) {
prompt := generatePrompt(vulnerability)
client := openai.NewClient(credentials.OpenAIAPIKey)
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
},
},
},
)

if err != nil {
log.Printf("ChatCompletion error: %v\n", err) //TODO: Need to pass the error back up the stack
return
}

vulnerability.Explanation = resp.Choices[0].Message.Content
log.Println(vulnerability.Explanation)

}

func generatePrompt(vulnerability models.Vulnerability) (prompt string) {

promptTemplate := `
Explain what {{ .Cve }} is and dig into: {{ .Description }} so it could be understood by a non-technical business user.
`
// Create a new template with a name
tmpl, err := template.New("prompt").Parse(promptTemplate)
if err != nil {
panic(err)
}

// Create a buffer to store the generated result
var resultBuffer bytes.Buffer

// Execute the template and write the result to the buffer
err = executeTemplate(&resultBuffer, tmpl, vulnerability)
if err != nil {
panic(err)
}

// Convert the buffer to a string and return it
return resultBuffer.String()
}

func executeTemplate(buffer *bytes.Buffer, tmpl *template.Template, data interface{}) error {
// Execute the template and write the result to the buffer
return tmpl.Execute(buffer, data)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ require (
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sashabaranov/go-openai v1.20.2
github.com/spf13/pflag v1.0.5 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/net v0.22.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.20.2 h1:nilzF2EKzaHyK4Rk2Dbu/aJEZbtIvskDIXvfS4yx+6M=
github.com/sashabaranov/go-openai v1.20.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
Expand Down
10 changes: 6 additions & 4 deletions lib/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,17 @@ func (s *Scanner) filterVulnerabilities(response []models.Package) {

// enrichAndIgnoreVulnerabilities enriches and ignores vulnerabilities as needed.
func (s *Scanner) enrichAndIgnoreVulnerabilities(response []models.Package, ignoredCVE []string) {
enricher, _ := enrichers.NewEnricher("epss")
epssEnricher, _ := enrichers.NewEnricher("epss")
openaiEnricher, _ := enrichers.NewEnricher("openai")
for i, p := range response {
enrichedVulnerabilities, _ := enricher.Enrich(p.Vulnerabilities, &s.Credentials)
response[i].Vulnerabilities = enrichedVulnerabilities

if len(ignoredCVE) > 0 {
filteredVulnerabilities := filters.Ignore(p.Vulnerabilities, ignoredCVE)
response[i].Vulnerabilities = filteredVulnerabilities
}

enrichedVulnerabilities, _ := epssEnricher.Enrich(p.Vulnerabilities, &s.Credentials)
aienrichedVulnerabilities, _ := openaiEnricher.Enrich(enrichedVulnerabilities, &s.Credentials)
response[i].Vulnerabilities = aienrichedVulnerabilities
}
}

Expand Down
1 change: 1 addition & 0 deletions models/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Vulnerability struct {
DisplayName string `json:"displayName,omitempty"`
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Explanation string `json:"explanation,omitempty"` //This is an enrichment via OpenAI
CvssScore float64 `json:"cvssScore,omitempty"`
CvssVector string `json:"cvssVector,omitempty"`
Cwe string `json:"cwe,omitempty"`
Expand Down

0 comments on commit a00baee

Please sign in to comment.