diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs b/src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs new file mode 100644 index 000000000..88705057d --- /dev/null +++ b/src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs @@ -0,0 +1,176 @@ +using Microsoft.Extensions.Logging; +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; + +namespace StackExchange.Redis; + +public partial class ConnectionMultiplexer +{ + /// + /// Enable the client tracking feature of redis + /// + /// see also https://redis.io/docs/manual/client-side-caching/ + /// The callback to be invoked when keys are determined to be invalidated + /// Additional flags to influence the behavior of client tracking + /// Optionally restricts client-side caching notifications for these connections to a subset of key prefixes; this has performance implications (see the PREFIX option in CLIENT TRACKING) + public void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default) + { + if (_clientSideTracking is not null) ThrowOnceOnly(); + if (!prefixes.IsEmpty && (options & ClientTrackingOptions.Broadcast) == 0) ThrowPrefixNeedsBroadcast(); + var obj = new ClientSideTrackingState(this, keyInvalidated, options, prefixes); + if (Interlocked.CompareExchange(ref _clientSideTracking, obj, null) is not null) ThrowOnceOnly(); + + static void ThrowOnceOnly() => throw new InvalidOperationException("The " + nameof(EnableServerAssistedClientSideTracking) + " method can be invoked once-only per multiplexer instance"); + static void ThrowPrefixNeedsBroadcast() => throw new ArgumentException("Prefixes can only be specified when " + nameof(ClientTrackingOptions) + "." + nameof(ClientTrackingOptions.Broadcast) + " is used", nameof(prefixes)); + } + + private ClientSideTrackingState? _clientSideTracking; + internal ClientSideTrackingState? ClientSideTracking => _clientSideTracking; + internal sealed class ClientSideTrackingState + { + public bool IsAlive { get; private set; } + private readonly Func _keyInvalidated; + public ClientTrackingOptions Options { get; } + public ReadOnlyMemory Prefixes { get; } + + private readonly Channel _notifications; + private readonly WeakReference _multiplexer; +#if NETCOREAPP3_1_OR_GREATER + private readonly Action? _concurrentCallback; +#else + private readonly WaitCallback? _concurrentCallback; +#endif + + public ClientSideTrackingState(ConnectionMultiplexer multiplexer, Func keyInvalidated, ClientTrackingOptions options, ReadOnlyMemory prefixes) + { + _keyInvalidated = keyInvalidated; + Options = options; + Prefixes = prefixes; + _notifications = Channel.CreateUnbounded(ChannelOptions); + _ = Task.Run(RunAsync); + IsAlive = true; + _multiplexer = new(multiplexer); + + if ((options & ClientTrackingOptions.ConcurrentInvalidation) != 0) + { + _concurrentCallback = OnInvalidate; + } + } + +#if !NETCOREAPP3_1_OR_GREATER + private void OnInvalidate(object state) => OnInvalidate((RedisKey)state); +#endif + + private void OnInvalidate(RedisKey key) + { + try // not optimized for sync completions + { + var pending = _keyInvalidated(key); + if (pending.IsCompleted) + { // observe result + pending.GetAwaiter().GetResult(); + } + else + { + _ = ObserveAsyncInvalidation(pending); + } + } + catch (Exception ex) // handle sync failure (via immediate throw or faulted ValueTask) + { + OnCallbackError(ex); + } + } + + private async Task ObserveAsyncInvalidation(ValueTask pending) + { + try + { + await pending.ConfigureAwait(false); + } + catch (Exception ex) + { + OnCallbackError(ex); + } + } + + private ConnectionMultiplexer? Multiplexer => _multiplexer.TryGetTarget(out var multiplexer) ? multiplexer : null; + + + private void OnCallbackError(Exception error) => Multiplexer?.Logger?.LogError(error, "Client-side tracking invalidation callback failure"); + + private async Task RunAsync() + { + while (await _notifications.Reader.WaitToReadAsync().ConfigureAwait(false)) + { + while (_notifications.Reader.TryRead(out var key)) + { + if (_concurrentCallback is not null) + { +#if NETCOREAPP3_1_OR_GREATER + ThreadPool.QueueUserWorkItem(_concurrentCallback, key, preferLocal: false); +#else + // eat the box + ThreadPool.QueueUserWorkItem(_concurrentCallback, key); +#endif + } + else + { + try + { + await _keyInvalidated(key).ConfigureAwait(false); + } + catch (Exception ex) + { + OnCallbackError(ex); + } + } + } + } + } + + public void Write(RedisKey key) => _notifications.Writer.TryWrite(key); + + public void Shutdown() + { + IsAlive = false; + _notifications.Writer.TryComplete(null); + } + + private static readonly UnboundedChannelOptions ChannelOptions = new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, AllowSynchronousContinuations = true }; + + + } +} + +/// +/// Additional flags to influence the behavior of client tracking +/// +[Flags] +public enum ClientTrackingOptions +{ + /// + /// No additional options + /// + None = 0, + /// + /// Enable tracking in broadcasting mode. In this mode invalidation messages are reported for all the prefixes specified, regardless of the keys requested by the connection. Instead when the broadcasting mode is not enabled, Redis will track which keys are fetched using read-only commands, and will report invalidation messages only for such keys. + /// + /// This corresponds to CLIENT TRACKING ... BCAST; using mode consumes less server memory, at the cost of more invalidation messages (i.e. clients are + /// likely to receive invalidation messages for keys that the individual client is not using); this can be partially mitigated by using prefixes + Broadcast = 1 << 0, + /// + /// Send notifications about keys modified by this connection itself. + /// + /// This corresponds to the inverse of CLIENT TRACKING ... NOLOOP; setting means that your own writes will cause self-notification; this + /// may mean that you discard a locally updated copy of the new value, hence this is disabled by default + NotifyForOwnCommands = 1 << 1, + + /// + /// Indicates that the callback specified for key invalidation should be invoked concurrently rather than sequentially + /// + ConcurrentInvalidation = 1 << 2, + + // to think about: OPTIN / OPTOUT ? I'm happy to implement on the basis of OPTIN for now, though +} diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs b/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs index 0a8b95be5..707ba7bf6 100644 --- a/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs +++ b/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs @@ -14,6 +14,12 @@ public partial class ConnectionMultiplexer internal void OnConnectionFailed(EndPoint endpoint, ConnectionType connectionType, ConnectionFailureType failureType, Exception exception, bool reconfigure, string? physicalName) { if (_isDisposed) return; + + if (connectionType is ConnectionType.Subscription) + { + GetServerEndPoint(endpoint, activate: false)?.OnSubscriberFailed(); + } + var handler = ConnectionFailed; if (handler != null) { diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.cs b/src/StackExchange.Redis/ConnectionMultiplexer.cs index cc239ad3f..75060fb01 100644 --- a/src/StackExchange.Redis/ConnectionMultiplexer.cs +++ b/src/StackExchange.Redis/ConnectionMultiplexer.cs @@ -1,4 +1,7 @@ -using System; +using Microsoft.Extensions.Logging; +using Pipelines.Sockets.Unofficial; +using StackExchange.Redis.Profiling; +using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; @@ -10,9 +13,6 @@ using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Pipelines.Sockets.Unofficial; -using StackExchange.Redis.Profiling; namespace StackExchange.Redis { @@ -355,6 +355,11 @@ internal void CheckMessage(Message message) { throw ExceptionFactory.TooManyArgs(message.CommandAndKey, message.ArgCount); } + + if (message.IsClientCaching && ClientSideTracking is null) + { + throw new InvalidOperationException("The " + nameof(CommandFlags.ClientCaching) + " flag can only be used if " + nameof(EnableServerAssistedClientSideTracking) + " has been called"); + } } internal bool TryResend(int hashSlot, Message message, EndPoint endpoint, bool isMoved) @@ -2268,7 +2273,7 @@ public async ValueTask DisposeAsync() public void Close(bool allowCommandsToComplete = true) { if (_isDisposed) return; - + _clientSideTracking?.Shutdown(); OnClosing(false); _isDisposed = true; _profilingSessionProvider = null; @@ -2295,6 +2300,7 @@ public void Close(bool allowCommandsToComplete = true) public async Task CloseAsync(bool allowCommandsToComplete = true) { _isDisposed = true; + _clientSideTracking?.Shutdown(); using (var tmp = pulse) { pulse = null; diff --git a/src/StackExchange.Redis/Enums/CommandFlags.cs b/src/StackExchange.Redis/Enums/CommandFlags.cs index bc93f328e..ddcec208d 100644 --- a/src/StackExchange.Redis/Enums/CommandFlags.cs +++ b/src/StackExchange.Redis/Enums/CommandFlags.cs @@ -81,7 +81,10 @@ public enum CommandFlags /// NoScriptCache = 512, - // 1024: Removed - was used for async timeout checks; never user-specified, so not visible on the public API + /// + /// Indicates a command that relates to server-assisted client-side caching; this corresponds to CLIENT CACHING YES being issues before the command + /// + ClientCaching = 1024, // 2048: Use subscription connection type; never user-specified, so not visible on the public API } diff --git a/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs b/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs index 58973df68..3340d132f 100644 --- a/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs +++ b/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs @@ -281,5 +281,8 @@ public interface IConnectionMultiplexer : IDisposable, IAsyncDisposable /// The destination stream to write the export to. /// The options to use for this export. void ExportConfiguration(Stream destination, ExportOptions options = ExportOptions.All); + + /// + void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default); } } diff --git a/src/StackExchange.Redis/Message.cs b/src/StackExchange.Redis/Message.cs index 65df7d0e7..c954b6e2f 100644 --- a/src/StackExchange.Redis/Message.cs +++ b/src/StackExchange.Redis/Message.cs @@ -75,7 +75,8 @@ internal abstract class Message : ICompletable #pragma warning restore CS0618 | CommandFlags.FireAndForget | CommandFlags.NoRedirect - | CommandFlags.NoScriptCache; + | CommandFlags.NoScriptCache + | CommandFlags.ClientCaching; private IResultBox? resultBox; private ResultProcessor? resultProcessor; @@ -197,6 +198,7 @@ public bool IsAdmin } public bool IsAsking => (Flags & AskingFlag) != 0; + public bool IsClientCaching => (Flags & CommandFlags.ClientCaching) != 0; internal bool IsScriptUnavailable => (Flags & ScriptUnavailableFlag) != 0; diff --git a/src/StackExchange.Redis/PhysicalBridge.cs b/src/StackExchange.Redis/PhysicalBridge.cs index 7041cf0af..8394c06e1 100644 --- a/src/StackExchange.Redis/PhysicalBridge.cs +++ b/src/StackExchange.Redis/PhysicalBridge.cs @@ -23,7 +23,9 @@ internal sealed class PhysicalBridge : IDisposable private const double ProfileLogSeconds = (1000 /* ms */ * ProfileLogSamples) / 1000.0; - private static readonly Message ReusableAskingCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.ASKING); + private static readonly Message + ReusableAskingCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.ASKING), + ReusableClientCachingYesCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.CACHING, RedisLiterals.yes); private readonly long[] profileLog = new long[ProfileLogSamples]; @@ -1494,13 +1496,13 @@ private WriteResult WriteMessageToServerInsideWriteLock(PhysicalConnection conne } if (message.IsAsking) { - var asking = ReusableAskingCommand; - connection.EnqueueInsideWriteLock(asking); - asking.WriteTo(connection); - asking.SetRequestSent(); - IncrementOpCount(); + RawWriteInternalMessageInsideWriteLock(connection, ReusableAskingCommand); } } + if (message.IsClientCaching && connection.EnsureServerAssistedClientSideTrackingInsideWriteLock()) + { + RawWriteInternalMessageInsideWriteLock(connection, ReusableClientCachingYesCommand); + } switch (cmd) { case RedisCommand.WATCH: @@ -1570,6 +1572,15 @@ private WriteResult WriteMessageToServerInsideWriteLock(PhysicalConnection conne } } + internal void RawWriteInternalMessageInsideWriteLock(PhysicalConnection connection, Message message) + { + message.SetInternalCall(); + connection.EnqueueInsideWriteLock(message); + message.WriteTo(connection); + message.SetRequestSent(); + IncrementOpCount(); + } + /// /// For testing only /// @@ -1583,5 +1594,6 @@ internal void SimulateConnectionFailure(SimulatedFailureType failureType) } internal RedisCommand? GetActiveMessage() => Volatile.Read(ref _activeMessage)?.Command; + internal void OnSubscriberFailed() => physical?.OnSubscriberFailed(); } } diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index eb0787606..79926ada1 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -19,6 +19,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using static StackExchange.Redis.Message; +using System.Threading.Channels; namespace StackExchange.Redis { @@ -1599,10 +1600,10 @@ private void MatchResult(in RawResult result) Trace("MESSAGE: " + channel); if (!channel.IsNull) { - if (TryGetPubSubPayload(items[2], out var payload)) + if (TryGetPubSubPayload(items[2], out var payload, out var source)) { _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(channel, channel, payload); + muxer.OnMessage(channel, channel, payload, source); } // could be multi-message: https://github.com/StackExchange/StackExchange.Redis/issues/2507 else if (TryGetMultiPubSubPayload(items[2], out var payloads)) @@ -1621,11 +1622,11 @@ private void MatchResult(in RawResult result) Trace("PMESSAGE: " + channel); if (!channel.IsNull) { - if (TryGetPubSubPayload(items[3], out var payload)) + if (TryGetPubSubPayload(items[3], out var payload, out var source)) { var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Pattern); _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(sub, channel, payload); + muxer.OnMessage(sub, channel, payload, source); } else if (TryGetMultiPubSubPayload(items[3], out var payloads)) { @@ -1661,8 +1662,9 @@ private void MatchResult(in RawResult result) _readStatus = ReadStatus.MatchResultComplete; _activeMessage = null; - static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, bool allowArraySingleton = true) + static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, out RawResult source, bool allowArraySingleton = true) { + source = value; if (value.IsNull) { parsed = RedisValue.Null; @@ -1676,7 +1678,7 @@ static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, bool parsed = value.AsRedisValue(); return true; case ResultType.MultiBulk when allowArraySingleton && value.ItemsCount == 1: - return TryGetPubSubPayload(in value[0], out parsed, allowArraySingleton: false); + return TryGetPubSubPayload(in value[0], out parsed, out source, allowArraySingleton: false); } parsed = default; return false; @@ -2071,5 +2073,168 @@ internal bool HasPendingCallerFacingItems() if (lockTaken) Monitor.Exit(_writtenAwaitingResponse); } } + + private int _clientTrackingState = (int)ClientTrackingState.NotInitialized; + private enum ClientTrackingState + { + NotInitialized = 0, + ActiveSingleConnectionPerItemTracking = 1, + ActiveSplitConnectionPerItemTracking = 2, + ActiveSingleConnectionBroadcast = 3, + ActiveSplitConnectionBroadcast = 4, + Broken = 10, // was active, now not + } + + private ClientTrackingState GetClientTrackingState() => (ClientTrackingState)Volatile.Read(ref _clientTrackingState); + + /// + /// initializes client caching state and returns True if CLIENT CACHING YES should be sent + /// + internal bool EnsureServerAssistedClientSideTrackingInsideWriteLock() => + GetClientTrackingState() switch + { + ClientTrackingState.ActiveSingleConnectionPerItemTracking => true, + ClientTrackingState.ActiveSplitConnectionPerItemTracking => true, + // don't add CLIENT CACHING per-item when in broadcast mode + ClientTrackingState.ActiveSingleConnectionBroadcast => false, + ClientTrackingState.ActiveSplitConnectionBroadcast => false, + // anything else? slow mode + _ => InitializeServerAssistedClientSideTrackingInsideWriteLock() + }; + + private bool InitializeServerAssistedClientSideTrackingInsideWriteLock() + { + var bridge = BridgeCouldBeNull; + if (bridge is null) + { + return false; // shutting down, be gentle in our nope + } + + var config = bridge.Multiplexer.ClientSideTracking; + if (config is not { IsAlive: true }) + { + return false; // not enabled (should already have faulted, note), or: already dead + } + + ClientTrackingState oldState, newState; + do + { + switch (oldState = GetClientTrackingState()) + { + case ClientTrackingState.ActiveSingleConnectionPerItemTracking: + case ClientTrackingState.ActiveSplitConnectionPerItemTracking: + return true; // we shouldn't be here, but: whatever + case ClientTrackingState.ActiveSingleConnectionBroadcast: + case ClientTrackingState.ActiveSplitConnectionBroadcast: + return false; // we shouldn't be here, but: whatever + case ClientTrackingState.Broken: + bridge.RawWriteInternalMessageInsideWriteLock(this, Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.TRACKING, RedisLiterals.OFF)); + oldState = ClientTrackingState.NotInitialized; // ack that we've reset things + Volatile.Write(ref _clientTrackingState, (int)oldState); + goto case ClientTrackingState.NotInitialized; + case ClientTrackingState.NotInitialized: + // note: this check will need to be removed in RESP3 + if (BridgeCouldBeNull?.ServerEndPoint is { SupportsSubscriptions: true } sep + && sep.GetBridge(ConnectionType.Subscription) is { IsConnected: true } sub) + { + var subId = sub.ConnectionId; + if (subId is not null) + { + // subscribe + bridge.Multiplexer.ExecuteSyncImpl(ReusableSubscribeClientCachingSubscribeMessage, null, sep, 0); + bridge.RawWriteInternalMessageInsideWriteLock(this, new ClientTrackingMessage(config, subId)); + newState = (config.Options & ClientTrackingOptions.Broadcast) == 0 ? ClientTrackingState.ActiveSplitConnectionPerItemTracking : ClientTrackingState.ActiveSplitConnectionBroadcast; + break; + } + } + return false; // unable to initialize; connections unavailable or similar + default: + return false; // unknown state + } + } while (Interlocked.CompareExchange(ref _clientTrackingState, (int)newState, (int)oldState) != (int)oldState); // redo from start if fighting with OnSubscriberFailed, which is only for the "Broken" scenario + + // we're now in a known state; we only issue CLIENT CACHING YES if we're in per-item tracking mode + return newState is ClientTrackingState.ActiveSingleConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionPerItemTracking; + + } + + internal void OnSubscriberFailed() + { + // if in split connection mode, then: our notifications have failed and we need to reset + if (GetClientTrackingState() is ClientTrackingState.ActiveSplitConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionBroadcast) + { + Volatile.Write(ref _clientTrackingState, (int)ClientTrackingState.Broken); + } + } + + private static readonly Message ReusableSubscribeClientCachingSubscribeMessage = Message.Create( + -1, CommandFlags.FireAndForget, RedisCommand.SUBSCRIBE, ConnectionMultiplexer.ClientCachingChannel); + + private sealed class ClientTrackingMessage : Message + { + private readonly ConnectionMultiplexer.ClientSideTrackingState _state; + private readonly long? _subId; // will be NULL in RESP3 + + public ClientTrackingMessage(ConnectionMultiplexer.ClientSideTrackingState state, long? subId) : base(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT) + { + _state = state; + _subId = subId; + } + + public override int ArgCount + { + get + { + var count = 3; // TRACKING ON {OPTIN|BCAST} + if (_subId is not null) + { + count += 2; // [REDIRECT client-id] + } + if (!_state.Prefixes.IsEmpty) + { + count += _state.Prefixes.Length + 1; // [PREFIX prefix ...] + } + var options = _state.Options; + if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0) + { + count++; // [NOLOOP] + } + return count; + } + } + + protected override void WriteImpl(PhysicalConnection physical) + { + physical.WriteHeader(Command, ArgCount); + physical.WriteBulkString(RedisLiterals.TRACKING); + physical.WriteBulkString(RedisLiterals.ON); + if (_subId is not null) + { + physical.WriteBulkString(RedisLiterals.REDIRECT); + physical.WriteBulkString(_subId.GetValueOrDefault()); + } + if (!_state.Prefixes.IsEmpty) + { + physical.WriteBulkString(RedisLiterals.PREFIX); + foreach (ref readonly RedisKey prefix in _state.Prefixes.Span) + { + physical.Write(in prefix); + } + } + var options = _state.Options; + if ((options & ClientTrackingOptions.Broadcast) == 0) + { + physical.WriteBulkString(RedisLiterals.OPTIN); + } + else + { + physical.WriteBulkString(RedisLiterals.BCAST); + } + if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0) + { + physical.WriteBulkString(RedisLiterals.NOLOOP); + } + } + } } } diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt index 5f282702b..0e2cf28f5 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt @@ -1 +1,8 @@ - \ No newline at end of file +StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.Broadcast = 1 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.ConcurrentInvalidation = 4 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.None = 0 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.NotifyForOwnCommands = 2 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.CommandFlags.ClientCaching = 1024 -> StackExchange.Redis.CommandFlags +StackExchange.Redis.ConnectionMultiplexer.EnableServerAssistedClientSideTracking(System.Func! keyInvalidated, StackExchange.Redis.ClientTrackingOptions options = StackExchange.Redis.ClientTrackingOptions.None, System.ReadOnlyMemory prefixes = default(System.ReadOnlyMemory)) -> void +StackExchange.Redis.IConnectionMultiplexer.EnableServerAssistedClientSideTracking(System.Func! keyInvalidated, StackExchange.Redis.ClientTrackingOptions options = StackExchange.Redis.ClientTrackingOptions.None, System.ReadOnlyMemory prefixes = default(System.ReadOnlyMemory)) -> void \ No newline at end of file diff --git a/src/StackExchange.Redis/RedisLiterals.cs b/src/StackExchange.Redis/RedisLiterals.cs index e926b6da4..f2429b640 100644 --- a/src/StackExchange.Redis/RedisLiterals.cs +++ b/src/StackExchange.Redis/RedisLiterals.cs @@ -50,12 +50,14 @@ public static readonly RedisValue AND = "AND", ANY = "ANY", ASC = "ASC", + BCAST = "BCAST", BEFORE = "BEFORE", BIT = "BIT", BY = "BY", BYLEX = "BYLEX", BYSCORE = "BYSCORE", BYTE = "BYTE", + CACHING = "CACHING", CH = "CH", CHANNELS = "CHANNELS", COPY = "COPY", @@ -97,21 +99,27 @@ public static readonly RedisValue MINMATCHLEN = "MINMATCHLEN", MODULE = "MODULE", NODES = "NODES", + NOLOOP = "NOLOOP", NOSAVE = "NOSAVE", NOT = "NOT", NUMPAT = "NUMPAT", NUMSUB = "NUMSUB", NX = "NX", OBJECT = "OBJECT", + OFF = "OFF", + OPTIN = "OPTIN", OR = "OR", + ON = "ON", PATTERN = "PATTERN", PAUSE = "PAUSE", PERSIST = "PERSIST", PING = "PING", + PREFIX = "PREFIX", PURGE = "PURGE", PX = "PX", PXAT = "PXAT", RANK = "RANK", + REDIRECT = "REDIRECT", REFCOUNT = "REFCOUNT", REPLACE = "REPLACE", RESET = "RESET", @@ -127,6 +135,7 @@ public static readonly RedisValue SKIPME = "SKIPME", STATS = "STATS", STORE = "STORE", + TRACKING = "TRACKING", TYPE = "TYPE", WEIGHTS = "WEIGHTS", WITHMATCHLEN = "WITHMATCHLEN", diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index 5a24a716e..067f5c263 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -75,7 +75,7 @@ internal bool GetSubscriberCounts(in RedisChannel channel, out int handlers, out /// /// Handler that executes whenever a message comes in, this doles out messages to any registered handlers. /// - internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload) + internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload, in RawResult rawPayload) { ICompletable? completable = null; ChannelMessageQueue? queues = null; @@ -91,22 +91,28 @@ internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, i { CompleteAsWorker(completable); } + if (subscription == ClientCachingChannel) + { + _clientSideTracking?.Write(rawPayload.AsRedisKey()); + } } + internal static readonly RedisChannel ClientCachingChannel = RedisChannel.Literal("__redis__:invalidate"); + internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, Sequence payload) { if (payload.IsSingleSegment) { - foreach (var message in payload.FirstSpan) + foreach (ref readonly RawResult message in payload.FirstSpan) { - OnMessage(subscription, channel, message.AsRedisValue()); + OnMessage(subscription, channel, message.AsRedisValue(), in message); } } else { - foreach (var message in payload) + foreach (ref readonly RawResult message in payload) { - OnMessage(subscription, channel, message.AsRedisValue()); + OnMessage(subscription, channel, message.AsRedisValue(), in message); } } } diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index a90023580..0791d5ab6 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -1032,5 +1032,7 @@ internal bool HasPendingCallerFacingItems() if (interactive?.HasPendingCallerFacingItems() == true) return true; return subscription?.HasPendingCallerFacingItems() ?? false; } + + internal void OnSubscriberFailed() => interactive?.OnSubscriberFailed(); } } diff --git a/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs b/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs new file mode 100644 index 000000000..7fc806587 --- /dev/null +++ b/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs @@ -0,0 +1,98 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace StackExchange.Redis.Tests; + +/// +/// Tests for . +/// +[Collection(SharedConnectionFixture.Key)] +public class ClientTrackingTests : TestBase +{ + public ClientTrackingTests(ITestOutputHelper output, SharedConnectionFixture fixture) : base(output, fixture) { } + + [Fact] + public async Task UseFlagWithoutEnabling() + { + using var conn = Create(shared: false); + var key = Me(); + var ex = await Assert.ThrowsAsync( + async () => await conn.GetDatabase().StringGetAsync(key, CommandFlags.ClientCaching) + ); + Assert.Equal("The ClientCaching flag can only be used if EnableServerAssistedClientSideTracking has been called", ex.Message); + } + + [Fact] + public void CallEnableTwice() + { + using var conn = Create(shared: false); + conn.EnableServerAssistedClientSideTracking(key => default); + var ex = Assert.Throws(() => conn.EnableServerAssistedClientSideTracking(key => default)); + Assert.Equal("The EnableServerAssistedClientSideTracking method can be invoked once-only per multiplexer instance", ex.Message); + } + + [Fact] + public void UsePrefixesWithoutBroadcast() + { + using var conn = Create(shared: false); + var ex = Assert.Throws(() => conn.EnableServerAssistedClientSideTracking(key => default, prefixes: new RedisKey[] { "abc" })); + Assert.StartsWith("Prefixes can only be specified when ClientTrackingOptions.Broadcast is used", ex.Message); + Assert.Equal("prefixes", ex.ParamName); + } + + [Theory] + [InlineData(ClientTrackingOptions.None)] + [InlineData(ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + public Task GetNotificationFromOwnConnection(ClientTrackingOptions options) => GetNotification(options, false); + + [Theory] + [InlineData(ClientTrackingOptions.None)] + [InlineData(ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + public Task GetNotificationFromExternalConnection(ClientTrackingOptions options) => GetNotification(options, true); + + private async Task GetNotification(ClientTrackingOptions options, bool externalConnectionMakesChange) + { + bool expectNotification = ((options & ClientTrackingOptions.NotifyForOwnCommands) != 0) || externalConnectionMakesChange; + + using var listen = Create(shared: false); + using var send = externalConnectionMakesChange ? Create() : listen; + + int value = (new Random().Next() % 1024) + 1024, notifyCount = 0; + + var key = Me(); + var db = listen.GetDatabase(); + db.KeyDelete(key); + db.StringSet(key, value); + + listen.EnableServerAssistedClientSideTracking(rkey => + { + if (rkey == key) Interlocked.Increment(ref notifyCount); + return default; + }, options); + + Assert.Equal(value, db.StringGet(key, CommandFlags.ClientCaching)); + Assert.Equal(0, Volatile.Read(ref notifyCount)); + + send.GetDatabase().StringIncrement(key, 5); + await Task.Delay(100); // allow time for the magic to happen + + Assert.Equal(expectNotification ? 1 : 0, Volatile.Read(ref notifyCount)); + Assert.Equal(value + 5, db.StringGet(key, CommandFlags.ClientCaching)); + + } +} diff --git a/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs b/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs index f61e73e32..78e20a539 100644 --- a/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs +++ b/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs @@ -185,6 +185,8 @@ public void ExportConfiguration(Stream destination, ExportOptions options = Expo public override string ToString() => _inner.ToString(); long? IInternalConnectionMultiplexer.GetConnectionId(EndPoint endPoint, ConnectionType type) => _inner.GetConnectionId(endPoint, type); + public void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default) + => _inner.EnableServerAssistedClientSideTracking(keyInvalidated, options, prefixes); } public void Dispose()