concurrency: STM snapshot isolation level

release-3.2
Anthony Romano 2016-12-29 20:19:08 -08:00
parent 8604d1863b
commit 8695511153
1 changed files with 70 additions and 37 deletions

View File

@ -15,6 +15,8 @@
package concurrency
import (
"math"
v3 "github.com/coreos/etcd/clientv3"
"golang.org/x/net/context"
)
@ -82,7 +84,7 @@ func WithPrefetch(keys ...string) stmOption {
return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
}
// NewSTM initiates a new STM instance.
// NewSTM initiates a new STM instance, using snapshot isolation by default.
func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
opts := &stmOptions{ctx: c.Ctx()}
for _, f := range so {
@ -95,22 +97,38 @@ func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnRespon
return f(s)
}
}
var s STM
return runSTM(mkSTM(c, opts), apply)
}
func mkSTM(c *v3.Client, opts *stmOptions) STM {
switch opts.iso {
case Serializable:
s = &stmSerializable{
case Snapshot:
s := &stmSerializable{
stm: stm{client: c, ctx: opts.ctx},
prefetch: make(map[string]*v3.GetResponse),
}
s.conflicts = func() []v3.Cmp {
return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
}
return s
case Serializable:
s := &stmSerializable{
stm: stm{client: c, ctx: opts.ctx},
prefetch: make(map[string]*v3.GetResponse),
}
s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
return s
case RepeatableReads:
s = &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
return s
case ReadCommitted:
ss := stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
s = &stmReadCommitted{ss}
s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
s.conflicts = func() []v3.Cmp { return nil }
return s
default:
panic("unsupported")
panic("unsupported stm")
}
return runSTM(s, apply)
}
type stmResponse struct {
@ -152,11 +170,13 @@ type stm struct {
client *v3.Client
ctx context.Context
// rset holds read key values and revisions
rset map[string]*v3.GetResponse
rset readSet
// wset holds overwritten keys and their values
wset writeSet
// getOpts are the opts used for gets
getOpts []v3.OpOption
// conflicts computes the current conflicts on the txn
conflicts func() []v3.Cmp
}
type stmPut struct {
@ -164,6 +184,33 @@ type stmPut struct {
op v3.Op
}
type readSet map[string]*v3.GetResponse
func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
for i, resp := range txnresp.Responses {
rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
}
}
func (rs readSet) first() int64 {
ret := int64(math.MaxInt64 - 1)
for _, resp := range rs {
if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret {
ret = resp.Kvs[0].ModRevision
}
}
return ret
}
// cmps guards the txn from updates to read set
func (rs readSet) cmps() []v3.Cmp {
cmps := make([]v3.Cmp, 0, len(rs))
for k, rk := range rs {
cmps = append(cmps, isKeyCurrent(k, rk))
}
return cmps
}
type writeSet map[string]stmPut
func (ws writeSet) get(keys ...string) *stmPut {
@ -175,6 +222,15 @@ func (ws writeSet) get(keys ...string) *stmPut {
return nil
}
// cmps returns a cmp list testing no writes have happened past rev
func (ws writeSet) cmps(rev int64) []v3.Cmp {
cmps := make([]v3.Cmp, 0, len(ws))
for key := range ws {
cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
}
return cmps
}
// puts is the list of ops for all pending writes
func (ws writeSet) puts() []v3.Op {
puts := make([]v3.Op, 0, len(ws))
@ -205,7 +261,7 @@ func (s *stm) Rev(key string) int64 {
}
func (s *stm) commit() *v3.TxnResponse {
txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...).Commit()
txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
if err != nil {
panic(stmError{err})
}
@ -215,15 +271,6 @@ func (s *stm) commit() *v3.TxnResponse {
return nil
}
// cmps guards the txn from updates to read set
func (s *stm) cmps() []v3.Cmp {
cmps := make([]v3.Cmp, 0, len(s.rset))
for k, rk := range s.rset {
cmps = append(cmps, isKeyCurrent(k, rk))
}
return cmps
}
func (s *stm) fetch(keys ...string) *v3.GetResponse {
if len(keys) == 0 {
return nil
@ -239,7 +286,7 @@ func (s *stm) fetch(keys ...string) *v3.GetResponse {
if err != nil {
panic(stmError{err})
}
addTxnResp(s.rset, keys, txnresp)
s.rset.add(keys, txnresp)
return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
}
@ -292,7 +339,7 @@ func (s *stmSerializable) gets() ([]string, []v3.Op) {
func (s *stmSerializable) commit() *v3.TxnResponse {
keys, getops := s.gets()
txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...)
txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
// use Else to prefetch keys in case of conflict to save a round trip
txnresp, err := txn.Else(getops...).Commit()
if err != nil {
@ -302,26 +349,12 @@ func (s *stmSerializable) commit() *v3.TxnResponse {
return txnresp
}
// load prefetch with Else data
addTxnResp(s.rset, keys, txnresp)
s.rset.add(keys, txnresp)
s.prefetch = s.rset
s.getOpts = nil
return nil
}
func addTxnResp(rset map[string]*v3.GetResponse, keys []string, txnresp *v3.TxnResponse) {
for i, resp := range txnresp.Responses {
rset[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
}
}
type stmReadCommitted struct{ stm }
// commit always goes through when read committed
func (s *stmReadCommitted) commit() *v3.TxnResponse {
s.rset = nil
return s.stm.commit()
}
func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
if len(r.Kvs) != 0 {
return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)