diff --git a/clientv3/kv.go b/clientv3/kv.go index 26197a8ed..4c7449088 100644 --- a/clientv3/kv.go +++ b/clientv3/kv.go @@ -111,8 +111,8 @@ func (kv *kv) Delete(ctx context.Context, key string, opts ...OpOption) (*Delete } func (kv *kv) Compact(ctx context.Context, rev int64) error { - r := &pb.CompactionRequest{Revision: rev} - _, err := kv.getRemote().Compact(ctx, r) + remote := kv.getRemote() + _, err := remote.Compact(ctx, &pb.CompactionRequest{Revision: rev}) if err == nil { return nil } @@ -121,7 +121,7 @@ func (kv *kv) Compact(ctx context.Context, rev int64) error { return rpctypes.Error(err) } - go kv.switchRemote(err) + go kv.switchRemote(remote, err) return rpctypes.Error(err) } @@ -135,6 +135,7 @@ func (kv *kv) Txn(ctx context.Context) Txn { func (kv *kv) Do(ctx context.Context, op Op) (OpResponse, error) { for { var err error + remote := kv.getRemote() switch op.t { // TODO: handle other ops case tRange: @@ -145,21 +146,21 @@ func (kv *kv) Do(ctx context.Context, op Op) (OpResponse, error) { r.SortTarget = pb.RangeRequest_SortTarget(op.sort.Target) } - resp, err = kv.getRemote().Range(ctx, r) + resp, err = remote.Range(ctx, r) if err == nil { return OpResponse{get: (*GetResponse)(resp)}, nil } case tPut: var resp *pb.PutResponse r := &pb.PutRequest{Key: op.key, Value: op.val, Lease: int64(op.leaseID)} - resp, err = kv.getRemote().Put(ctx, r) + resp, err = remote.Put(ctx, r) if err == nil { return OpResponse{put: (*PutResponse)(resp)}, nil } case tDeleteRange: var resp *pb.DeleteRangeResponse r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end} - resp, err = kv.getRemote().DeleteRange(ctx, r) + resp, err = remote.DeleteRange(ctx, r) if err == nil { return OpResponse{del: (*DeleteResponse)(resp)}, nil } @@ -173,32 +174,32 @@ func (kv *kv) Do(ctx context.Context, op Op) (OpResponse, error) { // do not retry on modifications if op.isWrite() { - go kv.switchRemote(err) + go kv.switchRemote(remote, err) return OpResponse{}, rpctypes.Error(err) } - if nerr := kv.switchRemote(err); nerr != nil { + if nerr := kv.switchRemote(remote, err); nerr != nil { return OpResponse{}, nerr } } } -func (kv *kv) switchRemote(prevErr error) error { - // Usually it's a bad idea to lock on network i/o but here it's OK - // since the link is down and new requests can't be processed anyway. - // Likewise, if connecting stalls, closing the Client can break the - // lock via context cancelation. +func (kv *kv) switchRemote(remote pb.KVClient, prevErr error) error { + kv.mu.Lock() + oldRemote := kv.remote + conn := kv.conn + kv.mu.Unlock() + if remote != oldRemote { + return nil + } + newConn, err := kv.c.retryConnection(conn, prevErr) kv.mu.Lock() defer kv.mu.Unlock() - - newConn, err := kv.c.retryConnection(kv.conn, prevErr) - if err != nil { - return rpctypes.Error(err) + if err == nil { + kv.conn = newConn + kv.remote = pb.NewKVClient(kv.conn) } - - kv.conn = newConn - kv.remote = pb.NewKVClient(kv.conn) - return nil + return rpctypes.Error(err) } func (kv *kv) getRemote() pb.KVClient { diff --git a/clientv3/txn.go b/clientv3/txn.go index 84ec4464e..875a325a0 100644 --- a/clientv3/txn.go +++ b/clientv3/txn.go @@ -141,8 +141,9 @@ func (txn *txn) Commit() (*TxnResponse, error) { kv := txn.kv for { + remote := kv.getRemote() r := &pb.TxnRequest{Compare: txn.cmps, Success: txn.sus, Failure: txn.fas} - resp, err := kv.getRemote().Txn(txn.ctx, r) + resp, err := remote.Txn(txn.ctx, r) if err == nil { return (*TxnResponse)(resp), nil } @@ -152,11 +153,11 @@ func (txn *txn) Commit() (*TxnResponse, error) { } if txn.isWrite { - go kv.switchRemote(err) + go kv.switchRemote(remote, err) return nil, rpctypes.Error(err) } - if nerr := kv.switchRemote(err); nerr != nil { + if nerr := kv.switchRemote(remote, err); nerr != nil { return nil, nerr } }