imposm3/database/postgis/tx.go

220 lines
3.9 KiB
Go
Raw Normal View History

2013-10-28 17:59:16 +04:00
package postgis
import (
"database/sql"
"fmt"
"sync"
)
2013-10-28 19:33:39 +04:00
type TableTx interface {
Begin(*sql.Tx) error
2013-10-28 19:33:39 +04:00
Insert(row []interface{}) error
Delete(id int64) error
End()
2013-10-28 19:33:39 +04:00
Commit() error
Rollback()
}
type tableTx struct {
2013-10-28 17:59:16 +04:00
Pg *PostGIS
Tx *sql.Tx
Table string
Spec *TableSpec
InsertStmt *sql.Stmt
DeleteStmt *sql.Stmt
InsertSql string
DeleteSql string
bulkImport bool
wg *sync.WaitGroup
rows chan []interface{}
}
2013-10-28 19:33:39 +04:00
func NewTableTx(pg *PostGIS, spec *TableSpec, bulkImport bool) TableTx {
tt := &tableTx{
2013-10-28 17:59:16 +04:00
Pg: pg,
2013-10-29 11:27:49 +04:00
Table: spec.FullName,
2013-10-28 17:59:16 +04:00
Spec: spec,
wg: &sync.WaitGroup{},
rows: make(chan []interface{}, 64),
bulkImport: bulkImport,
}
tt.wg.Add(1)
go tt.loop()
return tt
}
func (tt *tableTx) Begin(tx *sql.Tx) error {
var err error
if tx == nil {
tx, err = tt.Pg.Db.Begin()
if err != nil {
return err
}
2013-10-28 17:59:16 +04:00
}
tt.Tx = tx
if tt.bulkImport {
_, err = tx.Exec(fmt.Sprintf(`TRUNCATE TABLE "%s"."%s" RESTART IDENTITY`, tt.Pg.Schema, tt.Table))
if err != nil {
return err
}
}
if tt.bulkImport {
tt.InsertSql = tt.Spec.CopySQL()
} else {
tt.InsertSql = tt.Spec.InsertSQL()
}
stmt, err := tt.Tx.Prepare(tt.InsertSql)
if err != nil {
return &SQLError{tt.InsertSql, err}
}
tt.InsertStmt = stmt
if !tt.bulkImport {
// bulkImport creates COPY FROM STDIN stmt that doesn't
// permit other stmt
tt.DeleteSql = tt.Spec.DeleteSQL()
stmt, err = tt.Tx.Prepare(tt.DeleteSql)
if err != nil {
return &SQLError{tt.DeleteSql, err}
}
tt.DeleteStmt = stmt
}
return nil
}
2013-10-28 19:33:39 +04:00
func (tt *tableTx) Insert(row []interface{}) error {
2013-10-28 17:59:16 +04:00
tt.rows <- row
return nil
}
2013-10-28 19:33:39 +04:00
func (tt *tableTx) loop() {
2013-10-28 17:59:16 +04:00
for row := range tt.rows {
_, err := tt.InsertStmt.Exec(row...)
if err != nil {
// TODO
log.Fatal(&SQLInsertError{SQLError{tt.InsertSql, err}, row})
}
}
tt.wg.Done()
}
2013-10-28 19:33:39 +04:00
func (tt *tableTx) Delete(id int64) error {
2013-10-28 17:59:16 +04:00
if tt.bulkImport {
panic("unable to delete in bulkImport mode")
}
_, err := tt.DeleteStmt.Exec(id)
if err != nil {
return &SQLInsertError{SQLError{tt.DeleteSql, err}, id}
}
return nil
}
func (tt *tableTx) End() {
2013-10-28 17:59:16 +04:00
close(tt.rows)
tt.wg.Wait()
}
func (tt *tableTx) Commit() error {
tt.End()
2013-10-28 17:59:16 +04:00
if tt.bulkImport && tt.InsertStmt != nil {
_, err := tt.InsertStmt.Exec()
if err != nil {
return err
}
}
err := tt.Tx.Commit()
if err != nil {
return err
}
tt.Tx = nil
return nil
}
2013-10-28 19:33:39 +04:00
func (tt *tableTx) Rollback() {
rollbackIfTx(&tt.Tx)
}
// generalizedTableTx is the TableTx implementation for generalized tables.
// Unlike tableTx it runs every statement synchronously on the caller's
// goroutine.
type generalizedTableTx struct {
	Pg                     *PostGIS
	Tx                     *sql.Tx
	Table                  string
	Spec                   *GeneralizedTableSpec
	InsertStmt, DeleteStmt *sql.Stmt
	InsertSql, DeleteSql   string
}
func NewGeneralizedTableTx(pg *PostGIS, spec *GeneralizedTableSpec) TableTx {
tt := &generalizedTableTx{
Pg: pg,
2013-10-29 11:27:49 +04:00
Table: spec.FullName,
2013-10-28 19:33:39 +04:00
Spec: spec,
}
return tt
}
func (tt *generalizedTableTx) Begin(tx *sql.Tx) error {
var err error
if tx == nil {
tx, err = tt.Pg.Db.Begin()
if err != nil {
return err
}
2013-10-28 19:33:39 +04:00
}
tt.Tx = tx
tt.InsertSql = tt.Spec.InsertSQL()
stmt, err := tt.Tx.Prepare(tt.InsertSql)
if err != nil {
return &SQLError{tt.InsertSql, err}
}
tt.InsertStmt = stmt
tt.DeleteSql = tt.Spec.DeleteSQL()
stmt, err = tt.Tx.Prepare(tt.DeleteSql)
if err != nil {
return &SQLError{tt.DeleteSql, err}
}
tt.DeleteStmt = stmt
return nil
}
func (tt *generalizedTableTx) Insert(row []interface{}) error {
2013-10-29 11:27:49 +04:00
_, err := tt.InsertStmt.Exec(row[0])
if err != nil {
return &SQLInsertError{SQLError{tt.InsertSql, err}, row}
2013-10-28 19:33:39 +04:00
}
2013-10-29 11:27:49 +04:00
return nil
2013-10-28 19:33:39 +04:00
}
// Delete executes the prepared delete statement for id. A failure is
// wrapped in an SQLInsertError that carries id.
func (tt *generalizedTableTx) Delete(id int64) error {
	if _, execErr := tt.DeleteStmt.Exec(id); execErr != nil {
		return &SQLInsertError{SQLError{tt.DeleteSql, execErr}, id}
	}
	return nil
}
// End is a no-op: Insert and Delete execute synchronously, so there is no
// queued work to wait for.
func (tt *generalizedTableTx) End() {}
2013-10-28 19:33:39 +04:00
// Commit commits tt.Tx and, on success, sets tt.Tx to nil. On failure the
// commit error is returned and tt.Tx is kept.
func (tt *generalizedTableTx) Commit() error {
	if commitErr := tt.Tx.Commit(); commitErr != nil {
		return commitErr
	}
	tt.Tx = nil
	return nil
}
func (tt *generalizedTableTx) Rollback() {
2013-10-28 17:59:16 +04:00
rollbackIfTx(&tt.Tx)
}