default to sslmode=disable on localhost

master
Oliver Tonnhofer 2013-07-29 09:20:17 +02:00
parent 1721a53000
commit 96220606f2
2 changed files with 36 additions and 5 deletions

View File

@ -403,6 +403,7 @@ func clusterTable(pg *PostGIS, tableName string, srid int, columns []ColumnSpec)
type PostGIS struct {
Db *sql.DB
Params string
Schema string
BackupSchema string
Config database.Config
@ -415,11 +416,7 @@ type PostGIS struct {
func (pg *PostGIS) Open() error {
var err error
params, err := pq.ParseURL(pg.Config.ConnectionParams)
if err != nil {
return err
}
pg.Db, err = sql.Open("postgres", params)
pg.Db, err = sql.Open("postgres", pg.Params)
if err != nil {
return err
}
@ -580,6 +577,7 @@ func (pg *PostGIS) NewTableTx(spec *TableSpec, bulkImport bool) *TableTx {
func New(conf database.Config, m *mapping.Mapping) (database.DB, error) {
db := &PostGIS{}
db.Tables = make(map[string]*TableSpec)
db.GeneralizedTables = make(map[string]*GeneralizedTableSpec)
@ -596,6 +594,7 @@ func New(conf database.Config, m *mapping.Mapping) (database.DB, error) {
if err != nil {
return nil, err
}
params = disableDefaultSslOnLocalhost(params)
db.Schema, db.BackupSchema = schemasFromConnectionParams(params)
db.Prefix = prefixFromConnectionParams(params)
@ -607,6 +606,7 @@ func New(conf database.Config, m *mapping.Mapping) (database.DB, error) {
}
db.prepareGeneralizedTableSources()
db.Params = params
err = db.Open()
if err != nil {
return nil, err

View File

@ -3,6 +3,7 @@ package postgis
import (
"database/sql"
"fmt"
"os"
"strings"
"sync"
)
@ -26,6 +27,36 @@ func schemasFromConnectionParams(params string) (string, string) {
return schema, backupSchema
}
// disableDefaultSslOnLocalhost adds sslmode=disable to params
// when host is localhost/127.0.0.1 and the sslmode param and
// PGSSLMODE environment are both not set.
func disableDefaultSslOnLocalhost(params string) string {
parts := strings.Fields(params)
isLocalHost := false
for _, p := range parts {
if strings.HasPrefix(p, "sslmode=") {
return params
}
if p == "host=localhost" || p == "host=127.0.0.1" {
isLocalHost = true
}
}
if !isLocalHost {
return params
}
for _, v := range os.Environ() {
parts := strings.SplitN(v, "=", 2)
if parts[0] == "PGSSLMODE" {
return params
}
}
// found localhost but explicit no sslmode, disable sslmode
return params + " sslmode=disable"
}
func prefixFromConnectionParams(params string) string {
parts := strings.Fields(params)
var prefix string