diff --git a/lib/ocrypto/interfaces.go b/lib/ocrypto/interfaces.go new file mode 100644 index 000000000..4a339a2b4 --- /dev/null +++ b/lib/ocrypto/interfaces.go @@ -0,0 +1,30 @@ +package ocrypto + +import ( + "context" +) + +// Encapsulator interface for key encapsulation operations +type Encapsulator interface { + // Encrypt wraps a secret key with the encapsulation key + Encrypt(data []byte) ([]byte, error) + + // PublicKeyInPemFormat Returns public key in pem format, or the empty string if not present + PublicKeyInPemFormat() (string, error) + + // For EC schemes, this method returns the public part of the ephemeral key. + // Otherwise, it returns nil. + EphemeralKey() []byte +} + +// ProtectedKey represents a decrypted key with operations that can be performed on it +type ProtectedKey interface { + // VerifyBinding checks if the policy binding matches the given policy data + VerifyBinding(ctx context.Context, policy, binding []byte) error + + // Export returns the raw key data, optionally encrypting it with the provided encryptor + Export(encryptor Encapsulator) ([]byte, error) + + // Used to decrypt encrypted policies and metadata + DecryptAESGCM(iv []byte, body []byte, tagSize int) ([]byte, error) +} diff --git a/lib/ocrypto/protected_key.go b/lib/ocrypto/protected_key.go new file mode 100644 index 000000000..67aa36334 --- /dev/null +++ b/lib/ocrypto/protected_key.go @@ -0,0 +1,89 @@ +package ocrypto + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "errors" + "fmt" +) + +var ( + // ErrEmptyKeyData is returned when the key data is empty + ErrEmptyKeyData = errors.New("key data is empty") + // ErrPolicyHMACMismatch is returned when policy binding verification fails + ErrPolicyHMACMismatch = errors.New("policy hmac mismatch") +) + +// AESProtectedKey implements the ProtectedKey interface with an in-memory secret key +type AESProtectedKey struct { + rawKey []byte + aesGcm AesGcm +} + +var _ ProtectedKey = (*AESProtectedKey)(nil) + +// NewAESProtectedKey creates a new instance of AESProtectedKey +func NewAESProtectedKey(rawKey []byte) (*AESProtectedKey, error) { + if len(rawKey) == 0 { + return nil, ErrEmptyKeyData + } + // Create a defensive copy of the key + keyCopy := append([]byte{}, rawKey...) + + // Pre-initialize the AES-GCM cipher for performance + aesGcm, err := NewAESGcm(keyCopy) + if err != nil { + return nil, fmt.Errorf("failed to initialize AES-GCM cipher: %w", err) + } + + return &AESProtectedKey{ + rawKey: keyCopy, + aesGcm: aesGcm, + }, nil +} + +// DecryptAESGCM decrypts data using AES-GCM with the protected key +func (k *AESProtectedKey) DecryptAESGCM(iv []byte, body []byte, tagSize int) ([]byte, error) { + // Use the pre-initialized AES-GCM cipher for better performance + decryptedData, err := k.aesGcm.DecryptWithIVAndTagSize(iv, body, tagSize) + if err != nil { + return nil, fmt.Errorf("AES-GCM decryption failed: %w", err) + } + + return decryptedData, nil +} + +// Export returns the raw key data, optionally encrypting it with the provided Encapsulator +func (k *AESProtectedKey) Export(encapsulator Encapsulator) ([]byte, error) { + if encapsulator == nil { + // Return raw key data without encryption - caller should be aware of this + return append([]byte{}, k.rawKey...), nil + } + + // Encrypt the key data before returning + encryptedKey, err := encapsulator.Encrypt(k.rawKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt key data for export: %w", err) + } + + return encryptedKey, nil +} + +// VerifyBinding checks if the policy binding matches the given policy data +func (k *AESProtectedKey) VerifyBinding(_ context.Context, policy, policyBinding []byte) error { + actualHMAC := k.generateHMACDigest(policy) + + if !hmac.Equal(actualHMAC, policyBinding) { + return ErrPolicyHMACMismatch + } + + return nil +} + +// generateHMACDigest is a helper to generate an HMAC digest from a message using the key +func (k *AESProtectedKey) generateHMACDigest(msg []byte) []byte { + mac := hmac.New(sha256.New, k.rawKey) + mac.Write(msg) + return mac.Sum(nil) +} diff --git a/lib/ocrypto/protected_key_test.go b/lib/ocrypto/protected_key_test.go new file mode 100644 index 000000000..182f2591a --- /dev/null +++ b/lib/ocrypto/protected_key_test.go @@ -0,0 +1,198 @@ +package ocrypto + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAESProtectedKey(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + assert.NotNil(t, protectedKey) + assert.Equal(t, key, protectedKey.rawKey) +} + +func TestAESProtectedKey_DecryptAESGCM(t *testing.T) { + // Generate a random 256-bit key + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + // Test data + plaintext := []byte("Hello, World!") + + // Encrypt the data first using the same key + aesGcm, err := NewAESGcm(key) + require.NoError(t, err) + + encrypted, err := aesGcm.Encrypt(plaintext) + require.NoError(t, err) + + // Extract IV and ciphertext (first 12 bytes are IV for GCM standard nonce size) + iv := encrypted[:GcmStandardNonceSize] + ciphertext := encrypted[GcmStandardNonceSize:] + + // Test decryption + decrypted, err := protectedKey.DecryptAESGCM(iv, ciphertext, 16) // 16 is standard GCM tag size + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +func TestAESProtectedKey_DecryptAESGCM_InvalidKey(t *testing.T) { + // Empty key should fail + _, err := NewAESProtectedKey([]byte{}) + require.Error(t, err) + assert.ErrorIs(t, err, ErrEmptyKeyData) +} + +func TestAESProtectedKey_Export_NoEncapsulator(t *testing.T) { + key := []byte("test-key-12345678901234567890123") // 32 bytes + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + exported, err := protectedKey.Export(nil) + require.NoError(t, err) + assert.Equal(t, key, exported) +} + +func TestAESProtectedKey_Export_WithEncapsulator(t *testing.T) { + key := []byte("test-key-12345678901234567890123") // 32 bytes + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + // Mock encapsulator + mockEncapsulator := &mockEncapsulator{ + encryptFunc: func(data []byte) ([]byte, error) { + // Simple XOR encryption for testing + result := make([]byte, len(data)) + for i, b := range data { + result[i] = b ^ 0xFF + } + return result, nil + }, + } + + exported, err := protectedKey.Export(mockEncapsulator) + require.NoError(t, err) + + // Verify it was encrypted (should be different from original) + assert.NotEqual(t, key, exported) + assert.Len(t, exported, len(key)) + + // Verify we can decrypt it back + for i, b := range exported { + assert.Equal(t, key[i], b^0xFF) + } +} + +func TestAESProtectedKey_Export_EncapsulatorError(t *testing.T) { + key := []byte("test-key-12345678901234567890123") // 32 bytes + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + mockEncapsulator := &mockEncapsulator{ + encryptFunc: func(_ []byte) ([]byte, error) { + return nil, assert.AnError + }, + } + + _, err = protectedKey.Export(mockEncapsulator) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to encrypt key data for export") +} + +func TestAESProtectedKey_VerifyBinding(t *testing.T) { + key := []byte("test-key-12345678901234567890123") // 32 bytes + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + policy := []byte("test-policy-data") + ctx := context.Background() + + // Generate the expected HMAC + expectedHMAC := protectedKey.generateHMACDigest(policy) + + // Verify binding should succeed with correct HMAC + err = protectedKey.VerifyBinding(ctx, policy, expectedHMAC) + assert.NoError(t, err) +} + +func TestAESProtectedKey_VerifyBinding_Mismatch(t *testing.T) { + key := []byte("test-key-12345678901234567890123") // 32 bytes + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + policy := []byte("test-policy-data") + wrongBinding := []byte("wrong-binding-data") + ctx := context.Background() + + err = protectedKey.VerifyBinding(ctx, policy, wrongBinding) + require.Error(t, err) + assert.Equal(t, ErrPolicyHMACMismatch, err) +} + +func TestAESProtectedKey_VerifyBinding_DifferentPolicyData(t *testing.T) { + key := []byte("test-key-12345678901234567890123") // 32 bytes + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + ctx := context.Background() + + // Generate HMAC for first policy + policy1 := []byte("policy-data-1") + hmac1 := protectedKey.generateHMACDigest(policy1) + + // Try to verify with different policy data + policy2 := []byte("policy-data-2") + err = protectedKey.VerifyBinding(ctx, policy2, hmac1) + require.Error(t, err) + assert.Equal(t, ErrPolicyHMACMismatch, err) +} + +func TestAESProtectedKey_InterfaceCompliance(t *testing.T) { + key := make([]byte, 32) + protectedKey, err := NewAESProtectedKey(key) + require.NoError(t, err) + + // Ensure it implements the ProtectedKey interface + assert.Implements(t, (*ProtectedKey)(nil), protectedKey) +} + +// Mock encapsulator for testing +type mockEncapsulator struct { + encryptFunc func([]byte) ([]byte, error) + publicKeyPEMFunc func() (string, error) + ephemeralKeyFunc func() []byte +} + +func (m *mockEncapsulator) Encrypt(data []byte) ([]byte, error) { + if m.encryptFunc != nil { + return m.encryptFunc(data) + } + return data, nil +} + +func (m *mockEncapsulator) PublicKeyInPemFormat() (string, error) { + if m.publicKeyPEMFunc != nil { + return m.publicKeyPEMFunc() + } + return "", nil +} + +func (m *mockEncapsulator) EphemeralKey() []byte { + if m.ephemeralKeyFunc != nil { + return m.ephemeralKeyFunc() + } + return nil +}