From 96220606f234a3279ba7bf3e5e4738a5a16e7422 Mon Sep 17 00:00:00 2001 From: Oliver Tonnhofer Date: Mon, 29 Jul 2013 09:20:17 +0200 Subject: [PATCH] default to sslmode=disable on localhost --- database/postgis/postgis.go | 10 +++++----- database/postgis/util.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/database/postgis/postgis.go b/database/postgis/postgis.go index cf8ea31..8776719 100644 --- a/database/postgis/postgis.go +++ b/database/postgis/postgis.go @@ -403,6 +403,7 @@ func clusterTable(pg *PostGIS, tableName string, srid int, columns []ColumnSpec) type PostGIS struct { Db *sql.DB + Params string Schema string BackupSchema string Config database.Config @@ -415,11 +416,7 @@ type PostGIS struct { func (pg *PostGIS) Open() error { var err error - params, err := pq.ParseURL(pg.Config.ConnectionParams) - if err != nil { - return err - } - pg.Db, err = sql.Open("postgres", params) + pg.Db, err = sql.Open("postgres", pg.Params) if err != nil { return err } @@ -580,6 +577,7 @@ func (pg *PostGIS) NewTableTx(spec *TableSpec, bulkImport bool) *TableTx { func New(conf database.Config, m *mapping.Mapping) (database.DB, error) { db := &PostGIS{} + db.Tables = make(map[string]*TableSpec) db.GeneralizedTables = make(map[string]*GeneralizedTableSpec) @@ -596,6 +594,7 @@ func New(conf database.Config, m *mapping.Mapping) (database.DB, error) { if err != nil { return nil, err } + params = disableDefaultSslOnLocalhost(params) db.Schema, db.BackupSchema = schemasFromConnectionParams(params) db.Prefix = prefixFromConnectionParams(params) @@ -607,6 +606,7 @@ func New(conf database.Config, m *mapping.Mapping) (database.DB, error) { } db.prepareGeneralizedTableSources() + db.Params = params err = db.Open() if err != nil { return nil, err diff --git a/database/postgis/util.go b/database/postgis/util.go index 2cb7c55..fb8494b 100644 --- a/database/postgis/util.go +++ b/database/postgis/util.go @@ -3,6 +3,7 @@ package postgis import ( "database/sql" "fmt" + "os" "strings" "sync" ) @@ -26,6 +27,36 @@ func schemasFromConnectionParams(params string) (string, string) { return schema, backupSchema } +// disableDefaultSslOnLocalhost adds sslmode=disable to params +// when host is localhost/127.0.0.1 and the sslmode param and +// PGSSLMODE environment are both not set. +func disableDefaultSslOnLocalhost(params string) string { + parts := strings.Fields(params) + isLocalHost := false + for _, p := range parts { + if strings.HasPrefix(p, "sslmode=") { + return params + } + if p == "host=localhost" || p == "host=127.0.0.1" { + isLocalHost = true + } + } + + if !isLocalHost { + return params + } + + for _, v := range os.Environ() { + parts := strings.SplitN(v, "=", 2) + if parts[0] == "PGSSLMODE" { + return params + } + } + + // found localhost but explicit no sslmode, disable sslmode + return params + " sslmode=disable" +} + func prefixFromConnectionParams(params string) string { parts := strings.Fields(params) var prefix string