Add specific error for streams, to help identify the faulty stream.

master
klauspost 2015-10-27 10:36:29 +01:00
parent 0500314cc5
commit e78a382960
2 changed files with 60 additions and 21 deletions

View File

@ -14,6 +14,7 @@ package reedsolomon
import (
"bytes"
"errors"
"fmt"
"io"
"sync"
)
@ -76,6 +77,25 @@ type StreamEncoder interface {
Join(dst io.Writer, shards []io.Reader, outSize int64) error
}
// StreamReadError is returned when a read error is encountered
// that relates to a supplied stream. This will allow you to
// find out which reader/writer has failed.
type StreamError struct {
Err error // The error
Stream int // The stream number on which the error occurred
Op string // Will be "read" or "write".
}
// Error returns the error as a string
func (s StreamError) Error() string {
return fmt.Sprintf("error on %s stream #%d: %s", s.Op, s.Stream, s.Err)
}
// String returns the error as a string
func (s StreamError) String() string {
return s.Error()
}
// reedSolomon contains a matrix for a specific
// distribution of datashards and parity shards.
// Construct if using New()
@ -215,13 +235,13 @@ func readShards(dst [][]byte, in []io.Reader) error {
size = n
} else if n != size {
// Shard sizes must match.
return ErrShardSize
return StreamError{Err: ErrShardSize, Op: "read", Stream: i}
}
dst[i] = dst[i][0:n]
case nil:
continue
default:
return err
return StreamError{Err: err, Op: "read", Stream: i}
}
}
if size == 0 {
@ -240,17 +260,18 @@ func writeShards(out []io.Writer, in [][]byte) error {
}
n, err := out[i].Write(in[i])
if err != nil {
return err
return StreamError{Err: err, Op: "write", Stream: i}
}
//
if n != len(in[i]) {
return io.ErrShortWrite
return StreamError{Err: io.ErrShortWrite, Op: "write", Stream: i}
}
}
return nil
}
type readResult struct {
n int
size int
err error
}
@ -275,14 +296,13 @@ func cReadShards(dst [][]byte, in []io.Reader) error {
// The error is EOF only if no bytes were read.
// If an EOF happens after reading some but not all the bytes,
// ReadFull returns ErrUnexpectedEOF.
res <- readResult{size: n, err: err}
res <- readResult{size: n, err: err, n: i}
}(i)
}
wg.Wait()
close(res)
size := -1
i := 0
for r := range res {
switch r.err {
case io.ErrUnexpectedEOF, io.EOF:
@ -290,14 +310,13 @@ func cReadShards(dst [][]byte, in []io.Reader) error {
size = r.size
} else if r.size != size {
// Shard sizes must match.
return ErrShardSize
return StreamError{Err: ErrShardSize, Op: "read", Stream: r.n}
}
dst[i] = dst[i][0:r.size]
dst[r.n] = dst[r.n][0:r.size]
case nil:
default:
return r.err
return StreamError{Err: r.err, Op: "read", Stream: r.n}
}
i++
}
if size == 0 {
return io.EOF
@ -322,11 +341,11 @@ func cWriteShards(out []io.Writer, in [][]byte) error {
}
n, err := out[i].Write(in[i])
if err != nil {
errs <- err
errs <- StreamError{Err: err, Op: "write", Stream: i}
return
}
if n != len(in[i]) {
errs <- io.ErrShortWrite
errs <- StreamError{Err: io.ErrShortWrite, Op: "write", Stream: i}
}
}(i)
}
@ -445,7 +464,7 @@ func (r rsStream) Join(dst io.Writer, shards []io.Reader, outSize int64) error {
shards = shards[:r.r.DataShards]
for i := range shards {
if shards[i] == nil {
return ErrShardNoData
return StreamError{Err: ErrShardNoData, Op: "read", Stream: i}
}
}
// Join all shards
@ -485,7 +504,7 @@ func (r rsStream) Split(data io.Reader, dst []io.Writer, size int64) error {
for i := range dst {
if dst[i] == nil {
return ErrShardNoData
return StreamError{Err: ErrShardNoData, Op: "write", Stream: i}
}
}
@ -500,7 +519,7 @@ func (r rsStream) Split(data io.Reader, dst []io.Writer, size int64) error {
for i := range dst {
n, err := io.CopyN(dst[i], data, perShard)
if err != io.EOF && err != nil {
return err
return StreamError{Err: err, Op: "write", Stream: i}
}
if n != perShard {
return ErrShortData

View File

@ -59,9 +59,17 @@ func TestStreamEncoding(t *testing.T) {
badShards := emptyBuffers(10)
badShards[0] = randomBuffer(123)
badShards[1] = randomBuffer(123)
err = r.Encode(toReaders(badShards), toWriters(emptyBuffers(3)))
if err != ErrShardSize {
t.Errorf("expected %v, got %v", ErrShardSize, err)
if se, ok := err.(StreamError); ok {
if se.Err != ErrShardSize {
t.Errorf("expected %v, got %v", ErrShardSize, se.Err)
}
if se.Stream <= 1 {
t.Errorf("expected stream no to be >=2, was %d", se.Stream)
}
} else {
t.Errorf("expected error type %T, got %T", StreamError{}, err)
}
}
@ -110,9 +118,17 @@ func TestStreamEncodingConcurrent(t *testing.T) {
badShards := emptyBuffers(10)
badShards[0] = randomBuffer(123)
badShards[1] = randomBuffer(123)
err = r.Encode(toReaders(badShards), toWriters(emptyBuffers(3)))
if err != ErrShardSize {
t.Errorf("expected %v, got %v", ErrShardSize, err)
if se, ok := err.(StreamError); ok {
if se.Err != ErrShardSize {
t.Errorf("expected %v, got %v", ErrShardSize, se.Err)
}
if se.Stream <= 1 {
t.Errorf("expected stream no to be >=2, was %d", se.Stream)
}
} else {
t.Errorf("expected error type %T, got %T", StreamError{}, err)
}
}
@ -598,8 +614,12 @@ func TestStreamSplitJoin(t *testing.T) {
bufs := toReaders(emptyBuffers(5))
bufs[2] = nil
err = enc.Join(buf, bufs, 0)
if err != ErrShardNoData {
t.Errorf("expected %v, got %v", ErrShardNoData, err)
if se, ok := err.(StreamError); ok {
if se.Err != ErrShardNoData {
t.Errorf("expected %v, got %v", ErrShardNoData, se.Err)
}
} else {
t.Errorf("expected error type %T, got %T", StreamError{}, err)
}
err = enc.Join(buf, toReaders(toBuffers(splits)), int64(len(data)+1))