Skip to content

Commit 6589e71

Browse files
committed
Implement GetByTokenAddressBatch for account info and use it in PublicDistribution intent
1 parent 8f5e3d7 commit 6589e71

File tree

12 files changed

+164
-60
lines changed

12 files changed

+164
-60
lines changed

pkg/code/data/account/memory/store.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,24 @@ func (s *store) GetByTokenAddress(_ context.Context, address string) (*account.R
252252
return &cloned, nil
253253
}
254254

255+
// GetByTokenAddressBatch implements account.Store.GetByTokenAddressBatch
256+
func (s *store) GetByTokenAddressBatch(_ context.Context, addresses ...string) (map[string]*account.Record, error) {
257+
s.mu.Lock()
258+
defer s.mu.Unlock()
259+
260+
res := make(map[string]*account.Record)
261+
for _, address := range addresses {
262+
item := s.findByTokenAddress(address)
263+
if item == nil {
264+
return nil, account.ErrAccountInfoNotFound
265+
}
266+
267+
cloned := item.Clone()
268+
res[address] = &cloned
269+
}
270+
return res, nil
271+
}
272+
255273
// GetByAuthorityAddress implements account.Store.GetByAuthorityAddress
256274
func (s *store) GetByAuthorityAddress(_ context.Context, address string) (*account.Record, error) {
257275
s.mu.Lock()

pkg/code/data/account/postgres/model.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"database/sql"
66
"errors"
7+
"fmt"
8+
"strings"
79
"time"
810

911
"github.com/jmoiron/sqlx"
@@ -190,6 +192,30 @@ func dbGetByTokenAddress(ctx context.Context, db *sqlx.DB, address string) (*mod
190192
return res, nil
191193
}
192194

195+
func dbGetByTokenAddressBatch(ctx context.Context, db *sqlx.DB, addresses ...string) ([]*model, error) {
196+
res := []*model{}
197+
198+
individualFilters := make([]string, len(addresses))
199+
for i, address := range addresses {
200+
individualFilters[i] = fmt.Sprintf("'%s'", address)
201+
}
202+
203+
query := fmt.Sprintf(
204+
`SELECT id, owner_account, authority_account, token_account, mint_account, account_type, index, relationship_to, requires_deposit_sync, deposits_last_synced_at, requires_auto_return_check, requires_swap_retry, last_swap_retry_at, created_at FROM `+tableName+`
205+
WHERE token_account IN (%s)`,
206+
strings.Join(individualFilters, ", "),
207+
)
208+
209+
err := db.SelectContext(ctx, &res, query)
210+
if err != nil {
211+
return nil, pgutil.CheckNoRows(err, account.ErrAccountInfoNotFound)
212+
}
213+
if len(res) != len(addresses) {
214+
return nil, account.ErrAccountInfoNotFound
215+
}
216+
return res, nil
217+
}
218+
193219
func dbGetByAuthorityAddress(ctx context.Context, db *sqlx.DB, address string) (*model, error) {
194220
res := &model{}
195221

pkg/code/data/account/postgres/store.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,20 @@ func (s *store) GetByTokenAddress(ctx context.Context, address string) (*account
6767
return fromModel(model), nil
6868
}
6969

70+
// GetByTokenAddressBatch implements timelock.Store.GetByTokenAddressBatch
71+
func (s *store) GetByTokenAddressBatch(ctx context.Context, addresses ...string) (map[string]*account.Record, error) {
72+
models, err := dbGetByTokenAddressBatch(ctx, s.db, addresses...)
73+
if err != nil {
74+
return nil, err
75+
}
76+
77+
recorsdByAddress := make(map[string]*account.Record, len(models))
78+
for _, model := range models {
79+
recorsdByAddress[model.TokenAccount] = fromModel(model)
80+
}
81+
return recorsdByAddress, nil
82+
}
83+
7084
// GetByAuthorityAddress implements account.Store.GetByAuthorityAddress
7185
func (s *store) GetByAuthorityAddress(ctx context.Context, address string) (*account.Record, error) {
7286
model, err := dbGetByAuthorityAddress(ctx, s.db, address)

pkg/code/data/account/store.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ type Store interface {
2525
// GetByTokenAddress finds the record for a given token account address
2626
GetByTokenAddress(ctx context.Context, address string) (*Record, error)
2727

28+
// GetByTokenAddressBatch is like GetByTokenAddress, but for multiple accounts.
29+
// If any one account is missing, ErrAccountInfoNotFound is returned.
30+
GetByTokenAddressBatch(ctx context.Context, addresses ...string) (map[string]*Record, error)
31+
2832
// GetByAuthorityAddress finds the record for a given authority account address
2933
GetByAuthorityAddress(ctx context.Context, address string) (*Record, error)
3034

pkg/code/data/account/tests/tests.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ func RunTests(t *testing.T, s account.Store, teardown func()) {
2121
testPutMultipleRecords,
2222
testPutErrors,
2323
testGetLatestByOwner,
24+
testBatchedMethods,
2425
testRemoteSendEdgeCases,
2526
testRelationshipAccountEdgeCases,
2627
testSwapAccountEdgeCases,
@@ -317,6 +318,49 @@ func testGetLatestByOwner(t *testing.T, s account.Store) {
317318
})
318319
}
319320

321+
func testBatchedMethods(t *testing.T, s account.Store) {
322+
t.Run("testBatchedMethods", func(t *testing.T) {
323+
ctx := context.Background()
324+
325+
var records []*account.Record
326+
for i := 0; i < 100; i++ {
327+
record := &account.Record{
328+
OwnerAccount: fmt.Sprintf("owner%d", i),
329+
AuthorityAccount: fmt.Sprintf("authority%d", i),
330+
TokenAccount: fmt.Sprintf("token%d", i),
331+
MintAccount: fmt.Sprintf("mint%d", i),
332+
AccountType: commonpb.AccountType_POOL,
333+
Index: uint64(i),
334+
}
335+
336+
require.NoError(t, s.Put(ctx, record))
337+
338+
records = append(records, record)
339+
}
340+
341+
actual, err := s.GetByTokenAddressBatch(ctx, "token0", "token1")
342+
require.NoError(t, err)
343+
require.Len(t, actual, 2)
344+
assertEquivalentRecords(t, records[0], actual[records[0].TokenAccount])
345+
assertEquivalentRecords(t, records[1], actual[records[1].TokenAccount])
346+
347+
actual, err = s.GetByTokenAddressBatch(ctx, "token0", "token1", "token2", "token3", "token4")
348+
require.NoError(t, err)
349+
require.Len(t, actual, 5)
350+
assertEquivalentRecords(t, records[0], actual[records[0].TokenAccount])
351+
assertEquivalentRecords(t, records[1], actual[records[1].TokenAccount])
352+
assertEquivalentRecords(t, records[2], actual[records[2].TokenAccount])
353+
assertEquivalentRecords(t, records[3], actual[records[3].TokenAccount])
354+
assertEquivalentRecords(t, records[4], actual[records[4].TokenAccount])
355+
356+
_, err = s.GetByTokenAddressBatch(ctx, "not-found")
357+
assert.Equal(t, account.ErrAccountInfoNotFound, err)
358+
359+
_, err = s.GetByTokenAddressBatch(ctx, "token0", "not-found")
360+
assert.Equal(t, account.ErrAccountInfoNotFound, err)
361+
})
362+
}
363+
320364
func testRemoteSendEdgeCases(t *testing.T, s account.Store) {
321365
t.Run("testRemoteSendEdgeCases", func(t *testing.T) {
322366
ctx := context.Background()

pkg/code/data/internal.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ type DatabaseData interface {
9494
CreateAccountInfo(ctx context.Context, record *account.Record) error
9595
UpdateAccountInfo(ctx context.Context, record *account.Record) error
9696
GetAccountInfoByTokenAddress(ctx context.Context, address string) (*account.Record, error)
97+
GetAccountInfoByTokenAddressBatch(ctx context.Context, addresses ...string) (map[string]*account.Record, error)
9798
GetAccountInfoByAuthorityAddress(ctx context.Context, address string) (*account.Record, error)
9899
GetLatestAccountInfosByOwnerAddress(ctx context.Context, address string) (map[commonpb.AccountType][]*account.Record, error)
99100
GetLatestAccountInfoByOwnerAddressAndType(ctx context.Context, address string, accountType commonpb.AccountType) (*account.Record, error)
@@ -364,6 +365,9 @@ func (dp *DatabaseProvider) UpdateAccountInfo(ctx context.Context, record *accou
364365
func (dp *DatabaseProvider) GetAccountInfoByTokenAddress(ctx context.Context, address string) (*account.Record, error) {
365366
return dp.accounts.GetByTokenAddress(ctx, address)
366367
}
368+
func (dp *DatabaseProvider) GetAccountInfoByTokenAddressBatch(ctx context.Context, addresses ...string) (map[string]*account.Record, error) {
369+
return dp.accounts.GetByTokenAddressBatch(ctx, addresses...)
370+
}
367371
func (dp *DatabaseProvider) GetAccountInfoByAuthorityAddress(ctx context.Context, address string) (*account.Record, error) {
368372
return dp.accounts.GetByAuthorityAddress(ctx, address)
369373
}

pkg/code/data/timelock/memory/store.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,6 @@ func (s *store) GetByVault(_ context.Context, vault string) (*timelock.Record, e
9191
return nil, timelock.ErrTimelockNotFound
9292
}
9393

94-
func (s *store) GetByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) {
95-
s.mu.Lock()
96-
defer s.mu.Unlock()
97-
98-
if item := s.findByDepositPda(depositPda); item != nil {
99-
return item.Clone(), nil
100-
}
101-
return nil, timelock.ErrTimelockNotFound
102-
}
103-
10494
// GetByVaultBatch implements timelock.Store.GetByVaultBatch
10595
func (s *store) GetByVaultBatch(ctx context.Context, vaults ...string) (map[string]*timelock.Record, error) {
10696
s.mu.Lock()
@@ -118,6 +108,16 @@ func (s *store) GetByVaultBatch(ctx context.Context, vaults ...string) (map[stri
118108
return res, nil
119109
}
120110

111+
func (s *store) GetByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) {
112+
s.mu.Lock()
113+
defer s.mu.Unlock()
114+
115+
if item := s.findByDepositPda(depositPda); item != nil {
116+
return item.Clone(), nil
117+
}
118+
return nil, timelock.ErrTimelockNotFound
119+
}
120+
121121
// GetAllByState implements timelock.Store.GetAllByState
122122
func (s *store) GetAllByState(ctx context.Context, state timelock_token.TimelockState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*timelock.Record, error) {
123123
s.mu.Lock()

pkg/code/data/timelock/postgres/model.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -174,22 +174,6 @@ func dbGetByVault(ctx context.Context, db *sqlx.DB, vault string) (*model, error
174174
return res, nil
175175
}
176176

177-
func dbGetByDepositPda(ctx context.Context, db *sqlx.DB, depositPda string) (*model, error) {
178-
res := &model{}
179-
180-
query := `SELECT
181-
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at
182-
FROM ` + tableName + `
183-
WHERE deposit_pda_address = $1
184-
LIMIT 1`
185-
186-
err := db.GetContext(ctx, res, query, depositPda)
187-
if err != nil {
188-
return nil, pgutil.CheckNoRows(err, timelock.ErrTimelockNotFound)
189-
}
190-
return res, nil
191-
}
192-
193177
func dbGetByVaultBatch(ctx context.Context, db *sqlx.DB, vaults ...string) ([]*model, error) {
194178
res := []*model{}
195179

@@ -215,6 +199,22 @@ func dbGetByVaultBatch(ctx context.Context, db *sqlx.DB, vaults ...string) ([]*m
215199
return res, nil
216200
}
217201

202+
func dbGetByDepositPda(ctx context.Context, db *sqlx.DB, depositPda string) (*model, error) {
203+
res := &model{}
204+
205+
query := `SELECT
206+
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at
207+
FROM ` + tableName + `
208+
WHERE deposit_pda_address = $1
209+
LIMIT 1`
210+
211+
err := db.GetContext(ctx, res, query, depositPda)
212+
if err != nil {
213+
return nil, pgutil.CheckNoRows(err, timelock.ErrTimelockNotFound)
214+
}
215+
return res, nil
216+
}
217+
218218
func dbGetAllByState(ctx context.Context, db *sqlx.DB, state timelock_token.TimelockState, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*model, error) {
219219
res := []*model{}
220220

pkg/code/data/timelock/postgres/store.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,6 @@ func (s *store) GetByVault(ctx context.Context, vault string) (*timelock.Record,
5959
return fromModel(model), nil
6060
}
6161

62-
// GetByDepositPda implements timelock.Store.GetByDepositPda
63-
func (s *store) GetByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) {
64-
model, err := dbGetByDepositPda(ctx, s.db, depositPda)
65-
if err != nil {
66-
return nil, err
67-
}
68-
69-
return fromModel(model), nil
70-
}
71-
7262
// GetByVaultBatch implements timelock.Store.GetByVaultBatch
7363
func (s *store) GetByVaultBatch(ctx context.Context, vaults ...string) (map[string]*timelock.Record, error) {
7464
models, err := dbGetByVaultBatch(ctx, s.db, vaults...)
@@ -83,6 +73,16 @@ func (s *store) GetByVaultBatch(ctx context.Context, vaults ...string) (map[stri
8373
return timelocksByVault, nil
8474
}
8575

76+
// GetByDepositPda implements timelock.Store.GetByDepositPda
77+
func (s *store) GetByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) {
78+
model, err := dbGetByDepositPda(ctx, s.db, depositPda)
79+
if err != nil {
80+
return nil, err
81+
}
82+
83+
return fromModel(model), nil
84+
}
85+
8686
// GetOldestByState implements timelock.Store.GetAllByState
8787
func (s *store) GetAllByState(ctx context.Context, state timelock_token.TimelockState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*timelock.Record, error) {
8888
res, err := dbGetAllByState(ctx, s.db, state, cursor, limit, direction)

pkg/code/data/timelock/store.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ type Store interface {
1717
// GetByVault gets a timelock account's state by the vault address it's locking
1818
GetByVault(ctx context.Context, vault string) (*Record, error)
1919

20-
// GetByDepositPda gets a timelock account's state by the deposit PDA address
21-
GetByDepositPda(ctx context.Context, depositPda string) (*Record, error)
22-
2320
// GetByVaultBatch is like GetByVault, but for multiple accounts. If any one account
2421
// is missing, ErrTimelockNotFound is returned.
2522
GetByVaultBatch(ctx context.Context, vaults ...string) (map[string]*Record, error)
2623

24+
// GetByDepositPda gets a timelock account's state by the deposit PDA address
25+
GetByDepositPda(ctx context.Context, depositPda string) (*Record, error)
26+
2727
// GetAllByState gets all timelock accounts in the provided state
2828
GetAllByState(ctx context.Context, state timelock_token.TimelockState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*Record, error)
2929

0 commit comments

Comments
 (0)