Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type FlagsConfig struct {
HMACSecretName string

CacheSize int
CacheTTL int

VaultAddr string
VaultMount string
Expand Down
40 changes: 32 additions & 8 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"strings"
"time"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault-csi-provider/internal/auth"
Expand All @@ -27,18 +28,25 @@ import (
// and communicates with the Vault API.
type provider struct {
logger hclog.Logger
vaultResponseCache map[vaultResponseCacheKey]*api.Secret
vaultResponseCache map[vaultResponseCacheKey]*cachedSecretEntry
cacheTTL time.Duration

// Allows mocking Kubernetes API for tests.
authMethod *auth.KubernetesJWTAuth
hmacGenerator *hmacgen.HMACGenerator
clientCache *clientcache.ClientCache
}

func NewProvider(logger hclog.Logger, authMethod *auth.KubernetesJWTAuth, hmacGenerator *hmacgen.HMACGenerator, clientCache *clientcache.ClientCache) *provider {
type cachedSecretEntry struct {
Secret *api.Secret
CachedAt time.Time
}

func NewProvider(logger hclog.Logger, authMethod *auth.KubernetesJWTAuth, hmacGenerator *hmacgen.HMACGenerator, clientCache *clientcache.ClientCache, cacheTTL time.Duration) *provider {
p := &provider{
logger: logger,
vaultResponseCache: make(map[vaultResponseCacheKey]*api.Secret),
vaultResponseCache: make(map[vaultResponseCacheKey]*cachedSecretEntry),
cacheTTL: cacheTTL,

authMethod: authMethod,
hmacGenerator: hmacGenerator,
Expand Down Expand Up @@ -106,12 +114,12 @@ func decodeValue(data []byte, encoding string) ([]byte, error) {
}

func (p *provider) getSecret(ctx context.Context, client *vaultclient.Client, secretConfig config.Secret) ([]byte, error) {
var secret *api.Secret
var cached bool
key := vaultResponseCacheKey{secretPath: secretConfig.SecretPath, method: secretConfig.Method}
if secret, cached = p.vaultResponseCache[key]; !cached {
if cachedSecret, cached := p.vaultResponseCache[key]; !cached || isExpired(cachedSecret, p.cacheTTL) {
p.logger.Debug("Secret not cached or TTL expired", "secretConfig", secretConfig)

var err error
secret, err = client.RequestSecret(ctx, p.authMethod, secretConfig)
secret, err := client.RequestSecret(ctx, p.authMethod, secretConfig)
if err != nil {
return nil, fmt.Errorf("couldn't read secret %q: %w", secretConfig.ObjectName, err)
}
Expand All @@ -123,11 +131,16 @@ func (p *provider) getSecret(ctx context.Context, client *vaultclient.Client, se
p.logger.Warn("Warning in response from Vault API", "warning", w)
}

p.vaultResponseCache[key] = secret
p.vaultResponseCache[key] = &cachedSecretEntry{
Secret: secret,
CachedAt: time.Now(),
}
} else {
p.logger.Debug("Secret fetched from cache", "secretConfig", secretConfig)
}

secret := p.vaultResponseCache[key].Secret

// If no secretKey specified, we return the whole response as a JSON object.
if secretConfig.SecretKey == "" {
content, err := json.Marshal(secret)
Expand Down Expand Up @@ -220,3 +233,14 @@ func generateObjectVersion(secret config.Secret, hmacKey []byte, content []byte)
Version: base64.URLEncoding.EncodeToString(hash.Sum(nil)),
}, nil
}

// isExpired checks if a cached secret has expired.
// Returns 'true' if the time elapsed since CachedAt is greater
// than or equal to the TTL. A TTL equal to or less than zero is treated as 'never expires'.
func isExpired(cachedSecret *cachedSecretEntry, ttl time.Duration) bool {
// TTL equals to 0 or negative indicates cache never expires
if ttl <= 0 {
return false
}
return time.Since(cachedSecret.CachedAt) >= ttl
}
4 changes: 2 additions & 2 deletions internal/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func TestHandleMountRequest(t *testing.T) {
clientCache, err := clientcache.NewClientCache(hclog.Default(), 10)
require.NoError(t, err)
// While we hit the cache, the secret contents and versions should remain the same.
provider := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache)
provider := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache, 0)
for i := 0; i < 3; i++ {
resp, err := provider.HandleMountRequest(context.Background(), spcConfig, flagsConfig)
require.NoError(t, err)
Expand All @@ -252,7 +252,7 @@ func TestHandleMountRequest(t *testing.T) {

// The mockVaultHandler function below includes a dynamic counter in the content of secrets.
// That means mounting again with a fresh provider will update the contents of the secrets, which should update the version.
resp, err := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache).HandleMountRequest(context.Background(), spcConfig, flagsConfig)
resp, err := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache, 0).HandleMountRequest(context.Background(), spcConfig, flagsConfig)
require.NoError(t, err)

assert.Equal(t, (*v1alpha1.Error)(nil), resp.Error)
Expand Down
4 changes: 3 additions & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package server
import (
"context"
"fmt"
"time"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault-csi-provider/internal/auth"
Expand Down Expand Up @@ -54,7 +55,8 @@ func (s *Server) Mount(ctx context.Context, req *pb.MountRequest) (*pb.MountResp
}

authMethod := auth.NewKubernetesJWTAuth(s.logger.Named("auth"), s.k8sClient, cfg.Parameters, s.flagsConfig.VaultMount)
provider := provider.NewProvider(s.logger.Named("provider"), authMethod, s.hmacGenerator, s.clientCache)
cacheTTL := time.Duration(s.flagsConfig.CacheTTL) * time.Second
provider := provider.NewProvider(s.logger.Named("provider"), authMethod, s.hmacGenerator, s.clientCache, cacheTTL)
resp, err := provider.HandleMountRequest(ctx, cfg, s.flagsConfig)
if err != nil {
return nil, fmt.Errorf("error making mount request: %w", err)
Expand Down
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func realMain(logger hclog.Logger) error {
flag.StringVar(&flags.HMACSecretName, "hmac-secret-name", "vault-csi-provider-hmac-key", "Configure the Kubernetes secret name that the provider creates to store an HMAC key for generating secret version hashes")

flag.IntVar(&flags.CacheSize, "cache-size", 1000, "Set the maximum number of Vault tokens that will be cached in-memory. One Vault token will be stored for each pod on the same node that mounts secrets.")
flag.IntVar(&flags.CacheTTL, "cache-ttl", 0, "TTL for cached secrets, specified in seconds. Defaults to 0 (never expires).")

flag.StringVar(&flags.VaultAddr, "vault-addr", "", "Default address for connecting to Vault. Can also be specified via the VAULT_ADDR environment variable.")
flag.StringVar(&flags.VaultMount, "vault-mount", "kubernetes", "Default Vault mount path for authentication. Can refer to a Kubernetes or JWT auth mount.")
Expand Down
Loading