diff --git a/leopard.go b/leopard.go index dc53a00..ffe3bc3 100644 --- a/leopard.go +++ b/leopard.go @@ -207,7 +207,43 @@ func (r *reedSolomonFF16) EncodeIdx(dataShard []byte, idx int, parity [][]byte) } func (r *reedSolomonFF16) Join(dst io.Writer, shards [][]byte, outSize int) error { - return errors.New("not implemented") + // Do we have enough shards? + if len(shards) < r.DataShards { + return ErrTooFewShards + } + shards = shards[:r.DataShards] + + // Do we have enough data? + size := 0 + for _, shard := range shards { + if shard == nil { + return ErrReconstructRequired + } + size += len(shard) + + // Do we have enough data already? + if size >= outSize { + break + } + } + if size < outSize { + return ErrShortData + } + + // Copy data to dst + write := outSize + for _, shard := range shards { + if write < len(shard) { + _, err := dst.Write(shard[:write]) + return err + } + n, err := dst.Write(shard) + if err != nil { + return err + } + write -= n + } + return nil } func (r *reedSolomonFF16) Update(shards [][]byte, newDatashards [][]byte) error { @@ -215,7 +251,46 @@ func (r *reedSolomonFF16) Update(shards [][]byte, newDatashards [][]byte) error } func (r *reedSolomonFF16) Split(data []byte) ([][]byte, error) { - return nil, errors.New("not implemented") + if len(data) == 0 { + return nil, ErrShortData + } + dataLen := len(data) + // Calculate number of bytes per data shard. 
+ perShard := (len(data) + r.DataShards - 1) / r.DataShards + perShard = ((perShard + 63) / 64) * 64 + + if cap(data) > len(data) { + data = append(data, make([]byte, cap(data)-len(data))...) + } + + // Only allocate memory if necessary + var padding []byte + if len(data) < (r.Shards * perShard) { + // calculate maximum number of full shards in `data` slice + fullShards := len(data) / perShard + padding = make([]byte, r.Shards*perShard-perShard*fullShards) + copy(padding, data[perShard*fullShards:]) + data = data[0 : perShard*fullShards] + } else { + for i := dataLen; i < len(data); i++ { + data[i] = 0 + } + } + + // Split into equal-length shards. + dst := make([][]byte, r.Shards) + i := 0 + for ; i < len(dst) && len(data) >= perShard; i++ { + dst[i] = data[:perShard:perShard] + data = data[perShard:] + } + + for j := 0; i+j < len(dst); j++ { + dst[i+j] = padding[:perShard:perShard] + padding = padding[perShard:] + } + + return dst, nil } func (r *reedSolomonFF16) ReconstructSome(shards [][]byte, required []bool) error { @@ -267,6 +342,29 @@ func (r *reedSolomonFF16) reconstruct(shards [][]byte, recoverAll bool) error { return err } + // Quick check: are all of the shards present? If so, there's + // nothing to do. + numberPresent := 0 + dataPresent := 0 + for i := 0; i < r.Shards; i++ { + if len(shards[i]) != 0 { + numberPresent++ + if i < r.DataShards { + dataPresent++ + } + } + } + if numberPresent == r.Shards || !recoverAll && dataPresent == r.DataShards { + // Cool. All of the shards have data. We don't + // need to do anything. + return nil + } + + // Check if we have enough to reconstruct. 
+ if numberPresent < r.DataShards { + return ErrTooFewShards + } + shardSize := shardSize(shards) if shardSize%64 != 0 { return ErrShardSize diff --git a/leopard_test.go b/leopard_test.go new file mode 100644 index 0000000..bbe6346 --- /dev/null +++ b/leopard_test.go @@ -0,0 +1,167 @@ +package reedsolomon + +import ( + "bytes" + "math/rand" + "testing" +) + +func TestEncoderReconstructLeo(t *testing.T) { + testEncoderReconstructLeo(t) +} + +func testEncoderReconstructLeo(t *testing.T, o ...Option) { + // Create some sample data + var data = make([]byte, 2<<20) + fillRandom(data) + + // Create 5 data slices of 50000 elements each + enc, err := New(500, 300, testOptions(o...)...) + if err != nil { + t.Fatal(err) + } + shards, err := enc.Split(data) + if err != nil { + t.Fatal(err) + } + err = enc.Encode(shards) + if err != nil { + t.Fatal(err) + } + + // Check that it verifies + ok, err := enc.Verify(shards) + if !ok || err != nil { + t.Fatal("not ok:", ok, "err:", err) + } + + // Delete a shard + shards[0] = nil + + // Should reconstruct + err = enc.Reconstruct(shards) + if err != nil { + t.Fatal(err) + } + + // Check that it verifies + ok, err = enc.Verify(shards) + if !ok || err != nil { + t.Fatal("not ok:", ok, "err:", err) + } + + // Recover original bytes + buf := new(bytes.Buffer) + err = enc.Join(buf, shards, len(data)) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf.Bytes(), data) { + t.Fatal("recovered bytes do not match") + } + + // Corrupt a shard + shards[0] = nil + shards[1][0], shards[1][500] = 75, 75 + + // Should reconstruct (but with corrupted data) + err = enc.Reconstruct(shards) + if err != nil { + t.Fatal(err) + } + + // Check that it verifies + ok, err = enc.Verify(shards) + if ok || err != nil { + t.Fatal("error or ok:", ok, "err:", err) + } + + // Recovered data should not match original + buf.Reset() + err = enc.Join(buf, shards, len(data)) + if err != nil { + t.Fatal(err) + } + if bytes.Equal(buf.Bytes(), data) { + 
t.Fatal("corrupted data matches original") + } +} + +func TestEncoderReconstructFailLeo(t *testing.T) { + // Create some sample data + var data = make([]byte, 2<<20) + fillRandom(data) + + // Create an encoder with 500 data and 300 parity shards + enc, err := New(500, 300, testOptions()...) + if err != nil { + t.Fatal(err) + } + shards, err := enc.Split(data) + if err != nil { + t.Fatal(err) + } + err = enc.Encode(shards) + if err != nil { + t.Fatal(err) + } + + // Check that it verifies + ok, err := enc.Verify(shards) + if !ok || err != nil { + t.Fatal("not ok:", ok, "err:", err) + } + + // Delete more than parity shards + for i := 0; i < 301; i++ { + shards[i] = nil + } + + // Should not reconstruct + err = enc.Reconstruct(shards) + if err != ErrTooFewShards { + t.Fatal("want ErrTooFewShards, got:", err) + } +} + +func TestSplitJoinLeo(t *testing.T) { + var data = make([]byte, (250<<10)-1) + rand.Seed(0) + fillRandom(data) + + enc, _ := New(500, 300, testOptions()...) + shards, err := enc.Split(data) + if err != nil { + t.Fatal(err) + } + + _, err = enc.Split([]byte{}) + if err != ErrShortData { + t.Errorf("expected %v, got %v", ErrShortData, err) + } + + buf := new(bytes.Buffer) + err = enc.Join(buf, shards, 5000) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf.Bytes(), data[:5000]) { + t.Fatal("recovered data does not match original") + } + + err = enc.Join(buf, [][]byte{}, 0) + if err != ErrTooFewShards { + t.Errorf("expected %v, got %v", ErrTooFewShards, err) + } + + err = enc.Join(buf, shards, len(data)+500*64) + if err != ErrShortData { + t.Errorf("expected %v, got %v", ErrShortData, err) + } + + shards[0] = nil + err = enc.Join(buf, shards, len(data)) + if err != ErrReconstructRequired { + t.Errorf("expected %v, got %v", ErrReconstructRequired, err) + } +}