diff --git a/cmd/tools/grpc_test_server/main.go b/cmd/tools/grpc_test_server/main.go new file mode 100644 index 000000000..86a497c7a --- /dev/null +++ b/cmd/tools/grpc_test_server/main.go @@ -0,0 +1,144 @@ +package main + +import ( + "context" + "io" + "log" + "net" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/reflection" + + structpb "google.golang.org/protobuf/types/known/structpb" +) + +// Manual registration for bidi service using google.protobuf.Struct +func chatStreamHandler(_ interface{}, stream grpc.ServerStream) error { + for { + in := new(structpb.Struct) + if err := stream.RecvMsg(in); err != nil { + if err == io.EOF { + return nil + } + return err + } + // Enrich response with server metadata and timestamp so client logs are visible + enriched := map[string]interface{}{} + for k, v := range in.AsMap() { + enriched[k] = v + } + enriched["server"] = "grpc_test_server" + enriched["ts"] = time.Now().Format(time.RFC3339Nano) + enriched["note"] = "bidi echo" + out, _ := structpb.NewStruct(enriched) + if err := stream.SendMsg(out); err != nil { + return err + } + } +} + +var chatServiceDesc = grpc.ServiceDesc{ + ServiceName: "chat.Chat", + HandlerType: (*interface{})(nil), + Streams: []grpc.StreamDesc{{ + StreamName: "Stream", + Handler: chatStreamHandler, + ServerStreams: true, + ClientStreams: true, + }}, +} + +// Server-stream-only echo service using Struct +func echoStreamHandler(_ interface{}, stream grpc.ServerStream) error { + in := new(structpb.Struct) + if err := stream.RecvMsg(in); err != nil { + if err == io.EOF { + return nil + } + return err + } + // emit a few echoes and finish + for i := 0; i < 3; i++ { + if err := stream.SendMsg(in); err != nil { + return err + } + time.Sleep(200 * time.Millisecond) + } + return nil +} + +var echoServiceDesc = grpc.ServiceDesc{ + ServiceName: "echo.Echo", + HandlerType: (*interface{})(nil), + Streams: []grpc.StreamDesc{{ + StreamName: "Stream", + Handler: echoStreamHandler, + ServerStreams: true, + ClientStreams: false, + }}, +} + +// Client-stream-only ingest service: counts messages and returns a final result +func ingestStreamHandler(_ interface{}, stream grpc.ServerStream) error { + count := 0 + for { + in := new(structpb.Struct) + if err := stream.RecvMsg(in); err != nil { + if err == io.EOF { + break + } + return err + } + count++ + } + // respond once with count + out, _ := structpb.NewStruct(map[string]interface{}{"count": count}) + return stream.SendMsg(out) +} + +var ingestServiceDesc = grpc.ServiceDesc{ + ServiceName: "ingest.Ingest", + HandlerType: (*interface{})(nil), + Streams: []grpc.StreamDesc{{ + StreamName: "Stream", + Handler: ingestStreamHandler, + ServerStreams: false, + ClientStreams: true, + }}, +} + +func main() { + lis, err := net.Listen("tcp", ":50051") + if err != nil { + log.Fatalf("listen: %v", err) + } + s := grpc.NewServer() + + // Health service + hs := health.NewServer() + // Set default service to SERVING + hs.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + + healthpb.RegisterHealthServer(s, hs) + + // Reflection for dynamic clients + reflection.Register(s) + + // Register manual chat service + s.RegisterService(&chatServiceDesc, nil) + // Register echo server-stream-only service + s.RegisterService(&echoServiceDesc, nil) + // Register ingest client-stream-only service + s.RegisterService(&ingestServiceDesc, nil) + + log.Println("grpc test server listening on :50051") + if err := 
s.Serve(lis); err != nil { + log.Fatalf("serve: %v", err) + } +} + +// Keep the context import referenced; it is otherwise unused in this file +var _ = context.Background diff --git a/cmd/tools/grpc_test_server/pb/README.md b/cmd/tools/grpc_test_server/pb/README.md new file mode 100644 index 000000000..85e3e75df --- /dev/null +++ b/cmd/tools/grpc_test_server/pb/README.md @@ -0,0 +1,5 @@ +Generate Go bindings (requires protoc, protoc-gen-go, and protoc-gen-go-grpc): + +protoc --go_out=. --go-grpc_out=. chat.proto + + diff --git a/cmd/tools/grpc_test_server/pb/chat.proto b/cmd/tools/grpc_test_server/pb/chat.proto new file mode 100644 index 000000000..85ad149a5 --- /dev/null +++ b/cmd/tools/grpc_test_server/pb/chat.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; +package chat; + +import "google/protobuf/struct.proto"; + +service Chat { + rpc Stream(stream google.protobuf.Struct) returns (stream google.protobuf.Struct); +} + + diff --git a/cmd/tools/grpc_test_server/pb/echo.proto b/cmd/tools/grpc_test_server/pb/echo.proto new file mode 100644 index 000000000..fb45c8851 --- /dev/null +++ b/cmd/tools/grpc_test_server/pb/echo.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package echo; + +import "google/protobuf/struct.proto"; + +service Echo { + rpc Stream(google.protobuf.Struct) returns (stream google.protobuf.Struct); +} \ No newline at end of file diff --git a/cmd/tools/grpc_test_server/pb/google/protobuf/struct.proto b/cmd/tools/grpc_test_server/pb/google/protobuf/struct.proto new file mode 100644 index 000000000..3abde3fbf --- /dev/null +++ b/cmd/tools/grpc_test_server/pb/google/protobuf/struct.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; +package google.protobuf; + +option go_package = "google.golang.org/protobuf/types/known/structpb"; + +message Struct { + map<string, Value> fields = 1; +} + +message Value { + oneof kind { + NullValue null_value = 1; + double number_value = 2; + string string_value = 3; + bool bool_value = 4; + Struct struct_value = 5; + ListValue list_value = 6; + } +} + +enum NullValue { + NULL_VALUE = 0; +} + +message ListValue { + repeated Value values = 1; +} + + diff --git a/cmd/tools/grpc_test_server/pb/grpc/health/v1/health.proto b/cmd/tools/grpc_test_server/pb/grpc/health/v1/health.proto new file mode 100644 index 000000000..c582d43aa --- /dev/null +++ b/cmd/tools/grpc_test_server/pb/grpc/health/v1/health.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package grpc.health.v1; + +// Standard gRPC health checking protocol +service Health { + rpc Check(HealthCheckRequest) returns (HealthCheckResponse); +} + +message HealthCheckRequest { + string service = 1; +} + +enum HealthCheckResponse_ServingStatus { + UNKNOWN = 0; + SERVING = 1; + NOT_SERVING = 2; + SERVICE_UNKNOWN = 3; +} + +message HealthCheckResponse { + HealthCheckResponse_ServingStatus status = 1; +} + + diff --git a/cmd/tools/grpc_test_server/pb/ingest.proto b/cmd/tools/grpc_test_server/pb/ingest.proto new file mode 100644 index 000000000..9d39ad186 --- /dev/null +++ b/cmd/tools/grpc_test_server/pb/ingest.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package ingest; + +import "google/protobuf/struct.proto"; + +service Ingest { + rpc Stream(stream google.protobuf.Struct) returns (google.protobuf.Struct); +} \ No newline at end of file diff --git a/config/examples/grpc_bidi_chat.yaml b/config/examples/grpc_bidi_chat.yaml new file mode 100644 index 000000000..c21a66124 --- /dev/null +++ b/config/examples/grpc_bidi_chat.yaml @@ -0,0 +1,38 @@ +input: + label: "" + generate: + interval: 1s + mapping: | + root = { + "session_id": "demo",
"message": "hello" + } + +pipeline: + processors: [] + +output: + label: "" + grpc_client: + # minimal required fields + address: "127.0.0.1:50051" + method: "/chat.Chat/Stream" + rpc_type: "bidi" + proto_files: + - chat.proto + - google/protobuf/struct.proto + include_paths: + - cmd/tools/grpc_test_server/pb + # optional timeouts + call_timeout: "30s" + connect_timeout: "2s" + # optional retries + retry_max_attempts: 3 + retry_initial_backoff: "200ms" + retry_max_backoff: "2s" + retry_backoff_multiplier: 2 + # optional auth headers + # bearer_token: "your-token-here" + # auth_headers: + # x-api-key: "demo" + diff --git a/config/examples/grpc_client_stream.yaml b/config/examples/grpc_client_stream.yaml new file mode 100644 index 000000000..586382602 --- /dev/null +++ b/config/examples/grpc_client_stream.yaml @@ -0,0 +1,29 @@ +input: + label: "" + generate: + count: 3 + interval: 200ms + mapping: | + root = {} + +pipeline: + processors: [] + +output: + label: "" + grpc_client: + address: "127.0.0.1:50051" + method: "/ingest.Ingest/Stream" + rpc_type: "client_stream" + proto_files: + - ingest.proto + - google/protobuf/struct.proto + include_paths: + - cmd/tools/grpc_test_server/pb + call_timeout: "30s" + connect_timeout: "2s" + retry_max_attempts: 2 + retry_initial_backoff: "100ms" + retry_max_backoff: "1s" + retry_backoff_multiplier: 2 + diff --git a/config/examples/grpc_enhanced_test.yaml b/config/examples/grpc_enhanced_test.yaml new file mode 100644 index 000000000..bf1ce4cb3 --- /dev/null +++ b/config/examples/grpc_enhanced_test.yaml @@ -0,0 +1,32 @@ +input: + label: "" + generate: + count: 5 + interval: 500ms + mapping: | + root = { + "value": this.count + } + +pipeline: + processors: [] + +output: + label: "" + grpc_client: + address: "127.0.0.1:50051" + method: "/echo.Echo/Stream" + rpc_type: "unary" + # For outputs, provide the body via the input/generate mapping + proto_files: + - echo.proto + - google/protobuf/struct.proto + include_paths: + - cmd/tools/grpc_test_server/pb + call_timeout: "5s" + connect_timeout: "2s" + retry_max_attempts: 3 + retry_initial_backoff: "200ms" + retry_max_backoff: "1s" + retry_backoff_multiplier: 2 + diff --git a/config/examples/grpc_insecure_example.yaml b/config/examples/grpc_insecure_example.yaml new file mode 100644 index 000000000..7875e2f3e --- /dev/null +++ b/config/examples/grpc_insecure_example.yaml @@ -0,0 +1,35 @@ +input: + generate: + count: 3 + interval: 2s + mapping: 'root = { "request_id": uuid_v4(), "timestamp": now() }' + +output: + type: grpc_client + grpc_client: + rpc_type: client_stream + address: 127.0.0.1:50051 + method: /ingest.Ingest/Stream + + # Example using insecure transport for local testing + proto_files: [ "ingest.proto", "google/protobuf/struct.proto" ] + include_paths: [ "cmd/tools/grpc_test_server/pb" ] + + # Timeouts + call_timeout: 10s + connect_timeout: 2s + + # gRPC best practices + propagate_deadlines: true + + # Robust retry policy + retry_max_attempts: 3 + retry_initial_backoff: 1s + retry_max_backoff: 10s + retry_backoff_multiplier: 2.0 + + # Authentication headers (optional) + auth_headers: + client-type: bento-insecure + environment: local + trace-enabled: "true" diff --git a/config/examples/grpc_server_stream.yaml b/config/examples/grpc_server_stream.yaml new file mode 100644 index 000000000..2c4266e4b --- /dev/null +++ b/config/examples/grpc_server_stream.yaml @@ -0,0 +1,26 @@ +input: + label: "" + grpc_client: + address: "127.0.0.1:50051" + method: "/echo.Echo/Stream" + rpc_type: "server_stream" + 
request_json: "{}" + proto_files: + - echo.proto + - google/protobuf/struct.proto + include_paths: + - cmd/tools/grpc_test_server/pb + call_timeout: "10s" + connect_timeout: "2s" + retry_max_attempts: 2 + retry_initial_backoff: "100ms" + retry_max_backoff: "1s" + retry_backoff_multiplier: 2 + +pipeline: + processors: [] + +output: + label: "" + stdout: {} + diff --git a/config/examples/grpc_unary_healthcheck.yaml b/config/examples/grpc_unary_healthcheck.yaml new file mode 100644 index 000000000..a0ca0d846 --- /dev/null +++ b/config/examples/grpc_unary_healthcheck.yaml @@ -0,0 +1,25 @@ +input: + label: "" + generate: + count: 1 + mapping: "root = {}" + +pipeline: + processors: [] + +output: + label: "" + grpc_client: + address: "127.0.0.1:50051" + method: "/grpc.health.v1.Health/Check" + rpc_type: "unary" + proto_files: + - grpc/health/v1/health.proto + include_paths: + - cmd/tools/grpc_test_server/pb + # For unary requests set the input.request_json + # request_json here is not supported for outputs in minimal mode + + call_timeout: "5s" + connect_timeout: "2s" + diff --git a/internal/impl/grpc_client/common.go b/internal/impl/grpc_client/common.go new file mode 100644 index 000000000..8e0c15c07 --- /dev/null +++ b/internal/impl/grpc_client/common.go @@ -0,0 +1,1242 @@ +package grpc_client + +import ( + "context" + "errors" + "fmt" + "math/rand" + "strings" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/desc/protoparse" + "github.com/jhump/protoreflect/grpcreflect" + + "github.com/warpstreamlabs/bento/public/service" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +// Common config field names +const ( + fieldAddress = "address" + fieldMethod = "method" + fieldRPCType = "rpc_type" + fieldRequestJSON = "request_json" +) + +// createBaseConfigSpec creates the common configuration fields shared between input and output +func createBaseConfigSpec() *service.ConfigSpec { + spec := service.NewConfigSpec(). + Version("1.11.0"). + Categories("Services"). + Field(service.NewStringField(fieldAddress).Default("127.0.0.1:50051")). + Field(service.NewStringField(fieldMethod).Description("Full method name, e.g. /pkg.Service/Method")). + // Auth + Field(service.NewStringField("bearer_token").Secret().Optional()). + Field(service.NewStringMapField("auth_headers").Optional()). + // Core timing + Field(service.NewDurationField("call_timeout").Default("0s")). + Field(service.NewDurationField("connect_timeout").Default("0s")). + // Reflection / proto + Field(service.NewStringListField("proto_files").Optional()). + Field(service.NewStringListField("include_paths").Optional()). + // Best practices / minimal behavior + Field(service.NewBoolField("propagate_deadlines").Default(true)). + // Retry policy (simple) + Field(service.NewIntField("retry_max_attempts").Default(0)). + Field(service.NewDurationField("retry_initial_backoff").Default("1s")). + Field(service.NewDurationField("retry_max_backoff").Default("30s")). + Field(service.NewFloatField("retry_backoff_multiplier").Default(2.0)). + // Logging levels + Field(service.NewStringField("log_level_success").Default("debug")). 
+ Field(service.NewStringField("log_level_error").Default("debug")) + + return spec +} + +// enhanceCallContext enhances the context for gRPC calls; simplified to no-op +func enhanceCallContext(ctx context.Context, _ *Config, _ func(context.Context) context.Context) context.Context { + return ctx +} + +// injectMetadataIntoContext adds auth headers to the gRPC context (simplified) +func injectMetadataIntoContext(ctx context.Context, cfg *Config) context.Context { + if cfg == nil { + return ctx + } + md := metadata.MD{} + if len(cfg.AuthHeaders) > 0 { + for k, v := range cfg.AuthHeaders { + md.Set(strings.ToLower(k), v) + } + } + if cfg.BearerToken != "" { + md.Set("authorization", "Bearer "+cfg.BearerToken) + } + if len(md) == 0 { + return ctx + } + if existing, ok := metadata.FromOutgoingContext(ctx); ok { + merged := existing.Copy() + for k, v := range md { + for _, vv := range v { + merged.Append(k, vv) + } + } + return metadata.NewOutgoingContext(ctx, merged) + } + return metadata.NewOutgoingContext(ctx, md) +} + +// Default timing configuration constants +const ( + defaultRetryBackoffInitial = time.Second + defaultRetryBackoffMax = 30 * time.Second + defaultConnectionIdleTimeout = 30 * time.Minute + defaultSessionSweepInterval = time.Minute + defaultConnectionPoolSize = 1 + defaultRetryMultiplier = 2.0 + defaultCleanupTickerInterval = time.Minute + defaultHealthCheckInterval = 30 * time.Second + defaultMaxConnectionFailures = 3 + defaultFailureWindow = 5 * time.Minute +) + +// Magic numbers for message sizes and limits +const ( + minMethodNameLength = 3 // Minimum: "/a/b" +) + +// Config represents shared gRPC client configuration +type Config struct { + Address string + Method string + RPCType string + RequestJSON string + BearerToken string + AuthHeaders map[string]string + CallTimeout time.Duration + ConnectTimeout time.Duration + ProtoFiles []string + IncludePaths []string + + // gRPC best practices + RetryPolicy *RetryPolicy + + // Logging levels + LogLevelSuccess string + LogLevelError string + + // Optional call observer for outcomes + Observer CallObserver + + // Logger for shared components + Logger *service.Logger +} + +// CallObserver receives outcomes of calls +type CallObserver interface { + RecordCall(err error) +} + +// RetryPolicy defines retry behavior for gRPC calls +type RetryPolicy struct { + MaxAttempts int + InitialBackoff time.Duration + MaxBackoff time.Duration + BackoffMultiplier float64 + RetryableStatusCodes []codes.Code +} + +// Note: advanced gRPC service configs are not used in minimal mode + +// ParseConfigFromService extracts gRPC configuration from service config +func ParseConfigFromService(conf *service.ParsedConfig) (*Config, error) { + cfg := &Config{} + extractCoreConfig(conf, cfg) + extractAuthConfig(conf, cfg) + extractConnectionConfig(conf, cfg) + extractBestPracticesConfig(conf, cfg) + extractRetryPolicyConfig(conf, cfg) + return cfg, nil +} + +// extractCoreConfig extracts fundamental gRPC configuration fields +func extractCoreConfig(conf *service.ParsedConfig, cfg *Config) { + cfg.Address, _ = conf.FieldString(fieldAddress) + cfg.Method, _ = conf.FieldString(fieldMethod) + cfg.RPCType, _ = conf.FieldString(fieldRPCType) + cfg.RequestJSON, _ = conf.FieldString(fieldRequestJSON) +} + +// extractAuthConfig extracts authentication configuration +func extractAuthConfig(conf *service.ParsedConfig, cfg *Config) { + cfg.BearerToken, _ = conf.FieldString("bearer_token") + cfg.AuthHeaders, _ = conf.FieldStringMap("auth_headers") +} + +// 
extractConnectionConfig extracts connection-related configuration (simplified) +func extractConnectionConfig(conf *service.ParsedConfig, cfg *Config) { + cfg.CallTimeout, _ = conf.FieldDuration("call_timeout") + cfg.ConnectTimeout, _ = conf.FieldDuration("connect_timeout") + cfg.ProtoFiles, _ = conf.FieldStringList("proto_files") + cfg.IncludePaths, _ = conf.FieldStringList("include_paths") +} + +// extractBestPracticesConfig extracts minimal logging config +func extractBestPracticesConfig(conf *service.ParsedConfig, cfg *Config) { + cfg.LogLevelSuccess, _ = conf.FieldString("log_level_success") + cfg.LogLevelError, _ = conf.FieldString("log_level_error") +} + +// extractRetryPolicyConfig extracts retry policy configuration +func extractRetryPolicyConfig(conf *service.ParsedConfig, cfg *Config) { + maxAttempts, _ := conf.FieldInt("retry_max_attempts") + if maxAttempts <= 0 { + return + } + retryInitialBackoff, _ := conf.FieldDuration("retry_initial_backoff") + if retryInitialBackoff <= 0 { + retryInitialBackoff = defaultRetryBackoffInitial + } + retryMaxBackoff, _ := conf.FieldDuration("retry_max_backoff") + if retryMaxBackoff <= 0 { + retryMaxBackoff = defaultRetryBackoffMax + } + retryMultiplier := defaultRetryMultiplier + if multiplier, _ := conf.FieldFloat("retry_backoff_multiplier"); multiplier > 0 { + retryMultiplier = multiplier + } + cfg.RetryPolicy = &RetryPolicy{ + MaxAttempts: maxAttempts, + InitialBackoff: retryInitialBackoff, + MaxBackoff: retryMaxBackoff, + BackoffMultiplier: retryMultiplier, + RetryableStatusCodes: []codes.Code{ + codes.Unavailable, + codes.ResourceExhausted, + codes.Aborted, + codes.DeadlineExceeded, + }, + } +} + +// createConnection creates a single gRPC connection with minimal options +func createConnection(ctx context.Context, cfg *Config) (*grpc.ClientConn, error) { + opts, err := buildDialOptions(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to build dial options: %w", err) + } + conn, err := grpc.NewClient(cfg.Address, opts...) 
+ if err != nil { + return nil, fmt.Errorf("failed to create gRPC client: %w", err) + } + if cfg.ConnectTimeout > 0 { + ctxWait, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout) + defer cancel() + conn.Connect() + for { + st := conn.GetState() + if st == connectivity.Ready { + break + } + if !conn.WaitForStateChange(ctxWait, st) { + _ = conn.Close() + if err := ctxWait.Err(); err != nil { + return nil, err + } + return nil, errors.New("connection not ready within connect_timeout") + } + } + } + return conn, nil +} + +// headerCreds implements credentials.PerRPCCredentials with enhanced security +type headerCreds struct { + token string + headers map[string]string + tlsEnabled bool +} + +func (h headerCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + md := make(map[string]string) + if h.token != "" { + md["authorization"] = "Bearer " + h.token + } + for k, v := range h.headers { + if err := validateHeaderValue(v); err != nil { + continue + } + md[strings.ToLower(k)] = v + } + return md, nil +} + +// validateHeaderValue validates that a header value is safe +func validateHeaderValue(value string) error { + if len(value) > 4096 { + return errors.New("header value is too long (maximum 4096 characters)") + } + for _, char := range value { + if char < 32 && char != '\t' { + return errors.New("header value contains control character") + } + if char == '\r' || char == '\n' { + return errors.New("header value contains CRLF characters (potential header injection)") + } + } + return nil +} + +func (h headerCreds) RequireTransportSecurity() bool { return h.tlsEnabled } + +// CircuitBreakerState represents the state of the circuit breaker +type CircuitBreakerState int + +const ( + CircuitBreakerClosed CircuitBreakerState = iota // Normal operation + CircuitBreakerOpen // Failing, reject requests + CircuitBreakerHalfOpen // Testing if service recovered +) + +// String returns the string representation of the circuit breaker state +func (s CircuitBreakerState) String() string { + switch s { + case CircuitBreakerClosed: + return "closed" + case CircuitBreakerOpen: + return "open" + case CircuitBreakerHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// CircuitBreaker implements the circuit breaker pattern for connection management +type CircuitBreaker struct { + state CircuitBreakerState + failures int + lastFailure time.Time + nextAttempt time.Time + mu sync.RWMutex + + // Configuration + failureThreshold int // Number of failures before opening circuit + resetTimeout time.Duration // Time to wait before trying again + halfOpenMaxReqs int // Max requests allowed in half-open state + halfOpenCount int // Current requests in half-open state +} + +// NewCircuitBreaker creates a new circuit breaker +func NewCircuitBreaker(failureThreshold int, resetTimeout time.Duration) *CircuitBreaker { + return &CircuitBreaker{ + state: CircuitBreakerClosed, + failureThreshold: failureThreshold, + resetTimeout: resetTimeout, + halfOpenMaxReqs: 3, // Allow 3 test requests in half-open state + } +} + +// CanExecute checks if a request can be executed based on circuit breaker state +func (cb *CircuitBreaker) CanExecute() bool { + cb.mu.RLock() + defer cb.mu.RUnlock() + + switch cb.state { + case CircuitBreakerClosed: + return true + case CircuitBreakerOpen: + if time.Now().After(cb.nextAttempt) { + cb.mu.RUnlock() + cb.mu.Lock() + cb.state = CircuitBreakerHalfOpen + cb.halfOpenCount = 0 + cb.mu.Unlock() + cb.mu.RLock() + return true + } + return false + 
case CircuitBreakerHalfOpen: + return cb.halfOpenCount < cb.halfOpenMaxReqs + default: + return false + } +} + +// RecordSuccess records a successful request +func (cb *CircuitBreaker) RecordSuccess() { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case CircuitBreakerHalfOpen: + cb.halfOpenCount++ + if cb.halfOpenCount >= cb.halfOpenMaxReqs { + // Transition back to closed state + cb.state = CircuitBreakerClosed + cb.failures = 0 + cb.lastFailure = time.Time{} + } + case CircuitBreakerClosed: + // Reset failure count on success + cb.failures = 0 + cb.lastFailure = time.Time{} + } +} + +// RecordFailure records a failed request +func (cb *CircuitBreaker) RecordFailure() { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.failures++ + cb.lastFailure = time.Now() + + switch cb.state { + case CircuitBreakerHalfOpen: + // Half-open failure, go back to open + cb.state = CircuitBreakerOpen + cb.nextAttempt = time.Now().Add(cb.resetTimeout) + case CircuitBreakerClosed: + if cb.failures >= cb.failureThreshold { + cb.state = CircuitBreakerOpen + cb.nextAttempt = time.Now().Add(cb.resetTimeout) + } + } +} + +// GetState returns the current state of the circuit breaker +func (cb *CircuitBreaker) GetState() CircuitBreakerState { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.state +} + +// GetStats returns circuit breaker statistics +func (cb *CircuitBreaker) GetStats() map[string]interface{} { + cb.mu.RLock() + defer cb.mu.RUnlock() + + return map[string]interface{}{ + "state": cb.state.String(), + "failures": cb.failures, + "last_failure": cb.lastFailure.Format(time.RFC3339), + "next_attempt": cb.nextAttempt.Format(time.RFC3339), + "half_open_count": cb.halfOpenCount, + "failure_threshold": cb.failureThreshold, + "reset_timeout": cb.resetTimeout.String(), + } +} + +// ConnectionManager manages gRPC connections with proper lifecycle and pooling. +// +// The manager coordinates between connection pooling, automatic cleanup, and +// graceful shutdown. It runs a background goroutine that periodically checks +// for idle connections and replaces them to maintain connection freshness. +// +// Key Responsibilities: +// - Connection pool lifecycle management +// - Automatic cleanup of idle connections +// - Thread-safe access coordination +// - Graceful shutdown without resource leaks +// - Circuit breaker pattern integration +type ConnectionManager struct { + pool *ConnectionPool // Underlying connection pool + circuitBreaker *CircuitBreaker // Circuit breaker for fault tolerance + mu sync.RWMutex // Protects manager state during shutdown + closed bool // Indicates if manager is shut down +} + +// NewConnectionManager creates a new connection manager with pooling support +func NewConnectionManager(ctx context.Context, cfg *Config) (*ConnectionManager, error) { + pool := &ConnectionPool{ + connections: make([]connectionEntry, 0, 1), + cfg: cfg, + ctx: ctx, + } + // Circuit breaker with fixed sensible defaults + circuitBreaker := NewCircuitBreaker(5, 30*time.Second) + // Startup warnings + if cfg.Logger != nil { + if cfg.BearerToken != "" { + cfg.Logger.Warnf("Using bearer_token over insecure transport. Avoid sending credentials without TLS.") + } + for k := range cfg.AuthHeaders { + kl := strings.ToLower(k) + if strings.Contains(kl, "password") || strings.Contains(kl, "secret") || strings.Contains(kl, "token") || strings.Contains(kl, "key") { + cfg.Logger.Warnf("Auth header key '%s' may contain sensitive data. 
Ensure secure transport.", k) + } + } + } + for i := 0; i < 1; i++ { + conn, err := createConnection(ctx, cfg) + if err != nil { + pool.closeAllConnections() + return nil, fmt.Errorf("failed to create connection %d: %w", i, err) + } + now := time.Now() + pool.connections = append(pool.connections, connectionEntry{conn: conn, lastUsed: now, createdAt: now, healthChecked: now}) + } + cm := &ConnectionManager{pool: pool, circuitBreaker: circuitBreaker} + go cm.cleanupIdleConnections() + return cm, nil +} + +// createConnection unchanged except uses cfg.ConnectTimeout + +// GetConnection returns an available gRPC connection from the pool (thread-safe) +func (cm *ConnectionManager) GetConnection() (*grpc.ClientConn, error) { + cm.mu.RLock() + defer cm.mu.RUnlock() + + if cm.closed { + return nil, errors.New("connection manager is closed") + } + + // Check circuit breaker before attempting to get connection + if !cm.circuitBreaker.CanExecute() { + return nil, errors.New("circuit breaker is open - service unavailable") + } + + conn, err := cm.pool.getConnection() + if err != nil { + // Record failure in circuit breaker + cm.circuitBreaker.RecordFailure() + return nil, err + } + + // Record success in circuit breaker + cm.circuitBreaker.RecordSuccess() + return conn, nil +} + +// ValidateConnection checks if a specific connection is healthy and ready for use +func (cm *ConnectionManager) ValidateConnection(conn *grpc.ClientConn) bool { + cm.mu.RLock() + defer cm.mu.RUnlock() + + if cm.closed || conn == nil { + return false + } + + return isConnectionHealthy(conn) +} + +// cleanupIdleConnections uses fixed interval +func (cm *ConnectionManager) cleanupIdleConnections() { + interval := defaultCleanupTickerInterval + ticker := time.NewTicker(interval) + defer ticker.Stop() + for range ticker.C { + cm.mu.RLock() + if cm.closed { + cm.mu.RUnlock() + return + } + cm.mu.RUnlock() + cm.pool.cleanupIdle() + } +} + +// Close closes the connection manager and all connections (thread-safe) +func (cm *ConnectionManager) Close() error { + cm.mu.Lock() + defer cm.mu.Unlock() + + if cm.closed { + return nil + } + + cm.closed = true + + if cm.pool != nil { + cm.pool.closeAllConnections() + } + + return nil +} + +// GetCircuitBreakerState returns the current state of the circuit breaker +func (cm *ConnectionManager) GetCircuitBreakerState() CircuitBreakerState { + cm.mu.RLock() + defer cm.mu.RUnlock() + + if cm.closed || cm.circuitBreaker == nil { + return CircuitBreakerClosed + } + + return cm.circuitBreaker.GetState() +} + +// buildDialOptions simplified: insecure creds, per-RPC creds, deadlines, logging, retry +func buildDialOptions(ctx context.Context, cfg *Config) ([]grpc.DialOption, error) { + var opts []grpc.DialOption + transportCreds, err := buildTransportCredentials(cfg) + if err != nil { + return nil, fmt.Errorf("failed to build transport credentials: %w", err) + } + opts = append(opts, grpc.WithTransportCredentials(transportCreds)) + // Per-RPC credentials + if cfg.BearerToken != "" || len(cfg.AuthHeaders) > 0 { + opts = append(opts, grpc.WithPerRPCCredentials(headerCreds{token: cfg.BearerToken, headers: cfg.AuthHeaders, tlsEnabled: false})) + } + // Interceptors: deadlines, logging, retry + unaryInterceptors, streamInterceptors := buildInterceptors(cfg) + if len(unaryInterceptors) > 0 { + opts = append(opts, grpc.WithChainUnaryInterceptor(unaryInterceptors...)) + } + if len(streamInterceptors) > 0 { + opts = append(opts, grpc.WithChainStreamInterceptor(streamInterceptors...)) + } + return opts, nil +} 
+ +// buildTransportCredentials returns insecure credentials (TLS disabled) +func buildTransportCredentials(cfg *Config) (credentials.TransportCredentials, error) { + return insecure.NewCredentials(), nil +} + +// buildDefaultCallOptions left unchanged (not used) + +// buildInterceptors creates gRPC interceptors for observability and best practices +func buildInterceptors(cfg *Config) ([]grpc.UnaryClientInterceptor, []grpc.StreamClientInterceptor) { + var unaryInterceptors []grpc.UnaryClientInterceptor + var streamInterceptors []grpc.StreamClientInterceptor + // Always propagate deadlines + unaryInterceptors = append(unaryInterceptors, deadlineUnaryInterceptor) + streamInterceptors = append(streamInterceptors, deadlineStreamInterceptor) + // Logging + if cfg.Logger != nil { + log := cfg.Logger + unaryInterceptors = append(unaryInterceptors, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + start := time.Now() + err := invoker(ctx, method, req, reply, cc, opts...) + dur := time.Since(start) + st, _ := status.FromError(err) + if err != nil { + logAtLevel(log, cfg.LogLevelError, "grpc unary call failed", method, st.Code().String(), dur) + if cfg.Observer != nil { + cfg.Observer.RecordCall(err) + } + } else { + logAtLevel(log, cfg.LogLevelSuccess, "grpc unary call ok", method, "OK", dur) + if cfg.Observer != nil { + cfg.Observer.RecordCall(nil) + } + } + return err + }) + streamInterceptors = append(streamInterceptors, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + start := time.Now() + cs, err := streamer(ctx, desc, cc, method, opts...) + dur := time.Since(start) + st, _ := status.FromError(err) + if err != nil { + logAtLevel(log, cfg.LogLevelError, "grpc stream open failed", method, st.Code().String(), dur) + if cfg.Observer != nil { + cfg.Observer.RecordCall(err) + } + } else { + logAtLevel(log, cfg.LogLevelSuccess, "grpc stream opened", method, "OK", dur) + if cfg.Observer != nil { + cfg.Observer.RecordCall(nil) + } + } + return cs, err + }) + } + // Retry + if cfg.RetryPolicy != nil { + unaryInterceptors = append(unaryInterceptors, retryUnaryInterceptor(cfg.RetryPolicy)) + } + return unaryInterceptors, streamInterceptors +} + +// deadlineUnaryInterceptor propagates deadlines from context with enhanced handling +func deadlineUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + // Enhanced context deadline handling + ctx = enhanceContextWithDeadlines(ctx) + + // Add method-specific timeout if none exists + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + // Apply default timeout for unary calls + defaultTimeout := 30 * time.Second + newCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + ctx = newCtx + } + + // Ensure context is properly canceled on completion + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + return invoker(ctx, method, req, reply, cc, opts...) 
+} + +// deadlineStreamInterceptor propagates deadlines for streaming calls with enhanced handling +func deadlineStreamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // Enhanced context deadline handling for streams + ctx = enhanceContextWithDeadlines(ctx) + + // Add a default timeout for streaming calls if none exists. The cancel func + // must not be deferred here: deferring would cancel the stream context as + // soon as this interceptor returns, tearing the stream down immediately. + // Instead it is released once the timeout context itself is done. + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + // Apply longer default timeout for streaming calls + defaultTimeout := 5 * time.Minute + newCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + cs, err := streamer(newCtx, desc, cc, method, opts...) + if err != nil { + cancel() + return nil, err + } + go func() { + <-newCtx.Done() + cancel() + }() + return cs, nil + } + + return streamer(ctx, desc, cc, method, opts...) +} + +// enhanceContextWithDeadlines inspects any existing deadline; it currently leaves the context unchanged +func enhanceContextWithDeadlines(ctx context.Context) context.Context { + // Check if we already have a deadline + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + // Calculate remaining time + remaining := time.Until(deadline) + + // If deadline is too close, leave as-is (avoid creating new timers) + minRemainingTime := 100 * time.Millisecond + if remaining < minRemainingTime { + return ctx + } + + // Deadline is reasonable, use as-is + return ctx + } + + // No deadline exists, return original context + return ctx +} + +// retryUnaryInterceptor implements client-side retry logic +func retryUnaryInterceptor(retryPolicy *RetryPolicy) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + var lastErr error + backoff := retryPolicy.InitialBackoff + + for attempt := 0; attempt < retryPolicy.MaxAttempts; attempt++ { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + err := invoker(ctx, method, req, reply, cc, opts...)
+ if err == nil { + return nil // Success + } + + lastErr = err + + // Check if error is retryable + if !isRetryableError(err, retryPolicy.RetryableStatusCodes) { + return err + } + + // Don't retry on last attempt + if attempt == retryPolicy.MaxAttempts-1 { + break + } + + // Sleep with backoff + select { + case <-time.After(backoff): + backoff = time.Duration(float64(backoff) * retryPolicy.BackoffMultiplier) + if backoff > retryPolicy.MaxBackoff { + backoff = retryPolicy.MaxBackoff + } + case <-ctx.Done(): + return ctx.Err() + } + } + + return lastErr + } +} + +// isRetryableError determines if an error should be retried +func isRetryableError(err error, retryableCodes []codes.Code) bool { + if err == nil { + return false + } + + st, ok := status.FromError(err) + if !ok { + return false + } + + for _, code := range retryableCodes { + if st.Code() == code { + return true + } + } + + return false +} + +// MethodResolver handles method resolution with caching and performance optimizations +type MethodResolver struct { + cache sync.Map // string -> *methodCacheEntry +} + +// methodCacheEntry holds both the method descriptor and message pools +type methodCacheEntry struct { + method *desc.MethodDescriptor +} + +// NewMethodResolver creates a new method resolver +func NewMethodResolver() *MethodResolver { + return &MethodResolver{} +} + +// ResolveMethod resolves a method using reflection or proto files with enhanced caching +func (mr *MethodResolver) ResolveMethod(ctx context.Context, conn *grpc.ClientConn, cfg *Config) (*desc.MethodDescriptor, error) { + // Check cache first + key := mr.cacheKey(cfg) + if cached, ok := mr.cache.Load(key); ok { + entry := cached.(*methodCacheEntry) + return entry.method, nil + } + var method *desc.MethodDescriptor + var err error + if len(cfg.ProtoFiles) > 0 { + method, err = mr.resolveFromProtoFiles(cfg.Method, cfg.ProtoFiles, cfg.IncludePaths) + } else { + method, err = mr.resolveFromReflection(ctx, conn, cfg.Method) + } + if err != nil { + return nil, err + } + entry := &methodCacheEntry{method: method} + mr.cache.Store(key, entry) + mr.cache.Store(method.GetFullyQualifiedName(), entry) + return method, nil +} + +// GetMessagePools removed in minimal mode + +// resolveFromReflection resolves method using gRPC reflection +func (mr *MethodResolver) resolveFromReflection(ctx context.Context, conn *grpc.ClientConn, methodName string) (*desc.MethodDescriptor, error) { + rc := grpcreflect.NewClientAuto(ctx, conn) + defer rc.Reset() + + svcName, mName, err := parseMethodName(methodName) + if err != nil { + return nil, err + } + + svc, err := rc.ResolveService(svcName) + if err != nil { + return nil, fmt.Errorf("failed to resolve service %s: %w", svcName, err) + } + + method := svc.FindMethodByName(mName) + if method == nil { + return nil, fmt.Errorf("method not found: %s", methodName) + } + + return method, nil +} + +// resolveFromProtoFiles resolves method from proto files +func (mr *MethodResolver) resolveFromProtoFiles(methodName string, protoFiles, includePaths []string) (*desc.MethodDescriptor, error) { + var parser protoparse.Parser + if len(includePaths) > 0 { + parser.ImportPaths = includePaths + } + + fds, err := parser.ParseFiles(protoFiles...) 
+ if err != nil { + return nil, fmt.Errorf("failed to parse proto files: %w", err) + } + + svcName, mName, err := parseMethodName(methodName) + if err != nil { + return nil, err + } + + for _, fd := range fds { + for _, svc := range fd.GetServices() { + if svc.GetFullyQualifiedName() == svcName || svc.GetName() == svcName { + if method := svc.FindMethodByName(mName); method != nil { + return method, nil + } + } + } + } + + return nil, fmt.Errorf("method not found in provided proto files: %s", methodName) +} + +// parseMethodName parses a method name like "/pkg.Service/Method" into service and method names +// Optimized to avoid repeated string operations and memory allocations +func parseMethodName(full string) (string, string, error) { + // Fast path: check minimum length and format + if len(full) < minMethodNameLength { // Minimum: "/a/b" + return "", "", fmt.Errorf("invalid method format: %s (too short)", full) + } + + // Remove leading slash efficiently + start := 0 + if full[0] == '/' { + start = 1 + } + + // Find the last slash to separate service and method (single pass) + lastSlash := -1 + for i := len(full) - 1; i >= start; i-- { + if full[i] == '/' { + lastSlash = i + break + } + } + + if lastSlash == -1 || lastSlash == start { + return "", "", fmt.Errorf("invalid method format: %s (expected format: /service/method)", full) + } + + serviceName := full[start:lastSlash] + methodName := full[lastSlash+1:] + + // Validate non-empty (avoid string comparison) + if len(serviceName) == 0 || len(methodName) == 0 { + return "", "", fmt.Errorf("invalid method format: %s (service and method names cannot be empty)", full) + } + + return serviceName, methodName, nil +} + +// RetryConfig holds retry configuration +type RetryConfig struct { + InitialBackoff time.Duration + MaxBackoff time.Duration + MaxRetries int +} + +// DefaultRetryConfig returns sensible retry defaults +func DefaultRetryConfig() RetryConfig { + return RetryConfig{ + InitialBackoff: defaultRetryBackoffInitial, + MaxBackoff: defaultRetryBackoffMax, + MaxRetries: 5, + } +} + +// WithContextRetry performs an operation with exponential backoff retry and context awareness. +// +// Retry Strategy: +// - Implements exponential backoff with configurable initial delay and multiplier +// - Respects maximum backoff duration to prevent excessive wait times +// - Honors context cancellation at any point during retry attempts +// - Uses intelligent error classification for retry decisions +// +// Context Handling: +// - Checks for context cancellation before each retry attempt +// - Cancellation during backoff sleep immediately returns context error +// - Preserves last operation error when context is cancelled +// +// Error Handling: +// - Returns immediately on successful operation (nil error) +// - Accumulates the last error from failed attempts +// - Provides comprehensive error context including attempt count +// - Uses error classification to determine retry eligibility +// +// Usage Pattern: +// This is typically used for transient failures in gRPC operations where +// temporary network issues or service unavailability should be retried. 
+func WithContextRetry(ctx context.Context, cfg RetryConfig, operation func() error) error { + var lastErr error + backoff := cfg.InitialBackoff + + for attempt := 0; attempt <= cfg.MaxRetries; attempt++ { + // Check context cancellation before each attempt + select { + case <-ctx.Done(): + if lastErr != nil { + return fmt.Errorf("context cancelled, last error: %w", lastErr) + } + return ctx.Err() + default: + } + + if err := operation(); err != nil { + lastErr = err + + // Use error classification to determine if we should retry + classifiedErr := classifyGrpcError("", err) + if classifiedErr != nil && !classifiedErr.IsRetryable() { + // Don't retry non-retryable errors + return fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err) + } + + // Don't sleep after the last attempt + if attempt == cfg.MaxRetries { + break + } + + // Sleep with backoff, but respect context cancellation + select { + case <-time.After(backoff): + // Exponential backoff with jitter + backoff = time.Duration(float64(backoff) * defaultRetryMultiplier) + + // Add small jitter to prevent thundering herd (±10%) + jitterFactor := 0.1 * (2*rand.Float64() - 1) // Random value between -0.1 and 0.1 + jitter := time.Duration(float64(backoff) * jitterFactor) + backoff += jitter + + if backoff > cfg.MaxBackoff { + backoff = cfg.MaxBackoff + } + case <-ctx.Done(): + return fmt.Errorf("context cancelled during backoff, last error: %w", lastErr) + } + } else { + return nil // Success + } + } + + return fmt.Errorf("operation failed after %d attempts: %w", cfg.MaxRetries+1, lastErr) +} + +// ErrorType represents the classification of different types of errors +type ErrorType int + +const ( + ErrorTypeUnknown ErrorType = iota + ErrorTypeConnection + ErrorTypeTimeout + ErrorTypeAuthentication + ErrorTypeAuthorization + ErrorTypeRateLimit + ErrorTypeResourceExhausted + ErrorTypeUnavailable + ErrorTypeInternal + ErrorTypeInvalidArgument + ErrorTypeNotFound + ErrorTypeAlreadyExists + ErrorTypeFailedPrecondition + ErrorTypeAborted + ErrorTypeOutOfRange + ErrorTypeUnimplemented + ErrorTypeDataLoss + ErrorTypeCancelled + ErrorTypeDeadlineExceeded +) + +// String returns the string representation of the error type +func (et ErrorType) String() string { + switch et { + case ErrorTypeConnection: + return "connection" + case ErrorTypeTimeout: + return "timeout" + case ErrorTypeAuthentication: + return "authentication" + case ErrorTypeAuthorization: + return "authorization" + case ErrorTypeRateLimit: + return "rate_limit" + case ErrorTypeResourceExhausted: + return "resource_exhausted" + case ErrorTypeUnavailable: + return "unavailable" + case ErrorTypeInternal: + return "internal" + case ErrorTypeInvalidArgument: + return "invalid_argument" + case ErrorTypeNotFound: + return "not_found" + case ErrorTypeAlreadyExists: + return "already_exists" + case ErrorTypeFailedPrecondition: + return "failed_precondition" + case ErrorTypeAborted: + return "aborted" + case ErrorTypeOutOfRange: + return "out_of_range" + case ErrorTypeUnimplemented: + return "unimplemented" + case ErrorTypeDataLoss: + return "data_loss" + case ErrorTypeCancelled: + return "cancelled" + case ErrorTypeDeadlineExceeded: + return "deadline_exceeded" + default: + return "unknown" + } +} + +// GrpcError represents a classified gRPC error with additional context +type GrpcError struct { + Type ErrorType + Code codes.Code + Message string + Method string + Details []string + Retryable bool + OriginalErr error +} + +// Error implements the error interface +func (e 
*GrpcError) Error() string { + return fmt.Sprintf("%s (%s): %s", e.Type.String(), e.Method, e.Message) +} + +// Unwrap returns the original error for error wrapping +func (e *GrpcError) Unwrap() error { + return e.OriginalErr +} + +// IsRetryable returns whether this error should be retried +func (e *GrpcError) IsRetryable() bool { + return e.Retryable +} + +// classifyGrpcError analyzes a gRPC error and returns a classified GrpcError +func classifyGrpcError(method string, err error) *GrpcError { + if err == nil { + return nil + } + + grpcErr := &GrpcError{ + Method: method, + OriginalErr: err, + } + + // Check for gRPC status errors + if st, ok := status.FromError(err); ok { + grpcErr.Code = st.Code() + grpcErr.Message = st.Message() + + // Collect details if any + for _, d := range st.Details() { + if pm, ok := d.(proto.Message); ok { + b, _ := protojson.Marshal(pm) + grpcErr.Details = append(grpcErr.Details, string(b)) + } + } + + // Classify based on gRPC status code + switch st.Code() { + case codes.Canceled: + grpcErr.Type = ErrorTypeCancelled + grpcErr.Retryable = false // Don't retry cancelled operations + case codes.Unknown: + grpcErr.Type = ErrorTypeUnknown + grpcErr.Retryable = true // May be transient + case codes.InvalidArgument: + grpcErr.Type = ErrorTypeInvalidArgument + grpcErr.Retryable = false // Client error, don't retry + case codes.DeadlineExceeded: + grpcErr.Type = ErrorTypeDeadlineExceeded + grpcErr.Retryable = true // Network timeout, retry possible + case codes.NotFound: + grpcErr.Type = ErrorTypeNotFound + grpcErr.Retryable = false // Resource doesn't exist + case codes.AlreadyExists: + grpcErr.Type = ErrorTypeAlreadyExists + grpcErr.Retryable = false // Resource conflict + case codes.PermissionDenied: + grpcErr.Type = ErrorTypeAuthorization + grpcErr.Retryable = false // Authorization failure + case codes.ResourceExhausted: + grpcErr.Type = ErrorTypeResourceExhausted + grpcErr.Retryable = true // May be temporary resource exhaustion + case codes.FailedPrecondition: + grpcErr.Type = ErrorTypeFailedPrecondition + grpcErr.Retryable = false // Preconditions not met + case codes.Aborted: + grpcErr.Type = ErrorTypeAborted + grpcErr.Retryable = true // May be transient + case codes.OutOfRange: + grpcErr.Type = ErrorTypeOutOfRange + grpcErr.Retryable = false // Invalid range + case codes.Unimplemented: + grpcErr.Type = ErrorTypeUnimplemented + grpcErr.Retryable = false // Method not implemented + case codes.Internal: + grpcErr.Type = ErrorTypeInternal + grpcErr.Retryable = true // Server internal error, may be transient + case codes.Unavailable: + grpcErr.Type = ErrorTypeUnavailable + grpcErr.Retryable = true // Service unavailable, definitely retry + case codes.DataLoss: + grpcErr.Type = ErrorTypeDataLoss + grpcErr.Retryable = false // Data corruption, don't retry + case codes.Unauthenticated: + grpcErr.Type = ErrorTypeAuthentication + grpcErr.Retryable = false // Authentication failure + default: + grpcErr.Type = ErrorTypeUnknown + grpcErr.Retryable = false // Conservative default + } + + return grpcErr + } + + // Handle non-gRPC errors (connection errors, etc.) 
+ errStr := err.Error() + grpcErr.Message = errStr + + // Classify based on error message patterns + if strings.Contains(errStr, "connection") || strings.Contains(errStr, "dial") || + strings.Contains(errStr, "network") || strings.Contains(errStr, "timeout") { + grpcErr.Type = ErrorTypeConnection + grpcErr.Retryable = true + } else if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") { + grpcErr.Type = ErrorTypeTimeout + grpcErr.Retryable = true + } else { + grpcErr.Type = ErrorTypeUnknown + grpcErr.Retryable = false + } + + return grpcErr +} + +// formatGrpcError returns a richer error string including status code and details +func formatGrpcError(prefix, method string, err error) error { + classifiedErr := classifyGrpcError(method, err) + if classifiedErr != nil { + return classifiedErr + } + return fmt.Errorf("%s (%s): %w", prefix, method, err) +} + +// logAtLevel emits a structured message at a given level +func logAtLevel(log *service.Logger, level string, msg string, method string, code string, dur time.Duration) { + entry := log.With("method", method, "code", code, "duration", dur.String()) + switch strings.ToLower(level) { + case "warn", "warning": + entry.Warnf(msg) + case "info": + entry.Infof(msg) + default: + entry.Debugf(msg) + } +} + +func (mr *MethodResolver) cacheKey(cfg *Config) string { + if len(cfg.ProtoFiles) == 0 { + return cfg.Method + "|reflect" + } + return cfg.Method + "|" + strings.Join(cfg.ProtoFiles, ",") + "|" + strings.Join(cfg.IncludePaths, ",") +} diff --git a/internal/impl/grpc_client/common_unit_test.go b/internal/impl/grpc_client/common_unit_test.go new file mode 100644 index 000000000..fc9200a9d --- /dev/null +++ b/internal/impl/grpc_client/common_unit_test.go @@ -0,0 +1,171 @@ +package grpc_client + +import ( + "context" + "path/filepath" + "runtime" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +func Test_injectMetadataIntoContext_Base64Bin(t *testing.T) { + ctx := context.Background() + cfg := &Config{ + AuthHeaders: map[string]string{ + "foo": "bar", + "bin-key": "YmFy", // will be treated as plain in simplified path + }, + BearerToken: "tok", + } + ctx = injectMetadataIntoContext(ctx, cfg) + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatalf("expected metadata in context") + } + if got := md.Get("foo"); len(got) != 1 || got[0] != "bar" { + t.Fatalf("expected foo=bar, got %v", got) + } + if got := md.Get("authorization"); len(got) != 1 || got[0] != "Bearer tok" { + t.Fatalf("expected authorization bearer, got %v", got) + } +} + +func Test_parseMethodName(t *testing.T) { + svc, m, err := parseMethodName("/echo.Echo/Stream") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if svc != "echo.Echo" || m != "Stream" { + t.Fatalf("unexpected parsed: %s %s", svc, m) + } + + if _, _, err := parseMethodName("bad"); err == nil { + t.Fatalf("expected error for invalid method format") + } +} + +func Test_MethodResolver_FromProtoFiles(t *testing.T) { + mr := NewMethodResolver() + // Build include path relative to this file + _, file, _, _ := runtime.Caller(0) + dir := filepath.Dir(file) + include := filepath.Clean(filepath.Join(dir, "../../../cmd/tools/grpc_test_server/pb")) + + cfg := &Config{ + Method: "/echo.Echo/Stream", + ProtoFiles: []string{"echo.proto", "google/protobuf/struct.proto"}, + IncludePaths: []string{include}, + } + method, err := mr.ResolveMethod(context.Background(), nil, cfg) + if 
err != nil { + t.Fatalf("ResolveMethod failed: %v", err) + } + if method == nil || !method.IsServerStreaming() || method.IsClientStreaming() { + t.Fatalf("unexpected method streaming flags") + } +} + +func Test_CircuitBreaker_Transitions(t *testing.T) { + cb := NewCircuitBreaker(2, 10*time.Millisecond) + if !cb.CanExecute() { + t.Fatalf("expected can execute in closed state") + } + cb.RecordFailure() + if !cb.CanExecute() { + t.Fatalf("should still execute after first failure") + } + cb.RecordFailure() + if cb.CanExecute() { + t.Fatalf("should not execute when open") + } + time.Sleep(15 * time.Millisecond) + if !cb.CanExecute() { + t.Fatalf("expected half-open allowing limited requests") + } + cb.RecordSuccess() + cb.RecordSuccess() + cb.RecordSuccess() // should transition to closed + if cb.GetState() != CircuitBreakerClosed { + t.Fatalf("expected closed after successes, got %v", cb.GetState()) + } +} + +func Test_WithContextRetry_Attempts(t *testing.T) { + cfg := RetryConfig{InitialBackoff: 1 * time.Millisecond, MaxBackoff: 2 * time.Millisecond, MaxRetries: 3} + attempts := 0 + err := WithContextRetry(context.Background(), cfg, func() error { + attempts++ + if attempts < 4 { + return status.Error(codes.Unavailable, "retry me") + } + return nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if attempts != 4 { + t.Fatalf("expected 4 attempts, got %d", attempts) + } +} + +func Test_ClassifyGrpcError_Status(t *testing.T) { + st := status.Error(codes.Unavailable, "svc unavailable") + e := classifyGrpcError("/svc/m", st) + if e == nil || !e.IsRetryable() || e.Type != ErrorTypeUnavailable { + t.Fatalf("unexpected classification: %+v", e) + } +} + +func Test_headerCreds_GetRequestMetadata(t *testing.T) { + h := headerCreds{ + token: "abc", + headers: map[string]string{ + "X-Foo": "bar", + "bad": "line\r\ninjection", + }, + } + md, err := h.GetRequestMetadata(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if md["authorization"] != "Bearer abc" { + t.Fatalf("expected auth bearer, got %v", md["authorization"]) + } + if md["x-foo"] != "bar" { + t.Fatalf("expected x-foo=bar, got %v", md["x-foo"]) + } + if _, ok := md["bad"]; ok { + t.Fatalf("expected invalid header to be skipped") + } +} + +func Test_deadlineUnaryInterceptor_SetsDefault(t *testing.T) { + var hadDeadline bool + inv := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + _, hadDeadline = ctx.Deadline() + return nil + } + _ = deadlineUnaryInterceptor(context.Background(), "/svc/Method", nil, nil, nil, inv) + if !hadDeadline { + t.Fatalf("expected deadline to be set by interceptor") + } +} + +func Test_deadlineUnaryInterceptor_PreservesExisting(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + var seen time.Time + inv := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + seen, _ = ctx.Deadline() + return nil + } + _ = deadlineUnaryInterceptor(ctx, "/svc/Method", nil, nil, nil, inv) + if time.Until(seen) <= 0 { + t.Fatalf("expected existing deadline to be preserved") + } +} diff --git a/internal/impl/grpc_client/connection_pool.go b/internal/impl/grpc_client/connection_pool.go new file mode 100644 index 000000000..80af9e5d2 --- /dev/null +++ b/internal/impl/grpc_client/connection_pool.go @@ -0,0 +1,179 @@ +package grpc_client + +import ( + "context" + "errors" + "sync" + "time" + 
+ "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" +) + +// ConnectionPool manages a pool of gRPC connections for performance optimization. +// +// The pool implements a round-robin connection selection strategy with automatic +// connection release. Connections are marked as "in use" temporarily to prevent +// concurrent access conflicts, then automatically released after a short delay. +// +// Thread Safety: All public methods are thread-safe using RWMutex protection. +// The pool supports concurrent access from multiple goroutines. +// +// Lifecycle: Connections are created during pool initialization and replaced +// when they become idle beyond the configured timeout. +type ConnectionPool struct { + connections []connectionEntry // Pool of gRPC connections + mu sync.RWMutex // Protects concurrent access to pool state + cfg *Config // Configuration for connection management + nextIndex int // Round-robin index for connection selection + closed bool // Indicates if pool is closed + ctx context.Context // Parent context for new connections +} + +// connectionEntry represents a single gRPC connection in the pool with metadata +type connectionEntry struct { + conn *grpc.ClientConn // The actual gRPC connection + lastUsed time.Time // Timestamp of last usage for idle cleanup + createdAt time.Time // When this connection was created + failureCount int // Number of consecutive failures + lastFailure time.Time // Time of last failure + healthChecked time.Time // Last time health was checked +} + +// getConnection gets an available connection from the pool with state validation +func (cp *ConnectionPool) getConnection() (*grpc.ClientConn, error) { + cp.mu.Lock() + defer cp.mu.Unlock() + + if cp.closed { + return nil, errors.New("connection pool is closed") + } + for i := 0; i < len(cp.connections); i++ { + idx := (cp.nextIndex + i) % len(cp.connections) + entry := &cp.connections[idx] + if !isConnectionExcessivelyFailing(entry, defaultMaxConnectionFailures, defaultFailureWindow) { + if !isConnectionHealthy(entry.conn) { + recordConnectionFailure(entry) + if newConn, err := createConnection(cp.ctx, cp.cfg); err == nil { + entry.conn.Close() + now := time.Now() + entry.conn = newConn + entry.lastUsed = now + entry.createdAt = now + entry.failureCount = 0 + entry.lastFailure = time.Time{} + entry.healthChecked = now + } else { + continue + } + } + entry.lastUsed = time.Now() + entry.healthChecked = time.Now() + cp.nextIndex = (idx + 1) % len(cp.connections) + recordConnectionSuccess(entry) + return entry.conn, nil + } + } + return nil, errors.New("no connections available in pool") +} + +// cleanupIdle removes connections that have been idle for too long or are unhealthy +func (cp *ConnectionPool) cleanupIdle() { + cp.mu.Lock() + defer cp.mu.Unlock() + if cp.closed { + return + } + now := time.Now() + for i := range cp.connections { + entry := &cp.connections[i] + shouldReplace := false + if now.Sub(entry.lastUsed) > defaultConnectionIdleTimeout { + shouldReplace = true + } + if isConnectionExcessivelyFailing(entry, defaultMaxConnectionFailures, defaultFailureWindow) { + shouldReplace = true + } + if now.Sub(entry.healthChecked) > defaultHealthCheckInterval { + if !isConnectionHealthy(entry.conn) { + recordConnectionFailure(entry) + shouldReplace = true + } else { + entry.healthChecked = now + recordConnectionSuccess(entry) + } + } + if shouldReplace { + if entry.conn != nil { + entry.conn.Close() + } + if newConn, err := createConnection(cp.ctx, cp.cfg); err == nil { + entry.conn = 
newConn + entry.lastUsed = now + entry.createdAt = now + entry.failureCount = 0 + entry.lastFailure = time.Time{} + entry.healthChecked = now + } + } + } +} + +// closeAllConnections closes all connections in the pool +func (cp *ConnectionPool) closeAllConnections() { + cp.mu.Lock() + defer cp.mu.Unlock() + for _, entry := range cp.connections { + if entry.conn != nil { + entry.conn.Close() + } + } + cp.connections = nil + cp.closed = true +} + +// isConnectionHealthy validates the connection state and health +func isConnectionHealthy(conn *grpc.ClientConn) bool { + if conn == nil { + return false + } + switch conn.GetState() { + case connectivity.Ready, connectivity.Idle, connectivity.Connecting: + return true + default: + return false + } +} + +// helpers (minimal) +func recordConnectionFailure(entry *connectionEntry) { + if entry == nil { + return + } + entry.failureCount++ + entry.lastFailure = time.Now() +} + +func recordConnectionSuccess(entry *connectionEntry) { + if entry == nil { + return + } + if entry.failureCount > 0 { + entry.failureCount = 0 + entry.lastFailure = time.Time{} + } +} + +func isConnectionExcessivelyFailing(entry *connectionEntry, maxFailures int, failureWindow time.Duration) bool { + if entry == nil { + return true + } + if entry.failureCount >= maxFailures { + if time.Since(entry.lastFailure) <= failureWindow { + return true + } + entry.failureCount = 0 + entry.lastFailure = time.Time{} + } + return false +} diff --git a/internal/impl/grpc_client/input_grpc_client.go b/internal/impl/grpc_client/input_grpc_client.go new file mode 100644 index 000000000..8cee564fc --- /dev/null +++ b/internal/impl/grpc_client/input_grpc_client.go @@ -0,0 +1,347 @@ +package grpc_client + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/dynamic" + "github.com/jhump/protoreflect/dynamic/grpcdynamic" + "google.golang.org/protobuf/encoding/protojson" + structpb "google.golang.org/protobuf/types/known/structpb" + + "github.com/warpstreamlabs/bento/public/service" +) + +func genericInputSpec() *service.ConfigSpec { + return createBaseConfigSpec(). + Summary("Call an arbitrary gRPC method (unary or server-stream) using reflection to resolve types with enhanced security and performance"). + Field(service.NewStringField(fieldRPCType).Default("server_stream")). 
+ Field(service.NewStringField(fieldRequestJSON).Default("{}").Description("JSON request body used for unary or initial server-stream request")) +} + +// genericInput handles both unary and server-streaming gRPC input +type genericInput struct { + cfg *Config + connMgr *ConnectionManager + methodResolver *MethodResolver + reqIS *service.InterpolatedString + method *desc.MethodDescriptor + + // Server streaming state with proper cleanup + mu sync.Mutex + streamCtx context.Context + streamCancel context.CancelFunc + stream *grpcdynamic.ServerStream + streamOpen bool + shutdown bool + retryConfig RetryConfig +} + +func newGenericInput(conf *service.ParsedConfig, res *service.Resources) (service.Input, error) { + cfg, err := ParseConfigFromService(conf) + if err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + + // Attach logger for common code + cfg.Logger = res.Logger() + + reqIS, err := service.NewInterpolatedString(cfg.RequestJSON) + if err != nil { + return nil, fmt.Errorf("failed to create interpolated string: %w", err) + } + + connMgr, err := NewConnectionManager(context.Background(), cfg) + if err != nil { + return nil, fmt.Errorf("failed to create connection manager: %w", err) + } + + methodResolver := NewMethodResolver() + + conn, err := connMgr.GetConnection() + if err != nil { + connMgr.Close() + return nil, fmt.Errorf("failed to get connection: %w", err) + } + + method, err := methodResolver.ResolveMethod(context.Background(), conn, cfg) + if err != nil { + connMgr.Close() + return nil, fmt.Errorf("failed to resolve method: %w", err) + } + + return &genericInput{ + cfg: cfg, + connMgr: connMgr, + methodResolver: methodResolver, + reqIS: reqIS, + method: method, + retryConfig: func() RetryConfig { + r := DefaultRetryConfig() + if cfg.RetryPolicy != nil { + r.InitialBackoff = cfg.RetryPolicy.InitialBackoff + r.MaxBackoff = cfg.RetryPolicy.MaxBackoff + // Approximate: MaxRetries = MaxAttempts-1 + if cfg.RetryPolicy.MaxAttempts > 0 { + r.MaxRetries = cfg.RetryPolicy.MaxAttempts - 1 + } + } + return r + }(), + }, nil +} + +func (g *genericInput) Connect(_ context.Context) error { + return nil +} + +func (g *genericInput) Read(ctx context.Context) (*service.Message, service.AckFunc, error) { + g.mu.Lock() + if g.shutdown { + g.mu.Unlock() + return nil, nil, service.ErrNotConnected + } + g.mu.Unlock() + + if g.method == nil { + return nil, nil, service.ErrNotConnected + } + + // Build request message from JSON with optional pooling + requestMsg := dynamic.NewMessage(g.method.GetInputType()) + + reqJSON, rerr := g.reqIS.TryString(service.NewMessage(nil)) + if rerr != nil { + return nil, nil, fmt.Errorf("failed to render request_json: %w", rerr) + } + if reqJSON == "" { + reqJSON = "{}" + } + if uerr := requestMsg.UnmarshalJSON([]byte(reqJSON)); uerr != nil { + return nil, nil, fmt.Errorf("failed to unmarshal request JSON: %w", uerr) + } + + switch g.cfg.RPCType { + case "unary": + return g.handleUnaryCall(ctx, requestMsg) + case "server_stream": + return g.handleServerStreamCall(ctx, requestMsg, reqJSON) + default: + return nil, nil, fmt.Errorf("unsupported rpc_type for input: %s", g.cfg.RPCType) + } +} + +func (g *genericInput) handleUnaryCall(ctx context.Context, requestMsg *dynamic.Message) (*service.Message, service.AckFunc, error) { + // Validate method type + if g.method.IsServerStreaming() || g.method.IsClientStreaming() { + return nil, nil, fmt.Errorf("method %s is not unary", g.method.GetFullyQualifiedName()) + } + + conn, err := 
g.connMgr.GetConnection() + if err != nil { + return nil, nil, fmt.Errorf("failed to get connection: %w", err) + } + + stub := grpcdynamic.NewStub(conn) + + // Enhanced context handling with proper deadline propagation + callCtx := g.enhanceCallContext(ctx) + var cancel context.CancelFunc + + if g.cfg.CallTimeout > 0 { + callCtx, cancel = context.WithTimeout(callCtx, g.cfg.CallTimeout) + defer cancel() + } else if _, hasDeadline := callCtx.Deadline(); !hasDeadline { + // Apply default timeout for unary input calls + callCtx, cancel = context.WithTimeout(callCtx, 30*time.Second) + defer cancel() + } + + resp, err := stub.InvokeRpc(callCtx, g.method, requestMsg) + if err != nil { + return nil, nil, formatGrpcError("grpc_client unary call failed", g.method.GetFullyQualifiedName(), err) + } + + // Handle different response types + var respBytes []byte + switch v := resp.(type) { + case *dynamic.Message: + respBytes, err = v.MarshalJSON() + case *structpb.Struct: + m := protojson.MarshalOptions{EmitUnpopulated: false, UseProtoNames: false, AllowPartial: true, Multiline: false, Indent: ""} + respBytes, err = m.Marshal(v) + default: + return nil, nil, fmt.Errorf("unexpected response type from unary call: %T", resp) + } + + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + msg := service.NewMessage(respBytes) + return msg, func(context.Context, error) error { return nil }, nil +} + +func (g *genericInput) handleServerStreamCall(ctx context.Context, requestMsg *dynamic.Message, reqJSON string) (*service.Message, service.AckFunc, error) { + // Validate method type + if !g.method.IsServerStreaming() || g.method.IsClientStreaming() { + return nil, nil, fmt.Errorf("method %s is not server-streaming", g.method.GetFullyQualifiedName()) + } + + // Ensure stream is open + if err := g.ensureStreamOpen(ctx, requestMsg); err != nil { + return nil, nil, fmt.Errorf("failed to open stream: %w", err) + } + + for { + g.mu.Lock() + if g.shutdown || g.stream == nil { + g.mu.Unlock() + return nil, nil, service.ErrNotConnected + } + stream := g.stream + g.mu.Unlock() + + resp, err := stream.RecvMsg() + if err == nil { + // Handle different response types + var respBytes []byte + var marshalErr error + + switch v := resp.(type) { + case *dynamic.Message: + // Direct dynamic message + respBytes, marshalErr = v.MarshalJSON() + case *structpb.Struct: + // google.protobuf.Struct + m := protojson.MarshalOptions{EmitUnpopulated: false, UseProtoNames: false, AllowPartial: true, Multiline: false, Indent: ""} + respBytes, marshalErr = m.Marshal(v) + default: + return nil, nil, fmt.Errorf("unexpected stream response type: %T", resp) + } + + if marshalErr != nil { + return nil, nil, fmt.Errorf("failed to marshal stream response: %w", marshalErr) + } + + msg := service.NewMessage(respBytes) + return msg, func(context.Context, error) error { return nil }, nil + } + + if errors.Is(err, io.EOF) { + return nil, nil, service.ErrEndOfInput + } + + // Stream failed, attempt to reopen with retry + if reopenErr := g.reopenStreamWithRetry(ctx, reqJSON); reopenErr != nil { + return nil, nil, formatGrpcError("grpc_client failed to reopen server stream", g.method.GetFullyQualifiedName(), reopenErr) + } + } +} + +func (g *genericInput) ensureStreamOpen(ctx context.Context, requestMsg *dynamic.Message) error { + g.mu.Lock() + defer g.mu.Unlock() + + if g.streamOpen && g.stream != nil { + return nil + } + + return g.openStreamLocked(ctx, requestMsg) +} + +func (g *genericInput) openStreamLocked(ctx 
context.Context, requestMsg *dynamic.Message) error { + // Close existing stream if any + g.closeStreamLocked() + + conn, err := g.connMgr.GetConnection() + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + + stub := grpcdynamic.NewStub(conn) + + // Enhanced context handling for server streaming + streamCtx := g.enhanceCallContext(ctx) + var cancel context.CancelFunc + + if g.cfg.CallTimeout > 0 { + streamCtx, cancel = context.WithTimeout(streamCtx, g.cfg.CallTimeout) + } else { + // Apply default timeout for server streaming + defaultStreamTimeout := 15 * time.Minute + streamCtx, cancel = context.WithTimeout(streamCtx, defaultStreamTimeout) + } + + stream, err := stub.InvokeRpcServerStream(streamCtx, g.method, requestMsg) + if err != nil { + cancel() + return fmt.Errorf("failed to invoke server stream: %w", err) + } + + g.streamCtx = streamCtx + g.streamCancel = cancel + g.stream = stream + g.streamOpen = true + + return nil +} + +// enhanceCallContext enhances the context for gRPC calls with proper deadline and metadata handling +func (g *genericInput) enhanceCallContext(ctx context.Context) context.Context { + return enhanceCallContext(ctx, g.cfg, func(c context.Context) context.Context { + return injectMetadataIntoContext(c, g.cfg) + }) +} + +func (g *genericInput) reopenStreamWithRetry(ctx context.Context, reqJSON string) error { + // Use message pool if enabled for better performance + requestMsg := dynamic.NewMessage(g.method.GetInputType()) + + if err := requestMsg.UnmarshalJSON([]byte(reqJSON)); err != nil { + return fmt.Errorf("failed to unmarshal request for retry: %w", err) + } + + return WithContextRetry(ctx, g.retryConfig, func() error { + g.mu.Lock() + defer g.mu.Unlock() + + if g.shutdown { + return errors.New("input is shutting down") + } + + return g.openStreamLocked(ctx, requestMsg) + }) +} + +func (g *genericInput) closeStreamLocked() { + if g.streamCancel != nil { + g.streamCancel() + g.streamCancel = nil + } + g.stream = nil + g.streamOpen = false +} + +func (g *genericInput) Close(ctx context.Context) error { + g.mu.Lock() + g.shutdown = true + g.closeStreamLocked() + g.mu.Unlock() + + if g.connMgr != nil { + return g.connMgr.Close() + } + return nil +} + +func init() { + _ = service.RegisterInput("grpc_client", genericInputSpec(), func(conf *service.ParsedConfig, res *service.Resources) (service.Input, error) { + return newGenericInput(conf, res) + }) +} diff --git a/internal/impl/grpc_client/input_grpc_client_test.go b/internal/impl/grpc_client/input_grpc_client_test.go new file mode 100644 index 000000000..f8964df72 --- /dev/null +++ b/internal/impl/grpc_client/input_grpc_client_test.go @@ -0,0 +1,42 @@ +package grpc_client + +import ( + "context" + "testing" + "time" + + "google.golang.org/grpc/metadata" +) + +func TestInput_injectMetadataIntoContext(t *testing.T) { + cfg := &Config{ + BearerToken: "t", + AuthHeaders: map[string]string{"a": "1", "b": "2"}, + } + ctx := context.Background() + ctx = injectMetadataIntoContext(ctx, cfg) + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("expected outgoing metadata in context") + } + if v := md.Get("authorization"); len(v) == 0 || v[0] != "Bearer t" { + t.Fatalf("authorization = %v", v) + } + if v := md.Get("a"); len(v) == 0 || v[0] != "1" { + t.Fatalf("a = %v", v) + } + if v := md.Get("b"); len(v) == 0 || v[0] != "2" { + t.Fatalf("b = %v", v) + } +} + +func TestInput_enhanceCallContext(t *testing.T) { + g := &genericInput{cfg: &Config{}} + parent, cancel := 
context.WithTimeout(context.Background(), time.Second) + defer cancel() + ctx := g.enhanceCallContext(parent) + // No-op now; ensure it doesn't strip the context + if ctx == nil { + t.Fatal("expected non-nil context") + } +} diff --git a/internal/impl/grpc_client/integration_test.go b/internal/impl/grpc_client/integration_test.go new file mode 100644 index 000000000..dd38d767f --- /dev/null +++ b/internal/impl/grpc_client/integration_test.go @@ -0,0 +1,147 @@ +package grpc_client + +import ( + "context" + "os/exec" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/jhump/protoreflect/dynamic" + "github.com/jhump/protoreflect/dynamic/grpcdynamic" + "google.golang.org/grpc/codes" +) + +func startTestServer(t *testing.T) func() { + t.Helper() + cmd := exec.Command("go", "run", "./cmd/tools/grpc_test_server") + cmd.Dir = repoRoot(t) + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start test server: %v", err) + } + time.Sleep(1 * time.Second) + return func() { _ = cmd.Process.Kill(); _ = cmd.Wait() } +} + +func repoRoot(t *testing.T) string { + _, file, _, _ := runtime.Caller(0) + dir := filepath.Dir(file) + return filepath.Clean(filepath.Join(dir, "../../../")) +} + +func TestIntegration_ServerStream_OK(t *testing.T) { + stop := startTestServer(t) + defer stop() + cfg := &Config{ + Address: "127.0.0.1:50051", + Method: "/echo.Echo/Stream", + RPCType: "server_stream", + ProtoFiles: []string{"echo.proto", "google/protobuf/struct.proto"}, + IncludePaths: []string{filepath.Join(repoRoot(t), "cmd/tools/grpc_test_server/pb")}, + RetryPolicy: &RetryPolicy{MaxAttempts: 2, InitialBackoff: 10 * time.Millisecond, MaxBackoff: 20 * time.Millisecond, BackoffMultiplier: 2, RetryableStatusCodes: []codes.Code{codes.Unavailable}}, + ConnectTimeout: 2 * time.Second, + } + cm, err := NewConnectionManager(context.Background(), cfg) + if err != nil { + t.Fatalf("cm: %v", err) + } + defer cm.Close() + conn, err := cm.GetConnection() + if err != nil { + t.Fatalf("conn: %v", err) + } + mr := NewMethodResolver() + m, err := mr.ResolveMethod(context.Background(), conn, cfg) + if err != nil { + t.Fatalf("resolve: %v", err) + } + stub := grpcdynamic.NewStub(conn) + in := dynamic.NewMessage(m.GetInputType()) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err = stub.InvokeRpcServerStream(ctx, m, in) + if err != nil { + t.Fatalf("invoke stream: %v", err) + } +} + +func TestIntegration_ClientStream_OK(t *testing.T) { + stop := startTestServer(t) + defer stop() + cfg := &Config{ + Address: "127.0.0.1:50051", + Method: "/ingest.Ingest/Stream", + RPCType: "client_stream", + ProtoFiles: []string{"ingest.proto", "google/protobuf/struct.proto"}, + IncludePaths: []string{filepath.Join(repoRoot(t), "cmd/tools/grpc_test_server/pb")}, + RetryPolicy: &RetryPolicy{MaxAttempts: 2, InitialBackoff: 10 * time.Millisecond, MaxBackoff: 20 * time.Millisecond, BackoffMultiplier: 2, RetryableStatusCodes: []codes.Code{codes.Unavailable}}, + ConnectTimeout: 2 * time.Second, + } + cm, err := NewConnectionManager(context.Background(), cfg) + if err != nil { + t.Fatalf("cm: %v", err) + } + defer cm.Close() + conn, err := cm.GetConnection() + if err != nil { + t.Fatalf("conn: %v", err) + } + mr := NewMethodResolver() + m, err := mr.ResolveMethod(context.Background(), conn, cfg) + if err != nil { + t.Fatalf("resolve: %v", err) + } + stub := grpcdynamic.NewStub(conn) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + cs, err := 
stub.InvokeRpcClientStream(ctx, m) + if err != nil { + t.Fatalf("open client stream: %v", err) + } + for i := 0; i < 3; i++ { + if err := cs.SendMsg(dynamic.NewMessage(m.GetInputType())); err != nil { + t.Fatalf("send: %v", err) + } + } + if _, err := cs.CloseAndReceive(); err != nil { + t.Fatalf("close/recv: %v", err) + } +} + +func TestIntegration_CircuitBreaker_Transitions_Unreachable(t *testing.T) { + cfg := &Config{ + Address: "127.0.0.1:59999", + Method: "/echo.Echo/Stream", + RPCType: "server_stream", + RetryPolicy: &RetryPolicy{MaxAttempts: 1, InitialBackoff: 10 * time.Millisecond, MaxBackoff: 10 * time.Millisecond, BackoffMultiplier: 2}, + ConnectTimeout: 200 * time.Millisecond, + } + cm, err := NewConnectionManager(context.Background(), cfg) + if err == nil { + defer cm.Close() + } + if err == nil { + _, _ = cm.GetConnection() + _, _ = cm.GetConnection() + if cm.GetCircuitBreakerState() == CircuitBreakerClosed { + t.Fatalf("expected breaker to open after failures") + } + // wait a bit; internal breaker uses defaults + time.Sleep(300 * time.Millisecond) + _ = cm.GetCircuitBreakerState() + } +} + +func TestIntegration_ConnectTimeout_Fails(t *testing.T) { + cfg := &Config{ + Address: "10.255.255.1:65533", + Method: "/echo.Echo/Stream", + RPCType: "server_stream", + ConnectTimeout: 200 * time.Millisecond, + } + _, err := NewConnectionManager(context.Background(), cfg) + if err == nil { + t.Fatalf("expected connection manager creation to fail due to connect_timeout") + } +} diff --git a/internal/impl/grpc_client/output_grpc_client.go b/internal/impl/grpc_client/output_grpc_client.go new file mode 100644 index 000000000..e012ee1a0 --- /dev/null +++ b/internal/impl/grpc_client/output_grpc_client.go @@ -0,0 +1,686 @@ +package grpc_client + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/dynamic" + "github.com/jhump/protoreflect/dynamic/grpcdynamic" + "google.golang.org/protobuf/encoding/protojson" + structpb "google.golang.org/protobuf/types/known/structpb" + + "github.com/warpstreamlabs/bento/public/service" +) + +func genericOutputSpec() *service.ConfigSpec { + return createBaseConfigSpec(). + Summary("Call an arbitrary gRPC method (unary, client_stream, or bidi) using reflection to resolve types with enhanced security and performance"). + Field(service.NewStringField(fieldRPCType).Default("unary").Description("One of: unary, client_stream, bidi")). + Field(service.NewOutputMaxInFlightField()) +} + +// StreamSession represents a streaming gRPC session with comprehensive lifecycle management. +// +// StreamSession handles both client-streaming and bidirectional streaming patterns. +// It provides thread-safe access to stream state and automatic resource cleanup. +// +// Lifecycle Management: +// - Tracks session creation time and last usage for timeout enforcement +// - Maintains context cancellation for graceful stream termination +// - Provides thread-safe state management for concurrent access +// +// Stream Types Supported: +// - *grpcdynamic.ClientStream for client-streaming RPCs +// - *grpcdynamic.BidiStream for bidirectional streaming RPCs +// +// Thread Safety: All methods are thread-safe using RWMutex protection. 
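+//
+// Illustrative usage (a sketch only, using the constructor and methods
+// defined below in this file):
+//
+//	sess := NewStreamSession(clientStream, cancel)
+//	defer sess.Close()
+//	sess.UpdateLastUse()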
+type StreamSession struct { + stream interface{} // Can be *grpcdynamic.ClientStream or *grpcdynamic.BidiStream + lastUse time.Time // Last time this session was used for idle timeout + openedAt time.Time // When this session was created for max lifetime + cancel context.CancelFunc // Cancels the stream context for graceful shutdown + closed bool // Indicates if session has been closed + mu sync.RWMutex // Protects concurrent access to session state +} + +// NewStreamSession creates a new stream session +func NewStreamSession(stream interface{}, cancel context.CancelFunc) *StreamSession { + now := time.Now() + return &StreamSession{ + stream: stream, + lastUse: now, + openedAt: now, + cancel: cancel, + } +} + +// UpdateLastUse updates the last use time (thread-safe) +func (s *StreamSession) UpdateLastUse() { + s.mu.Lock() + defer s.mu.Unlock() + s.lastUse = time.Now() +} + +// GetLastUse returns the last use time (thread-safe) +func (s *StreamSession) GetLastUse() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.lastUse +} + +// GetOpenedAt returns when the session was opened (thread-safe) +func (s *StreamSession) GetOpenedAt() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.openedAt +} + +// Close closes the session (thread-safe) +func (s *StreamSession) Close() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return + } + + s.closed = true + + // Close the appropriate stream type + switch stream := s.stream.(type) { + case *grpcdynamic.ClientStream: + _, _ = stream.CloseAndReceive() + case *grpcdynamic.BidiStream: + _ = stream.CloseSend() + } + + if s.cancel != nil { + s.cancel() + } +} + +// IsClosed returns whether the session is closed (thread-safe) +func (s *StreamSession) IsClosed() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.closed +} + +// GetStream returns the underlying stream (thread-safe) +func (s *StreamSession) GetStream() interface{} { + s.mu.RLock() + defer s.mu.RUnlock() + return s.stream +} + +// SessionManager manages streaming sessions with automatic cleanup and lifecycle enforcement. +// +// The SessionManager coordinates multiple concurrent streaming sessions, applying +// configurable timeout policies and providing centralized session lifecycle management. 
+// +// Cleanup Strategy: +// - Runs a background goroutine that periodically sweeps for expired sessions +// - Enforces both idle timeout (time since last use) and max lifetime policies +// - Performs graceful session shutdown with proper resource cleanup +// +// Concurrency Model: +// - Thread-safe session storage using RWMutex protection +// - Supports concurrent session creation, access, and cleanup +// - Prevents resource leaks through systematic session tracking +// +// Session Routing: +// - Sessions are identified by string keys (typically from message metadata) +// - Enables message routing to appropriate streaming contexts +// - Supports session-based stateful streaming patterns +type SessionManager struct { + sessions map[string]*StreamSession // Active sessions indexed by session key + mu sync.RWMutex // Protects concurrent access to sessions map + stopCh chan struct{} // Signals cleanup goroutine to stop + stopped bool // Indicates if manager is shut down + idleTimeout time.Duration // Time after which idle sessions are closed + maxLifetime time.Duration // Maximum time a session can remain open + log *service.Logger // Logger for session lifecycle events +} + +// NewSessionManager creates a new session manager +func NewSessionManager(idleTimeout, maxLifetime time.Duration, log *service.Logger) *SessionManager { + sm := &SessionManager{ + sessions: make(map[string]*StreamSession), + stopCh: make(chan struct{}), + idleTimeout: idleTimeout, + maxLifetime: maxLifetime, + log: log, + } + + // Start cleanup goroutine + go sm.cleanup() + + return sm +} + +// GetSession returns an existing session or nil +func (sm *SessionManager) GetSession(key string) *StreamSession { + sm.mu.RLock() + defer sm.mu.RUnlock() + + session, exists := sm.sessions[key] + if !exists || session.IsClosed() { + return nil + } + + return session +} + +// SetSession stores a session +func (sm *SessionManager) SetSession(key string, session *StreamSession) { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Close existing session if any + if existing, exists := sm.sessions[key]; exists { + existing.Close() + } + + sm.sessions[key] = session +} + +// RemoveSession removes and closes a session +func (sm *SessionManager) RemoveSession(key string) { + sm.mu.Lock() + defer sm.mu.Unlock() + + if session, exists := sm.sessions[key]; exists { + session.Close() + delete(sm.sessions, key) + } +} + +// Close closes all sessions and stops the manager +func (sm *SessionManager) Close() { + sm.mu.Lock() + if sm.stopped { + sm.mu.Unlock() + return + } + sm.stopped = true + + // Close all sessions + for key, session := range sm.sessions { + session.Close() + delete(sm.sessions, key) + } + sm.mu.Unlock() + + // Stop cleanup goroutine + close(sm.stopCh) +} + +// cleanup runs in a goroutine to clean up expired sessions +func (sm *SessionManager) cleanup() { + ticker := time.NewTicker(defaultSessionSweepInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + sm.sweepExpiredSessions() + case <-sm.stopCh: + return + } + } +} + +// sweepExpiredSessions removes expired sessions based on idle timeout and max lifetime. +// +// This method implements the core cleanup logic for session management: +// +// 1. Idle Timeout Enforcement: +// - Checks if sessions haven't been used within the configured idle timeout +// - Removes sessions that have been inactive too long to free resources +// +// 2. 
Max Lifetime Enforcement: +// - Ensures sessions don't exceed their maximum allowed lifetime +// - Prevents indefinitely long-running sessions that could cause resource leaks +// +// 3. Graceful Cleanup: +// - Properly closes each expired session before removal +// - Calls session.Close() to trigger context cancellation and stream cleanup +// - Removes session from the active sessions map +// +// Thread Safety: Acquires write lock for the duration of the sweep operation. +func (sm *SessionManager) sweepExpiredSessions() { + sm.mu.Lock() + defer sm.mu.Unlock() + + if sm.stopped { + return + } + + now := time.Now() + for key, session := range sm.sessions { + shouldRemove := false + + if sm.idleTimeout > 0 && now.Sub(session.GetLastUse()) > sm.idleTimeout { + shouldRemove = true + if sm.log != nil { + sm.log.Debugf("Removing idle session %s", key) + } + } else if sm.maxLifetime > 0 && now.Sub(session.GetOpenedAt()) > sm.maxLifetime { + shouldRemove = true + if sm.log != nil { + sm.log.Debugf("Removing expired session %s", key) + } + } + + if shouldRemove { + session.Close() + delete(sm.sessions, key) + } + } +} + +// UnifiedOutput handles all gRPC output types with shared implementation +type UnifiedOutput struct { + cfg *Config + connMgr *ConnectionManager + methodResolver *MethodResolver + method *desc.MethodDescriptor + sessionMgr *SessionManager + retryConfig RetryConfig + + // Streaming state + mu sync.Mutex + shutdown bool +} + +func newUnifiedOutput(conf *service.ParsedConfig, res *service.Resources) (service.Output, int, error) { + cfg, err := ParseConfigFromService(conf) + if err != nil { + return nil, 0, fmt.Errorf("failed to parse config: %w", err) + } + + maxInFlight, _ := conf.FieldMaxInFlight() + + // Attach logger for common code + cfg.Logger = res.Logger() + + connMgr, err := NewConnectionManager(context.Background(), cfg) + if err != nil { + return nil, 0, fmt.Errorf("failed to create connection manager: %w", err) + } + + methodResolver := NewMethodResolver() + + conn, err := connMgr.GetConnection() + if err != nil { + connMgr.Close() + return nil, 0, fmt.Errorf("failed to get connection: %w", err) + } + + method, err := methodResolver.ResolveMethod(context.Background(), conn, cfg) + if err != nil { + connMgr.Close() + return nil, 0, fmt.Errorf("failed to resolve method: %w", err) + } + + // Validate method type based on RPC type + if err := validateMethodType(method, cfg.RPCType); err != nil { + connMgr.Close() + return nil, 0, err + } + + // Create session manager for streaming types + var sessionMgr *SessionManager + if isStreamingType(cfg.RPCType) { + // Fixed defaults in minimal mode + sessionMgr = NewSessionManager(60*time.Second, 10*time.Minute, res.Logger()) + } + + return &UnifiedOutput{ + cfg: cfg, + connMgr: connMgr, + methodResolver: methodResolver, + method: method, + sessionMgr: sessionMgr, + retryConfig: func() RetryConfig { + r := DefaultRetryConfig() + if cfg.RetryPolicy != nil { + r.InitialBackoff = cfg.RetryPolicy.InitialBackoff + r.MaxBackoff = cfg.RetryPolicy.MaxBackoff + if cfg.RetryPolicy.MaxAttempts > 0 { + r.MaxRetries = cfg.RetryPolicy.MaxAttempts - 1 + } + } + return r + }(), + }, maxInFlight, nil +} + +// validateMethodType validates that the method matches the expected RPC type +func validateMethodType(method *desc.MethodDescriptor, rpcType string) error { + switch rpcType { + case "", "unary": + if method.IsServerStreaming() || method.IsClientStreaming() { + return fmt.Errorf("method %s is not unary", method.GetFullyQualifiedName()) + } + case 
"client_stream": + if !method.IsClientStreaming() || method.IsServerStreaming() { + return fmt.Errorf("method %s is not client-streaming", method.GetFullyQualifiedName()) + } + case "bidi": + if !method.IsClientStreaming() || !method.IsServerStreaming() { + return fmt.Errorf("method %s is not bidirectional", method.GetFullyQualifiedName()) + } + default: + return fmt.Errorf("unsupported rpc_type: %s", rpcType) + } + return nil +} + +// isStreamingType returns true if the RPC type requires streaming +func isStreamingType(rpcType string) bool { + return rpcType == "client_stream" || rpcType == "bidi" +} + +func (u *UnifiedOutput) Connect(ctx context.Context) error { + return nil +} + +func (u *UnifiedOutput) Write(ctx context.Context, msg *service.Message) error { + u.mu.Lock() + if u.shutdown { + u.mu.Unlock() + return service.ErrNotConnected + } + u.mu.Unlock() + + if u.method == nil { + return service.ErrNotConnected + } + + requestMsg := dynamic.NewMessage(u.method.GetInputType()) + + msgBytes, err := msg.AsBytes() + if err != nil { + return fmt.Errorf("failed to get message bytes: %w", err) + } + if len(msgBytes) == 0 { + msgBytes = []byte("{}") + } + if err := requestMsg.UnmarshalJSON(msgBytes); err != nil { + return fmt.Errorf("failed to unmarshal message JSON: %w", err) + } + + switch u.cfg.RPCType { + case "", "unary": + return u.handleUnaryWrite(ctx, requestMsg) + case "client_stream": + return u.handleClientStreamWrite(ctx, requestMsg, msg) + case "bidi": + return u.handleBidiWrite(ctx, requestMsg, msg) + default: + return fmt.Errorf("unsupported rpc_type: %s", u.cfg.RPCType) + } +} + +func (u *UnifiedOutput) handleUnaryWrite(ctx context.Context, requestMsg *dynamic.Message) error { + conn, err := u.connMgr.GetConnection() + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + + stub := grpcdynamic.NewStub(conn) + + // Enhanced context handling with proper deadline propagation + callCtx := u.enhanceCallContext(ctx) + var cancel context.CancelFunc + + if u.cfg.CallTimeout > 0 { + callCtx, cancel = context.WithTimeout(callCtx, u.cfg.CallTimeout) + defer cancel() + } else if _, hasDeadline := callCtx.Deadline(); !hasDeadline { + // Apply default timeout if none specified + callCtx, cancel = context.WithTimeout(callCtx, 30*time.Second) + defer cancel() + } + + _, err = stub.InvokeRpc(callCtx, u.method, requestMsg) + if err != nil { + return formatGrpcError("grpc_client unary call failed", u.method.GetFullyQualifiedName(), err) + } + + return nil +} + +// enhanceCallContext enhances the context for gRPC calls with proper deadline and metadata handling +func (u *UnifiedOutput) enhanceCallContext(ctx context.Context) context.Context { + return enhanceCallContext(ctx, u.cfg, func(c context.Context) context.Context { + return injectMetadataIntoContext(c, u.cfg) + }) +} + +func (u *UnifiedOutput) handleClientStreamWrite(ctx context.Context, requestMsg *dynamic.Message, msg *service.Message) error { + sessionKey := "default" // Client streams don't use session keys + + session := u.sessionMgr.GetSession(sessionKey) + if session == nil { + if err := u.createClientStreamSession(ctx, sessionKey); err != nil { + return fmt.Errorf("failed to create client stream session: %w", err) + } + session = u.sessionMgr.GetSession(sessionKey) + } + + if session == nil { + return errors.New("failed to get client stream session") + } + + return WithContextRetry(ctx, u.retryConfig, func() error { + session.UpdateLastUse() + + clientStream, ok := 
session.GetStream().(*grpcdynamic.ClientStream) + if !ok { + return errors.New("invalid client stream type") + } + + if err := clientStream.SendMsg(requestMsg); err != nil { + // Remove failed session and retry will recreate it + u.sessionMgr.RemoveSession(sessionKey) + if recreateErr := u.createClientStreamSession(ctx, sessionKey); recreateErr != nil { + return fmt.Errorf("failed to recreate client stream: %w", recreateErr) + } + + newSession := u.sessionMgr.GetSession(sessionKey) + if newSession == nil { + return errors.New("failed to get recreated client stream session") + } + + newClientStream, ok := newSession.GetStream().(*grpcdynamic.ClientStream) + if !ok { + return errors.New("invalid recreated client stream type") + } + + return newClientStream.SendMsg(requestMsg) + } + + return nil + }) +} + +func (u *UnifiedOutput) handleBidiWrite(ctx context.Context, requestMsg *dynamic.Message, msg *service.Message) error { + // Use a fixed metadata key for session routing in minimal mode + sessionKey, _ := msg.MetaGet("session_id") + if sessionKey == "" { + sessionKey = "default" + } + + session := u.sessionMgr.GetSession(sessionKey) + if session == nil { + if err := u.createBidiStreamSession(ctx, sessionKey); err != nil { + return fmt.Errorf("failed to create bidi stream session: %w", err) + } + session = u.sessionMgr.GetSession(sessionKey) + } + + if session == nil { + return errors.New("failed to get bidi stream session") + } + + return WithContextRetry(ctx, u.retryConfig, func() error { + session.UpdateLastUse() + + bidiStream, ok := session.GetStream().(*grpcdynamic.BidiStream) + if !ok { + return errors.New("invalid bidi stream type") + } + + if err := bidiStream.SendMsg(requestMsg); err != nil { + // Remove failed session and retry will recreate it + u.sessionMgr.RemoveSession(sessionKey) + if recreateErr := u.createBidiStreamSession(ctx, sessionKey); recreateErr != nil { + return fmt.Errorf("failed to recreate bidi stream: %w", recreateErr) + } + + newSession := u.sessionMgr.GetSession(sessionKey) + if newSession == nil { + return errors.New("failed to get recreated bidi stream session") + } + + newBidiStream, ok := newSession.GetStream().(*grpcdynamic.BidiStream) + if !ok { + return errors.New("invalid recreated bidi stream type") + } + + return newBidiStream.SendMsg(requestMsg) + } + + return nil + }) +} + +func (u *UnifiedOutput) createClientStreamSession(ctx context.Context, sessionKey string) error { + conn, err := u.connMgr.GetConnection() + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + + stub := grpcdynamic.NewStub(conn) + + // Enhanced context handling for streaming + streamCtx := u.enhanceCallContext(ctx) + var cancel context.CancelFunc + + if u.cfg.CallTimeout > 0 { + streamCtx, cancel = context.WithTimeout(streamCtx, u.cfg.CallTimeout) + } else { + // Apply default timeout for streaming operations + defaultStreamTimeout := 10 * time.Minute + streamCtx, cancel = context.WithTimeout(streamCtx, defaultStreamTimeout) + } + + clientStream, err := stub.InvokeRpcClientStream(streamCtx, u.method) + if err != nil { + cancel() + return formatGrpcError("grpc_client failed to create client stream", u.method.GetFullyQualifiedName(), err) + } + + session := NewStreamSession(clientStream, cancel) + u.sessionMgr.SetSession(sessionKey, session) + + return nil +} + +func (u *UnifiedOutput) createBidiStreamSession(ctx context.Context, sessionKey string) error { + conn, err := u.connMgr.GetConnection() + if err != nil { + return fmt.Errorf("failed to get 
connection: %w", err) + } + + stub := grpcdynamic.NewStub(conn) + + // Enhanced context handling for bidirectional streaming + streamCtx := u.enhanceCallContext(ctx) + var cancel context.CancelFunc + + if u.cfg.CallTimeout > 0 { + streamCtx, cancel = context.WithTimeout(streamCtx, u.cfg.CallTimeout) + } else { + // Apply longer default timeout for bidirectional streaming + defaultBidiTimeout := 30 * time.Minute + streamCtx, cancel = context.WithTimeout(streamCtx, defaultBidiTimeout) + } + + bidiStream, err := stub.InvokeRpcBidiStream(streamCtx, u.method) + if err != nil { + cancel() + return formatGrpcError("grpc_client failed to create bidi stream", u.method.GetFullyQualifiedName(), err) + } + + session := NewStreamSession(bidiStream, cancel) + u.sessionMgr.SetSession(sessionKey, session) + + // Start response handler if configured + // Always log responses at debug in minimal mode if logger available + go u.handleBidiResponses(bidiStream, sessionKey) + + return nil +} + +func (u *UnifiedOutput) handleBidiResponses(bidiStream *grpcdynamic.BidiStream, sessionKey string) { + for { + resp, err := bidiStream.RecvMsg() + if err != nil { + // Log the error and exit + if u.sessionMgr != nil && u.sessionMgr.log != nil { + u.sessionMgr.log.With("session", sessionKey, "error", err).Debug("bidi response handler ended") + } + return + } + + // Handle different response types for logging + var respBytes []byte + var marshalErr error + + switch v := resp.(type) { + case *dynamic.Message: + respBytes, marshalErr = v.MarshalJSON() + case *structpb.Struct: + m := protojson.MarshalOptions{EmitUnpopulated: false, UseProtoNames: false, AllowPartial: true, Multiline: false, Indent: ""} + respBytes, marshalErr = m.Marshal(v) + default: + // Skip logging for unknown types + continue + } + + if marshalErr == nil && u.sessionMgr != nil && u.sessionMgr.log != nil { + u.sessionMgr.log.With("session", sessionKey).Debug(string(respBytes)) + } + } +} + +func (u *UnifiedOutput) Close(ctx context.Context) error { + u.mu.Lock() + u.shutdown = true + u.mu.Unlock() + + // Close session manager (this stops background goroutines) + if u.sessionMgr != nil { + u.sessionMgr.Close() + } + + // Close connection manager + if u.connMgr != nil { + return u.connMgr.Close() + } + + return nil +} + +func init() { + _ = service.RegisterOutput("grpc_client", genericOutputSpec(), func(conf *service.ParsedConfig, res *service.Resources) (service.Output, int, error) { + return newUnifiedOutput(conf, res) + }) +} diff --git a/internal/impl/grpc_client/output_grpc_client_test.go b/internal/impl/grpc_client/output_grpc_client_test.go new file mode 100644 index 000000000..bd1d812c1 --- /dev/null +++ b/internal/impl/grpc_client/output_grpc_client_test.go @@ -0,0 +1,41 @@ +package grpc_client + +import ( + "context" + "testing" + "time" + + "google.golang.org/grpc/metadata" +) + +func TestOutput_injectMetadataIntoContext(t *testing.T) { + cfg := &Config{ + BearerToken: "secret", + AuthHeaders: map[string]string{"x-api-key": "k", "foo": "bar"}, + } + ctx := context.Background() + ctx = injectMetadataIntoContext(ctx, cfg) + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("expected outgoing metadata in context") + } + assertMD := func(k, want string) { + vals := md.Get(k) + if len(vals) == 0 || vals[0] != want { + t.Fatalf("metadata %s = %v, want %s", k, vals, want) + } + } + assertMD("authorization", "Bearer secret") + assertMD("x-api-key", "k") + assertMD("foo", "bar") +} + +func TestOutput_enhanceCallContext_NoOp(t *testing.T) { 
+ u := &UnifiedOutput{cfg: &Config{}} + parent, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ctx := u.enhanceCallContext(parent) + if ctx == nil { + t.Fatal("expected non-nil context") + } +} diff --git a/internal/impl/grpc_client/session_manager_test.go b/internal/impl/grpc_client/session_manager_test.go new file mode 100644 index 000000000..e02c18dde --- /dev/null +++ b/internal/impl/grpc_client/session_manager_test.go @@ -0,0 +1,28 @@ +package grpc_client + +import ( + "context" + "testing" +) + +func TestSessionManager_SetGetRemove(t *testing.T) { + sm := NewSessionManager(0, 0, nil) + defer sm.Close() + + if sm.GetSession("k") != nil { + t.Fatal("expected no session initially") + } + + _, cancel := context.WithCancel(context.Background()) + s := NewStreamSession(&struct{}{}, cancel) + sm.SetSession("k", s) + + if sm.GetSession("k") == nil { + t.Fatal("expected session present after SetSession") + } + + sm.RemoveSession("k") + if sm.GetSession("k") != nil { + t.Fatal("expected session removed after RemoveSession") + } +} diff --git a/public/components/all/package.go b/public/components/all/package.go index cd8f4ca5c..370566f40 100644 --- a/public/components/all/package.go +++ b/public/components/all/package.go @@ -23,6 +23,7 @@ import ( _ "github.com/warpstreamlabs/bento/public/components/elasticsearch" _ "github.com/warpstreamlabs/bento/public/components/etcd" _ "github.com/warpstreamlabs/bento/public/components/gcp" + _ "github.com/warpstreamlabs/bento/public/components/grpc_client" _ "github.com/warpstreamlabs/bento/public/components/hdfs" _ "github.com/warpstreamlabs/bento/public/components/influxdb" _ "github.com/warpstreamlabs/bento/public/components/io" diff --git a/public/components/grpc_client/package.go b/public/components/grpc_client/package.go new file mode 100644 index 000000000..46571f06f --- /dev/null +++ b/public/components/grpc_client/package.go @@ -0,0 +1,6 @@ +package grpc_client + +import ( + // Bring in the internal generic gRPC plugin definitions. + _ "github.com/warpstreamlabs/bento/internal/impl/grpc_client" +)
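
The diff shows ConnectionPool's cleanupIdle and closeAllConnections, but not the loop that drives them. Below is a minimal sketch of such a driver; how the *ConnectionPool itself is constructed is outside this excerpt, so the pool value is assumed to be supplied by the caller.

```go
package grpc_client

import (
	"context"
	"time"
)

// runPoolMaintenance periodically sweeps idle or failing connections and
// tears the pool down when the parent context ends. Construction of the
// *ConnectionPool is not part of this excerpt, so cp is assumed to be
// provided by the caller.
func runPoolMaintenance(ctx context.Context, cp *ConnectionPool) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			cp.closeAllConnections()
			return
		case <-ticker.C:
			cp.cleanupIdle()
		}
	}
}
```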
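The integration tests wait a fixed one second after starting the helper server, which can be flaky on slow machines. A hedged alternative is to poll the standard gRPC health service until it reports SERVING; this sketch assumes the helper server registers that service and uses grpc.Dial, which is still available even though newer grpc-go releases prefer grpc.NewClient.

```go
package grpc_client

import (
	"context"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

// waitForServing blocks until the server at addr reports SERVING via the
// standard gRPC health service, or fails the test after timeout. It is a
// sketch of what startTestServer could use instead of a fixed sleep.
func waitForServing(t *testing.T, addr string, timeout time.Duration) {
	t.Helper()

	conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		t.Fatalf("dial %s: %v", addr, err)
	}
	defer conn.Close()

	client := healthpb.NewHealthClient(conn)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
		resp, err := client.Check(ctx, &healthpb.HealthCheckRequest{})
		cancel()
		if err == nil && resp.GetStatus() == healthpb.HealthCheckResponse_SERVING {
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Fatalf("server at %s not SERVING within %s", addr, timeout)
}
```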
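handleBidiWrite routes each message to a per-session bidi stream keyed by the "session_id" metadata entry, falling back to a shared "default" session when the key is absent. A small sketch of producing such messages with the public service API:

```go
package example

import (
	"github.com/warpstreamlabs/bento/public/service"
)

// newSessionMessage attaches the "session_id" metadata key that the bidi
// grpc_client output uses to pick (or create) a per-session stream; messages
// without the key all share the "default" session.
func newSessionMessage(sessionID string, body []byte) *service.Message {
	msg := service.NewMessage(body)
	msg.MetaSet("session_id", sessionID)
	return msg
}
```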
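The init functions register both plugins under the name grpc_client, so they can be wired into a pipeline with the public StreamBuilder. The sketch below is illustrative only: the YAML keys (address, method, rpc_type, request_json) are assumptions inferred from the field constants, whose definitions live outside this diff, and should be confirmed against config.go before use.

```go
package main

import (
	"context"
	"log"

	"github.com/warpstreamlabs/bento/public/service"

	// Register the grpc_client plugins from this diff plus the stdout output.
	_ "github.com/warpstreamlabs/bento/public/components/grpc_client"
	_ "github.com/warpstreamlabs/bento/public/components/io"
)

func main() {
	builder := service.NewStreamBuilder()

	// Field names are assumed; confirm them against the plugin's config spec.
	if err := builder.AddInputYAML(`
grpc_client:
  address: 127.0.0.1:50051
  method: /echo.Echo/Stream
  rpc_type: server_stream
  request_json: '{}'
`); err != nil {
		log.Fatal(err)
	}
	if err := builder.AddOutputYAML(`stdout: {}`); err != nil {
		log.Fatal(err)
	}

	stream, err := builder.Build()
	if err != nil {
		log.Fatal(err)
	}
	if err := stream.Run(context.Background()); err != nil {
		log.Fatal(err)
	}
}
```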