Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions tests/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ package tests
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"

Expand All @@ -24,7 +22,6 @@ import (
"go.temporal.io/sdk/worker"
"go.temporal.io/sdk/workflow"
"go.temporal.io/server/common/dynamicconfig"
"go.temporal.io/server/common/testing/freeport"
"go.temporal.io/server/common/testing/protoassert"
"go.temporal.io/server/common/testing/protorequire"
"go.temporal.io/server/common/testing/testvars"
Expand Down Expand Up @@ -53,27 +50,13 @@ func TestCallbacksSuite(t *testing.T) {
suite.Run(t, new(CallbacksSuite))
}

func (s *CallbacksSuite) runNexusCompletionHTTPServer(t *testing.T, h *completionHandler, listenAddr string) {
func (s *CallbacksSuite) runNexusCompletionHTTPServer(t *testing.T, h *completionHandler) string {
hh := nexus.NewCompletionHTTPHandler(nexus.CompletionHandlerOptions{Handler: h})
srv := &http.Server{Addr: listenAddr, Handler: hh}
listener, err := net.Listen("tcp", listenAddr)
s.NoError(err)

errCh := make(chan error, 1)
go func() {
errCh <- srv.Serve(listener)
}()

srv := httptest.NewServer(hh)
t.Cleanup(func() {
// Graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
if ctx.Err() != nil {
require.NoError(t, err)
require.ErrorIs(t, <-errCh, http.ErrServerClosed)
}
srv.Close()
})
return srv.URL
}

func (s *CallbacksSuite) TestWorkflowCallbacks_InvalidArgument() {
Expand Down Expand Up @@ -243,8 +226,11 @@ func (s *CallbacksSuite) TestWorkflowNexusCallbacks_CarriedOver() {
requestCh: make(chan *nexus.CompletionRequest, 2),
requestCompleteCh: make(chan error, 2),
}
callbackAddress := fmt.Sprintf("localhost:%d", freeport.MustGetFreePort())
s.runNexusCompletionHTTPServer(s.T(), ch, callbackAddress)
defer func() {
close(ch.requestCh)
close(ch.requestCompleteCh)
}()
callbackAddress := s.runNexusCompletionHTTPServer(s.T(), ch)

w := worker.New(sdkClient, taskQueue, worker.Options{})
w.RegisterWorkflowWithOptions(tc.wf, workflow.RegisterOptions{Name: workflowType})
Expand Down Expand Up @@ -276,15 +262,15 @@ func (s *CallbacksSuite) TestWorkflowNexusCallbacks_CarriedOver() {
{
Variant: &commonpb.Callback_Nexus_{
Nexus: &commonpb.Callback_Nexus{
Url: "http://" + callbackAddress + "/cb1",
Url: callbackAddress + "/cb1",
},
},
Links: []*commonpb.Link{links[0]},
},
{
Variant: &commonpb.Callback_Nexus_{
Nexus: &commonpb.Callback_Nexus{
Url: "http://" + callbackAddress + "/cb2",
Url: callbackAddress + "/cb2",
},
},
Links: []*commonpb.Link{links[1]},
Expand Down Expand Up @@ -449,8 +435,11 @@ func (s *CallbacksSuite) TestNexusResetWorkflowWithCallback() {
requestCh: make(chan *nexus.CompletionRequest, 2),
requestCompleteCh: make(chan error, 2),
}
callbackAddress := fmt.Sprintf("localhost:%d", freeport.MustGetFreePort())
s.runNexusCompletionHTTPServer(s.T(), ch, callbackAddress)
defer func() {
close(ch.requestCh)
close(ch.requestCompleteCh)
}()
callbackAddress := s.runNexusCompletionHTTPServer(s.T(), ch)

w := worker.New(sdkClient, taskQueue.GetName(), worker.Options{})

Expand All @@ -473,14 +462,14 @@ func (s *CallbacksSuite) TestNexusResetWorkflowWithCallback() {
{
Variant: &commonpb.Callback_Nexus_{
Nexus: &commonpb.Callback_Nexus{
Url: "http://" + callbackAddress + "/cb1",
Url: callbackAddress + "/cb1",
},
},
},
{
Variant: &commonpb.Callback_Nexus_{
Nexus: &commonpb.Callback_Nexus{
Url: "http://" + callbackAddress + "/cb2",
Url: callbackAddress + "/cb2",
},
},
},
Expand Down Expand Up @@ -571,9 +560,17 @@ func (s *CallbacksSuite) TestNexusResetWorkflowWithCallback() {
s.NoError(err)

for range cbs {
completion := <-ch.requestCh
s.Equal(nexus.OperationStateSucceeded, completion.State)
ch.requestCompleteCh <- nil
select {
case completion := <-ch.requestCh:
s.Equal(nexus.OperationStateSucceeded, completion.State)
case <-time.After(time.Second):
s.Fail("timeout waiting for callback")
}
select {
case ch.requestCompleteCh <- nil:
case <-time.After(time.Second):
s.Fail("timeout writing to completion channel")
}
}

s.EventuallyWithT(
Expand Down Expand Up @@ -635,8 +632,11 @@ func (s *CallbacksSuite) TestNexusResetWorkflowWithCallback_ResetToNotBaseRun()
requestCh: make(chan *nexus.CompletionRequest, 1),
requestCompleteCh: make(chan error, 1),
}
callbackAddress := fmt.Sprintf("localhost:%d", freeport.MustGetFreePort())
s.runNexusCompletionHTTPServer(s.T(), ch, callbackAddress)
defer func() {
close(ch.requestCh)
close(ch.requestCompleteCh)
}()
callbackAddress := s.runNexusCompletionHTTPServer(s.T(), ch)

w := worker.New(sdkClient, taskQueue.GetName(), worker.Options{})

Expand Down Expand Up @@ -683,7 +683,7 @@ func (s *CallbacksSuite) TestNexusResetWorkflowWithCallback_ResetToNotBaseRun()

// 2. Start WF second time w/ callbacks (new run)
cbs := []*commonpb.Callback{
{Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://" + callbackAddress + "/cb1"}}},
{Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: callbackAddress + "/cb1"}}},
}

request2 := proto.Clone(request1).(*workflowservice.StartWorkflowExecutionRequest)
Expand Down
7 changes: 6 additions & 1 deletion tests/max_buffered_event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tests

import (
"context"
"crypto/rand"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -175,14 +176,18 @@ func (s *MaxBufferedEventSuite) TestBufferedEventsMutableStateSizeLimit() {

// now send 3 signals with 500KB payload each, all of them will be buffered
buf := make([]byte, 500*1024)
// fill the slice with random data to make sure the
// encoder does not zero out the data
_, err := rand.Read(buf)
s.NoError(err)
largePayload := payloads.EncodeBytes(buf)
for i := 0; i < 3; i++ {
err := s.SdkClient().SignalWorkflow(testCtx, wid, "", "test-signal", largePayload)
s.NoError(err)
}

// send 4th signal, this will fail the started workflow task and force terminate the workflow
err := s.SdkClient().SignalWorkflow(testCtx, wid, "", "test-signal", largePayload)
err = s.SdkClient().SignalWorkflow(testCtx, wid, "", "test-signal", largePayload)
s.NoError(err)

// unblock goroutine that runs local activity
Expand Down
Loading