shared int sh_min_idx;
shared uint sh_total;
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
+shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
// Map float values to uint such that unsigned integer comparisons give
// the same ordering as float comparisons: positive values have the high
// bit set, negative values are fully inverted.
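// A minimal sketch of such a mapping, assuming IEEE-754 bit patterns
// (the actual f2ui helper is defined elsewhere in this shader):
//   uint f2ui(float f) {
//       uint u = floatBitsToUint(f);
//       return (u & 0x80000000u) != 0u ? ~u : (u | 0x80000000u);
//   }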
// We need to compact these values to the start of the dst_row array.
// Have each subgroup count how many items it'll store, so other
// subgroups can compute their base offset.
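// e.g. if four subgroups will store {3, 5, 2, 4} items, their base
// offsets are the exclusive prefix sums {0, 3, 8, 10}.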
- bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
- uvec4 b = subgroupBallot(top);
- uint bit_count = subgroupBallotBitCount(b);
- if ((tid % SUBGROUP_SIZE) == 0) {
- offset_partials[tid / SUBGROUP_SIZE] = bit_count;
- }
- barrier();
+ // Values strictly greater than range_min must be stored. For values equal
+ // to range_min, there can be ties and it's possible we'll need to store
+ // an arbitrary subset of them.
+ // If total == p.k, take a fast path that doesn't need to handle ties.
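+ // e.g. for mapped values {9, 7, 7, 7, 5} and k == 2, the search ends
+ // with range_min == 7 and total == 4: only one slot remains for the
+ // three tied 7s, so an arbitrary one is kept. With k == 4, total == 4
+ // == k and every value >= 7 is stored, so no tie handling is needed.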
+ if (total == p.k) {
+ bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+ uvec4 b = subgroupBallot(top);
+ uint bit_count = subgroupBallotBitCount(b);
+ if ((tid % SUBGROUP_SIZE) == 0) {
+ offset_partials[tid / SUBGROUP_SIZE] = bit_count;
+ }
+ barrier();
- uint out_idx = 0;
- [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
- if (i < tid / SUBGROUP_SIZE) {
- out_idx += offset_partials[i];
+ uint out_idx = 0;
+ [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+ if (i < tid / SUBGROUP_SIZE) {
+ out_idx += offset_partials[i];
+ }
}
- }
- uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
- if (top) {
- // TODO: Copy directly to the output?
- dst_row[out_idx + bit_count_ex] = v;
+ uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+ if (top) {
+ // TODO: Copy directly to the output?
+ dst_row[out_idx + bit_count_ex] = v;
+ }
+ } else {
+ bool top = f2ui(intBitsToFloat(v.y)) > range_min;
+ bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
+ uvec4 b_top = subgroupBallot(top);
+ uvec4 b_eq_min = subgroupBallot(eq_min);
+ uint bit_count_top = subgroupBallotBitCount(b_top);
+ uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
+ if ((tid % SUBGROUP_SIZE) == 0) {
+ offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
+ eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
+ }
+ barrier();
+
+ uint out_idx = 0;
+ uint eq_min_base = 0;
+ uint eq_min_idx = 0;
+ [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+ if (i < tid / SUBGROUP_SIZE) {
+ out_idx += offset_partials[i];
+ eq_min_idx += eq_min_partials[i];
+ }
+ eq_min_base += offset_partials[i];
+ }
+ // values equal to range_min are stored after all strictly-greater values
+ eq_min_idx += eq_min_base;
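+ // e.g. if the strictly-greater counts sum to 5 across all subgroups,
+ // the first value equal to range_min lands at dst_row[5]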
+
+ uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
+ uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
+ if (top) {
+ // TODO: Copy directly to the output?
+ dst_row[out_idx + bit_count_ex_top] = v;
+ }
+ if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
+ dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
+ }
}
barrier();
return mse_a_b / mse_a_0;
}
-// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
-static double jdst(const int32_t * a, const int32_t * b, size_t n) {
- std::unordered_map<int32_t, size_t> set_a;
- std::unordered_map<int32_t, size_t> set_b;
+// difference between 2 sets (Jaccard distance: 1 - |A∩B|/|A∪B|;
+// 0 - no difference, 1 - no overlap)
+template <typename T>
+static double jdst(const T * a, const T * b, size_t n) {
+ std::unordered_map<T, size_t> set_a;
+ std::unordered_map<T, size_t> set_b;
for (size_t i = 0; i < n; ++i) {
set_a[a[i]]++;
const ggml_type type;
const std::array<int64_t, 4> ne;
const int k;
+ const bool ties;
+ ggml_tensor * input {};
std::string vars() override {
- return VARS_TO_STR3(type, ne, k);
+ return VARS_TO_STR4(type, ne, k, ties);
}
test_top_k(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {16, 10, 10, 10},
- int k = 4)
- : type(type), ne(ne), k(k) {}
+ int k = 4, bool ties = false)
+ : type(type), ne(ne), k(k), ties(ties) {}
double max_err() override {
return 0.0;
}
+ // When there are ties, only validate the final result.
+ // The comparison logic in err() can't handle the sentinel tensors.
+ bool run_whole_graph() override { return ties; }
+
double err(const float * a, const float * b, size_t n) override {
- std::vector<int32_t> ia(n);
- std::vector<int32_t> ib(n);
+ // When there are no ties, we expect the exact same set of indices,
+ // but possibly in a different order. When there are ties, the indices
+ // can be different but the input values they correspond to should be
+ // the same. The logic for ties could work for non-ties, but only for
+ // the output tensor, not for the sentinel tensors.
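+ // e.g. for a row {5, 5, 3} with k == 2, both index sets {0, 1} and
+ // {1, 0} are valid results; gathering the values gives {5, 5} for
+ // either, so they compare equal, while an output like {0, 0} is
+ // rejected by the duplicate-index check below.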
+ if (ties) {
+ std::vector<float> src(ggml_nelements(input));
+
+ ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
+
+ double diff = 0.0f;
+
+ GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
+ int64_t cols = input->ne[0];
+ std::vector<int32_t> ia(k);
+ std::vector<int32_t> ib(k);
+ std::vector<float> asrc(k);
+ std::vector<float> bsrc(k);
+ for (int64_t r = 0; r < ggml_nrows(input); r++) {
+ // Convert indices for the row back to integer
+ for (int64_t c = 0; c < k; c++) {
+ ia[c] = (int32_t)a[r * k + c];
+ ib[c] = (int32_t)b[r * k + c];
+ }
+ // The src values for each row should match.
+ for (int64_t c = 0; c < k; c++) {
+ asrc[c] = src[r * cols + ia[c]];
+ bsrc[c] = src[r * cols + ib[c]];
+ }
+ diff += jdst(asrc.data(), bsrc.data(), k);
+ // There should be no duplicate indices
+ std::sort(ia.begin(), ia.end());
+ std::sort(ib.begin(), ib.end());
+ if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
+ diff += 1;
+ }
+ if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
+ diff += 1;
+ }
+ }
+ return diff;
+ } else {
+ std::vector<int32_t> ia(n);
+ std::vector<int32_t> ib(n);
- double diff = 0.0f;
+ double diff = 0.0f;
- for (size_t i = 0; i < n; i++) {
- ia[i] = (int32_t) a[i];
- ib[i] = (int32_t) b[i];
+ for (size_t i = 0; i < n; i++) {
+ ia[i] = (int32_t) a[i];
+ ib[i] = (int32_t) b[i];
- // penalize the result if the data is not integer valued
- diff += std::fabs(a[i] - ia[i]);
- diff += std::fabs(b[i] - ib[i]);
- }
+ // penalize the result if the data is not integer valued
+ diff += std::fabs(a[i] - ia[i]);
+ diff += std::fabs(b[i] - ib[i]);
+ }
- return diff + jdst(ia.data(), ib.data(), n);
+ return diff + jdst(ia.data(), ib.data(), n);
+ }
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
+ // Save 'a' for err()
+ input = a;
+
ggml_tensor * out = ggml_top_k(ctx, a, k);
ggml_set_name(out, "out");
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- // initialize with unique values to avoid ties
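+ // without ties, each value is unique; with ties, the integer division
+ // below maps groups of up to tie_denom consecutive values to the same
+ // number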
+ int tie_denom = std::max(1, std::min(10, k / 2));
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
- data[i] = i;
+ if (ties) {
+ // integer division to introduce duplicates
+ data[i] = i / tie_denom;
+ } else {
+ data[i] = i;
+ }
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
if (k <= 1<<i) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
+ test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k, true));
}
}
}