From 8695511153ca9be7eeee0907d6378d9af11eaab9 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Thu, 29 Dec 2016 20:19:08 -0800 Subject: [PATCH] concurrency: STM snapshot isolation level --- clientv3/concurrency/stm.go | 107 +++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 37 deletions(-) diff --git a/clientv3/concurrency/stm.go b/clientv3/concurrency/stm.go index 12873262e..fd3ca6236 100644 --- a/clientv3/concurrency/stm.go +++ b/clientv3/concurrency/stm.go @@ -15,6 +15,8 @@ package concurrency import ( + "math" + v3 "github.com/coreos/etcd/clientv3" "golang.org/x/net/context" ) @@ -82,7 +84,7 @@ func WithPrefetch(keys ...string) stmOption { return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) } } -// NewSTM initiates a new STM instance. +// NewSTM initiates a new STM instance, using snapshot isolation by default. func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) { opts := &stmOptions{ctx: c.Ctx()} for _, f := range so { @@ -95,22 +97,38 @@ func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnRespon return f(s) } } - var s STM + return runSTM(mkSTM(c, opts), apply) +} + +func mkSTM(c *v3.Client, opts *stmOptions) STM { switch opts.iso { - case Serializable: - s = &stmSerializable{ + case Snapshot: + s := &stmSerializable{ stm: stm{client: c, ctx: opts.ctx}, prefetch: make(map[string]*v3.GetResponse), } + s.conflicts = func() []v3.Cmp { + return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...) + } + return s + case Serializable: + s := &stmSerializable{ + stm: stm{client: c, ctx: opts.ctx}, + prefetch: make(map[string]*v3.GetResponse), + } + s.conflicts = func() []v3.Cmp { return s.rset.cmps() } + return s case RepeatableReads: - s = &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} + s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} + s.conflicts = func() []v3.Cmp { return s.rset.cmps() } + return s case ReadCommitted: - ss := stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} - s = &stmReadCommitted{ss} + s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} + s.conflicts = func() []v3.Cmp { return nil } + return s default: - panic("unsupported") + panic("unsupported stm") } - return runSTM(s, apply) } type stmResponse struct { @@ -152,11 +170,13 @@ type stm struct { client *v3.Client ctx context.Context // rset holds read key values and revisions - rset map[string]*v3.GetResponse + rset readSet // wset holds overwritten keys and their values wset writeSet // getOpts are the opts used for gets getOpts []v3.OpOption + // conflicts computes the current conflicts on the txn + conflicts func() []v3.Cmp } type stmPut struct { @@ -164,6 +184,33 @@ type stmPut struct { op v3.Op } +type readSet map[string]*v3.GetResponse + +func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) { + for i, resp := range txnresp.Responses { + rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange()) + } +} + +func (rs readSet) first() int64 { + ret := int64(math.MaxInt64 - 1) + for _, resp := range rs { + if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret { + ret = resp.Kvs[0].ModRevision + } + } + return ret +} + +// cmps guards the txn from updates to read set +func (rs readSet) cmps() []v3.Cmp { + cmps := make([]v3.Cmp, 0, len(rs)) + for k, rk := range rs { + cmps = append(cmps, isKeyCurrent(k, rk)) + } + return cmps +} + type writeSet map[string]stmPut func (ws writeSet) get(keys ...string) *stmPut { @@ -175,6 +222,15 @@ func (ws writeSet) get(keys ...string) *stmPut { return nil } +// cmps returns a cmp list testing no writes have happened past rev +func (ws writeSet) cmps(rev int64) []v3.Cmp { + cmps := make([]v3.Cmp, 0, len(ws)) + for key := range ws { + cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev)) + } + return cmps +} + // puts is the list of ops for all pending writes func (ws writeSet) puts() []v3.Op { puts := make([]v3.Op, 0, len(ws)) @@ -205,7 +261,7 @@ func (s *stm) Rev(key string) int64 { } func (s *stm) commit() *v3.TxnResponse { - txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...).Commit() + txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit() if err != nil { panic(stmError{err}) } @@ -215,15 +271,6 @@ func (s *stm) commit() *v3.TxnResponse { return nil } -// cmps guards the txn from updates to read set -func (s *stm) cmps() []v3.Cmp { - cmps := make([]v3.Cmp, 0, len(s.rset)) - for k, rk := range s.rset { - cmps = append(cmps, isKeyCurrent(k, rk)) - } - return cmps -} - func (s *stm) fetch(keys ...string) *v3.GetResponse { if len(keys) == 0 { return nil @@ -239,7 +286,7 @@ func (s *stm) fetch(keys ...string) *v3.GetResponse { if err != nil { panic(stmError{err}) } - addTxnResp(s.rset, keys, txnresp) + s.rset.add(keys, txnresp) return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange()) } @@ -292,7 +339,7 @@ func (s *stmSerializable) gets() ([]string, []v3.Op) { func (s *stmSerializable) commit() *v3.TxnResponse { keys, getops := s.gets() - txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...) + txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...) // use Else to prefetch keys in case of conflict to save a round trip txnresp, err := txn.Else(getops...).Commit() if err != nil { @@ -302,26 +349,12 @@ func (s *stmSerializable) commit() *v3.TxnResponse { return txnresp } // load prefetch with Else data - addTxnResp(s.rset, keys, txnresp) + s.rset.add(keys, txnresp) s.prefetch = s.rset s.getOpts = nil return nil } -func addTxnResp(rset map[string]*v3.GetResponse, keys []string, txnresp *v3.TxnResponse) { - for i, resp := range txnresp.Responses { - rset[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange()) - } -} - -type stmReadCommitted struct{ stm } - -// commit always goes through when read committed -func (s *stmReadCommitted) commit() *v3.TxnResponse { - s.rset = nil - return s.stm.commit() -} - func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp { if len(r.Kvs) != 0 { return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)