]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
dolly : disable interactive_port on Windows (#339)
authorBorislav Stanimirov <redacted>
Tue, 4 Jul 2023 13:26:29 +0000 (16:26 +0300)
committerGitHub <redacted>
Tue, 4 Jul 2023 13:26:29 +0000 (16:26 +0300)
examples/dolly-v2/main.cpp

index 9bc5e1a799b38c5e8ac59a1e93a12053301c1205..b95c5b1c6a11d2bd6983c737b83cb7d588f318a7 100644 (file)
 #include <string>
 #include <vector>
 
+#if !defined(_WIN32)
+#define DOLLY_INTERACTIVE_PORT
+#endif
+
+#if defined(DOLLY_INTERACTIVE_PORT)
 #include <arpa/inet.h>
 #include <netinet/in.h>
 #include <sys/socket.h>
 #include <unistd.h>
+#endif
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -775,6 +781,7 @@ std::string execute_prompt(
     return output;
 }
 
+#if defined(DOLLY_INTERACTIVE_PORT)
 int setup_port(const int port) {
     int sockfd = socket(AF_INET, SOCK_STREAM, 0);
     if (sockfd < 0) {
@@ -818,6 +825,7 @@ std::string read_from_port(int sockfd, int clientfd) {
     }
     return std::string("");
 }
+#endif
 
 int main(int argc, char ** argv) {
     ggml_time_init();
@@ -865,6 +873,7 @@ int main(int argc, char ** argv) {
         test_gpt_tokenizer(vocab, params.token_test);
     }
 
+#if defined(DOLLY_INTERACTIVE_PORT)
     int sockfd;
     if (params.interactive_port != -1) {
         sockfd = setup_port(params.interactive_port);
@@ -874,17 +883,21 @@ int main(int argc, char ** argv) {
         fprintf(stdout, "Model is ready on port %i\n", params.interactive_port);
         fflush(stdout);
     }
+#endif
 
-    if (params.interactive or params.interactive_port != -1) {
+    if (params.interactive || params.interactive_port != -1) {
         while (true) {
             std::string prompt_input;
+#if defined(DOLLY_INTERACTIVE_PORT)
             int clientfd;
             if (params.interactive_port != -1) {
                 sockaddr_in clientaddr;
                 socklen_t clientaddrlen = sizeof(clientaddr);
-                clientfd = accept(sockfd, (struct sockaddr *)&clientaddr, &clientaddrlen);   
+                clientfd = accept(sockfd, (struct sockaddr *)&clientaddr, &clientaddrlen);
                 prompt_input = read_from_port(sockfd, clientfd);
-            } else {
+            } else
+#endif
+            {
                 printf("Please enter your quesiton:\n>");
                 fflush(stdout);
 
@@ -899,6 +912,7 @@ int main(int argc, char ** argv) {
             // call the model
             const std::string response = execute_prompt(model, vocab, prompt, params, rng, t_load_us, t_sample_us, t_predict_us, mem_per_token, n_past, true);
 
+#if defined(DOLLY_INTERACTIVE_PORT)
             if (params.interactive_port != -1) {
                 if (write(clientfd, response.c_str(), response.size()) < 0) {
                     std::cerr << "Failed to write to client\n";
@@ -907,8 +921,9 @@ int main(int argc, char ** argv) {
                 if (close(clientfd) < 0) {
                     std::cerr << "Failed to close client socket\n";
                 }
-            }
-            else {
+            } else
+#endif
+            {
                 printf("%s\n\n", response.c_str());
             }
             fflush(stdout);
@@ -936,9 +951,11 @@ int main(int argc, char ** argv) {
 
     ggml_free(model.ctx);
 
+#if defined(DOLLY_INTERACTIVE_PORT)
     if (params.interactive_port != -1 && close(sockfd) < 0) {
         std::cerr << "Failed to close server socket\n";
     }
+#endif
 
     return 0;
 }