`post_sampling_probs`: Returns the probabilities of the top `n_probs` tokens after applying the sampling chain.
+`response_fields`: A list of response fields to return, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. Nested fields are addressed with `/`-separated paths. If a specified field is missing, it is simply omitted from the response without triggering an error. See the example sketch below.
+
**Response format**
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until the end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note that the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
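For illustration, a minimal sketch of a non-streaming request that uses `response_fields`, written in Python with the third-party `requests` package (the host/port and the use of `requests` are assumptions for this sketch, not part of the change):

```python
# Minimal sketch: query a running llama-server instance (assumed at localhost:8080).
import requests  # assumption: third-party package, `pip install requests`

resp = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        # Restrict the response to these fields; nested fields use '/' paths.
        "response_fields": ["content", "generation_settings/n_predict"],
    },
)
body = resp.json()
print(body["content"])
# A nested path comes back under the flattened path string as its key.
print(body["generation_settings/n_predict"])
```

Requesting a path that does not resolve simply drops that key from the result, as implemented by the JSON helper at the end of this diff.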
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this many milliseconds
std::vector<std::string> antiprompt;
+ std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
bool ignore_eos = false;
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+ params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
+ std::vector<std::string> response_fields;
slot_params generation_params;
if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
}
- return res;
+ return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
}
json to_json_oaicompat_chat() {
res->tokens = slot.generated_tokens;
res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+ res->response_fields = slot.params.response_fields;
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
# assert match_regex(re_content, res.body["content"])
+@pytest.mark.parametrize(
+ "prompt,n_predict,response_fields",
+ [
+ ("I believe the meaning of life is", 8, []),
+ ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
+ ],
+)
+def test_completion_response_fields(
+ prompt: str, n_predict: int, response_fields: list[str]
+):
+ global server
+ server.start()
+ res = server.make_request(
+ "POST",
+ "/completion",
+ data={
+ "n_predict": n_predict,
+ "prompt": prompt,
+ "response_fields": response_fields,
+ },
+ )
+ assert res.status_code == 200
+ assert "content" in res.body
+ assert len(res.body["content"])
+ if response_fields:
+ assert res.body["generation_settings/n_predict"] == n_predict
+ assert res.body["prompt"] == "<s> " + prompt
+ assert isinstance(res.body["content"], str)
+ assert len(res.body) == len(response_fields)
+ else:
+ assert len(res.body)
+ assert "generation_settings" in res.body
+
+
def test_n_probs():
global server
server.start()
return false;
}
+// extract values from a JSON object by '/'-separated paths, e.g. "key1/key2"
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+ json result = json::object();
+
+ for (const std::string & path : paths) {
+ json current = js;
+ const auto keys = string_split<std::string>(path, /*separator*/ '/');
+ bool valid_path = true;
+ for (const std::string & k : keys) {
+ if (valid_path && current.is_object() && current.contains(k)) {
+ current = current[k];
+ } else {
+ valid_path = false;
+ }
+ }
+ if (valid_path) {
+ result[path] = current;
+ }
+ }
+ return result;
+}
+
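To make the lookup semantics concrete, here is the same logic sketched in Python (an illustrative translation; `get_nested_values` is a hypothetical name, not part of the codebase):

```python
# Python sketch of json_get_nested_values above (illustrative translation only).
def get_nested_values(paths: list[str], js: dict) -> dict:
    """Return {path: value} for each '/'-separated path that resolves in js."""
    result = {}
    for path in paths:
        current = js
        valid_path = True
        for key in path.split("/"):
            if valid_path and isinstance(current, dict) and key in current:
                current = current[key]
            else:
                valid_path = False  # missing key or non-object step: drop this path
        if valid_path:
            result[path] = current
    return result

# Missing paths are silently omitted, matching the C++ helper's behavior:
doc = {"content": "hi", "generation_settings": {"n_predict": 8}}
print(get_nested_values(["content", "generation_settings/n_predict", "nope/x"], doc))
# -> {'content': 'hi', 'generation_settings/n_predict': 8}
```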
/**
* this handles 2 cases:
* - only string, example: "string"