imposm3/database/postgis/postgis.go

650 lines
14 KiB
Go
Raw Normal View History

2013-05-15 15:00:42 +04:00
package postgis
2013-05-06 21:14:37 +04:00
import (
"database/sql"
2013-05-14 18:15:35 +04:00
"errors"
2013-05-08 18:45:14 +04:00
"fmt"
2013-06-17 17:38:00 +04:00
pq "github.com/olt/pq"
2013-05-15 15:00:42 +04:00
"goposm/database"
2013-05-28 16:07:06 +04:00
"goposm/logging"
2013-05-14 18:15:35 +04:00
"goposm/mapping"
"runtime"
2013-05-08 18:45:14 +04:00
"strings"
2013-06-11 16:21:27 +04:00
"sync"
2013-05-06 21:14:37 +04:00
)
2013-05-28 16:07:06 +04:00
// log is the package-level logger, created under the name "PostGIS".
var log = logging.NewLogger("PostGIS")
2013-05-08 18:45:14 +04:00
// SQLError pairs a failed SQL statement with the error the database
// driver reported for it.
type SQLError struct {
	query         string
	originalError error
}

// Error renders the driver error followed by the statement that caused it.
func (e *SQLError) Error() string {
	cause := e.originalError.Error()
	return fmt.Sprintf("SQL Error: %s in query %s", cause, e.query)
}

// SQLInsertError is an SQLError that additionally carries the data that
// was passed to the failing statement.
type SQLInsertError struct {
	SQLError
	data interface{}
}

// Error renders the driver error, the statement and the offending data
// (formatted with %+v).
func (e *SQLInsertError) Error() string {
	cause := e.originalError.Error()
	return fmt.Sprintf("SQL Error: %s in query %s (%+v)", cause, e.query, e.data)
}
func createTable(tx *sql.Tx, spec TableSpec) error {
2013-05-08 18:45:14 +04:00
var sql string
var err error
err = dropTableIfExists(tx, spec.Schema, spec.Name)
2013-05-06 21:14:37 +04:00
if err != nil {
return err
2013-05-06 21:14:37 +04:00
}
2013-05-08 18:45:14 +04:00
sql = spec.CreateTableSQL()
_, err = tx.Exec(sql)
2013-05-06 21:14:37 +04:00
if err != nil {
2013-05-08 18:45:14 +04:00
return &SQLError{sql, err}
2013-05-06 21:14:37 +04:00
}
err = addGeometryColumn(tx, spec.Name, spec)
if err != nil {
return err
}
return nil
}
func addGeometryColumn(tx *sql.Tx, tableName string, spec TableSpec) error {
2013-05-17 13:42:19 +04:00
geomType := strings.ToUpper(spec.GeometryType)
if geomType == "POLYGON" {
geomType = "GEOMETRY" // for multipolygon support
}
sql := fmt.Sprintf("SELECT AddGeometryColumn('%s', '%s', 'geometry', '%d', '%s', 2);",
spec.Schema, tableName, spec.Srid, geomType)
row := tx.QueryRow(sql)
var void interface{}
err := row.Scan(&void)
if err != nil {
return &SQLError{sql, err}
}
return nil
}
func populateGeometryColumn(tx *sql.Tx, tableName string, spec TableSpec) error {
sql := fmt.Sprintf("SELECT Populate_Geometry_Columns('%s.%s'::regclass);",
spec.Schema, tableName)
row := tx.QueryRow(sql)
2013-05-14 18:15:35 +04:00
var void interface{}
err := row.Scan(&void)
2013-05-06 21:14:37 +04:00
if err != nil {
2013-05-08 18:45:14 +04:00
return &SQLError{sql, err}
2013-05-06 21:14:37 +04:00
}
2013-05-08 18:45:14 +04:00
return nil
}
func (pg *PostGIS) createSchema(schema string) error {
2013-05-08 18:45:14 +04:00
var sql string
var err error
if schema == "public" {
2013-05-10 12:29:44 +04:00
return nil
}
sql = fmt.Sprintf("SELECT EXISTS(SELECT schema_name FROM information_schema.schemata WHERE schema_name = '%s');",
schema)
row := pg.Db.QueryRow(sql)
var exists bool
err = row.Scan(&exists)
if err != nil {
return &SQLError{sql, err}
}
if exists {
return nil
}
sql = fmt.Sprintf("CREATE SCHEMA \"%s\"", schema)
2013-05-08 18:45:14 +04:00
_, err = pg.Db.Exec(sql)
2013-05-06 21:14:37 +04:00
if err != nil {
2013-05-08 18:45:14 +04:00
return &SQLError{sql, err}
2013-05-06 21:14:37 +04:00
}
2013-05-08 18:45:14 +04:00
return nil
}
2013-05-06 21:14:37 +04:00
2013-05-14 18:15:35 +04:00
func (pg *PostGIS) InsertBatch(table string, rows [][]interface{}) error {
spec, ok := pg.Tables[table]
if !ok {
return errors.New("unkown table: " + table)
2013-05-06 21:14:37 +04:00
}
2013-05-08 18:45:14 +04:00
tx, err := pg.Db.Begin()
2013-05-06 21:14:37 +04:00
if err != nil {
2013-05-08 18:45:14 +04:00
return err
2013-05-06 21:14:37 +04:00
}
2013-05-22 11:49:03 +04:00
defer rollbackIfTx(&tx)
2013-05-06 21:14:37 +04:00
2013-05-08 18:45:14 +04:00
sql := spec.InsertSQL()
stmt, err := tx.Prepare(sql)
2013-05-06 21:14:37 +04:00
if err != nil {
2013-05-08 18:45:14 +04:00
return &SQLError{sql, err}
}
2013-05-14 18:15:35 +04:00
defer stmt.Close()
2013-05-08 18:45:14 +04:00
2013-05-14 18:15:35 +04:00
for _, row := range rows {
_, err := stmt.Exec(row...)
2013-05-08 18:45:14 +04:00
if err != nil {
2013-05-14 18:15:35 +04:00
return &SQLInsertError{SQLError{sql, err}, row}
2013-05-08 18:45:14 +04:00
}
2013-05-06 21:14:37 +04:00
}
2013-05-08 18:45:14 +04:00
err = tx.Commit()
if err != nil {
return err
}
2013-05-22 11:49:03 +04:00
tx = nil // set nil to prevent rollback
2013-05-08 18:45:14 +04:00
return nil
}
func (pg *PostGIS) Init() error {
if err := pg.createSchema(pg.Schema); err != nil {
2013-05-08 18:45:14 +04:00
return err
}
2013-05-14 18:15:35 +04:00
tx, err := pg.Db.Begin()
if err != nil {
return err
}
defer rollbackIfTx(&tx)
2013-05-14 18:15:35 +04:00
for _, spec := range pg.Tables {
if err := createTable(tx, *spec); err != nil {
2013-05-08 18:45:14 +04:00
return err
}
}
err = tx.Commit()
if err != nil {
return err
}
tx = nil
2013-05-08 18:45:14 +04:00
return nil
}
// TableNames returns the keys of pg.Tables followed by the keys of
// pg.GeneralizedTables. Within each group the order follows map iteration
// and is therefore unspecified. The result is nil when both maps are empty.
func (pg *PostGIS) TableNames() []string {
	var result []string
	for key := range pg.Tables {
		result = append(result, key)
	}
	for key := range pg.GeneralizedTables {
		result = append(result, key)
	}
	return result
}
2013-05-22 11:49:03 +04:00
// Finish creates spatial indices on all tables.
func (pg *PostGIS) Finish() error {
2013-05-28 16:07:06 +04:00
defer log.StopStep(log.StartStep(fmt.Sprintf("Creating geometry indices")))
worker := int(runtime.NumCPU() / 2)
2013-06-20 16:23:06 +04:00
if worker < 1 {
worker = 1
}
2013-07-04 17:39:13 +04:00
p := newWorkerPool(worker, len(pg.Tables)+len(pg.GeneralizedTables))
for tableName, tbl := range pg.Tables {
2013-05-22 11:49:03 +04:00
tableName := pg.Prefix + tableName
table := tbl
p.in <- func() error {
2013-07-04 17:39:13 +04:00
return createIndex(pg, tableName, table.Columns)
2013-05-22 11:49:03 +04:00
}
}
for tableName, tbl := range pg.GeneralizedTables {
tableName := pg.Prefix + tableName
table := tbl
p.in <- func() error {
2013-07-04 17:39:13 +04:00
return createIndex(pg, tableName, table.Source.Columns)
}
}
2013-07-04 17:39:13 +04:00
err := p.wait()
if err != nil {
return err
}
2013-05-22 10:46:39 +04:00
return nil
}
2013-07-04 17:39:13 +04:00
// createIndex creates the indices for tableName in pg.Schema based on the
// given column list:
//
//   - for each column whose type name is "GEOMETRY", a GiST index named
//     "<tableName>_geom";
//   - for each column whose field type name is "id", a BTREE index named
//     "<tableName>_osm_id_idx".
//
// Each statement is logged as its own step. The first failure stops the
// function and is returned as *SQLError so that the failing statement is
// part of the message, matching the other helpers in this file.
func createIndex(pg *PostGIS, tableName string, columns []ColumnSpec) error {
	for _, col := range columns {
		if col.Type.Name() == "GEOMETRY" {
			query := fmt.Sprintf(`CREATE INDEX "%s_geom" ON "%s"."%s" USING GIST ("%s")`,
				tableName, pg.Schema, tableName, col.Name)
			step := log.StartStep(fmt.Sprintf("Creating geometry index on %s", tableName))
			_, err := pg.Db.Exec(query)
			log.StopStep(step)
			if err != nil {
				return &SQLError{query, err}
			}
		}
		if col.FieldType.Name == "id" {
			query := fmt.Sprintf(`CREATE INDEX "%s_osm_id_idx" ON "%s"."%s" USING BTREE ("%s")`,
				tableName, pg.Schema, tableName, col.Name)
			step := log.StartStep(fmt.Sprintf("Creating OSM id index on %s", tableName))
			_, err := pg.Db.Exec(query)
			log.StopStep(step)
			if err != nil {
				return &SQLError{query, err}
			}
		}
	}
	return nil
}
2013-05-22 13:48:34 +04:00
func (pg *PostGIS) Generalize() error {
2013-05-28 16:07:06 +04:00
defer log.StopStep(log.StartStep(fmt.Sprintf("Creating generalized tables")))
worker := int(runtime.NumCPU() / 2)
2013-06-20 16:23:06 +04:00
if worker < 1 {
worker = 1
}
2013-05-22 13:48:34 +04:00
// generalized tables can depend on other generalized tables
// create tables with non-generalized sources first
p := newWorkerPool(worker, len(pg.GeneralizedTables))
2013-05-22 13:48:34 +04:00
for _, table := range pg.GeneralizedTables {
if table.SourceGeneralized == nil {
tbl := table // for following closure
p.in <- func() error {
if err := pg.generalizeTable(tbl); err != nil {
return err
}
tbl.created = true
return nil
2013-05-22 13:48:34 +04:00
}
}
}
err := p.wait()
if err != nil {
return err
}
2013-05-22 13:48:34 +04:00
// next create tables with created generalized sources until
// no new source is created
created := true
for created {
created = false
p := newWorkerPool(worker, len(pg.GeneralizedTables))
2013-05-22 13:48:34 +04:00
for _, table := range pg.GeneralizedTables {
if !table.created && table.SourceGeneralized.created {
tbl := table // for following closure
p.in <- func() error {
if err := pg.generalizeTable(tbl); err != nil {
return err
}
tbl.created = true
created = true
return nil
2013-05-22 13:48:34 +04:00
}
}
}
err := p.wait()
if err != nil {
return err
}
2013-05-22 13:48:34 +04:00
}
return nil
}
func (pg *PostGIS) generalizeTable(table *GeneralizedTableSpec) error {
2013-05-28 16:07:06 +04:00
defer log.StopStep(log.StartStep(fmt.Sprintf("Generalizing %s into %s",
pg.Prefix+table.SourceName, pg.Prefix+table.Name)))
2013-05-22 13:48:34 +04:00
tx, err := pg.Db.Begin()
if err != nil {
return err
}
defer rollbackIfTx(&tx)
var where string
if table.Where != "" {
where = " WHERE " + table.Where
}
var cols []string
for _, col := range table.Source.Columns {
cols = append(cols, col.Type.GeneralizeSql(&col, table))
}
if err := dropTableIfExists(tx, pg.Schema, table.Name); err != nil {
return err
}
columnSQL := strings.Join(cols, ",\n")
sql := fmt.Sprintf(`CREATE TABLE "%s"."%s" AS (SELECT %s FROM "%s"."%s"%s)`,
pg.Schema, table.Name, columnSQL, pg.Schema,
pg.Prefix+table.SourceName, where)
2013-05-28 16:07:06 +04:00
2013-05-22 13:48:34 +04:00
_, err = tx.Exec(sql)
if err != nil {
return &SQLError{sql, err}
}
err = populateGeometryColumn(tx, table.Name, *table.Source)
2013-05-22 13:48:34 +04:00
if err != nil {
return err
}
2013-05-22 13:48:34 +04:00
err = tx.Commit()
if err != nil {
return err
}
tx = nil // set nil to prevent rollback
return nil
}
2013-07-04 17:52:14 +04:00
// Optimize clusters tables on new GeoHash index.
2013-07-04 13:26:53 +04:00
func (pg *PostGIS) Optimize() error {
defer log.StopStep(log.StartStep(fmt.Sprintf("Clustering on geometry")))
worker := int(runtime.NumCPU() / 2)
if worker < 1 {
worker = 1
}
2013-07-04 17:52:14 +04:00
p := newWorkerPool(worker, len(pg.Tables)+len(pg.GeneralizedTables))
2013-07-04 13:26:53 +04:00
for tableName, tbl := range pg.Tables {
tableName := pg.Prefix + tableName
table := tbl
p.in <- func() error {
return clusterTable(pg, tableName, table.Srid, table.Columns)
}
}
for tableName, tbl := range pg.GeneralizedTables {
tableName := pg.Prefix + tableName
table := tbl
p.in <- func() error {
return clusterTable(pg, tableName, table.Source.Srid, table.Source.Columns)
}
}
2013-07-04 17:52:14 +04:00
err := p.wait()
2013-07-04 13:26:53 +04:00
if err != nil {
return err
}
return nil
}
2013-07-04 17:52:14 +04:00
// clusterTable physically orders tableName (in pg.Schema) by GeoHash.
//
// For the first column whose type name is "GEOMETRY" it creates an index
// named "<tableName>_geom_geohash" on the GeoHash of the column's bounding
// box (tagged with srid and transformed to EPSG:4326) and then runs CLUSTER
// on that index. Columns after the first geometry column are ignored; a
// table without a geometry column is left untouched.
//
// The column name is quoted like every other identifier in this file, and
// failures are returned as *SQLError so the failing statement is reported.
func clusterTable(pg *PostGIS, tableName string, srid int, columns []ColumnSpec) error {
	for _, col := range columns {
		if col.Type.Name() != "GEOMETRY" {
			continue
		}

		step := log.StartStep(fmt.Sprintf("Indexing %s on geohash", tableName))
		query := fmt.Sprintf(`CREATE INDEX "%s_geom_geohash" ON "%s"."%s" (ST_GeoHash(ST_Transform(ST_SetSRID(Box2D("%s"), %d), 4326)))`,
			tableName, pg.Schema, tableName, col.Name, srid)
		_, err := pg.Db.Exec(query)
		log.StopStep(step)
		if err != nil {
			return &SQLError{query, err}
		}

		step = log.StartStep(fmt.Sprintf("Clustering %s on geohash", tableName))
		query = fmt.Sprintf(`CLUSTER "%s_geom_geohash" ON "%s"."%s"`,
			tableName, pg.Schema, tableName)
		_, err = pg.Db.Exec(query)
		log.StopStep(step)
		if err != nil {
			return &SQLError{query, err}
		}

		// Only the first geometry column is used for clustering.
		break
	}
	return nil
}
2013-06-11 12:42:32 +04:00
type PostGIS struct {
Db *sql.DB
Params string
2013-06-11 12:42:32 +04:00
Schema string
BackupSchema string
Config database.Config
Tables map[string]*TableSpec
GeneralizedTables map[string]*GeneralizedTableSpec
Prefix string
InputBuffer *InsertBuffer
2013-06-11 12:42:32 +04:00
}
func (pg *PostGIS) Open() error {
var err error
pg.Db, err = sql.Open("postgres", pg.Params)
2013-06-11 12:42:32 +04:00
if err != nil {
return err
}
// check that the connection actually works
err = pg.Db.Ping()
if err != nil {
return err
}
return nil
}
// Insert forwards row for the given table to pg.InputBuffer.
// pg.InputBuffer is assigned by Begin or BeginBulk.
func (pg *PostGIS) Insert(table string, row []interface{}) {
	pg.InputBuffer.Insert(table, row)
}
// Delete forwards the deletion of id from the given table to
// pg.InputBuffer. It always returns nil.
func (pg *PostGIS) Delete(table string, id int64) error {
	pg.InputBuffer.Delete(table, id)
	return nil
}
func (pg *PostGIS) Begin() error {
2013-06-21 12:33:49 +04:00
pg.InputBuffer = NewInsertBuffer(pg, false)
return nil
}
// BeginBulk installs a fresh insert buffer created with bulk mode enabled.
// It always returns nil.
func (pg *PostGIS) BeginBulk() error {
	const bulk = true
	pg.InputBuffer = NewInsertBuffer(pg, bulk)
	return nil
}
// Abort delegates to pg.InputBuffer.Abort and returns its result.
func (pg *PostGIS) Abort() error {
	return pg.InputBuffer.Abort()
}
// End delegates to pg.InputBuffer.End and returns its result.
func (pg *PostGIS) End() error {
	return pg.InputBuffer.End()
}
// Close closes the database handle pg.Db.
func (pg *PostGIS) Close() error {
	return pg.Db.Close()
}
type TableTx struct {
2013-06-21 12:33:49 +04:00
Pg *PostGIS
Tx *sql.Tx
Table string
Spec *TableSpec
InsertStmt *sql.Stmt
DeleteStmt *sql.Stmt
InsertSql string
DeleteSql string
bulkImport bool
wg *sync.WaitGroup
rows chan []interface{}
}
func (tt *TableTx) Begin() error {
tx, err := tt.Pg.Db.Begin()
if err != nil {
return err
}
tt.Tx = tx
if tt.bulkImport {
_, err = tx.Exec(fmt.Sprintf(`TRUNCATE TABLE "%s"."%s" RESTART IDENTITY`, tt.Pg.Schema, tt.Table))
if err != nil {
return err
}
2013-06-11 16:43:14 +04:00
}
2013-06-21 12:33:49 +04:00
if tt.bulkImport {
tt.InsertSql = tt.Spec.CopySQL()
} else {
tt.InsertSql = tt.Spec.InsertSQL()
}
stmt, err := tt.Tx.Prepare(tt.InsertSql)
if err != nil {
2013-06-21 12:33:49 +04:00
return &SQLError{tt.InsertSql, err}
}
2013-06-21 12:33:49 +04:00
tt.InsertStmt = stmt
if !tt.bulkImport {
// bulkImport creates COPY FROM STDIN stmt that doesn't
// permit other stmt
tt.DeleteSql = tt.Spec.DeleteSQL()
stmt, err = tt.Tx.Prepare(tt.DeleteSql)
if err != nil {
return &SQLError{tt.DeleteSql, err}
}
tt.DeleteStmt = stmt
}
return nil
}
func (tt *TableTx) Insert(row []interface{}) error {
2013-06-11 16:21:27 +04:00
tt.rows <- row
return nil
}
2013-06-11 16:21:27 +04:00
func (tt *TableTx) loop() {
for row := range tt.rows {
2013-06-21 12:33:49 +04:00
_, err := tt.InsertStmt.Exec(row...)
2013-06-11 16:21:27 +04:00
if err != nil {
// TODO
2013-06-21 12:33:49 +04:00
log.Fatal(&SQLInsertError{SQLError{tt.InsertSql, err}, row})
2013-06-11 16:21:27 +04:00
}
}
tt.wg.Done()
}
func (tt *TableTx) Delete(id int64) error {
2013-06-21 12:33:49 +04:00
if tt.bulkImport {
panic("unable to delete in bulkImport mode")
}
2013-06-21 12:33:49 +04:00
_, err := tt.DeleteStmt.Exec(id)
if err != nil {
2013-06-21 12:33:49 +04:00
return &SQLInsertError{SQLError{tt.DeleteSql, err}, id}
}
return nil
}
func (tt *TableTx) Commit() error {
2013-06-11 16:21:27 +04:00
close(tt.rows)
tt.wg.Wait()
2013-06-21 12:33:49 +04:00
if tt.bulkImport && tt.InsertStmt != nil {
_, err := tt.InsertStmt.Exec()
2013-06-17 17:17:39 +04:00
if err != nil {
return err
}
}
err := tt.Tx.Commit()
if err != nil {
return err
}
tt.Tx = nil
return nil
}
// Rollback passes the address of tt.Tx to rollbackIfTx. Commit sets tt.Tx
// to nil on success, so a rollback after a successful commit is expected to
// be a no-op (rollbackIfTx is defined elsewhere; confirm).
func (tt *TableTx) Rollback() {
	rollbackIfTx(&tt.Tx)
}
2013-06-21 12:33:49 +04:00
func (pg *PostGIS) NewTableTx(spec *TableSpec, bulkImport bool) *TableTx {
2013-06-11 16:21:27 +04:00
tt := &TableTx{
2013-06-21 12:33:49 +04:00
Pg: pg,
Table: spec.Name,
Spec: spec,
wg: &sync.WaitGroup{},
rows: make(chan []interface{}, 64),
bulkImport: bulkImport,
}
2013-06-11 16:21:27 +04:00
tt.wg.Add(1)
go tt.loop()
return tt
}
func New(conf database.Config, m *mapping.Mapping) (database.DB, error) {
2013-05-08 18:45:14 +04:00
db := &PostGIS{}
2013-05-14 18:15:35 +04:00
db.Tables = make(map[string]*TableSpec)
2013-05-22 13:48:34 +04:00
db.GeneralizedTables = make(map[string]*GeneralizedTableSpec)
2013-05-08 18:45:14 +04:00
db.Config = conf
if strings.HasPrefix(db.Config.ConnectionParams, "postgis://") {
db.Config.ConnectionParams = strings.Replace(
db.Config.ConnectionParams,
"postgis", "postgres", 1,
)
}
params, err := pq.ParseURL(db.Config.ConnectionParams)
if err != nil {
return nil, err
}
params = disableDefaultSslOnLocalhost(params)
db.Schema, db.BackupSchema = schemasFromConnectionParams(params)
db.Prefix = prefixFromConnectionParams(params)
for name, table := range m.Tables {
db.Tables[name] = NewTableSpec(db, table)
}
2013-05-22 13:48:34 +04:00
for name, table := range m.GeneralizedTables {
db.GeneralizedTables[name] = NewGeneralizedTableSpec(db, table)
}
2013-07-04 17:52:14 +04:00
db.prepareGeneralizedTableSources()
2013-05-22 13:48:34 +04:00
db.Params = params
err = db.Open()
2013-05-08 18:45:14 +04:00
if err != nil {
return nil, err
}
return db, nil
}
2013-05-15 15:00:42 +04:00
2013-07-04 17:52:14 +04:00
// prepareGeneralizedTableSources checks that every generalized table has an
// existing source and sets .Source to the original (non-generalized) source
// table, even when the direct source is itself a generalized table.
//
// A table whose source name matches neither a regular nor a generalized
// table is logged and keeps a nil Source. The propagation loop runs only
// while it still resolves at least one table per pass, so missing or
// mutually dependent sources leave Source nil instead of looping forever.
func (pg *PostGIS) prepareGeneralizedTableSources() {
	for name, table := range pg.GeneralizedTables {
		if source, ok := pg.Tables[table.SourceName]; ok {
			table.Source = source
		} else if source, ok := pg.GeneralizedTables[table.SourceName]; ok {
			table.SourceGeneralized = source
		} else {
			log.Printf("missing source '%s' for generalized table '%s'\n",
				table.SourceName, name)
		}
	}

	// Propagate the original source along chains of generalized tables
	// until a full pass resolves nothing new.
	for resolved := true; resolved; {
		resolved = false
		for _, table := range pg.GeneralizedTables {
			if table.Source != nil {
				continue
			}
			source, ok := pg.GeneralizedTables[table.SourceName]
			if ok && source.Source != nil {
				table.Source = source.Source
				resolved = true
			}
		}
	}
}
2013-05-15 15:00:42 +04:00
// init registers New as the constructor for both the "postgres" and the
// "postgis" database names.
func init() {
	database.Register("postgres", New)
	database.Register("postgis", New)
}