const int n; // cols
const int m; // rows
const int r; // rows to get
- const int b; // batch size
+ const int be1; // batch size
+ const int be2; // batch size
const bool v; // view (non-contiguous src1)
std::string vars() override {
- return VARS_TO_STR6(type, n, m, r, b, v);
+ return VARS_TO_STR7(type, n, m, r, be1, be2, v);
}
- test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
- : type(type), n(n), m(m), r(r), b(b), v(v) {}
+ test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int be1 = 1, int be2 = 1, bool v = false)
+ : type(type), n(n), m(m), r(r), be1(be1), be2(be2), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
+ ggml_tensor * in = ggml_new_tensor_4d(ctx, type, n, m, be1, be2);
ggml_set_name(in, "in");
- ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+ ggml_tensor * rows = ggml_new_tensor_3d(ctx, GGML_TYPE_I32, r, be1, be2);
ggml_set_name(rows, "rows");
if (v) {
- rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+ rows = ggml_view_3d(ctx, rows, r/2, be1, be2, rows->nb[1], rows->nb[2], 0);
ggml_set_name(rows, "view_of_rows");
}
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// rows
- std::vector<int> data(r*b);
- for (int i = 0; i < r*b; i++) {
+ std::vector<int> data(r*be1*be2);
+ for (int i = 0; i < r*be1*be2; i++) {
data[i] = rand() % m;
}
- ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
+ ggml_backend_tensor_set(t, data.data(), 0, r * be1 * be2 * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
}
- test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_Q4_0}) {
+ test_cases.emplace_back(new test_get_rows(type, 300*256, 5, 4, 1, 2, false));
+ test_cases.emplace_back(new test_get_rows(type, 256, 80000, 70000, 2, 1, false));
+ test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, 700, 100, false));
+ }
+
+ test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
for (ggml_type type : all_types) {
for (int b : {1, 7}) {
for (bool v : {false, true}) {
- test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
+ test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, 1, v));
}
}
}
for (int b : {1, 7}) {
for (bool v : {false, true}) {
- test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, v));
+ test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, 1, v));
}
}