diff --git a/pkg/process/v6/writer.go b/pkg/process/v6/writer.go index 3c77b4dc..0e5328e9 100644 --- a/pkg/process/v6/writer.go +++ b/pkg/process/v6/writer.go @@ -127,8 +127,8 @@ func (w *writer) writeEntry(entry transformers.RelatedEntries) error { } // fillInMissingSeverity will add a severity entry to the vulnerability record if it is missing, empty, or "unknown". -// The upstream NVD record is used to fill in these missing values. Note that the NVD provider is always guaranteed -// to be processed first before other providers. +// The upstream NVD record is used to fill in these missing values for CVEs, and GitHub records are used for GHSAs. +// Note that the NVD provider is always processed first, then GitHub, then other providers. func (w *writer) fillInMissingSeverity(handle *grypeDB.VulnerabilityHandle) { if handle == nil { return @@ -141,14 +141,26 @@ func (w *writer) fillInMissingSeverity(handle *grypeDB.VulnerabilityHandle) { id := strings.ToLower(blob.ID) isCVE := strings.HasPrefix(id, "cve-") - if strings.ToLower(handle.ProviderID) == "nvd" && isCVE { + isGHSA := strings.HasPrefix(id, "ghsa-") + + // Cache severity data from NVD (for CVEs) and GitHub (for GHSAs) + providerID := strings.ToLower(handle.ProviderID) + if providerID == "nvd" && isCVE { if len(blob.Severities) > 0 { w.severityCache[id] = blob.Severities[0] } return } - if !isCVE { + if providerID == "github" && isGHSA { + if len(blob.Severities) > 0 { + w.severityCache[id] = blob.Severities[0] + } + return + } + + // Only process CVEs and GHSAs for backfilling + if !isCVE && !isGHSA { return } @@ -171,15 +183,23 @@ func (w *writer) fillInMissingSeverity(handle *grypeDB.VulnerabilityHandle) { return // already has a severity, don't normalize } - // add the top NVD severity value - nvdSev, ok := w.severityCache[id] + // Look up cached severity data (NVD for CVEs, GitHub for GHSAs) + cachedSev, ok := w.severityCache[id] if !ok { - log.WithFields("id", blob.ID).Trace("unable to find NVD severity") + sourceType := "NVD" + if isGHSA { + sourceType = "GitHub" + } + log.WithFields("id", blob.ID).Trace("unable to find " + sourceType + " severity") return } - log.WithFields("id", blob.ID, "provider", handle.Provider, "sev-from", topSevStr, "sev-to", nvdSev).Trace("overriding irrelevant severity with data from NVD record") - sevs = append([]grypeDB.Severity{nvdSev}, sevs...) + sourceType := "NVD" + if isGHSA { + sourceType = "GitHub" + } + log.WithFields("id", blob.ID, "provider", handle.Provider, "sev-from", topSevStr, "sev-to", cachedSev).Trace("overriding irrelevant severity with data from " + sourceType + " record") + sevs = append([]grypeDB.Severity{cachedSev}, sevs...) handle.BlobValue.Severities = sevs } diff --git a/pkg/process/v6/writer_test.go b/pkg/process/v6/writer_test.go index bd9cdd05..5e5000f3 100644 --- a/pkg/process/v6/writer_test.go +++ b/pkg/process/v6/writer_test.go @@ -32,10 +32,10 @@ func TestFillInMissingSeverity(t *testing.T) { expected: nil, }, { - name: "non-CVE ID", + name: "non-CVE/non-GHSA ID", handle: &grypeDB.VulnerabilityHandle{ BlobValue: &grypeDB.VulnerabilityBlob{ - ID: "GHSA-123", + ID: "OTHER-123", Severities: []grypeDB.Severity{ {Value: "high"}, }, @@ -59,6 +59,21 @@ func TestFillInMissingSeverity(t *testing.T) { expected: []grypeDB.Severity{{Value: "critical"}}, expectCacheUpdate: true, }, + { + name: "GitHub provider with GHSA", + handle: &grypeDB.VulnerabilityHandle{ + ProviderID: "github", + BlobValue: &grypeDB.VulnerabilityBlob{ + ID: "GHSA-1234-5678-9abc", + Severities: []grypeDB.Severity{ + {Value: "high"}, + }, + }, + }, + severityCache: map[string]grypeDB.Severity{}, + expected: []grypeDB.Severity{{Value: "high"}}, + expectCacheUpdate: true, + }, { name: "CVE with existing severities", handle: &grypeDB.VulnerabilityHandle{ @@ -93,6 +108,20 @@ func TestFillInMissingSeverity(t *testing.T) { }, expected: []grypeDB.Severity{{Value: "high"}}, }, + { + name: "GHSA with no severities, using cache", + handle: &grypeDB.VulnerabilityHandle{ + ProviderID: "alpine", + BlobValue: &grypeDB.VulnerabilityBlob{ + ID: "GHSA-abcd-efgh-ijkl", + Severities: []grypeDB.Severity{}, + }, + }, + severityCache: map[string]grypeDB.Severity{ + "ghsa-abcd-efgh-ijkl": {Value: "medium"}, + }, + expected: []grypeDB.Severity{{Value: "medium"}}, + }, } for _, tt := range tests { diff --git a/pkg/provider/providers/providers.go b/pkg/provider/providers/providers.go index 301d6a22..26a98193 100644 --- a/pkg/provider/providers/providers.go +++ b/pkg/provider/providers/providers.go @@ -2,6 +2,7 @@ package providers import ( "fmt" + "sort" "github.com/mitchellh/mapstructure" @@ -32,14 +33,32 @@ func New(root string, vCfg vunnel.Config, cfgs ...provider.Config) (provider.Pro if err != nil { return nil, err } - if p.ID().Name == "nvd" { - // it is important that NVD is processed first since other providers depend on the severity information from these records - providers = append([]provider.Provider{p}, providers...) - } else { - providers = append(providers, p) - } + providers = append(providers, p) } + sort.SliceStable(providers, func(i, j int) bool { + nameI, nameJ := providers[i].ID().Name, providers[j].ID().Name + + // NVD always comes first + if nameI == "nvd" { + return true + } + if nameJ == "nvd" { + return false + } + + // GitHub comes second (after NVD) + if nameI == "github" { + return true + } + if nameJ == "github" { + return false + } + + // All others maintain original order + return false + }) + return providers, nil } diff --git a/pkg/provider/providers/providers_test.go b/pkg/provider/providers/providers_test.go new file mode 100644 index 00000000..86350a6a --- /dev/null +++ b/pkg/provider/providers/providers_test.go @@ -0,0 +1,118 @@ +package providers + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/anchore/grype-db/pkg/provider" + "github.com/anchore/grype-db/pkg/provider/providers/vunnel" +) + +// mockProvider is a simple mock implementation of provider.Provider for testing +type mockProvider struct { + name string +} + +func (m mockProvider) ID() provider.Identifier { + return provider.Identifier{Name: m.name} +} + +func (m mockProvider) GetProviderState() provider.State { + return provider.State{} +} + +func (m mockProvider) GetSchemaVersion() int { + return 1 +} + +func (m mockProvider) GetWorkspace() string { + return "" +} + +func (m mockProvider) Close() error { + return nil +} + +func TestNew_ProviderOrdering(t *testing.T) { + tests := []struct { + name string + providerNames []string + expectedOrdering []string + }{ + { + name: "nvd first, github second", + providerNames: []string{"other1", "github", "nvd", "other2"}, + expectedOrdering: []string{"nvd", "github", "other1", "other2"}, + }, + { + name: "only nvd", + providerNames: []string{"nvd"}, + expectedOrdering: []string{"nvd"}, + }, + { + name: "only github", + providerNames: []string{"github"}, + expectedOrdering: []string{"github"}, + }, + { + name: "no nvd or github", + providerNames: []string{"other1", "other2", "other3"}, + expectedOrdering: []string{"other1", "other2", "other3"}, + }, + { + name: "nvd and github only", + providerNames: []string{"github", "nvd"}, + expectedOrdering: []string{"nvd", "github"}, + }, + { + name: "multiple others with nvd and github", + providerNames: []string{"alpine", "github", "debian", "nvd", "ubuntu", "centos"}, + expectedOrdering: []string{"nvd", "github", "alpine", "debian", "ubuntu", "centos"}, + }, + { + name: "reverse alphabetical input", + providerNames: []string{"ubuntu", "nvd", "github", "debian", "alpine"}, + expectedOrdering: []string{"nvd", "github", "ubuntu", "debian", "alpine"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock configs for each provider name + var configs []provider.Config + for _, name := range tt.providerNames { + configs = append(configs, provider.Config{ + Identifier: provider.Identifier{ + Name: name, + Kind: provider.VunnelKind, + }, + }) + } + + // Call New with mock vunnel config + vCfg := vunnel.Config{ + GenerateConfigs: false, // Don't generate additional configs for this test + } + + providers, err := New("test-root", vCfg, configs...) + require.NoError(t, err) + require.Len(t, providers, len(tt.expectedOrdering)) + + // Verify the ordering + var actualOrdering []string + for _, p := range providers { + actualOrdering = append(actualOrdering, p.ID().Name) + } + + require.Equal(t, tt.expectedOrdering, actualOrdering) + }) + } +} + +func TestNew_EmptyConfigs(t *testing.T) { + vCfg := vunnel.Config{GenerateConfigs: false} + _, err := New("test-root", vCfg) + require.Error(t, err) + require.Equal(t, ErrNoProviders, err) +}