validate all addresses resolved from hostname
Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
@@ -37,6 +37,7 @@ type Config struct {
|
|||||||
Password string
|
Password string
|
||||||
TLSConfig *tls.Config // nil disables TLS
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||||
|
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||||
BuildFrontend BuildFrontendFunc
|
BuildFrontend BuildFrontendFunc
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
|
|
||||||
@@ -77,7 +78,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||||
} else {
|
} else {
|
||||||
network = "tcp"
|
network = "tcp"
|
||||||
address = fmt.Sprintf("%s:%d", host, port)
|
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||||
}
|
}
|
||||||
return network, address
|
return network, address
|
||||||
}
|
}
|
||||||
@@ -190,6 +191,8 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
config.DialFunc = defaultDialer.DialContext
|
config.DialFunc = defaultDialer.DialContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||||
|
|
||||||
notRuntimeParams := map[string]struct{}{
|
notRuntimeParams := map[string]struct{}{
|
||||||
"host": struct{}{},
|
"host": struct{}{},
|
||||||
"port": struct{}{},
|
"port": struct{}{},
|
||||||
@@ -495,6 +498,10 @@ func makeDefaultDialer() *net.Dialer {
|
|||||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeDefaultResolver() *net.Resolver {
|
||||||
|
return net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
||||||
return func(r io.Reader, w io.Writer) Frontend {
|
return func(r io.Reader, w io.Writer) Frontend {
|
||||||
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
|
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
|
||||||
|
|||||||
@@ -43,6 +43,9 @@ type Notification struct {
|
|||||||
// DialFunc is a function that can be used to connect to a PostgreSQL server.
|
// DialFunc is a function that can be used to connect to a PostgreSQL server.
|
||||||
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
|
// LookupFunc is a function that can be used to lookup IPs addrs from host.
|
||||||
|
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
|
||||||
|
|
||||||
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
|
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
|
||||||
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
|
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
|
||||||
|
|
||||||
@@ -123,6 +126,15 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||||||
}
|
}
|
||||||
fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
|
fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
|
||||||
|
|
||||||
|
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fallbackConfigs) == 0 {
|
||||||
|
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
|
||||||
|
}
|
||||||
|
|
||||||
for _, fc := range fallbackConfigs {
|
for _, fc := range fallbackConfigs {
|
||||||
pgConn, err = connect(ctx, config, fc)
|
pgConn, err = connect(ctx, config, fc)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -147,6 +159,27 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||||||
return pgConn, nil
|
return pgConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
|
||||||
|
var configs []*FallbackConfig
|
||||||
|
|
||||||
|
for _, fb := range fallbacks {
|
||||||
|
ips, err := lookupFn(ctx, fb.Host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
configs = append(configs, &FallbackConfig{
|
||||||
|
Host: ip,
|
||||||
|
Port: fb.Port,
|
||||||
|
TLSConfig: fb.TLSConfig,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return configs, nil
|
||||||
|
}
|
||||||
|
|
||||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||||
pgConn := new(PgConn)
|
pgConn := new(PgConn)
|
||||||
pgConn.config = config
|
pgConn.config = config
|
||||||
|
|||||||
Reference in New Issue
Block a user