Skip to content

Commit aa3fabf

Browse files
committed
refactor: explicitly return at most one endpoint in generateTXTRecord
1 parent 6596ea1 commit aa3fabf

File tree

3 files changed

+49
-66
lines changed

3 files changed

+49
-66
lines changed

registry/txt.go

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,8 @@ func (im *TXTRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error
256256
// The migration is done for the TXT records owned by this instance only.
257257
if len(txtRecordsMap) > 0 && ep.Labels[endpoint.OwnerLabelKey] == im.ownerID {
258258
if plan.IsManagedRecord(ep.RecordType, im.managedRecordTypes, im.excludeRecordTypes) {
259-
// Get desired TXT records and detect the missing ones
260-
desiredTXTs := im.generateTXTRecord(ep)
261-
for _, desiredTXT := range desiredTXTs {
259+
// Get desired TXT record and check whether it is missing
260+
if desiredTXT := im.generateTXTRecord(ep); desiredTXT != nil {
262261
if _, exists := txtRecordsMap[desiredTXT.DNSName]; !exists {
263262
ep.WithProviderSpecific(providerSpecificForceUpdate, "true")
264263
}
@@ -276,13 +275,7 @@ func (im *TXTRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error
276275
return endpoints, nil
277276
}
278277

279-
func (im *TXTRegistry) generateTXTRecord(r *endpoint.Endpoint) []*endpoint.Endpoint {
280-
return im.generateTXTRecordWithFilter(r, func(ep *endpoint.Endpoint) bool { return true })
281-
}
282-
283-
func (im *TXTRegistry) generateTXTRecordWithFilter(r *endpoint.Endpoint, filter func(*endpoint.Endpoint) bool) []*endpoint.Endpoint {
284-
endpoints := make([]*endpoint.Endpoint, 0)
285-
278+
func (im *TXTRegistry) generateTXTRecord(r *endpoint.Endpoint) *endpoint.Endpoint {
286279
recordType := r.RecordType
287280
// AWS Alias records are encoded as type "cname"
288281
if isAlias, found := r.GetProviderSpecificProperty("alias"); found && isAlias == "true" && recordType == endpoint.RecordTypeA {
@@ -293,11 +286,8 @@ func (im *TXTRegistry) generateTXTRecordWithFilter(r *endpoint.Endpoint, filter
293286
txtNew.WithSetIdentifier(r.SetIdentifier)
294287
txtNew.Labels[endpoint.OwnedRecordLabelKey] = r.DNSName
295288
txtNew.ProviderSpecific = r.ProviderSpecific
296-
if filter(txtNew) {
297-
endpoints = append(endpoints, txtNew)
298-
}
299289
}
300-
return endpoints
290+
return txtNew
301291
}
302292

303293
// ApplyChanges updates dns provider with the changes
@@ -317,7 +307,9 @@ func (im *TXTRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes)
317307
}
318308
r.Labels[endpoint.OwnerLabelKey] = im.ownerID
319309

320-
filteredChanges.Create = append(filteredChanges.Create, im.generateTXTRecordWithFilter(r, im.existingTXTs.isAbsent)...)
310+
if txt := im.generateTXTRecord(r); txt != nil && im.existingTXTs.isAbsent(txt) {
311+
filteredChanges.Create = append(filteredChanges.Create, txt)
312+
}
321313

322314
if im.cacheInterval > 0 {
323315
im.addToCache(r)
@@ -328,7 +320,9 @@ func (im *TXTRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes)
328320
// when we delete TXT records for which value has changed (due to new label) this would still work because
329321
// !!! TXT record value is uniquely generated from the Labels of the endpoint. Hence old TXT record can be uniquely reconstructed
330322
// !!! After migration to the new TXT registry format we can drop records in old format here!!!
331-
filteredChanges.Delete = append(filteredChanges.Delete, im.generateTXTRecord(r)...)
323+
if txt := im.generateTXTRecord(r); txt != nil {
324+
filteredChanges.Delete = append(filteredChanges.Delete, txt)
325+
}
332326

333327
if im.cacheInterval > 0 {
334328
im.removeFromCache(r)
@@ -339,7 +333,9 @@ func (im *TXTRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes)
339333
for _, r := range filteredChanges.UpdateOld {
340334
// when we updateOld TXT records for which value has changed (due to new label) this would still work because
341335
// !!! TXT record value is uniquely generated from the Labels of the endpoint. Hence old TXT record can be uniquely reconstructed
342-
filteredChanges.UpdateOld = append(filteredChanges.UpdateOld, im.generateTXTRecord(r)...)
336+
if txt := im.generateTXTRecord(r); txt != nil {
337+
filteredChanges.UpdateOld = append(filteredChanges.UpdateOld, txt)
338+
}
343339
// remove old version of record from cache
344340
if im.cacheInterval > 0 {
345341
im.removeFromCache(r)
@@ -348,7 +344,9 @@ func (im *TXTRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes)
348344

349345
// make sure TXT records are consistently updated as well
350346
for _, r := range filteredChanges.UpdateNew {
351-
filteredChanges.UpdateNew = append(filteredChanges.UpdateNew, im.generateTXTRecord(r)...)
347+
if txt := im.generateTXTRecord(r); txt != nil {
348+
filteredChanges.UpdateNew = append(filteredChanges.UpdateNew, txt)
349+
}
352350
// add new version of record to cache
353351
if im.cacheInterval > 0 {
354352
im.addToCache(r)

registry/txt_encryption_test.go

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,28 +109,26 @@ func TestGenerateTXTGenerateTextRecordEncryptionWihDecryption(t *testing.T) {
109109
key := []byte(k)
110110
r, err := NewTXTRegistry(p, "", "", "owner", time.Minute, "", []string{}, []string{}, true, key)
111111
assert.NoError(t, err, "Error creating TXT registry")
112-
txtRecords := r.generateTXTRecord(test.record)
113-
assert.Len(t, txtRecords, len(test.record.Targets))
112+
txt := r.generateTXTRecord(test.record)
113+
assert.NotNil(t, txt)
114114

115-
for _, txt := range txtRecords {
116-
// should return a TXT record with the encryption nonce label. At the moment nonce is not set as label.
117-
assert.NotContains(t, txt.Labels, "txt-encryption-nonce")
115+
// should return a TXT record with the encryption nonce label. At the moment nonce is not set as label.
116+
assert.NotContains(t, txt.Labels, "txt-encryption-nonce")
118117

119-
assert.Len(t, txt.Targets, 1)
120-
assert.LessOrEqual(t, len(txt.Targets), 1)
118+
assert.Len(t, txt.Targets, 1)
119+
assert.LessOrEqual(t, len(txt.Targets), 1)
121120

122-
// decrypt targets
123-
for _, target := range txtRecords[0].Targets {
124-
encryptedText, errUnquote := strconv.Unquote(target)
125-
assert.NoError(t, errUnquote, "Error unquoting the encrypted text")
121+
// decrypt targets
122+
for _, target := range txt.Targets {
123+
encryptedText, errUnquote := strconv.Unquote(target)
124+
assert.NoError(t, errUnquote, "Error unquoting the encrypted text")
126125

127-
actual, nonce, errDecrypt := endpoint.DecryptText(encryptedText, r.txtEncryptAESKey)
128-
assert.NoError(t, errDecrypt, "Error decrypting the encrypted text")
126+
actual, nonce, errDecrypt := endpoint.DecryptText(encryptedText, r.txtEncryptAESKey)
127+
assert.NoError(t, errDecrypt, "Error decrypting the encrypted text")
129128

130-
assert.True(t, strings.HasPrefix(encryptedText, nonce),
131-
"Nonce '%s' should be a prefix of the encrypted text: '%s'", nonce, encryptedText)
132-
assert.Equal(t, test.decrypted, actual)
133-
}
129+
assert.True(t, strings.HasPrefix(encryptedText, nonce),
130+
"Nonce '%s' should be a prefix of the encrypted text: '%s'", nonce, encryptedText)
131+
assert.Equal(t, test.decrypted, actual)
134132
}
135133
})
136134
}

registry/txt_test.go

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,14 +1516,12 @@ func TestNewTXTScheme(t *testing.T) {
15161516

15171517
func TestGenerateTXT(t *testing.T) {
15181518
record := newEndpointWithOwner("foo.test-zone.example.org", "new-foo.loadbalancer.com", endpoint.RecordTypeCNAME, "owner")
1519-
expectedTXT := []*endpoint.Endpoint{
1520-
{
1521-
DNSName: "cname-foo.test-zone.example.org",
1522-
Targets: endpoint.Targets{"\"heritage=external-dns,external-dns/owner=owner\""},
1523-
RecordType: endpoint.RecordTypeTXT,
1524-
Labels: map[string]string{
1525-
endpoint.OwnedRecordLabelKey: "foo.test-zone.example.org",
1526-
},
1519+
expectedTXT := &endpoint.Endpoint{
1520+
DNSName: "cname-foo.test-zone.example.org",
1521+
Targets: endpoint.Targets{"\"heritage=external-dns,external-dns/owner=owner\""},
1522+
RecordType: endpoint.RecordTypeTXT,
1523+
Labels: map[string]string{
1524+
endpoint.OwnedRecordLabelKey: "foo.test-zone.example.org",
15271525
},
15281526
}
15291527
p := inmemory.NewInMemoryProvider()
@@ -1535,14 +1533,12 @@ func TestGenerateTXT(t *testing.T) {
15351533

15361534
func TestGenerateTXTForAAAA(t *testing.T) {
15371535
record := newEndpointWithOwner("foo.test-zone.example.org", "2001:DB8::1", endpoint.RecordTypeAAAA, "owner")
1538-
expectedTXT := []*endpoint.Endpoint{
1539-
{
1540-
DNSName: "aaaa-foo.test-zone.example.org",
1541-
Targets: endpoint.Targets{"\"heritage=external-dns,external-dns/owner=owner\""},
1542-
RecordType: endpoint.RecordTypeTXT,
1543-
Labels: map[string]string{
1544-
endpoint.OwnedRecordLabelKey: "foo.test-zone.example.org",
1545-
},
1536+
expectedTXT := &endpoint.Endpoint{
1537+
DNSName: "aaaa-foo.test-zone.example.org",
1538+
Targets: endpoint.Targets{"\"heritage=external-dns,external-dns/owner=owner\""},
1539+
RecordType: endpoint.RecordTypeTXT,
1540+
Labels: map[string]string{
1541+
endpoint.OwnedRecordLabelKey: "foo.test-zone.example.org",
15461542
},
15471543
}
15481544
p := inmemory.NewInMemoryProvider()
@@ -1560,8 +1556,8 @@ func TestFailGenerateTXT(t *testing.T) {
15601556
RecordType: endpoint.RecordTypeCNAME,
15611557
Labels: map[string]string{},
15621558
}
1563-
// A bad DNS name returns empty expected TXT
1564-
expectedTXT := []*endpoint.Endpoint{}
1559+
// A bad DNS name returns nil
1560+
var expectedTXT *endpoint.Endpoint
15651561
p := inmemory.NewInMemoryProvider()
15661562
p.CreateZone(testZone)
15671563
r, _ := NewTXTRegistry(p, "", "", "owner", time.Hour, "", []string{}, []string{}, false, nil)
@@ -1714,23 +1710,14 @@ func TestGenerateTXTRecordWithNewFormatOnly(t *testing.T) {
17141710
for _, tc := range testCases {
17151711
t.Run(tc.name, func(t *testing.T) {
17161712
r, _ := NewTXTRegistry(p, "", "", "owner", time.Hour, "", []string{}, []string{}, false, nil)
1717-
records := r.generateTXTRecord(tc.endpoint)
1713+
txt := r.generateTXTRecord(tc.endpoint)
17181714

1719-
assert.Len(t, records, tc.expectedRecords, tc.description)
1715+
assert.NotNil(t, txt, tc.description)
17201716

1721-
for _, record := range records {
1722-
assert.Equal(t, endpoint.RecordTypeTXT, record.RecordType)
1723-
}
1717+
assert.Equal(t, endpoint.RecordTypeTXT, txt.RecordType)
17241718

17251719
if tc.endpoint.RecordType == endpoint.RecordTypeAAAA {
1726-
hasNewFormat := false
1727-
for _, record := range records {
1728-
if strings.HasPrefix(record.DNSName, tc.expectedPrefix) {
1729-
hasNewFormat = true
1730-
break
1731-
}
1732-
}
1733-
assert.True(t, hasNewFormat,
1720+
assert.True(t, strings.HasPrefix(txt.DNSName, tc.expectedPrefix),
17341721
"Should have at least one record with prefix %s when using new format", tc.expectedPrefix)
17351722
}
17361723
})

0 commit comments

Comments
 (0)