diff --git a/conn_test.go b/conn_test.go index e34662ae..467f6ecc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,7 +1,9 @@ package pgx_test import ( + "bytes" "context" + "log" "os" "strings" "sync" @@ -837,6 +839,37 @@ func TestLogPassesContext(t *testing.T) { } } +func TestLoggerFunc(t *testing.T) { + t.Parallel() + + const testMsg = "foo" + + buf := bytes.Buffer{} + logger := log.New(&buf, "", 0) + + createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc { + return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { + logger.Printf("%s", testMsg) + } + } + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.Logger = createAdapterFn(logger) + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + buf.Reset() // Clear logs written when establishing connection + + if _, err := conn.Exec(context.TODO(), ";"); err != nil { + t.Fatal(err) + } + + if strings.TrimSpace(buf.String()) != testMsg { + t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) + } +} + func TestIdentifierSanitize(t *testing.T) { t.Parallel() diff --git a/logger.go b/logger.go index 89fd5af5..19a74123 100644 --- a/logger.go +++ b/logger.go @@ -47,6 +47,14 @@ type Logger interface { Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) } +// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface +type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) + +// Log delegates the logging request to the wrapped function +func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { + f(ctx, level, msg, data) +} + // LogLevelFromString converts log level string to constant // // Valid levels: