diff --git a/geom/geos/geos_test.go b/geom/geos/geos_test.go index 3f9ab99..112e117 100644 --- a/geom/geos/geos_test.go +++ b/geom/geos/geos_test.go @@ -1,6 +1,10 @@ package geos -import "testing" +import ( + "fmt" + + "testing" +) func TestFoo(t *testing.T) { _ = NewGeos() @@ -23,3 +27,63 @@ func BenchmarkWKB(b *testing.B) { g.Destroy(geom) } } + +func TestIndexQuery(t *testing.T) { + g := NewGeos() + defer g.Finish() + + idx := g.CreateIndex() + + for i := 0; i < 10; i++ { + p := g.FromWkt(fmt.Sprintf("POLYGON((%d 0, 10 0, 10 10, %d 10, %d 0))", i, i, i)) + if p == nil { + t.Fatal() + } + g.IndexAdd(idx, p) + } + + if geoms := g.IndexQuery(idx, g.Point(0, 10.000001)); len(geoms) != 0 { + t.Fatal(geoms) + } + + if geoms := g.IndexQuery(idx, g.Point(9.5, 5)); len(geoms) != 10 { + t.Fatal(geoms) + } + + if geoms := g.IndexQuery(idx, g.Point(0.5, 5)); len(geoms) != 1 { + t.Fatal(geoms) + } + if geoms := g.IndexQuery(idx, g.Point(4.5, 5)); len(geoms) != 5 { + t.Fatal(geoms) + } + +} + +func BenchmarkIndexQuery(b *testing.B) { + g := NewGeos() + defer g.Finish() + + idx := g.CreateIndex() + for i := 0; i < 10; i++ { + p := g.FromWkt(fmt.Sprintf("POLYGON((%d 0, 10 0, 10 10, %d 10, %d 0))", i, i, i)) + if p == nil { + b.Fatal() + } + g.IndexAdd(idx, p) + } + + for i := 0; i < b.N; i++ { + if geoms := g.IndexQuery(idx, g.Point(8.5, 5)); len(geoms) != 9 { + b.Fatal(geoms) + } + } + + // if geoms := g.IndexQuery(idx, g.Point(0, 0)); len(geoms) != 10 { + // b.Fatal(geoms) + // } + + // if geoms := g.IndexQuery(idx, g.Point(5, 5)); len(geoms) != 10 { + // b.Fatal(geoms) + // } + +} diff --git a/geom/geos/geos_wrap.go b/geom/geos/geos_wrap.go index 8e61cd4..b61d106 100644 --- a/geom/geos/geos_wrap.go +++ b/geom/geos/geos_wrap.go @@ -3,12 +3,13 @@ package geos /* #cgo LDFLAGS: -lgeos_c #include "geos_c.h" +#include #include #include #include +#include extern void goLogString(char *msg); -extern void goSendQueryResult(size_t, void *); void debug_wrap(const char *fmt, ...) { va_list a_list; @@ -31,9 +32,33 @@ void initGEOS_debug() { return initGEOS(devnull, debug_wrap); } -// wrap goIndexSendQueryResult -void IndexQuerySendCallback(void *item, void *userdata) { - goIndexSendQueryResult((size_t)item, userdata); +typedef struct { + uint32_t num; + uint32_t *arr; + uint32_t arrCap; +} queryResult; + +void queryResultAppend(queryResult *r, int idx) { + r->num += 1; + if (r->num >= r->arrCap) { + uint32_t newCap = r->arrCap > 0 ? r->arrCap * 2 : 8; + uint32_t *newArr = malloc(sizeof(uint32_t) * newCap); + if (r->arrCap == 0) { + r->arr = newArr; + } else { + memcpy(newArr, r->arr, r->num-1); + free(r->arr); + r->arr = newArr; + } + r->arrCap = newCap; + } + r->arr[r->num] = idx; +} + +void IndexQueryCallback(void *item, void *userdata) { + int idx = (size_t)item; + queryResult *result = (queryResult *)userdata; + queryResultAppend(result, idx); } void IndexAdd( @@ -47,14 +72,18 @@ void IndexAdd( GEOSSTRtree_insert_r(handle, tree, g, (void *)id); } + // query with our custom callback -void IndexQuery( +uint32_t *IndexQuery( GEOSContextHandle_t handle, GEOSSTRtree *tree, const GEOSGeometry *g, - void *userdata) + uint32_t *num) { - GEOSSTRtree_query_r(handle, tree, g, IndexQuerySendCallback, userdata); - } + queryResult result = {0}; + GEOSSTRtree_query_r(handle, tree, g, IndexQueryCallback, &result); + *num = result.num; + return result.arr; +} */ import "C" diff --git a/geom/geos/index.go b/geom/geos/index.go index 62dd951..4bfc1fa 100644 --- a/geom/geos/index.go +++ b/geom/geos/index.go @@ -4,10 +4,11 @@ package geos #cgo LDFLAGS: -lgeos_c #include "geos_c.h" #include +#include -extern void IndexQuerySendCallback(void *, void *); +extern void IndexQueryCallback(void *, void *); extern void goIndexSendQueryResult(size_t, void *); -extern void IndexQuery(GEOSContextHandle_t, GEOSSTRtree *, const GEOSGeometry *, void *); +extern uint32_t *IndexQuery(GEOSContextHandle_t, GEOSSTRtree *, const GEOSGeometry *, uint32_t *); extern void IndexAdd(GEOSContextHandle_t, GEOSSTRtree *, const GEOSGeometry *, size_t); */ @@ -47,23 +48,17 @@ func (this *Geos) IndexAdd(index *Index, geom *Geom) { // IndexQuery queries the index for intersections with geom. func (this *Geos) IndexQuery(index *Index, geom *Geom) []IndexGeom { - hits := make(chan int) - go func() { - // using a pointer to our hits chan to pass it through - // C.IndexQuerySendCallback (in C.IndexQuery) back - // to goIndexSendQueryResult - C.IndexQuery(this.v, index.v, geom.v, unsafe.Pointer(&hits)) - close(hits) - }() + var num C.uint32_t + r := C.IndexQuery(this.v, index.v, geom.v, &num) + if r == nil { + return nil + } + hits := (*[2 << 16]C.uint32_t)(unsafe.Pointer(r))[:num] + defer C.free(unsafe.Pointer(r)) + var geoms []IndexGeom for idx := range hits { geoms = append(geoms, index.geoms[idx]) } return geoms } - -//export goIndexSendQueryResult -func goIndexSendQueryResult(id C.size_t, ptr unsafe.Pointer) { - results := *(*chan int)(ptr) - results <- int(id) -}