sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
- tester.apply(llama_sampler_init_top_p(p, 1));
+ tester.apply(llama_sampler_init_top_p(p, 0));
tester.apply(llama_sampler_init_dist (0));
DUMP(&tester.cur_p);
sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
- tester.apply(llama_sampler_init_min_p(p, 1));
+ tester.apply(llama_sampler_init_min_p(p, 0));
tester.apply(llama_sampler_init_dist (0));
DUMP(&tester.cur_p);
sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
- tester.apply(llama_sampler_init_typical(p, 1));
+ tester.apply(llama_sampler_init_typical(p, 0));
DUMP(&tester.cur_p);
tester.check();
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
+ test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.05f);
printf("XTC should:\n");
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
printf("XTC should not:\n");
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
- test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
- test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
+ test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
+ test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);