Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions internal/server/spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,40 @@ package spanner
import (
"context"
"fmt"
"sync"
"time"

"cloud.google.com/go/spanner"
"gopkg.in/yaml.v3"
)

const (
// CACHE_DURATION defines how long the CompletionTimestamp is kept in memory before being refetched.
CACHE_DURATION = 5 * time.Second
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we rename this to TIMESTAMP_CACHE_DURATION just to be explicit

)

// SpannerClient encapsulates the Spanner client.
// It also holds a small in-memory cache of the latest CompletionTimestamp so
// that stale reads do not have to refetch it on every query.
type SpannerClient struct {
// client is the underlying Cloud Spanner client used to run queries.
client *spanner.Client

// Cache for storing CompletionTimestamp for stale reads.
// cacheMutex guards cachedTimestamp and cacheExpiry.
cacheMutex sync.RWMutex
// cachedTimestamp is the most recently fetched CompletionTimestamp; it is
// nil until a fetch has succeeded.
cachedTimestamp *time.Time
// cacheExpiry is the instant (as reported by clock) from which
// cachedTimestamp is considered expired and must be refetched.
cacheExpiry time.Time

// For mocking in tests.
// timestampFetcher retrieves the CompletionTimestamp on a cache miss.
timestampFetcher func(context.Context) (*time.Time, error)
// clock returns the current time used for cache-expiry decisions.
clock func() time.Time
}

// newSpannerClient creates a new SpannerClient wrapping the given Spanner client.
//
// The returned client uses the real wall clock (time.Now) for cache-expiry
// checks and fetches the CompletionTimestamp from Spanner on a cache miss.
// Both hooks are plain struct fields so tests can replace them.
func newSpannerClient(client *spanner.Client) *SpannerClient {
	sc := &SpannerClient{
		client: client,
		clock:  time.Now, // Default to real time
	}
	// Assigned after construction because the method value needs sc itself.
	sc.timestampFetcher = sc.fetchCompletionTimestampFromSpanner
	return sc
}

// NewSpannerClient creates a new SpannerClient from the config yaml string.
Expand All @@ -43,7 +64,15 @@ func NewSpannerClient(ctx context.Context, spannerConfigYaml string) (*SpannerCl
if err != nil {
return nil, fmt.Errorf("failed to create SpannerClient: %w", err)
}
return newSpannerClient(client), nil
sc := newSpannerClient(client)

// Cache initial CompletionTimestamp
_, err = sc.GetStalenessTimestampBound(ctx)
if err != nil {
return nil, fmt.Errorf("failed to warm up stable timestamp cache: %w", err)
}

return sc, nil
}

// createSpannerClient creates the database name string and initializes the Spanner client.
Expand Down
118 changes: 118 additions & 0 deletions internal/server/spanner/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package spanner

import (
"context"
"fmt"
"strings"
"testing"
"time"
)

// TestCacheHit verifies that a second lookup made before the cache expires is
// served from memory and does not invoke the timestamp fetcher again.
func TestCacheHit(t *testing.T) {
	now := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC)
	completed := now.Add(-1 * time.Minute)
	fetches := 0

	sc := &SpannerClient{
		clock: func() time.Time { return now },
		timestampFetcher: func(_ context.Context) (*time.Time, error) {
			fetches++
			return &completed, nil
		},
	}
	ctx := context.Background()

	// The first lookup misses and populates the cache.
	if _, err := sc.getCompletionTimestamp(ctx); err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}
	if fetches != 1 {
		t.Fatalf("Setup failed, expected 1 fetch, got %d", fetches)
	}

	// The clock has not moved, so this lookup falls inside the cache window.
	if _, err := sc.getCompletionTimestamp(ctx); err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}
	if fetches != 1 {
		t.Errorf("Expected timestamp fetch count to remain 1, got %d", fetches)
	}
}

// TestCacheExpiration verifies that once the mocked clock moves past the
// cache expiry, the next lookup refetches the timestamp and pushes the expiry
// forward by CACHE_DURATION from the new "now".
func TestCacheExpiration(t *testing.T) {
	var fetchCount int
	mockTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC)
	stableTime := mockTime.Add(-1 * time.Minute)

	sc := &SpannerClient{
		cacheExpiry: mockTime.Add(CACHE_DURATION),
		// The closure reads mockTime on every call, so reassigning mockTime
		// below advances the client's clock.
		clock: func() time.Time { return mockTime },
	}
	sc.timestampFetcher = func(ctx context.Context) (*time.Time, error) {
		fetchCount++
		return &stableTime, nil
	}
	// Initialization will populate cache.
	_, err := sc.getCompletionTimestamp(context.Background())
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}
	if fetchCount != 1 {
		t.Fatalf("Setup failed, expected 1 fetch, got %d", fetchCount)
	}

	// Advance time past expiration.
	mockTime = mockTime.Add(6 * time.Second)
	_, err = sc.getCompletionTimestamp(context.Background())
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}
	if fetchCount != 2 {
		t.Errorf("Expected timestamp fetch count to increase to 2, got %d", fetchCount)
	}

	// The clock is mocked, so the new expiry must match exactly. A one-sided
	// "difference > tolerance" check would also pass when the expiry was never
	// refreshed (the stale value is earlier, giving a negative difference).
	expectedExpiry := mockTime.Add(CACHE_DURATION)
	if !sc.cacheExpiry.Equal(expectedExpiry) {
		t.Errorf("Cache expiry was not correctly updated after refetch: got %v, want %v",
			sc.cacheExpiry, expectedExpiry)
	}
}

// TestGetStalenessTimestampBound checks that the bound returned for stale
// reads is non-nil and that its string form mentions the date of the
// timestamp supplied by the fetcher.
func TestGetStalenessTimestampBound(t *testing.T) {
	now := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC)
	completed := now.Add(-5 * time.Minute) // Completion time is 5 minutes ago

	sc := &SpannerClient{
		cacheExpiry: now.Add(CACHE_DURATION),
		clock:       func() time.Time { return now },
		timestampFetcher: func(_ context.Context) (*time.Time, error) {
			return &completed, nil
		},
	}

	bound, err := sc.GetStalenessTimestampBound(context.Background())
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}
	if bound == nil {
		t.Fatal("Expected a non-nil TimestampBound")
	}

	// Approximate check: only the calendar date is looked for in the String()
	// output; the longer form below is used purely for the failure message.
	want := fmt.Sprintf("ReadTimestamp(%s)", completed.Format(time.RFC3339Nano))
	got := bound.String()
	if !strings.Contains(got, completed.Format("2006-01-02")) {
		t.Errorf("Expected ReadTimestamp containing %v, got %s", want, got)
	}
}
103 changes: 100 additions & 3 deletions internal/server/spanner/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ package spanner
import (
"context"
"fmt"
"log/slog"
"time"

"cloud.google.com/go/spanner"
v2 "github.com/datacommonsorg/mixer/internal/server/v2"
"google.golang.org/api/iterator"
"google.golang.org/grpc/codes"
)

const (
Expand Down Expand Up @@ -211,16 +214,47 @@ func (sc *SpannerClient) queryAndCollect(
newStruct func() interface{},
withStruct func(interface{}),
) error {
iter := sc.client.Single().Query(ctx, stmt)
timestampBound, err := sc.GetStalenessTimestampBound(ctx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want a feature flag to disable stale reads?

if err != nil {
return err
}

// Attempt stale read
iter := sc.client.Single().WithTimestampBound(*timestampBound).Query(ctx, stmt)
defer iter.Stop()

err = sc.processRows(iter, newStruct, withStruct)

// Check if the error is due to an expired timestamp (FAILED_PRECONDITION).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious when can that happen...max 7 day timeout?

// Currently the timestamp is set manually so can naturally get stale.
// So for now, just log an error and fallback to a strong read.
// TODO: Once the Spanner instance is set to periodically update the timestamp, increase severity of check, as this indicates that ingestion failed.
if spanner.ErrCode(err) == codes.FailedPrecondition {
slog.Error("Stale read timestamp expired (before earliest_version_time). Falling back to StrongRead.",
"expiredTimestamp", timestampBound.String())

// Fallback to strong read
strongBound := spanner.StrongRead()
iter = sc.client.Single().WithTimestampBound(strongBound).Query(ctx, stmt)
defer iter.Stop()

err = sc.processRows(iter, newStruct, withStruct)
}
if err != nil {
return fmt.Errorf("failed to execute Spanner query after fallback attempt: %w", err)
}

return nil
}

func (sc *SpannerClient) processRows(iter *spanner.RowIterator, newStruct func() interface{}, withStruct func(interface{})) error {
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return fmt.Errorf("failed to fetch row: %w", err)
return err
}

rowStruct := newStruct()
Expand All @@ -229,6 +263,69 @@ func (sc *SpannerClient) queryAndCollect(
}
withStruct(rowStruct)
}

return nil
}

// fetchCompletionTimestampFromSpanner reads the first row produced by the
// completion-timestamp query and returns its first column as a time.
// It returns an error when the query yields no rows, when fetching the row
// fails, or when the column cannot be decoded.
func (sc *SpannerClient) fetchCompletionTimestampFromSpanner(ctx context.Context) (*time.Time, error) {
	rows := sc.client.Single().Query(ctx, *GetCompletionTimestampQuery())
	defer rows.Stop()

	first, err := rows.Next()
	switch {
	case err == iterator.Done:
		return nil, fmt.Errorf("no rows found in IngestionHistory")
	case err != nil:
		return nil, fmt.Errorf("failed to fetch row: %w", err)
	}

	completion := new(time.Time)
	if err := first.Column(0, completion); err != nil {
		return nil, fmt.Errorf("failed to read CompletionTimestamp column: %w", err)
	}
	return completion, nil
}

// getCompletionTimestamp returns the latest reported CompletionTimestamp.
// It prioritizes returning a value from an in-memory cache to reduce Spanner traffic.
//
// The cached pointer is returned as-is, so all callers within one cache window
// share the same *time.Time. If the fetcher fails, the error is returned
// unchanged and the cache is left untouched.
func (sc *SpannerClient) getCompletionTimestamp(ctx context.Context) (*time.Time, error) {
// Check cache
// Fast path: only a read lock is taken when the cached value is still fresh.
sc.cacheMutex.RLock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, we need to think through how consistency would be ensured across caches in different mixer instances.

if sc.cachedTimestamp != nil && sc.clock().Before(sc.cacheExpiry) {
sc.cacheMutex.RUnlock()
return sc.cachedTimestamp, nil
}
sc.cacheMutex.RUnlock()

// Fetch from Spanner
// The write lock is held for the whole fetch (released by the defer), so
// other callers block here until the refresh finishes.
sc.cacheMutex.Lock()
defer sc.cacheMutex.Unlock()

// Re-check the cache under the write lock (to prevent a race condition
// where another goroutine updated it between the RUnlock and this Lock)
if sc.cachedTimestamp != nil && sc.clock().Before(sc.cacheExpiry) {
return sc.cachedTimestamp, nil
}
timestamp, err := sc.timestampFetcher(ctx)
if err != nil {
return nil, err
}

// Update cache
// The expiry is measured from the clock reading taken after the fetch returns.
sc.cachedTimestamp = timestamp
sc.cacheExpiry = sc.clock().Add(CACHE_DURATION)

return timestamp, nil
}

// GetStalenessTimestampBound returns the TimestampBound that should be used for stale reads in Spanner.
// The bound is built by passing the (possibly cached) CompletionTimestamp to
// spanner.ReadTimestamp. Any error from obtaining the CompletionTimestamp is
// returned unchanged, with a nil bound.
func (sc *SpannerClient) GetStalenessTimestampBound(ctx context.Context) (*spanner.TimestampBound, error) {
completionTimestamp, err := sc.getCompletionTimestamp(ctx)
if err != nil {
return nil, err
}

// The completion timestamp is a plain time.Time; the TimestampBound is the
// Spanner read option constructed from it.
timestampBound := spanner.ReadTimestamp(*completionTimestamp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what's the difference b/w completion timestamp and timestampBound?

return &timestampBound, nil
}
6 changes: 6 additions & 0 deletions internal/server/spanner/query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ import (
v2 "github.com/datacommonsorg/mixer/internal/server/v2"
)

// GetCompletionTimestampQuery returns the parameterless statement whose SQL is
// statements.getCompletionTimestamp.
func GetCompletionTimestampQuery() *spanner.Statement {
	stmt := new(spanner.Statement)
	stmt.SQL = statements.getCompletionTimestamp
	return stmt
}

func GetNodePropsQuery(ids []string, out bool) *spanner.Statement {
switch out {
case true:
Expand Down
9 changes: 9 additions & 0 deletions internal/server/spanner/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (

// SQL / GQL statements executed by the SpannerClient
var statements = struct {
// Fetch latest CompletionTimestamp from IngestionHistory table.
getCompletionTimestamp string
// Fetch Properties for out arcs.
getPropsBySubjectID string
// Fetch Properties for in arcs.
Expand Down Expand Up @@ -74,6 +76,13 @@ var statements = struct {
// Resolve one property to another.
resolvePropToProp string
}{
getCompletionTimestamp: ` SELECT
CompletionTimestamp
FROM
IngestionHistory
ORDER BY
CompletionTimestamp DESC
LIMIT 1`,
getPropsBySubjectID: ` GRAPH DCGraph MATCH -[e:Edge
WHERE
e.subject_id IN UNNEST(@ids)]->
Expand Down
12 changes: 11 additions & 1 deletion test/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"encoding/json"
"log"
"log/slog"
"net"
"os"
"path"
Expand Down Expand Up @@ -392,7 +393,16 @@ func NewSpannerClient() *spanner.SpannerClient {
}
_, filename, _, _ := runtime.Caller(0)
spannerGraphInfoYamlPath := path.Join(path.Dir(filename), "../deploy/storage/spanner_graph_info.yaml")
return newSpannerClient(context.Background(), spannerGraphInfoYamlPath)
sc := newSpannerClient(context.Background(), spannerGraphInfoYamlPath)

// Cache initial CompletionTimestamp
_, err := sc.GetStalenessTimestampBound(context.Background())
if err != nil {
slog.Error("failed to warm up stable timestamp cache", "error", err)
return nil
}

return sc
}

func newSpannerClient(ctx context.Context, spannerGraphInfoYamlPath string) *spanner.SpannerClient {
Expand Down