From bcab5ad15eabccb1f050d31ff3ddc2bfecaf9297 Mon Sep 17 00:00:00 2001 From: Cameron Reikes Date: Mon, 29 May 2023 03:38:05 -0700 Subject: [PATCH] Make desktop AI gens asynchronous with threads --- main.c | 302 +++++++++++++++++++++++++++++++++--------------- makeprompt.h | 2 - thirdparty/md.h | 17 +-- 3 files changed, 219 insertions(+), 102 deletions(-) diff --git a/main.c b/main.c index 9fadc2e..0c37eb8 100644 --- a/main.c +++ b/main.c @@ -119,6 +119,7 @@ void web_arena_set_auto_align(WebArena *arena, size_t align) #include "md.c" #pragma warning(pop) +MD_Arena *persistent_arena = 0; // watch out, arenas have limited size. #include @@ -131,13 +132,6 @@ void web_arena_set_auto_align(WebArena *arena, size_t align) #include "profiling.h" -#ifdef DESKTOP -#ifdef WINDOWS -#include -#else -#error "Only know how to do desktop http requests on windows" -#endif // WINDOWS -#endif // DESKTOP double clamp(double d, double min, double max) @@ -438,6 +432,166 @@ LPCWSTR windows_string(MD_String8 s) } #endif +#ifdef DESKTOP +#ifdef WINDOWS +#include +#include + +typedef struct ChatRequest +{ + struct ChatRequest *next; + struct ChatRequest *prev; + int id; + int status; + char generated[MAX_SENTENCE_LENGTH]; + int generated_length; + uintptr_t thread_handle; + MD_Arena *arena; + MD_String8 post_req_body; // allocated on thread_arena +} ChatRequest; + +ChatRequest *requests_first = 0; +ChatRequest *requests_last = 0; + +int next_request_id = 1; +ChatRequest *requests_free_list = 0; + +void generation_thread(void* my_request_voidptr) +{ + ChatRequest *my_request = (ChatRequest*)my_request_voidptr; + + bool succeeded = true; + +#define WinAssertWithErrorCode(X) if( !( X ) ) { unsigned int error = GetLastError(); Log("Error %u in %s\n", error, #X); my_request->status = 2; return; } + + HINTERNET hSession = WinHttpOpen(L"PlayGPT winhttp backend", WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0); + WinAssertWithErrorCode(hSession); + + LPCWSTR windows_server_name = windows_string(MD_S8Lit(SERVER_DOMAIN)); + HINTERNET hConnect = WinHttpConnect(hSession, windows_server_name, SERVER_PORT, 0); + WinAssertWithErrorCode(hConnect); + int security_flags = 0; + if(IS_SERVER_SECURE) + { + security_flags = WINHTTP_FLAG_SECURE; + } + + HINTERNET hRequest = WinHttpOpenRequest(hConnect, L"POST", L"completion", 0, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, security_flags); + WinAssertWithErrorCode(hRequest); + + // @IMPORTANT @TODO the windows_string allocates on the frame arena, but + // according to https://learn.microsoft.com/en-us/windows/win32/api/winhttp/nf-winhttp-winhttpsendrequest + // the buffer needs to remain available as long as the http request is running, so to make this async and do the loading thing need some other way to allocate the winndows string.... arenas bad? + succeeded = WinHttpSendRequest(hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, (LPVOID)my_request->post_req_body.str, (DWORD)my_request->post_req_body.size, (DWORD)my_request->post_req_body.size, 0); + if(!succeeded) + { + Log("Couldn't do the web: %u\n", GetLastError()); + my_request->status = 2; + } + if(succeeded) + { + WinAssertWithErrorCode(WinHttpReceiveResponse(hRequest, 0)); + + DWORD status_code; + DWORD status_code_size = sizeof(status_code); + WinAssertWithErrorCode(WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &status_code, &status_code_size, WINHTTP_NO_HEADER_INDEX)); + Log("Status code: %u\n", status_code); + + DWORD dwSize = 0; + MD_String8List received_data_list = {0}; + do + { + dwSize = 0; + WinAssertWithErrorCode(WinHttpQueryDataAvailable(hRequest, &dwSize)); + + if(dwSize == 0) + { + Log("Didn't get anything back.\n"); + } + else + { + MD_u8* out_buffer = MD_PushArray(my_request->arena, MD_u8, dwSize + 1); + DWORD dwDownloaded = 0; + WinAssertWithErrorCode(WinHttpReadData(hRequest, (LPVOID)out_buffer, dwSize, &dwDownloaded)); + out_buffer[dwDownloaded - 1] = '\0'; + Log("Got this from http, size %d: %s\n", dwDownloaded, out_buffer); + MD_S8ListPush(my_request->arena, &received_data_list, MD_S8(out_buffer, dwDownloaded)); + } + } while (dwSize > 0); + MD_String8 received_data = MD_S8ListJoin(my_request->arena, received_data_list, &(MD_StringJoin){0}); + + MD_String8 ai_response = MD_S8Substring(received_data, 1, received_data.size); + if(ai_response.size > ARRLEN(my_request->generated)) + { + Log("%lld too big for %lld\n", ai_response.size, ARRLEN(my_request->generated)); + my_request->status = 2; + return; + } + memcpy(my_request->generated, ai_response.str, ai_response.size); + my_request->generated_length = (int)ai_response.size; + my_request->status = 1; + } +} + +int make_generation_request(MD_String8 post_req_body) +{ + ChatRequest *to_return = 0; + if(requests_free_list) + { + to_return = requests_free_list; + requests_free_list = requests_free_list->next; + //MD_StackPop(requests_free_list); + *to_return = (ChatRequest){0}; + } + else + { + to_return = MD_PushArrayZero(persistent_arena, ChatRequest, 1); + } + to_return->arena = MD_ArenaAlloc(); + to_return->id = next_request_id; + next_request_id += 1; + + to_return->post_req_body.str = MD_PushArrayZero(to_return->arena, MD_u8, post_req_body.size); + to_return->post_req_body.size = post_req_body.size; + memcpy(to_return->post_req_body.str, post_req_body.str, post_req_body.size); + + to_return->thread_handle = _beginthread(generation_thread, 0, to_return); + assert(to_return->thread_handle); + + MD_DblPushBack(requests_first, requests_last, to_return); + + return to_return->id; +} + +// should never return null +// @TODO @IMPORTANT this doesn't work with save games because it assumes the id is always +// valid but saved IDs won't be valid on reboot +ChatRequest *get_by_id(int id) +{ + for(ChatRequest *cur = requests_first; cur; cur = cur->next) + { + if(cur->id == id) + { + return cur; + } + } + assert(false); + return 0; +} + +void done_with_request(int id) +{ + ChatRequest *req = get_by_id(id); + MD_ArenaRelease(req->arena); + MD_DblRemove(requests_first, requests_last, req); + MD_StackPush(requests_free_list, req); +} + +#else +#error "Only know how to do desktop http requests on windows" +#endif // WINDOWS +#endif // DESKTOP + MD_String8 tprint(char *format, ...) { MD_String8 to_return = {0}; @@ -937,6 +1091,9 @@ bool perform_action(Entity *from, Action a) bool propagate_to_party = from->is_character || (from->is_npc && from->standing == STANDING_JOINED); + if(action_target == player) propagate_to_party = true; + + if(context.eavesdropped_from_party) propagate_to_party = false; if(propagate_to_party) @@ -952,6 +1109,8 @@ bool perform_action(Entity *from, Action a) } } + + // npcs in party when they talk should have their speech heard by who the player is talking to if(from->is_npc && from->standing == STANDING_JOINED) { if(gete(player->talking_to) && gete(player->talking_to) != from) @@ -1219,6 +1378,7 @@ void init(void) #endif frame_arena = MD_ArenaAlloc(); + persistent_arena = MD_ArenaAlloc(); Log("Size of entity struct: %zu\n", sizeof(Entity)); Log("Size of %d gs.entities: %zu kb\n", (int)ARRLEN(gs.entities), sizeof(gs.entities) / 1024); @@ -2954,15 +3114,23 @@ void frame(void) ENTITIES_ITER(gs.entities) { assert(!(it->exists && it->generation == 0)); -#ifdef WEB if (it->is_npc) { if (it->gen_request_id != 0) { assert(it->gen_request_id > 0); + +#ifdef DESKTOP + int status = get_by_id(it->gen_request_id)->status; +#else +#ifdef WEB int status = EM_ASM_INT( { return get_generation_request_status($0); }, it->gen_request_id); +#else +#error "Don't know how to do this stuff on this platform." +#endif // WEB +#endif // DESKTOP if (status == 0) { // simply not done yet @@ -2973,10 +3141,16 @@ void frame(void) { // done! we can get the string char sentence_cstr[MAX_SENTENCE_LENGTH] = { 0 }; +#ifdef WEB EM_ASM( { let generation = get_generation_request_content($0); stringToUTF8(generation, $1, $2); }, it->gen_request_id, sentence_cstr, ARRLEN(sentence_cstr) - 1); // I think minus one for null terminator... +#endif + +#ifdef DESKTOP + memcpy(sentence_cstr, get_by_id(it->gen_request_id)->generated, get_by_id(it->gen_request_id)->generated_length); +#endif MD_String8 sentence_str = MD_S8CString(sentence_cstr); @@ -2997,9 +3171,14 @@ void frame(void) MD_ReleaseScratch(scratch); +#ifdef WEB EM_ASM( { done_with_generation_request($0); }, it->gen_request_id); +#endif +#ifdef DESKTOP + done_with_request(it->gen_request_id); +#endif } else if (status == 2) { @@ -3008,7 +3187,7 @@ void frame(void) Action to_perform = {0}; MD_String8 speech_mdstring = MD_S8Lit("I'm not sure..."); memcpy(to_perform.speech, speech_mdstring.str, speech_mdstring.size); - to_perform.speech_length = speech_mdstring.size; + to_perform.speech_length = (int)speech_mdstring.size; perform_action(it, to_perform); } else if (status == -1) @@ -3023,7 +3202,6 @@ void frame(void) } } } -#endif if (fabsf(it->vel.x) > 0.01f) @@ -3578,99 +3756,39 @@ void frame(void) else { MD_String8 post_request_body = MD_S8Fmt(scratch.arena, "|%.*s", MD_S8VArg(prompt_str)); + it->gen_request_id = make_generation_request(post_request_body); + } -#define WinAssertWithErrorCode(X) if( !( X ) ) { unsigned int error = GetLastError(); Log("Error %u in %s\n", error, #X); assert(false); } - - - - HINTERNET hSession = WinHttpOpen(L"PlayGPT winhttp backend", WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0); - WinAssertWithErrorCode(hSession); - - LPCWSTR windows_server_name = windows_string(MD_S8Lit(SERVER_DOMAIN)); - HINTERNET hConnect = WinHttpConnect(hSession, windows_server_name, SERVER_PORT, 0); - WinAssertWithErrorCode(hConnect); - int security_flags = 0; - if(IS_SERVER_SECURE) + // something to mock + if(ai_response.size > 0) + { + Log("Mocking...\n"); + Action a = {0}; + MD_String8 error_message = MD_S8Lit("Something really bad happened bro. File " STRINGIZE(__FILE__) " Line " STRINGIZE(__LINE__)); + if(succeeded) { - security_flags = WINHTTP_FLAG_SECURE; + error_message = parse_chatgpt_response(scratch.arena, it, ai_response, &a); } - HINTERNET hRequest = WinHttpOpenRequest(hConnect, L"POST", L"completion", 0, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, security_flags); - WinAssertWithErrorCode(hRequest); - - // @IMPORTANT @TODO the windows_string allocates on the frame arena, but - // according to https://learn.microsoft.com/en-us/windows/win32/api/winhttp/nf-winhttp-winhttpsendrequest - // the buffer needs to remain available as long as the http request is running, so to make this async and do the loading thing need some other way to allocate the winndows string.... arenas bad? - succeeded = WinHttpSendRequest(hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, (LPVOID)post_request_body.str, (DWORD)post_request_body.size, (DWORD)post_request_body.size, 0); - if(!succeeded) + if(mocking_the_ai_response) { - Log("Couldn't do the web: %u\n", GetLastError()); + assert(succeeded); + assert(error_message.size == 0); + perform_action(it, a); } - if(succeeded) + else { - WinAssertWithErrorCode(WinHttpReceiveResponse(hRequest, 0)); - - DWORD status_code; - DWORD status_code_size = sizeof(status_code); - WinAssertWithErrorCode(WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &status_code, &status_code_size, WINHTTP_NO_HEADER_INDEX)); - Log("Status code: %u\n", status_code); - - DWORD dwSize = 0; - MD_String8List received_data_list = {0}; - do + if(succeeded) { - dwSize = 0; - WinAssertWithErrorCode(WinHttpQueryDataAvailable(hRequest, &dwSize)); - - if(dwSize == 0) + if (error_message.size == 0) { - Log("Didn't get anything back.\n"); + perform_action(it, a); } else { - MD_u8* out_buffer = MD_PushArray(scratch.arena, MD_u8, dwSize + 1); - DWORD dwDownloaded = 0; - WinAssertWithErrorCode(WinHttpReadData(hRequest, (LPVOID)out_buffer, dwSize, &dwDownloaded)); - out_buffer[dwDownloaded - 1] = '\0'; - Log("Got this from http, size %d: %s\n", dwDownloaded, out_buffer); - MD_S8ListPush(scratch.arena, &received_data_list, MD_S8(out_buffer, dwDownloaded)); + Log("There was an error with the AI: %.*s", MD_S8VArg(error_message)); + remember_error(it, error_message); } - } while (dwSize > 0); - MD_String8 received_data = MD_S8ListJoin(scratch.arena, received_data_list, &(MD_StringJoin){0}); - - ai_response = MD_S8Substring(received_data, 1, received_data.size); - } - else - { - it->perceptions_dirty = true; - } - } - - Action a = {0}; - MD_String8 error_message = MD_S8Lit("Something really bad happened bro. File " STRINGIZE(__FILE__) " Line " STRINGIZE(__LINE__)); - if(succeeded) - { - error_message = parse_chatgpt_response(scratch.arena, it, ai_response, &a); - } - - if(mocking_the_ai_response) - { - assert(succeeded); - assert(error_message.size == 0); - perform_action(it, a); - } - else - { - if(succeeded) - { - if (error_message.size == 0) - { - perform_action(it, a); - } - else - { - Log("There was an error with the AI: %.*s", MD_S8VArg(error_message)); - remember_error(it, error_message); } } } @@ -3919,12 +4037,10 @@ void frame(void) PROFILE_SCOPE("entity rendering") ENTITIES_ITER(gs.entities) { -#ifdef WEB if (it->gen_request_id != 0) { draw_quad((DrawParams) { true, quad_centered(AddV2(it->pos, V2(0.0, 50.0)), V2(100.0, 100.0)), IMG(image_thinking), WHITE }); } -#endif Color col = LerpV4(WHITE, it->damage, RED); if (it->is_npc) @@ -4540,6 +4656,8 @@ void frame(void) void cleanup(void) { free(fontBuffer); + MD_ArenaRelease(frame_arena); + MD_ArenaRelease(persistent_arena); sg_shutdown(); hmfree(imui_state); Log("Cleaning up\n"); diff --git a/makeprompt.h b/makeprompt.h index 8d0f321..33ff472 100644 --- a/makeprompt.h +++ b/makeprompt.h @@ -236,9 +236,7 @@ typedef struct Entity NPCPlayerStanding standing; NpcKind npc_kind; PathCacheHandle cached_path; -#ifdef WEB int gen_request_id; -#endif bool walking; double shotgun_timer; bool moved; diff --git a/thirdparty/md.h b/thirdparty/md.h index a7da6f5..791685d 100644 --- a/thirdparty/md.h +++ b/thirdparty/md.h @@ -362,12 +362,13 @@ (zchk(f)?\ ((f)=(l)=(n),zset((n)->next),zset((n)->prev)):\ ((n)->prev=(l),(l)->next=(n),(l)=(n),zset((n)->next))) -#define MD_DblRemove_NPZ(f,l,n,next,prev,zset) (((f)==(n)?\ -((f)=(f)->next,zset((f)->prev)):\ -(l)==(n)?\ -((l)=(l)->prev,zset((l)->next)):\ -((n)->next->prev=(n)->prev,\ -(n)->prev->next=(n)->next))) +#define MD_DblRemove_NPZ(f,l,n,next,prev,zchk,zset) (((f)==(n))?\ +((f)=(f)->next, (zchk(f) ? (zset(l)) : zset((f)->prev))):\ +((l)==(n))?\ +((l)=(l)->prev, (zchk(l) ? (zset(f)) : zset((l)->next))):\ +((zchk((n)->next) ? (0) : ((n)->next->prev=(n)->prev)),\ +(zchk((n)->prev) ? (0) : ((n)->prev->next=(n)->next)))) + // compositions #define MD_QueuePush(f,l,n) MD_QueuePush_NZ(f,l,n,next,MD_CheckNull,MD_SetNull) @@ -376,11 +377,11 @@ #define MD_StackPop(f) MD_StackPop_NZ(f,next,MD_CheckNull) #define MD_DblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNull,MD_SetNull) #define MD_DblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNull,MD_SetNull) -#define MD_DblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_SetNull) +#define MD_DblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_CheckNull,MD_SetNull) #define MD_NodeDblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNil,MD_SetNil) #define MD_NodeDblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNil,MD_SetNil) -#define MD_NodeDblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_SetNil) +#define MD_NodeDblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_CheckNil,MD_SetNil) //~ Memory Operations