diff --git a/src/Dtmworkflow/Workflow.cs b/src/Dtmworkflow/Workflow.cs index 930a02c..b60653a 100644 --- a/src/Dtmworkflow/Workflow.cs +++ b/src/Dtmworkflow/Workflow.cs @@ -17,10 +17,13 @@ public partial class Workflow public virtual WorkflowImp WorkflowImp { get; set; } = new WorkflowImp(); + public Dictionary Context => _context ??= new(); + private readonly IDtmClient _httpClient; private readonly IDtmgRPCClient _grpcClient; private readonly Dtmcli.IBranchBarrierFactory _bbFactory; private readonly ILogger _logger; + private Dictionary _context; public Workflow(IDtmClient httpClient, IDtmgRPCClient grpcClient, Dtmcli.IBranchBarrierFactory bbFactory, ILogger logger) { diff --git a/src/Dtmworkflow/WorkflowGlobalTransaction.cs b/src/Dtmworkflow/WorkflowGlobalTransaction.cs index ce4bbec..99b1e10 100644 --- a/src/Dtmworkflow/WorkflowGlobalTransaction.cs +++ b/src/Dtmworkflow/WorkflowGlobalTransaction.cs @@ -20,6 +20,11 @@ public WorkflowGlobalTransaction(IWorkflowFactory workflowFactory, ILoggerFactor } public async Task Execute(string name, string gid, byte[] data, bool isHttp = true) + { + return await this.Execute(name, gid, data, null, isHttp); + } + + public async Task Execute(string name, string gid, byte[] data, Action wfAction, bool isHttp = true) { if (!this._handlers.TryGetValue(name, out var handler)) { @@ -27,6 +32,8 @@ public async Task Execute(string name, string gid, byte[] data, bool isH } var wf = _workflowFactory.NewWorkflow(name, gid, data, isHttp); + if (wfAction != null) + wfAction(wf); foreach (var fn in handler.Custom) { diff --git a/tests/Dtmgrpc.IntegrationTests/WorkflowGrpcTest.cs b/tests/Dtmgrpc.IntegrationTests/WorkflowGrpcTest.cs index 9d3735e..9b53a58 100644 --- a/tests/Dtmgrpc.IntegrationTests/WorkflowGrpcTest.cs +++ b/tests/Dtmgrpc.IntegrationTests/WorkflowGrpcTest.cs @@ -94,5 +94,32 @@ public async Task Execute_Success() status = await ITTestHelper.GetTranStatus(gid); Assert.Equal("succeed", status); } + + + [Fact] + public async Task ExecuteWithWfAction() + { + var provider = ITTestHelper.AddDtmGrpc(); + var workflowFactory = provider.GetRequiredService(); + var loggerFactory = provider.GetRequiredService(); + WorkflowGlobalTransaction workflowGlobalTransaction = new WorkflowGlobalTransaction(workflowFactory, loggerFactory); + string wfName = Guid.NewGuid().ToString(); + workflowGlobalTransaction.Register(wfName, (workflow, data) => + { + Assert.NotNull(workflow.Context); + Assert.Equal(2, workflow.Context.Count); + Assert.Equal("value1", workflow.Context["key1-string"]); + Assert.Equal(7, workflow.Context["key2-int"]); + + Assert.Equal("input", Encoding.UTF8.GetString(data)); + + return Task.FromResult("output"u8.ToArray()); + }); + await workflowGlobalTransaction.Execute(wfName, Guid.NewGuid().ToString(), "input"u8.ToArray(), workflow => + { + workflow.Context.Add("key1-string", "value1"); + workflow.Context.Add("key2-int", 7); + }); + } } } \ No newline at end of file