From 12bf4bce9db0d69c3d69300147dac860d9eb02fe Mon Sep 17 00:00:00 2001 From: Ansuman Sahoo Date: Tue, 18 Nov 2025 19:07:53 +0530 Subject: [PATCH] vmdriver(vz): Wait for Start() to complete on server side Signed-off-by: Ansuman Sahoo --- pkg/driver/external/client/methods.go | 33 ++++++++++++++++++++++----- pkg/driver/external/driver.pb.go | 5 ++++ pkg/driver/external/driver.proto | 5 ++++ pkg/driver/external/server/methods.go | 11 +++++++++ pkg/driver/external/server/server.go | 3 +++ 5 files changed, 51 insertions(+), 6 deletions(-) diff --git a/pkg/driver/external/client/methods.go b/pkg/driver/external/client/methods.go index 536d49b1221..e1e02839b9d 100644 --- a/pkg/driver/external/client/methods.go +++ b/pkg/driver/external/client/methods.go @@ -58,6 +58,9 @@ func (d *DriverClient) CreateDisk(ctx context.Context) error { return nil } +// Start initiates the driver instance and receives streaming responses. It blocks until +// receiving the initial success response, then spawns a goroutine to consume subsequent +// error messages from the stream. Any errors from the driver are sent to the channel. func (d *DriverClient) Start(ctx context.Context) (chan error, error) { d.logger.Debug("Starting driver instance") @@ -67,19 +70,37 @@ func (d *DriverClient) Start(ctx context.Context) (chan error, error) { return nil, err } + // Blocking to receive an initial response to ensure Start() is initiated + // at the server-side. + initialResp, err := stream.Recv() + if err != nil { + d.logger.WithError(err).Error("Error receiving initial response from driver start") + return nil, err + } + if !initialResp.Success { + return nil, errors.New(initialResp.Error) + } + + go func() { + <-ctx.Done() + if closeErr := stream.CloseSend(); closeErr != nil { + d.logger.WithError(closeErr).Warn("Failed to close stream") + } + }() + errCh := make(chan error, 1) go func() { for { - errorStream, err := stream.Recv() + respStream, err := stream.Recv() if err != nil { - d.logger.Errorf("Error receiving response from driver: %v", err) + d.logger.Infof("Error receiving response from driver: %v", err) return } - d.logger.Debugf("Received response: %v", errorStream) - if !errorStream.Success { - errCh <- errors.New(errorStream.Error) + d.logger.Debugf("Received response: %v", respStream) + if !respStream.Success { + errCh <- errors.New(respStream.Error) } else { - errCh <- nil + close(errCh) return } } diff --git a/pkg/driver/external/driver.pb.go b/pkg/driver/external/driver.pb.go index bcef9b8aa4f..27a8cc891b5 100644 --- a/pkg/driver/external/driver.pb.go +++ b/pkg/driver/external/driver.pb.go @@ -154,6 +154,11 @@ func (x *InfoResponse) GetInfoJson() []byte { return nil } +// StartResponse is a streamed response for Start() RPC. It tries to mimic +// errChan from pkg/driver/driver.go. The server sends an initial response +// with success=true when Start() is initiated. If errors occur, they are +// sent as success=false with the error field populated. When the error channel +// closes, a final success=true message is sent. type StartResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` diff --git a/pkg/driver/external/driver.proto b/pkg/driver/external/driver.proto index 9496ae6f945..db57444627f 100644 --- a/pkg/driver/external/driver.proto +++ b/pkg/driver/external/driver.proto @@ -46,6 +46,11 @@ message InfoResponse{ bytes info_json = 1; } +// StartResponse is a streamed response for Start() RPC. It tries to mimic +// errChan from pkg/driver/driver.go. The server sends an initial response +// with success=true when Start() is initiated. If errors occur, they are +// sent as success=false with the error field populated. When the error channel +// closes, a final success=true message is sent. message StartResponse { bool success = 1; string error = 2; diff --git a/pkg/driver/external/server/methods.go b/pkg/driver/external/server/methods.go index aaa3b33e409..0046dbecc5c 100644 --- a/pkg/driver/external/server/methods.go +++ b/pkg/driver/external/server/methods.go @@ -26,9 +26,20 @@ func (s *DriverServer) Start(_ *emptypb.Empty, stream pb.Driver_StartServer) err errChan, err := s.driver.Start(stream.Context()) if err != nil { s.logger.Errorf("Start failed: %v", err) + if sendErr := stream.Send(&pb.StartResponse{Success: false, Error: err.Error()}); sendErr != nil { + s.logger.Errorf("Failed to send error response: %v", sendErr) + return status.Errorf(codes.Internal, "failed to send error response: %v", sendErr) + } return status.Errorf(codes.Internal, "failed to start driver: %v", err) } + // First send a success response upon receiving the errChan to unblock the client + // and start receiving errors (if any). + if err := stream.Send(&pb.StartResponse{Success: true}); err != nil { + s.logger.Errorf("Failed to send success response: %v", err) + return status.Errorf(codes.Internal, "failed to send success response: %v", err) + } + for { select { case err, ok := <-errChan: diff --git a/pkg/driver/external/server/server.go b/pkg/driver/external/server/server.go index bc69ffe8c78..1cd3536d901 100644 --- a/pkg/driver/external/server/server.go +++ b/pkg/driver/external/server/server.go @@ -207,6 +207,9 @@ func handlePreConfiguredDriverAction(ctx context.Context, driver driver.Driver) } } +// Start begins the driver startup process. It sends an initial response to unblock +// the client and then streams subsequent errors(if any), as the driver initializes. +// A final success message is streamed upon successful completion. func Start(extDriver *registry.ExternalDriver, instName string) error { extDriver.Logger.Debugf("Starting external driver at %s", extDriver.Path) if instName == "" {