diff --git a/tracelog/tracelog.go b/tracelog/tracelog.go index 7b19f343..78b15fbe 100644 --- a/tracelog/tracelog.go +++ b/tracelog/tracelog.go @@ -132,6 +132,7 @@ const ( tracelogBatchCtxKey tracelogCopyFromCtxKey tracelogConnectCtxKey + tracelogPrepareCtxKey ) type traceQueryData struct { @@ -282,6 +283,38 @@ func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEn } } +type tracePrepareData struct { + startTime time.Time + name string + sql string +} + +func (tl *TraceLog) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + return context.WithValue(ctx, tracelogPrepareCtxKey, &tracePrepareData{ + startTime: time.Now(), + name: data.Name, + sql: data.SQL, + }) +} + +func (tl *TraceLog) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + prepareData := ctx.Value(tracelogPrepareCtxKey).(*tracePrepareData) + + endTime := time.Now() + interval := endTime.Sub(prepareData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "time": interval, "alreadyPrepared": data.AlreadyPrepared}) + } +} + func (tl *TraceLog) shouldLog(lvl LogLevel) bool { return tl.LogLevel >= lvl } diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go index ae666066..96b79de9 100644 --- a/tracelog/tracelog_test.go +++ b/tracelog/tracelog_test.go @@ -41,6 +41,20 @@ func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg strin l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } +func (l *testLogger) Clear() { + l.logs = l.logs[0:0] +} + +func (l *testLogger) FilterByMsg(msg string) (res []testLog) { + for _, log := range l.logs { + if log.msg == msg { + res = append(res, log) + } + } + + return res +} + func TestContextGetsPassedToLogMethod(t *testing.T) { t.Parallel() @@ -58,7 +72,7 @@ func TestContextGetsPassedToLogMethod(t *testing.T) { } pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + logger.Clear() // Clear any logs written when establishing connection ctx = context.WithValue(context.Background(), "ctxdata", "foo") _, err := conn.Exec(ctx, `;`) @@ -120,20 +134,24 @@ func TestLogQuery(t *testing.T) { } pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + logger.Clear() // Clear any logs written when establishing connection _, err := conn.Exec(ctx, `select $1::text`, "testing") require.NoError(t, err) - require.Len(t, logger.logs, 1) - require.Equal(t, "Query", logger.logs[0].msg) - require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logs := logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + + logger.Clear() _, err = conn.Exec(ctx, `foo`, "testing") require.Error(t, err) - require.Len(t, logger.logs, 2) - require.Equal(t, "Query", logger.logs[1].msg) - require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl) - require.Equal(t, err, logger.logs[1].data["err"]) + + logs = logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelError, logs[0].lvl) + require.Equal(t, err, logs[0].data["err"]) }) } @@ -155,7 +173,7 @@ func TestLogQueryArgsHandlesUTF8(t *testing.T) { } pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + logger.Clear() // Clear any logs written when establishing connection var s string for i := 0; i < 63; i++ { @@ -165,17 +183,21 @@ func TestLogQueryArgsHandlesUTF8(t *testing.T) { _, err := conn.Exec(ctx, `select $1::text`, s) require.NoError(t, err) - require.Len(t, logger.logs, 1) - require.Equal(t, "Query", logger.logs[0].msg) - require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) - require.Equal(t, s, logger.logs[0].data["args"].([]any)[0]) + + logs := logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + require.Equal(t, s, logs[0].data["args"].([]any)[0]) + + logger.Clear() _, err = conn.Exec(ctx, `select $1::text`, s+"000") require.NoError(t, err) - require.Len(t, logger.logs, 2) - require.Equal(t, "Query", logger.logs[1].msg) - require.Equal(t, tracelog.LogLevelInfo, logger.logs[1].lvl) - require.Equal(t, s+" (truncated 3 bytes)", logger.logs[1].data["args"].([]any)[0]) + + logs = logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + require.Equal(t, s+" (truncated 3 bytes)", logs[0].data["args"].([]any)[0]) }) } @@ -199,7 +221,7 @@ func TestLogCopyFrom(t *testing.T) { _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) require.NoError(t, err) - logger.logs = logger.logs[0:0] + logger.Clear() inputRows := [][]any{ {int32(1)}, @@ -209,11 +231,12 @@ func TestLogCopyFrom(t *testing.T) { copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) require.NoError(t, err) require.EqualValues(t, len(inputRows), copyCount) - require.Len(t, logger.logs, 1) - require.Equal(t, "CopyFrom", logger.logs[0].msg) - require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) - logger.logs = logger.logs[0:0] + logs := logger.FilterByMsg("CopyFrom") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + + logger.Clear() inputRows = [][]any{ {"not an integer"}, @@ -223,9 +246,10 @@ func TestLogCopyFrom(t *testing.T) { copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) require.Error(t, err) require.EqualValues(t, 0, copyCount) - require.Len(t, logger.logs, 1) - require.Equal(t, "CopyFrom", logger.logs[0].msg) - require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) + + logs = logger.FilterByMsg("CopyFrom") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelError, logs[0].lvl) }) } @@ -248,7 +272,7 @@ func TestLogConnect(t *testing.T) { require.Equal(t, "Connect", logger.logs[0].msg) require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) - logger.logs = logger.logs[0:0] + logger.Clear() config, err = pgx.ParseConfig("host=/invalid") require.NoError(t, err) @@ -279,7 +303,7 @@ func TestLogBatchStatementsOnExec(t *testing.T) { } pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + logger.Clear() // Clear any logs written when establishing connection batch := &pgx.Batch{} batch.Queue("create table foo (id bigint)") @@ -323,7 +347,7 @@ func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { } pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + logger.Clear() // Clear any logs written when establishing connection batch := &pgx.Batch{} batch.Queue("select generate_series(1,$1)", 100) @@ -341,3 +365,64 @@ func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { assert.Equal(t, "BatchClose", logger.logs[2].msg) }) } + +func TestLogPrepare(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + }, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + + logs := logger.FilterByMsg("Prepare") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + + logger.Clear() + + _, err = conn.Exec(ctx, `foo aaaa`, "testing") + require.Error(t, err) + + logs = logger.FilterByMsg("Prepare") + require.Len(t, logs, 1) + require.Equal(t, err, logs[0].data["err"]) + }) + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + _, err := conn.Prepare(ctx, "test_query_1", `select $1::int`) + require.NoError(t, err) + + require.Len(t, logger.logs, 1) + require.Equal(t, "Prepare", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.Clear() + + _, err = conn.Prepare(ctx, `test_query_2`, "foo aaaa") + require.Error(t, err) + + require.Len(t, logger.logs, 1) + require.Equal(t, "Prepare", logger.logs[0].msg) + require.Equal(t, err, logger.logs[0].data["err"]) + }) +}