diff --git a/internal/config/config.go b/internal/config/config.go index fb4af20f..512e02d3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,6 +37,7 @@ type FlagsConfig struct { HMACSecretName string CacheSize int + CacheTTL int VaultAddr string VaultMount string diff --git a/internal/provider/provider.go b/internal/provider/provider.go index bfc18bca..ad2b8986 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault-csi-provider/internal/auth" @@ -27,7 +28,8 @@ import ( // and communicates with the Vault API. type provider struct { logger hclog.Logger - vaultResponseCache map[vaultResponseCacheKey]*api.Secret + vaultResponseCache map[vaultResponseCacheKey]*cachedSecretEntry + cacheTTL time.Duration // Allows mocking Kubernetes API for tests. authMethod *auth.KubernetesJWTAuth @@ -35,10 +37,16 @@ type provider struct { clientCache *clientcache.ClientCache } -func NewProvider(logger hclog.Logger, authMethod *auth.KubernetesJWTAuth, hmacGenerator *hmacgen.HMACGenerator, clientCache *clientcache.ClientCache) *provider { +type cachedSecretEntry struct { + Secret *api.Secret + CachedAt time.Time +} + +func NewProvider(logger hclog.Logger, authMethod *auth.KubernetesJWTAuth, hmacGenerator *hmacgen.HMACGenerator, clientCache *clientcache.ClientCache, cacheTTL time.Duration) *provider { p := &provider{ logger: logger, - vaultResponseCache: make(map[vaultResponseCacheKey]*api.Secret), + vaultResponseCache: make(map[vaultResponseCacheKey]*cachedSecretEntry), + cacheTTL: cacheTTL, authMethod: authMethod, hmacGenerator: hmacGenerator, @@ -106,12 +114,12 @@ func decodeValue(data []byte, encoding string) ([]byte, error) { } func (p *provider) getSecret(ctx context.Context, client *vaultclient.Client, secretConfig config.Secret) ([]byte, error) { - var secret *api.Secret - var cached bool key := vaultResponseCacheKey{secretPath: secretConfig.SecretPath, method: secretConfig.Method} - if secret, cached = p.vaultResponseCache[key]; !cached { + if cachedSecret, cached := p.vaultResponseCache[key]; !cached || isExpired(cachedSecret, p.cacheTTL) { + p.logger.Debug("Secret not cached or TTL expired", "secretConfig", secretConfig) + var err error - secret, err = client.RequestSecret(ctx, p.authMethod, secretConfig) + secret, err := client.RequestSecret(ctx, p.authMethod, secretConfig) if err != nil { return nil, fmt.Errorf("couldn't read secret %q: %w", secretConfig.ObjectName, err) } @@ -123,11 +131,16 @@ func (p *provider) getSecret(ctx context.Context, client *vaultclient.Client, se p.logger.Warn("Warning in response from Vault API", "warning", w) } - p.vaultResponseCache[key] = secret + p.vaultResponseCache[key] = &cachedSecretEntry{ + Secret: secret, + CachedAt: time.Now(), + } } else { p.logger.Debug("Secret fetched from cache", "secretConfig", secretConfig) } + secret := p.vaultResponseCache[key].Secret + // If no secretKey specified, we return the whole response as a JSON object. if secretConfig.SecretKey == "" { content, err := json.Marshal(secret) @@ -220,3 +233,14 @@ func generateObjectVersion(secret config.Secret, hmacKey []byte, content []byte) Version: base64.URLEncoding.EncodeToString(hash.Sum(nil)), }, nil } + +// isExpired checks if a cached secret has expired. +// Returns 'true' if the time elapsed since CachedAt is greater +// than or equal to the TTL. A TTL equal to or less than zero is treated as 'never expires'. +func isExpired(cachedSecret *cachedSecretEntry, ttl time.Duration) bool { + // TTL equals to 0 or negative indicates cache never expires + if ttl <= 0 { + return false + } + return time.Since(cachedSecret.CachedAt) >= ttl +} diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 8d8e3b42..09e45e98 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -236,7 +236,7 @@ func TestHandleMountRequest(t *testing.T) { clientCache, err := clientcache.NewClientCache(hclog.Default(), 10) require.NoError(t, err) // While we hit the cache, the secret contents and versions should remain the same. - provider := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache) + provider := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache, 0) for i := 0; i < 3; i++ { resp, err := provider.HandleMountRequest(context.Background(), spcConfig, flagsConfig) require.NoError(t, err) @@ -252,7 +252,7 @@ func TestHandleMountRequest(t *testing.T) { // The mockVaultHandler function below includes a dynamic counter in the content of secrets. // That means mounting again with a fresh provider will update the contents of the secrets, which should update the version. - resp, err := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache).HandleMountRequest(context.Background(), spcConfig, flagsConfig) + resp, err := NewProvider(hclog.Default(), authMethod, hmacGenerator, clientCache, 0).HandleMountRequest(context.Background(), spcConfig, flagsConfig) require.NoError(t, err) assert.Equal(t, (*v1alpha1.Error)(nil), resp.Error) diff --git a/internal/server/server.go b/internal/server/server.go index eefb1e64..de73aa5e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ package server import ( "context" "fmt" + "time" "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault-csi-provider/internal/auth" @@ -54,7 +55,8 @@ func (s *Server) Mount(ctx context.Context, req *pb.MountRequest) (*pb.MountResp } authMethod := auth.NewKubernetesJWTAuth(s.logger.Named("auth"), s.k8sClient, cfg.Parameters, s.flagsConfig.VaultMount) - provider := provider.NewProvider(s.logger.Named("provider"), authMethod, s.hmacGenerator, s.clientCache) + cacheTTL := time.Duration(s.flagsConfig.CacheTTL) * time.Second + provider := provider.NewProvider(s.logger.Named("provider"), authMethod, s.hmacGenerator, s.clientCache, cacheTTL) resp, err := provider.HandleMountRequest(ctx, cfg, s.flagsConfig) if err != nil { return nil, fmt.Errorf("error making mount request: %w", err) diff --git a/main.go b/main.go index 2f536d5a..0a99e1c8 100644 --- a/main.go +++ b/main.go @@ -69,6 +69,7 @@ func realMain(logger hclog.Logger) error { flag.StringVar(&flags.HMACSecretName, "hmac-secret-name", "vault-csi-provider-hmac-key", "Configure the Kubernetes secret name that the provider creates to store an HMAC key for generating secret version hashes") flag.IntVar(&flags.CacheSize, "cache-size", 1000, "Set the maximum number of Vault tokens that will be cached in-memory. One Vault token will be stored for each pod on the same node that mounts secrets.") + flag.IntVar(&flags.CacheTTL, "cache-ttl", 0, "TTL for cached secrets, specified in seconds. Defaults to 0 (never expires).") flag.StringVar(&flags.VaultAddr, "vault-addr", "", "Default address for connecting to Vault. Can also be specified via the VAULT_ADDR environment variable.") flag.StringVar(&flags.VaultMount, "vault-mount", "kubernetes", "Default Vault mount path for authentication. Can refer to a Kubernetes or JWT auth mount.")