Merge pull request #41 from georgysavva/add-config-copy
Add Config.Copy() method that returns a smart copy of the config.
This commit is contained in:
@@ -62,6 +62,35 @@ type Config struct {
|
|||||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||||
|
// The only exception is the TLSConfig field:
|
||||||
|
// according to the tls.Config docs it must not be modified after creation.
|
||||||
|
func (c *Config) Copy() *Config {
|
||||||
|
newConf := new(Config)
|
||||||
|
*newConf = *c
|
||||||
|
if newConf.TLSConfig != nil {
|
||||||
|
newConf.TLSConfig = c.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
if newConf.RuntimeParams != nil {
|
||||||
|
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||||
|
for k, v := range c.RuntimeParams {
|
||||||
|
newConf.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newConf.Fallbacks != nil {
|
||||||
|
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||||
|
for i, fallback := range c.Fallbacks {
|
||||||
|
newFallback := new(FallbackConfig)
|
||||||
|
*newFallback = *fallback
|
||||||
|
if newFallback.TLSConfig != nil {
|
||||||
|
newFallback.TLSConfig = fallback.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
newConf.Fallbacks[i] = newFallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newConf
|
||||||
|
}
|
||||||
|
|
||||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||||
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||||
type FallbackConfig struct {
|
type FallbackConfig struct {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package pgconn_test
|
package pgconn_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@@ -527,6 +528,44 @@ func TestParseConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigCopyReturnsEqualConfig(t *testing.T) {
|
||||||
|
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5"
|
||||||
|
original, err := pgconn.ParseConfig(connString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) {
|
||||||
|
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5"
|
||||||
|
original, err := pgconn.ParseConfig(connString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config")
|
||||||
|
|
||||||
|
copied.Port = uint16(5433)
|
||||||
|
copied.RuntimeParams["foo"] = "bar"
|
||||||
|
copied.Fallbacks[0].Port = uint16(5433)
|
||||||
|
|
||||||
|
assert.Equal(t, uint16(5432), original.Port)
|
||||||
|
assert.Equal(t, "", original.RuntimeParams["foo"])
|
||||||
|
assert.Equal(t, uint16(5432), original.Fallbacks[0].Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
|
||||||
|
connString := os.Getenv("PGX_TEST_CONN_STRING")
|
||||||
|
original, err := pgconn.ParseConfig(connString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
_, err = pgconn.ConnectConfig(context.Background(), copied)
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
||||||
if !assert.NotNil(t, expected) {
|
if !assert.NotNil(t, expected) {
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user