diff --git a/public/service/logger.go b/public/service/logger.go index 7f798c545a..0cfa218a52 100644 --- a/public/service/logger.go +++ b/public/service/logger.go @@ -2,10 +2,116 @@ package service import ( "fmt" + "os" "github.com/warpstreamlabs/bento/internal/log" ) +// airGapLogger adapts a LeveledLogger to implement log.Modular for use in Bento streams. +type airGapLogger struct { + l LeveledLogger +} + +func newAirGapLogger(logger LeveledLogger) log.Modular { + return &airGapLogger{l: logger} +} + +func (a *airGapLogger) WithFields(fields map[string]string) log.Modular { + if a.l == nil { + return nil + } + + switch t := a.l.(type) { + case interface { + WithFields(fields map[string]string) log.Modular + }: + return &airGapLogger{l: t.WithFields(fields)} + } + + return a.clone() +} + +func (a *airGapLogger) With(keyValues ...any) log.Modular { + if a.l == nil { + return nil + } + + switch t := a.l.(type) { + case interface { + With(keyValues ...any) log.Modular + }: + return &airGapLogger{l: t.With(keyValues...)} + } + + return a.clone() +} + +func (a *airGapLogger) Error(format string, v ...any) { + if a.l == nil { + return + } + a.l.Error(format, v...) +} +func (a *airGapLogger) Warn(format string, v ...any) { + if a.l == nil { + return + } + a.l.Warn(format, v...) +} +func (a *airGapLogger) Info(format string, v ...any) { + if a.l == nil { + return + } + a.l.Info(format, v...) +} +func (a *airGapLogger) Debug(format string, v ...any) { + if a.l == nil { + return + } + a.l.Debug(format, v...) +} +func (a *airGapLogger) Trace(format string, v ...any) { + if a.l == nil { + return + } + + switch fl := a.l.(type) { + case interface { + Trace(format string, v ...any) + }: + fl.Trace(format, v...) + return + } + // Logger does not implement Trace, so fallback to Debug. + a.l.Debug(format, v...) +} + +func (a *airGapLogger) Fatal(format string, v ...any) { + if a.l == nil { + return + } + + switch fl := a.l.(type) { + case interface { + Fatal(format string, v ...any) + }: + fl.Fatal(format, v...) + return + } + // Logger does not implement Fatal, so fallback to + // Error and exit with a status code 1. + a.l.Error(format, v...) + os.Exit(1) +} + +func (a *airGapLogger) clone() *airGapLogger { + if a.l == nil { + return nil + } + l := *a + return &l +} + // Logger allows plugin authors to write custom logs from components that are // exported the same way as native Bento logs. It's safe to pass around a nil // pointer for testing components. diff --git a/public/service/logger_test.go b/public/service/logger_test.go index 61287be6cd..3843bffba1 100644 --- a/public/service/logger_test.go +++ b/public/service/logger_test.go @@ -2,6 +2,8 @@ package service import ( "bytes" + "fmt" + "log/slog" "testing" "github.com/stretchr/testify/assert" @@ -66,3 +68,140 @@ func TestReverseAirGapLoggerDodgyFields(t *testing.T) { {"@service":"bento","field4":"value4","field5":"value5","level":"info","msg":"foo4"} `, buf.String()) } + +func TestAirGapLogger(t *testing.T) { + lConf := log.NewConfig() + lConf.AddTimeStamp = false + lConf.Format = "json" + + var buf bytes.Buffer + logger, err := log.New(&buf, ifs.OS(), lConf) + require.NoError(t, err) + + agLogger := newAirGapLogger(logger) + + agLogger.Error("foo: %v", "bar1") + agLogger.Warn("foo: %v", "bar2") + agLogger.Info("foo: %v", "bar3") + agLogger.Debug("foo: %v", "bar4") + + agLogger.With("key", "value").Info("log") + + assert.Equal(t, `{"@service":"bento","level":"error","msg":"foo: bar1"} +{"@service":"bento","level":"warning","msg":"foo: bar2"} +{"@service":"bento","level":"info","msg":"foo: bar3"} +{"@service":"bento","key":"value","level":"info","msg":"log"} +`, buf.String()) +} + +func TestAirGapSlogLogger(t *testing.T) { + var buf bytes.Buffer + slogLogger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + return slog.Attr{} + } + return a + }, + })) + + airGap := newAirGapLogger(slogLogger) + airGap.Info("test message") + + assert.Equal(t, `{"level":"INFO","msg":"test message"} +`, buf.String()) +} + +func TestAirGapLoggerChaining(t *testing.T) { + lConf := log.NewConfig() + lConf.AddTimeStamp = false + lConf.Format = "json" + + var buf bytes.Buffer + logger, err := log.New(&buf, ifs.OS(), lConf) + require.NoError(t, err) + + agLogger := newAirGapLogger(logger) + + agLogger.WithFields(map[string]string{"service": "test", "version": "1.0"}).Info("with fields") + agLogger.With("key1", "value1", "key2", "value2").Info("with context") + agLogger.WithFields(map[string]string{"service": "api"}).With("request_id", "123").Info("chained") + + chainedLogger := agLogger.WithFields(map[string]string{"service": "api", "version": "1.0"}) + chainedLogger2 := chainedLogger.WithFields(map[string]string{"env": "prod"}) + chainedLogger2.Info("multiple withfields") + + chainedLogger3 := agLogger.With("request_id", "123", "user_id", "456") + chainedLogger4 := chainedLogger3.With("action", "login") + chainedLogger4.Info("multiple with") + + chain1 := agLogger.WithFields(map[string]string{"chain": "1"}) + chain2 := agLogger.WithFields(map[string]string{"chain": "2"}) + chain1.Info("from chain 1") + chain2.Info("from chain 2") + + assert.Equal(t, `{"@service":"bento","level":"info","msg":"with fields","service":"test","version":"1.0"} +{"@service":"bento","key1":"value1","key2":"value2","level":"info","msg":"with context"} +{"@service":"bento","level":"info","msg":"chained","request_id":"123","service":"api"} +{"@service":"bento","env":"prod","level":"info","msg":"multiple withfields","service":"api","version":"1.0"} +{"@service":"bento","action":"login","level":"info","msg":"multiple with","request_id":"123","user_id":"456"} +{"@service":"bento","chain":"1","level":"info","msg":"from chain 1"} +{"@service":"bento","chain":"2","level":"info","msg":"from chain 2"} +`, buf.String()) +} + +func TestAirGapLoggerTrace(t *testing.T) { + lConf := log.NewConfig() + lConf.AddTimeStamp = false + lConf.Format = "json" + lConf.LogLevel = "trace" + + var buf bytes.Buffer + logger, err := log.New(&buf, ifs.OS(), lConf) + require.NoError(t, err) + + agLogger := newAirGapLogger(logger) + agLogger.Trace("trace message: %s", "test") + + assert.Equal(t, `{"@service":"bento","level":"trace","msg":"trace message: test"} +`, buf.String()) +} + +type mockBasicLogger struct { + logs *[]string +} + +func (m *mockBasicLogger) Error(format string, v ...any) { + *m.logs = append(*m.logs, fmt.Sprintf("ERROR: "+format, v...)) +} + +func (m *mockBasicLogger) Warn(format string, v ...any) { + *m.logs = append(*m.logs, fmt.Sprintf("WARN: "+format, v...)) +} + +func (m *mockBasicLogger) Info(format string, v ...any) { + *m.logs = append(*m.logs, fmt.Sprintf("INFO: "+format, v...)) +} + +func (m *mockBasicLogger) Debug(format string, v ...any) { + *m.logs = append(*m.logs, fmt.Sprintf("DEBUG: "+format, v...)) +} + +func TestAirGapLoggerNonChainingLogger(t *testing.T) { + var logs []string + basicLogger := &struct { + LeveledLogger + }{ + LeveledLogger: &mockBasicLogger{logs: &logs}, + } + + agLogger := newAirGapLogger(basicLogger) + + chainedLogger := agLogger.WithFields(map[string]string{"ignored": "field"}) + chainedLogger2 := chainedLogger.With("ignored", "context") + chainedLogger2.Info("test message") + + expected := []string{"INFO: test message"} + assert.Equal(t, expected, logs) +} diff --git a/public/service/service.go b/public/service/service.go index f1bda4ebe2..214eaa794c 100644 --- a/public/service/service.go +++ b/public/service/service.go @@ -38,7 +38,7 @@ func RunCLI(ctx context.Context, optFuncs ...CLIOptFunc) { cliOpts.outLoggerFn(&Logger{m: l}) } if cliOpts.teeLogger != nil { - return log.TeeLogger(l, log.NewBentoLogAdapter(cliOpts.teeLogger)), nil + return log.TeeLogger(l, newAirGapLogger(cliOpts.teeLogger)), nil } return l, nil } diff --git a/public/service/stream_builder.go b/public/service/stream_builder.go index 4c38636097..e8bc1327dd 100644 --- a/public/service/stream_builder.go +++ b/public/service/stream_builder.go @@ -162,12 +162,31 @@ func (s *StreamBuilder) SetPrintLogger(l PrintLogger) { s.customLogger = log.Wrap(l) } -// SetLogger sets a customer logger via Go's standard logging interface, +// LeveledLogger is an interface supported by most loggers. +type LeveledLogger interface { + Error(format string, v ...any) + Warn(format string, v ...any) + Info(format string, v ...any) + Debug(format string, v ...any) +} + +var ( + _ LeveledLogger = (*slog.Logger)(nil) + _ LeveledLogger = (log.Modular)(nil) +) + +// SetLogger sets a slog logger via Bento's standard logging interface, // allowing you to replace the default Bento logger with your own. func (s *StreamBuilder) SetLogger(l *slog.Logger) { s.customLogger = log.NewBentoLogAdapter(l) } +// SetLeveledLogger sets a custom logger via Bento's standard logging interface, +// allowing you to replace the default Bento logger with your own. +func (s *StreamBuilder) SetLeveledLogger(l LeveledLogger) { + s.customLogger = newAirGapLogger(l) +} + // HTTPMultiplexer is an interface supported by most HTTP multiplexers. type HTTPMultiplexer interface { HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request))