From c54154da9e35cab25232314cf69ab9d78447f9a5 Mon Sep 17 00:00:00 2001
From: Peter C
Date: Mon, 12 Sep 2016 12:31:07 -0700
Subject: [PATCH] Add Inverse Matrix caching in a Thread-Safe Lookup Tree (#36)

* Add matrix inversion caching

* Benchmark and Parallel Benchmark tests for Reconstruct
---
 inversion_tree.go      | 160 ++++++++++++++++++++++++++++++++++++++++++++
 inversion_tree_test.go | 125 +++++++++++++++++++++++++++++++
 reedsolomon.go         |  69 ++++++++++++++------
 reedsolomon_test.go    | 158 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 494 insertions(+), 18 deletions(-)
 create mode 100644 inversion_tree.go
 create mode 100644 inversion_tree_test.go

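Before the diffs, a sketch of how the new tree is meant to be used. This is illustrative only and not part of the patch: the function name is hypothetical, it would have to live inside the package (the tree is unexported), and an identity matrix stands in for the real decode matrix that Reconstruct builds and inverts on a cache miss.

    // Lookup-or-compute flow: ask the tree for the inverse keyed by the
    // missing row indices, and build and insert it when absent.
    func exampleInversionTreeLookup() (matrix, error) {
        // 3 data shards and 2 parity shards, as in the new unit tests.
        tree := newInversionTree(3, 2)
        invalid := []int{1, 4} // shards 1 and 4 are missing, sorted ascending

        decode := tree.GetInvertedMatrix(invalid)
        if decode == nil {
            // Cache miss: the real caller inverts the sub-matrix of
            // surviving rows here; the identity matrix is only a placeholder.
            decode, _ = identityMatrix(3)
            if err := tree.InsertInvertedMatrix(invalid, decode, 5); err != nil {
                return nil, err
            }
        }
        return decode, nil
    }

Repeat lookups for the same loss pattern return the cached matrix and only ever take the read lock.
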
diff --git a/inversion_tree.go b/inversion_tree.go
new file mode 100644
index 0000000..c9d8ab2
--- /dev/null
+++ b/inversion_tree.go
@@ -0,0 +1,160 @@
+/**
+ * A thread-safe tree which caches inverted matrices.
+ *
+ * Copyright 2016, Peter Collins
+ */
+
+package reedsolomon
+
+import (
+	"errors"
+	"sync"
+)
+
+// The tree uses a Reader-Writer mutex to make it thread-safe
+// when accessing cached matrices and inserting new ones.
+type inversionTree struct {
+	mutex *sync.RWMutex
+	root  inversionNode
+}
+
+type inversionNode struct {
+	matrix   matrix
+	children []*inversionNode
+}
+
+// newInversionTree initializes a tree for storing inverted matrices.
+// Note that the root node holds the identity matrix, as it implies
+// there were no errors with the original data.
+func newInversionTree(dataShards, parityShards int) inversionTree {
+	identity, _ := identityMatrix(dataShards)
+	root := inversionNode{
+		matrix:   identity,
+		children: make([]*inversionNode, dataShards+parityShards),
+	}
+	return inversionTree{
+		mutex: &sync.RWMutex{},
+		root:  root,
+	}
+}
+
+// GetInvertedMatrix returns the cached inverted matrix, or nil if it
+// is not found in the tree keyed on the indices of invalid rows.
+func (t inversionTree) GetInvertedMatrix(invalidIndices []int) matrix {
+	// Lock the tree for reading before accessing it.
+	t.mutex.RLock()
+	defer t.mutex.RUnlock()
+
+	// If no invalid indices were given, return the root
+	// identity matrix.
+	if len(invalidIndices) == 0 {
+		return t.root.matrix
+	}
+
+	// Recursively search for the inverted matrix in the tree, passing in
+	// 0 as the parent index as we start at the root of the tree.
+	return t.root.getInvertedMatrix(invalidIndices, 0)
+}
+
+// errAlreadySet is returned if the root node matrix would be overwritten.
+var errAlreadySet = errors.New("the root node identity matrix is already set")
+
+// InsertInvertedMatrix inserts a new inverted matrix into the tree
+// keyed by the indices of invalid rows. The total number of shards
+// is required for creating the proper length lists of child nodes for
+// each node.
+func (t inversionTree) InsertInvertedMatrix(invalidIndices []int, matrix matrix, shards int) error {
+	// If no invalid indices were given then we are done because the
+	// root node is already set with the identity matrix.
+	if len(invalidIndices) == 0 {
+		return errAlreadySet
+	}
+
+	if !matrix.IsSquare() {
+		return errNotSquare
+	}
+
+	// Lock the tree for writing before accessing it.
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+
+	// Recursively create nodes for the inverted matrix in the tree until
+	// we reach the node to insert the matrix to. We start by passing in
+	// 0 as the parent index as we start at the root of the tree.
+	t.root.insertInvertedMatrix(invalidIndices, matrix, shards, 0)
+
+	return nil
+}
+
+func (n inversionNode) getInvertedMatrix(invalidIndices []int, parent int) matrix {
+	// Get the child node to search next from the list of children. The
+	// list of children starts relative to the parent index passed in
+	// because the indices of invalid rows are sorted (by default). As we
+	// search recursively, the first invalid index gets popped off the list,
+	// so when searching through the list of children, use that first invalid
+	// index to find the child node.
+	firstIndex := invalidIndices[0]
+	node := n.children[firstIndex-parent]
+
+	// If the child node doesn't exist in the list yet, fail fast by
+	// returning, so we can construct and insert the proper inverted matrix.
+	if node == nil {
+		return nil
+	}
+
+	// If there's more than one invalid index left in the list we should
+	// keep searching recursively.
+	if len(invalidIndices) > 1 {
+		// Search recursively on the child node by passing in the invalid indices
+		// with the first index popped off the front. Also the parent index to
+		// pass down is the first index plus one.
+		return node.getInvertedMatrix(invalidIndices[1:], firstIndex+1)
+	}
+	// If there aren't any more invalid indices to search, we've found our
+	// node. Return it, but keep in mind that the matrix could still be
+	// nil because intermediary nodes in the tree are sometimes created with
+	// their inversion matrices uninitialized.
+	return node.matrix
+}
+
+func (n inversionNode) insertInvertedMatrix(invalidIndices []int, matrix matrix, shards, parent int) {
+	// As above, get the child node to search next from the list of children.
+	// The list of children starts relative to the parent index passed in
+	// because the indices of invalid rows are sorted (by default). As we
+	// search recursively, the first invalid index gets popped off the list,
+	// so when searching through the list of children, use that first invalid
+	// index to find the child node.
+	firstIndex := invalidIndices[0]
+	node := n.children[firstIndex-parent]
+
+	// If the child node doesn't exist in the list yet, create a new
+	// node (we hold the writer lock) and add it to the list
+	// of children.
+	if node == nil {
+		// Make the length of the list of children equal to the number
+		// of shards minus the first invalid index because the list of
+		// invalid indices is sorted, so at most that many further
+		// invalid indices can follow in the tree.
+		node = &inversionNode{
+			children: make([]*inversionNode, shards-firstIndex),
+		}
+		// Insert the new node into the tree at the first index relative
+		// to the parent index that was given in this recursive call.
+		n.children[firstIndex-parent] = node
+	}
+
+	// If there's more than one invalid index left in the list we should
+	// keep searching recursively in order to find the node where our
+	// matrix should be cached.
+	if len(invalidIndices) > 1 {
+		// As above, search recursively on the child node by passing in
+		// the invalid indices with the first index popped off the front.
+		// Also pass down the total number of shards and the parent index,
+		// which is the first index plus one.
+		node.insertInvertedMatrix(invalidIndices[1:], matrix, shards, firstIndex+1)
+	} else {
+		// If there aren't any more invalid indices to search, we've found our
+		// node. Cache the inverted matrix in this node.
+		node.matrix = matrix
+	}
+}
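The children[firstIndex-parent] arithmetic in the two functions above is the subtle part of the tree. A small standalone helper, shown only as an illustration (it is not part of the patch), computes the slot chosen at each level for a whole key at once:

    // childSlots mirrors the traversal above: each step selects
    // children[index-parent] and then passes index+1 down as the new parent.
    func childSlots(invalidIndices []int) []int {
        slots := make([]int, 0, len(invalidIndices))
        parent := 0
        for _, index := range invalidIndices {
            slots = append(slots, index-parent)
            parent = index + 1
        }
        return slots
    }

With 5 shards and missing rows {1, 4}, childSlots returns {1, 2}: slot 1 among the root's 5 children, then slot 2 among the 4 (= 5-1) children allocated under that node.
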
diff --git a/inversion_tree_test.go b/inversion_tree_test.go
new file mode 100644
index 0000000..49f5a17
--- /dev/null
+++ b/inversion_tree_test.go
@@ -0,0 +1,125 @@
+/**
+ * Unit tests for inversion tree.
+ *
+ * Copyright 2016, Peter Collins
+ */
+
+package reedsolomon
+
+import (
+	"testing"
+)
+
+func TestNewInversionTree(t *testing.T) {
+	tree := newInversionTree(3, 2)
+
+	children := len(tree.root.children)
+	if children != 5 {
+		t.Fatal("Root node children list length", children, "!=", 5)
+	}
+
+	str := tree.root.matrix.String()
+	expect := "[[1, 0, 0], [0, 1, 0], [0, 0, 1]]"
+	if str != expect {
+		t.Fatal(str, "!=", expect)
+	}
+}
+
+func TestGetInvertedMatrix(t *testing.T) {
+	tree := newInversionTree(3, 2)
+
+	matrix := tree.GetInvertedMatrix([]int{})
+	str := matrix.String()
+	expect := "[[1, 0, 0], [0, 1, 0], [0, 0, 1]]"
+	if str != expect {
+		t.Fatal(str, "!=", expect)
+	}
+
+	matrix = tree.GetInvertedMatrix([]int{1})
+	if matrix != nil {
+		t.Fatal(matrix, "!= nil")
+	}
+
+	matrix = tree.GetInvertedMatrix([]int{1, 2})
+	if matrix != nil {
+		t.Fatal(matrix, "!= nil")
+	}
+
+	matrix, err := newMatrix(3, 3)
+	if err != nil {
+		t.Fatalf("Failed initializing new Matrix : %s", err)
+	}
+	err = tree.InsertInvertedMatrix([]int{1}, matrix, 5)
+	if err != nil {
+		t.Fatalf("Failed inserting new Matrix : %s", err)
+	}
+
+	cachedMatrix := tree.GetInvertedMatrix([]int{1})
+	if cachedMatrix == nil {
+		t.Fatal(cachedMatrix, "== nil")
+	}
+	if matrix.String() != cachedMatrix.String() {
+		t.Fatal(matrix.String(), "!=", cachedMatrix.String())
+	}
+}
+
+func TestInsertInvertedMatrix(t *testing.T) {
+	tree := newInversionTree(3, 2)
+
+	matrix, err := newMatrix(3, 3)
+	if err != nil {
+		t.Fatalf("Failed initializing new Matrix : %s", err)
+	}
+	err = tree.InsertInvertedMatrix([]int{1}, matrix, 5)
+	if err != nil {
+		t.Fatalf("Failed inserting new Matrix : %s", err)
+	}
+
+	err = tree.InsertInvertedMatrix([]int{}, matrix, 5)
+	if err == nil {
+		t.Fatal("Should have failed inserting the root node matrix", matrix)
+	}
+
+	matrix, err = newMatrix(3, 2)
+	if err != nil {
+		t.Fatalf("Failed initializing new Matrix : %s", err)
+	}
+	err = tree.InsertInvertedMatrix([]int{2}, matrix, 5)
+	if err == nil {
+		t.Fatal("Should have failed inserting a non-square matrix", matrix)
+	}
+
+	matrix, err = newMatrix(3, 3)
+	if err != nil {
+		t.Fatalf("Failed initializing new Matrix : %s", err)
+	}
+	err = tree.InsertInvertedMatrix([]int{0, 1}, matrix, 5)
+	if err != nil {
+		t.Fatalf("Failed inserting new Matrix : %s", err)
+	}
+}
+
+func TestDoubleInsertInvertedMatrix(t *testing.T) {
+	tree := newInversionTree(3, 2)
+
+	matrix, err := newMatrix(3, 3)
+	if err != nil {
+		t.Fatalf("Failed initializing new Matrix : %s", err)
+	}
+	err = tree.InsertInvertedMatrix([]int{1}, matrix, 5)
+	if err != nil {
+		t.Fatalf("Failed inserting new Matrix : %s", err)
+	}
+	err = tree.InsertInvertedMatrix([]int{1}, matrix, 5)
+	if err != nil {
+		t.Fatalf("Failed inserting new Matrix : %s", err)
+	}
+
+	cachedMatrix := tree.GetInvertedMatrix([]int{1})
+	if cachedMatrix == nil {
+		t.Fatal(cachedMatrix, "== nil")
+	}
+	if matrix.String() != cachedMatrix.String() {
+		t.Fatal(matrix.String(), "!=", cachedMatrix.String())
+	}
+}
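The tests above drive the tree from a single goroutine. Since the RWMutex is the point of the type, a concurrent smoke test along the following lines could be run under go test -race. This is a sketch only, not part of the patch; the function name is hypothetical and it assumes "sync" is added to the test file's imports.

    func TestInversionTreeConcurrent(t *testing.T) {
        tree := newInversionTree(3, 2)
        var wg sync.WaitGroup
        for g := 0; g < 8; g++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                for i := 0; i < 1000; i++ {
                    if tree.GetInvertedMatrix([]int{1}) == nil {
                        m, _ := identityMatrix(3)
                        // Concurrent inserts of the same key are harmless:
                        // the last writer simply overwrites the cached matrix.
                        if err := tree.InsertInvertedMatrix([]int{1}, m, 5); err != nil {
                            t.Error(err)
                        }
                    }
                }
            }()
        }
        wg.Wait()
    }
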
diff --git a/reedsolomon.go b/reedsolomon.go
index 0c98981..914ebe0 100644
--- a/reedsolomon.go
+++ b/reedsolomon.go
@@ -81,6 +81,7 @@ type reedSolomon struct {
 	ParityShards int // Number of parity shards, should not be modified.
 	Shards       int // Total number of shards. Calculated, and should not be modified.
 	m            matrix
+	tree         inversionTree
 	parity       [][]byte
 }
 
@@ -128,6 +129,13 @@ func New(dataShards, parityShards int) (Encoder, error) {
 	top, _ = top.Invert()
 	r.m, _ = vm.Multiply(top)
 
+	// Inverted matrices are cached in a tree keyed by the indices
+	// of the invalid rows of the data to reconstruct.
+	// The inversion root node will have the identity matrix as
+	// its inversion matrix because it implies there are no errors
+	// with the original data.
+	r.tree = newInversionTree(dataShards, parityShards)
+
 	r.parity = make([][]byte, parityShards)
 	for i := range r.parity {
 		r.parity[i] = r.m[dataShards+i]
@@ -380,36 +388,61 @@ func (r reedSolomon) Reconstruct(shards [][]byte) error {
 		return ErrTooFewShards
 	}
 
-	// Pull out the rows of the matrix that correspond to the
-	// shards that we have and build a square matrix. This
-	// matrix could be used to generate the shards that we have
-	// from the original data.
-	//
-	// Also, pull out an array holding just the shards that
+	// Pull out an array holding just the shards that
 	// correspond to the rows of the submatrix. These shards
 	// will be the input to the decoding process that re-creates
 	// the missing data shards.
-	subMatrix, _ := newMatrix(r.DataShards, r.DataShards)
+	//
+	// Also, create an array of indices of the valid rows we do have
+	// and the invalid rows we don't have, up until we have enough valid rows.
 	subShards := make([][]byte, r.DataShards)
+	validIndices := make([]int, r.DataShards)
+	invalidIndices := make([]int, 0)
 	subMatrixRow := 0
 	for matrixRow := 0; matrixRow < r.Shards && subMatrixRow < r.DataShards; matrixRow++ {
 		if len(shards[matrixRow]) != 0 {
-			for c := 0; c < r.DataShards; c++ {
-				subMatrix[subMatrixRow][c] = r.m[matrixRow][c]
-			}
 			subShards[subMatrixRow] = shards[matrixRow]
+			validIndices[subMatrixRow] = matrixRow
 			subMatrixRow++
+		} else {
+			invalidIndices = append(invalidIndices, matrixRow)
 		}
 	}
 
-	// Invert the matrix, so we can go from the encoded shards
-	// back to the original data. Then pull out the row that
-	// generates the shard that we want to decode. Note that
-	// since this matrix maps back to the original data, it can
-	// be used to create a data shard, but not a parity shard.
-	dataDecodeMatrix, err := subMatrix.Invert()
-	if err != nil {
-		return err
+	// Attempt to get the cached inverted matrix out of the tree
+	// based on the indices of the invalid rows.
+	dataDecodeMatrix := r.tree.GetInvertedMatrix(invalidIndices)
+
+	// If the inverted matrix isn't cached in the tree yet we must
+	// construct it ourselves and insert it into the tree for the
+	// future. In this way the inversion tree is lazily loaded.
+	if dataDecodeMatrix == nil {
+		// Pull out the rows of the matrix that correspond to the
+		// shards that we have and build a square matrix. This
+		// matrix could be used to generate the shards that we have
+		// from the original data.
+		subMatrix, _ := newMatrix(r.DataShards, r.DataShards)
+		for subMatrixRow, validIndex := range validIndices {
+			for c := 0; c < r.DataShards; c++ {
+				subMatrix[subMatrixRow][c] = r.m[validIndex][c]
+			}
+		}
+		// Invert the matrix, so we can go from the encoded shards
+		// back to the original data. Then pull out the row that
+		// generates the shard that we want to decode. Note that
+		// since this matrix maps back to the original data, it can
+		// be used to create a data shard, but not a parity shard.
+		dataDecodeMatrix, err = subMatrix.Invert()
+		if err != nil {
+			return err
+		}
+
+		// Cache the inverted matrix in the tree for future use keyed on the
+		// indices of the invalid rows.
+		err = r.tree.InsertInvertedMatrix(invalidIndices, dataDecodeMatrix, r.Shards)
+		if err != nil {
+			return err
+		}
 	}
 
 	// Re-create any data shards that were missing.
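The access pattern the Reconstruct changes above are aimed at is a caller that keeps losing the same shards, so the inversion cost is paid once and then amortized. A hypothetical example, not part of the patch (the function name is made up, and it assumes shards was already sized and encoded by a 10+4 Encoder from New):

    // reconstructRepeatedly drops the same two shards on every iteration;
    // only the first Reconstruct inverts a matrix, later calls reuse the
    // cached inverse for the {2, 7} loss pattern under the read lock.
    func reconstructRepeatedly(enc Encoder, shards [][]byte) error {
        for i := 0; i < 100; i++ {
            shards[2] = nil
            shards[7] = nil
            if err := enc.Reconstruct(shards); err != nil {
                return err
            }
        }
        return nil
    }

The benchmarks added below exercise the same machinery with randomized loss patterns, serially and in parallel.
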
diff --git a/reedsolomon_test.go b/reedsolomon_test.go
index 61f8154..d9c876e 100644
--- a/reedsolomon_test.go
+++ b/reedsolomon_test.go
@@ -377,6 +377,164 @@ func BenchmarkVerify10x4x16M(b *testing.B) {
 	benchmarkVerify(b, 10, 4, 16*1024*1024)
 }
 
+func corruptRandom(shards [][]byte, dataShards, parityShards int) {
+	shardsToCorrupt := rand.Intn(parityShards)
+	for i := 1; i <= shardsToCorrupt; i++ {
+		shards[rand.Intn(dataShards+parityShards)] = nil
+	}
+}
+
+func benchmarkReconstruct(b *testing.B, dataShards, parityShards, shardSize int) {
+	r, err := New(dataShards, parityShards)
+	if err != nil {
+		b.Fatal(err)
+	}
+	shards := make([][]byte, parityShards+dataShards)
+	for s := range shards {
+		shards[s] = make([]byte, shardSize)
+	}
+
+	rand.Seed(0)
+	for s := 0; s < dataShards; s++ {
+		fillRandom(shards[s])
+	}
+	err = r.Encode(shards)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	b.SetBytes(int64(shardSize * dataShards))
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		corruptRandom(shards, dataShards, parityShards)
+
+		err = r.Reconstruct(shards)
+		if err != nil {
+			b.Fatal(err)
+		}
+		ok, err := r.Verify(shards)
+		if err != nil {
+			b.Fatal(err)
+		}
+		if !ok {
+			b.Fatal("Verification failed")
+		}
+	}
+}
+
+// Benchmark 10 data slices with 2 parity slices holding 10000 bytes each
+func BenchmarkReconstruct10x2x10000(b *testing.B) {
+	benchmarkReconstruct(b, 10, 2, 10000)
+}
+
+// Benchmark 50 data slices with 5 parity slices holding 100000 bytes each
+func BenchmarkReconstruct50x5x50000(b *testing.B) {
+	benchmarkReconstruct(b, 50, 5, 100000)
+}
+
+// Benchmark 10 data slices with 2 parity slices holding 1MB bytes each
+func BenchmarkReconstruct10x2x1M(b *testing.B) {
+	benchmarkReconstruct(b, 10, 2, 1024*1024)
+}
+
+// Benchmark 5 data slices with 2 parity slices holding 1MB bytes each
+func BenchmarkReconstruct5x2x1M(b *testing.B) {
+	benchmarkReconstruct(b, 5, 2, 1024*1024)
+}
+
+// Benchmark 10 data slices with 4 parity slices holding 1MB bytes each
+func BenchmarkReconstruct10x4x1M(b *testing.B) {
+	benchmarkReconstruct(b, 10, 4, 1024*1024)
+}
+
+// Benchmark 50 data slices with 20 parity slices holding 1MB bytes each
+func BenchmarkReconstruct50x20x1M(b *testing.B) {
+	benchmarkReconstruct(b, 50, 20, 1024*1024)
+}
+
+// Benchmark 10 data slices with 4 parity slices holding 16MB bytes each
+func BenchmarkReconstruct10x4x16M(b *testing.B) {
+	benchmarkReconstruct(b, 10, 4, 16*1024*1024)
+}
+
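Two notes on the benchmark helpers. First, in corruptRandom above, rand.Intn(parityShards) can return 0 and the chosen indices can repeat, so some iterations reconstruct nothing; if that skews the numbers, a variant that always drops at least one distinct shard could be used instead. The following is only an illustration with a made-up name, not part of the patch:

    func corruptRandomDistinct(shards [][]byte, dataShards, parityShards int) {
        // Drop between 1 and parityShards distinct shards so every iteration
        // performs a real reconstruction that is still guaranteed to succeed.
        shardsToCorrupt := 1 + rand.Intn(parityShards)
        for _, idx := range rand.Perm(dataShards + parityShards)[:shardsToCorrupt] {
            shards[idx] = nil
        }
    }

Second, in benchmarkReconstructP below, the closure passed to b.RunParallel assigns to the err captured from the enclosing function (err = r.Encode(shards) and err = r.Reconstruct(shards)), so the parallel goroutines all write the same variable and go test -race will flag it; declaring the error locally inside the closure, for example err := r.Encode(shards), keeps each goroutine's error separate.
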
+func benchmarkReconstructP(b *testing.B, dataShards, parityShards, shardSize int) {
+	r, err := New(dataShards, parityShards)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	b.SetBytes(int64(shardSize * dataShards))
+	runtime.GOMAXPROCS(runtime.NumCPU())
+	b.ResetTimer()
+
+	b.RunParallel(func(pb *testing.PB) {
+		shards := make([][]byte, parityShards+dataShards)
+		for s := range shards {
+			shards[s] = make([]byte, shardSize)
+		}
+
+		rand.Seed(0)
+		for s := 0; s < dataShards; s++ {
+			fillRandom(shards[s])
+		}
+		err = r.Encode(shards)
+		if err != nil {
+			b.Fatal(err)
+		}
+
+		for pb.Next() {
+			corruptRandom(shards, dataShards, parityShards)
+
+			err = r.Reconstruct(shards)
+			if err != nil {
+				b.Fatal(err)
+			}
+			ok, err := r.Verify(shards)
+			if err != nil {
+				b.Fatal(err)
+			}
+			if !ok {
+				b.Fatal("Verification failed")
+			}
+		}
+	})
+}
+
+// Benchmark 10 data slices with 2 parity slices holding 10000 bytes each
+func BenchmarkReconstructP10x2x10000(b *testing.B) {
+	benchmarkReconstructP(b, 10, 2, 10000)
+}
+
+// Benchmark 50 data slices with 5 parity slices holding 100000 bytes each
+func BenchmarkReconstructP50x5x50000(b *testing.B) {
+	benchmarkReconstructP(b, 50, 5, 100000)
+}
+
+// Benchmark 10 data slices with 2 parity slices holding 1MB bytes each
+func BenchmarkReconstructP10x2x1M(b *testing.B) {
+	benchmarkReconstructP(b, 10, 2, 1024*1024)
+}
+
+// Benchmark 5 data slices with 2 parity slices holding 1MB bytes each
+func BenchmarkReconstructP5x2x1M(b *testing.B) {
+	benchmarkReconstructP(b, 5, 2, 1024*1024)
+}
+
+// Benchmark 10 data slices with 4 parity slices holding 1MB bytes each
+func BenchmarkReconstructP10x4x1M(b *testing.B) {
+	benchmarkReconstructP(b, 10, 4, 1024*1024)
+}
+
+// Benchmark 50 data slices with 20 parity slices holding 1MB bytes each
+func BenchmarkReconstructP50x20x1M(b *testing.B) {
+	benchmarkReconstructP(b, 50, 20, 1024*1024)
+}
+
+// Benchmark 10 data slices with 4 parity slices holding 16MB bytes each
+func BenchmarkReconstructP10x4x16M(b *testing.B) {
+	benchmarkReconstructP(b, 10, 4, 16*1024*1024)
+}
+
 func TestEncoderReconstruct(t *testing.T) {
 	// Create some sample data
 	var data = make([]byte, 250000)