2
0

Add BuildFrontendFunc in Config

Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
Artemiy Ryabinkov
2019-08-08 11:46:25 +03:00
parent fa7e06489b
commit 0a99b543c0
3 changed files with 36 additions and 30 deletions
+19 -11
View File
@@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"math"
"net"
@@ -18,6 +19,7 @@ import (
"time"
"github.com/jackc/pgpassfile"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors"
)
@@ -26,20 +28,18 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server.
type Config struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext
BuildFrontendFunc BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig
// MinReadBufferSize used to configure size of connection read buffer.
MinReadBufferSize int
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used validate that server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with
@@ -476,6 +476,14 @@ func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultBuildFrontendFunc() BuildFrontendFunc {
return func(r io.Reader) Frontend {
frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), nil)
return frontend
}
}
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
-4
View File
@@ -8,8 +8,6 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
@@ -25,7 +23,5 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+17 -15
View File
@@ -15,7 +15,6 @@ import (
"sync"
"time"
"github.com/jackc/chunkreader/v2"
"github.com/jackc/pgconn/internal/ctxwatch"
"github.com/jackc/pgio"
"github.com/jackc/pgproto3/v2"
@@ -41,9 +40,12 @@ type Notification struct {
Payload string
}
// DialFunc is a function that can be used to connect to a PostgreSQL server
// DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader) Frontend
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
@@ -56,6 +58,11 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event.
type NotificationHandler func(*PgConn, *Notification)
// Frontend used to receive messages from backend.
type Frontend interface {
Receive() (pgproto3.BackendMessage, error)
}
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
@@ -63,7 +70,7 @@ type PgConn struct {
secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
TxStatus byte
Frontend *pgproto3.Frontend
frontend Frontend
Config *Config
@@ -106,6 +113,9 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
if config.DialFunc == nil {
config.DialFunc = makeDefaultDialer().DialContext
}
if config.BuildFrontendFunc == nil {
config.BuildFrontendFunc = makeDefaultBuildFrontendFunc()
}
if config.RuntimeParams == nil {
config.RuntimeParams = make(map[string]string)
}
@@ -171,15 +181,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
func() { pgConn.conn.SetDeadline(time.Time{}) },
)
cr, err := chunkreader.NewConfig(pgConn.conn, chunkreader.Config{MinBufLen: config.MinReadBufferSize})
if err != nil {
return nil, err
}
pgConn.Frontend, err = pgproto3.NewFrontend(cr, pgConn.conn)
if err != nil {
return nil, err
}
pgConn.frontend = config.BuildFrontendFunc(pgConn.conn)
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
@@ -298,7 +300,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive()
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()
@@ -318,10 +320,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
// If a timeout error happened in the background try the read again.
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
msg, err = pgConn.Frontend.Receive()
msg, err = pgConn.frontend.Receive()
}
} else {
msg, err = pgConn.Frontend.Receive()
msg, err = pgConn.frontend.Receive()
}
if err != nil {