165 lines
3.3 KiB
Go
165 lines
3.3 KiB
Go
package postgis
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// schemasFromConnectionParams extracts the schema and backup schema names
// from the whitespace-separated connection parameters.
//
// Tokens starting with "schema=" and "backupschema=" are recognised; when a
// key occurs more than once the last occurrence wins. A missing or empty
// value falls back to "import" for the schema and "backup" for the backup
// schema.
func schemasFromConnectionParams(params string) (string, string) {
	const (
		schemaKey = "schema="
		backupKey = "backupschema="
	)

	schema, backupSchema := "", ""
	for _, field := range strings.Fields(params) {
		switch {
		case strings.HasPrefix(field, schemaKey):
			schema = field[len(schemaKey):]
		case strings.HasPrefix(field, backupKey):
			backupSchema = field[len(backupKey):]
		}
	}

	// Apply the defaults for anything that was absent or given as empty.
	if len(schema) == 0 {
		schema = "import"
	}
	if len(backupSchema) == 0 {
		backupSchema = "backup"
	}
	return schema, backupSchema
}
|
|
|
|
// disableDefaultSslOnLocalhost appends " sslmode=disable" to params when the
// connection targets the local host and no SSL mode was chosen explicitly.
//
// params is a whitespace-separated list of key=value tokens. It is returned
// unchanged when
//   - it already contains a token starting with "sslmode=",
//   - it contains neither "host=localhost" nor "host=127.0.0.1", or
//   - the PGSSLMODE environment variable is set (even to an empty value).
func disableDefaultSslOnLocalhost(params string) string {
	isLocalHost := false
	for _, p := range strings.Fields(params) {
		if strings.HasPrefix(p, "sslmode=") {
			// An explicit sslmode in the parameters always wins.
			return params
		}
		if p == "host=localhost" || p == "host=127.0.0.1" {
			isLocalHost = true
		}
	}
	if !isLocalHost {
		return params
	}

	// LookupEnv distinguishes "unset" from "set to empty"; any explicit
	// setting, including an empty one, is left for the driver to interpret.
	if _, ok := os.LookupEnv("PGSSLMODE"); ok {
		return params
	}

	// Local host and no explicit sslmode anywhere: disable SSL.
	return params + " sslmode=disable"
}
|
|
|
|
// prefixFromConnectionParams returns the prefix configured through a
// "prefix=" token in the whitespace-separated connection parameters.
//
// Only the first "prefix=" token is considered. If it is missing or has an
// empty value, "osm_" is returned. The result always ends with an
// underscore; one is appended when the configured value lacks it.
func prefixFromConnectionParams(params string) string {
	const key = "prefix="

	configured := ""
	for _, field := range strings.Fields(params) {
		if !strings.HasPrefix(field, key) {
			continue
		}
		// First match decides, even when its value is empty.
		configured = strings.TrimPrefix(field, key)
		break
	}

	if configured == "" {
		return "osm_"
	}
	if strings.HasSuffix(configured, "_") {
		return configured
	}
	return configured + "_"
}
|
|
|
|
// tableExists reports whether information_schema.tables lists a table named
// table in the given schema. The lookup runs inside tx.
//
// schema and table are passed as bind parameters (PostgreSQL-style $1/$2
// positional placeholders) rather than being spliced into the SQL text, so
// names containing quotes cannot break or alter the query.
//
// On a query or scan failure it returns false together with the error.
func tableExists(tx *sql.Tx, schema, table string) (bool, error) {
	const query = `SELECT EXISTS(SELECT * FROM information_schema.tables WHERE table_name=$1 AND table_schema=$2)`

	var exists bool
	if err := tx.QueryRow(query, table, schema).Scan(&exists); err != nil {
		return false, err
	}
	return exists, nil
}
|
|
|
|
func dropTableIfExists(tx *sql.Tx, schema, table string) error {
|
|
exists, err := tableExists(tx, schema, table)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !exists {
|
|
return nil
|
|
}
|
|
sqlStmt := fmt.Sprintf("SELECT DropGeometryTable('%s', '%s');",
|
|
schema, table)
|
|
row := tx.QueryRow(sqlStmt)
|
|
var void interface{}
|
|
err = row.Scan(&void)
|
|
if err != nil {
|
|
return &SQLError{sqlStmt, err}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func rollbackIfTx(tx **sql.Tx) {
|
|
if *tx != nil {
|
|
if err := tx.Rollback(); err != nil {
|
|
log.Fatal("rollback failed", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// workerPool runs functions on a fixed number of worker goroutines.
//
// Workers take tasks from in and send each task's result to out. No submit
// helper is defined here; callers are expected to enqueue tasks by sending
// on in before calling wait. wait returns the first non-nil error it
// receives, or nil once all workers have exited and every result was nil.
type workerPool struct {
	in  chan func() error // queued tasks; closed by wait
	out chan error        // one result per executed task
	wg  *sync.WaitGroup   // counts running workerLoop goroutines
}

// newWorkerPool creates a pool and starts worker goroutines running
// workerLoop. tasks is the buffer size of both the task and the result
// channel.
func newWorkerPool(worker, tasks int) *workerPool {
	p := &workerPool{
		in:  make(chan func() error, tasks),
		out: make(chan error, tasks),
		wg:  &sync.WaitGroup{},
	}
	for i := 0; i < worker; i++ {
		p.wg.Add(1)
		go p.workerLoop()
	}
	return p
}

// workerLoop executes tasks from in until the channel is closed and drained,
// forwarding every result (nil or not) to out.
func (p *workerPool) workerLoop() {
	defer p.wg.Done()
	for f := range p.in {
		p.out <- f()
	}
}

// wait closes the task channel and blocks until either a task reports an
// error or all workers have finished.
//
// It returns the first non-nil error received from out. When an error is
// returned early, remaining tasks keep running in the background and their
// results are left in out. It returns nil only after every worker has
// exited and all results still buffered in out have been checked.
//
// wait must be called at most once: it closes in.
func (p *workerPool) wait() error {
	close(p.in)

	// done is closed (not sent on) so the helper goroutine can never block,
	// even when wait has already returned with an error.
	done := make(chan struct{})
	go func() {
		p.wg.Wait()
		close(done)
	}()

	for {
		select {
		case err := <-p.out:
			if err != nil {
				return err
			}
		case <-done:
			// select picks randomly among ready cases, so results may
			// still be buffered in out at this point. All workers have
			// exited, hence no further sends can happen and a
			// non-blocking drain sees every remaining result.
			for {
				select {
				case err := <-p.out:
					if err != nil {
						return err
					}
				default:
					return nil
				}
			}
		}
	}
}
|