diff --git a/src/osd_rmw.cpp b/src/osd_rmw.cpp index aedc3a6b..eb168280 100644 --- a/src/osd_rmw.cpp +++ b/src/osd_rmw.cpp @@ -1084,3 +1084,162 @@ void calc_rmw_parity_ec(osd_rmw_stripe_t *stripes, int pg_size, int pg_minsize, } calc_rmw_parity_copy_parity(stripes, pg_size, pg_minsize, read_osd_set, write_osd_set, chunk_size, start, end); } + +// Generate subsets of k items each in {0..n-1} +static bool first_combination(int *subset, int k, int n) +{ + if (k > n) + return false; + for (int i = 0; i < k; i++) + subset[i] = i; + return true; +} + +static bool next_combination(int *subset, int k, int n) +{ + int pos = k-1; + while (true) + { + subset[pos]++; + if (subset[pos] >= n-(k-1-pos)) + { + if (pos == 0) + return false; + pos--; + } + else + break; + } + for (pos++; pos < k; pos++) + { + subset[pos] = subset[pos-1]+1; + } + return true; +} + +static int c_n_k(int n, int k) +{ + int c = 1; + for (int i = n; i > k; i--) + c *= i; + for (int i = 2; i <= (n-k); i++) + c /= i; + return c; +} + +std::vector ec_find_good(osd_rmw_stripe_t *stripes, int pg_size, int pg_minsize, bool is_xor, + uint32_t chunk_size, uint32_t bitmap_size, int max_bruteforce) +{ + std::vector found_valid; + int cur_live[pg_size], live_count = 0; + osd_num_t fake_osd_set[pg_size]; + for (int role = 0; role < pg_size; role++) + { + if (!stripes[role].missing) + { + cur_live[live_count++] = role; + fake_osd_set[role] = role+1; + } + } + if (live_count <= pg_minsize) + { + return std::vector(); + } + // Try to locate errors using brute force if there isn't too many combinations + osd_rmw_stripe_t brute_stripes[pg_size]; + int out_count = live_count-pg_minsize; + bool brute_force = out_count > 1 && c_n_k(live_count-1, out_count-1) <= max_bruteforce; + int subset[pg_minsize], outset[out_count]; + // Select all combinations with items except the last one (== anything to compare) + first_combination(subset, pg_minsize, live_count-1); + uint8_t *tmp_buf = (uint8_t*)malloc_or_die(pg_size*chunk_size); + do + { + memcpy(brute_stripes, stripes, sizeof(osd_rmw_stripe_t)*pg_size); + int i = 0, j = 0, k = 0; + for (; i < pg_minsize; i++, j++) + while (j < subset[i]) + outset[k++] = j++; + while (j < pg_size) + outset[k++] = j++; + for (int i = 0; i < out_count; i++) + { + brute_stripes[cur_live[outset[i]]].missing = true; + brute_stripes[cur_live[outset[i]]].read_buf = tmp_buf+cur_live[outset[i]]*chunk_size; + } + for (int i = 0; i < pg_minsize; i++) + { + brute_stripes[i].write_buf = brute_stripes[i].read_buf; + brute_stripes[i].req_start = 0; + brute_stripes[i].req_end = chunk_size; + } + for (int i = pg_minsize; i < pg_size; i++) + { + brute_stripes[i].write_buf = tmp_buf+i*chunk_size; + } + if (is_xor) + { + assert(pg_size == pg_minsize+1); + reconstruct_stripes_xor(brute_stripes, pg_size, bitmap_size); + } + else + { + reconstruct_stripes_ec(brute_stripes, pg_size, pg_minsize, bitmap_size); + calc_rmw_parity_ec(brute_stripes, pg_size, pg_minsize, fake_osd_set, fake_osd_set, chunk_size, bitmap_size); + } + for (int i = pg_minsize; i < pg_size; i++) + { + brute_stripes[i].read_buf = brute_stripes[i].write_buf; + } + int max_live = 0; + for (int i = 0; i < pg_size; i++) + { + if (!brute_stripes[i].missing) + { + max_live = i; + } + } + int valid_count = 0; + for (int i = 0; i < out_count; i++) + { + // Only compare with chunks after the last one so first N + each 1 after them don't repeat + // I.e. compare (1,2,3,4) with (5,6) and (1,2,3,5) only with (6) and so on + if (cur_live[outset[i]] > max_live && + memcmp(brute_stripes[cur_live[outset[i]]].read_buf, + stripes[cur_live[outset[i]]].read_buf, chunk_size) == 0) + { + brute_stripes[cur_live[outset[i]]].missing = false; + valid_count++; + } + } + if (valid_count > 0) + { + if (found_valid.size()) + { + // Ambiguity: we found multiple valid sets and don't know which one is correct + found_valid.clear(); + break; + } + for (int i = 0; i < pg_size; i++) + { + if (!brute_stripes[i].missing) + { + found_valid.push_back(i); + } + } + if (valid_count == out_count) + { + // All chunks are good + break; + } + } + if (!brute_force) + { + // Do not attempt brute force if there are too many combinations because even + // if we find it we won't be able to check that it's the only good one + break; + } + } while (out_count > 1 && next_combination(subset, pg_minsize, live_count-1)); + free(tmp_buf); + return found_valid; +} diff --git a/src/osd_rmw.h b/src/osd_rmw.h index 42c257fc..f09e663e 100644 --- a/src/osd_rmw.h +++ b/src/osd_rmw.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "object_id.h" #include "osd_id.h" @@ -54,3 +55,6 @@ void reconstruct_stripes_ec(osd_rmw_stripe_t *stripes, int pg_size, int pg_minsi void calc_rmw_parity_ec(osd_rmw_stripe_t *stripes, int pg_size, int pg_minsize, uint64_t *read_osd_set, uint64_t *write_osd_set, uint32_t chunk_size, uint32_t bitmap_size); + +std::vector ec_find_good(osd_rmw_stripe_t *stripes, int pg_size, int pg_minsize, bool is_xor, + uint32_t chunk_size, uint32_t bitmap_size, int max_bruteforce); diff --git a/src/osd_rmw_test.cpp b/src/osd_rmw_test.cpp index 07a5d552..45299a17 100644 --- a/src/osd_rmw_test.cpp +++ b/src/osd_rmw_test.cpp @@ -28,6 +28,7 @@ void test14(); void test15(bool second); void test16(); void test_recover_22_d2(); +void test_ec42_error_bruteforce(); int main(int narg, char *args[]) { @@ -64,6 +65,8 @@ int main(int narg, char *args[]) test16(); // Test 17 test_recover_22_d2(); + // Error bruteforce + test_ec42_error_bruteforce(); // End printf("all ok\n"); return 0; @@ -1106,3 +1109,64 @@ void test_recover_22_d2() // Done use_ec(4, 2, false); } + +/*** + +EC 4+2 error location bruteforce + +***/ + +static void assert_eq_vec(const std::vector & a, const std::vector & b) +{ + printf("Expect ["); + for (int i = 0; i < a.size(); i++) + printf(" %d", a[i]); + printf(" ] have ["); + for (int i = 0; i < b.size(); i++) + printf(" %d", b[i]); + printf(" ]\n"); + assert(a == b); +} + +void test_ec42_error_bruteforce() +{ + use_ec(6, 4, true); + osd_num_t osd_set[6] = { 1, 2, 3, 4, 5, 6 }; + osd_rmw_stripe_t stripes[6] = {}; + split_stripes(4, 4096, 0, 4096 * 4, stripes); + uint8_t *write_buf = (uint8_t*)malloc_or_die(4096 * 6); + set_pattern(write_buf+0*4096, 4096, PATTERN0); + set_pattern(write_buf+1*4096, 4096, PATTERN1); + set_pattern(write_buf+2*4096, 4096, PATTERN2); + set_pattern(write_buf+3*4096, 4096, PATTERN3); + uint8_t *rmw_buf = (uint8_t*)calc_rmw(write_buf, stripes, osd_set, 6, 4, 6, osd_set, 4096, 0); + calc_rmw_parity_ec(stripes, 6, 4, osd_set, osd_set, 4096, 0); + check_pattern(stripes[4].write_buf, 4096, PATTERN0^PATTERN1^PATTERN2^PATTERN3); + check_pattern(stripes[5].write_buf, 4096, 0x139274739ae6f387); // 2nd EC chunk + memcpy(write_buf+4*4096, stripes[4].write_buf, 4096); + memcpy(write_buf+5*4096, stripes[5].write_buf, 4096); + // Try to locate errors + for (int i = 0; i < 6; i++) + { + stripes[i].read_start = 0; + stripes[i].read_end = 4096; + stripes[i].read_buf = write_buf+i*4096; + stripes[i].write_buf = NULL; + } + // All good chunks + auto res = ec_find_good(stripes, 6, 4, false, 4096, 0, 100); + assert_eq_vec(res, std::vector({0, 1, 2, 3, 4, 5})); + // 1 missing chunk + set_pattern(write_buf+1*4096, 4096, 0); + res = ec_find_good(stripes, 6, 4, false, 4096, 0, 100); + assert_eq_vec(res, std::vector({0, 2, 3, 4, 5})); + // 2 missing chunks + set_pattern(write_buf+1*4096, 4096, 0); + set_pattern(write_buf+5*4096, 4096, 0); + res = ec_find_good(stripes, 6, 4, false, 4096, 0, 100); + assert_eq_vec(res, std::vector()); + // Done + free(rmw_buf); + free(write_buf); + use_ec(6, 4, false); +}