Make desktop AI gens asynchronous with threads

main
parent eb8948a24c
commit bcab5ad15e

302
main.c

@ -119,6 +119,7 @@ void web_arena_set_auto_align(WebArena *arena, size_t align)
#include "md.c"
#pragma warning(pop)
MD_Arena *persistent_arena = 0; // watch out, arenas have limited size.
#include <math.h>
@ -131,13 +132,6 @@ void web_arena_set_auto_align(WebArena *arena, size_t align)
#include "profiling.h"
#ifdef DESKTOP
#ifdef WINDOWS
#include <WinHttp.h>
#else
#error "Only know how to do desktop http requests on windows"
#endif // WINDOWS
#endif // DESKTOP
double clamp(double d, double min, double max)
@ -438,6 +432,166 @@ LPCWSTR windows_string(MD_String8 s)
}
#endif
#ifdef DESKTOP
#ifdef WINDOWS
#include <WinHttp.h>
#include <process.h>
typedef struct ChatRequest
{
struct ChatRequest *next;
struct ChatRequest *prev;
int id;
int status;
char generated[MAX_SENTENCE_LENGTH];
int generated_length;
uintptr_t thread_handle;
MD_Arena *arena;
MD_String8 post_req_body; // allocated on thread_arena
} ChatRequest;
ChatRequest *requests_first = 0;
ChatRequest *requests_last = 0;
int next_request_id = 1;
ChatRequest *requests_free_list = 0;
void generation_thread(void* my_request_voidptr)
{
ChatRequest *my_request = (ChatRequest*)my_request_voidptr;
bool succeeded = true;
#define WinAssertWithErrorCode(X) if( !( X ) ) { unsigned int error = GetLastError(); Log("Error %u in %s\n", error, #X); my_request->status = 2; return; }
HINTERNET hSession = WinHttpOpen(L"PlayGPT winhttp backend", WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0);
WinAssertWithErrorCode(hSession);
LPCWSTR windows_server_name = windows_string(MD_S8Lit(SERVER_DOMAIN));
HINTERNET hConnect = WinHttpConnect(hSession, windows_server_name, SERVER_PORT, 0);
WinAssertWithErrorCode(hConnect);
int security_flags = 0;
if(IS_SERVER_SECURE)
{
security_flags = WINHTTP_FLAG_SECURE;
}
HINTERNET hRequest = WinHttpOpenRequest(hConnect, L"POST", L"completion", 0, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, security_flags);
WinAssertWithErrorCode(hRequest);
// @IMPORTANT @TODO the windows_string allocates on the frame arena, but
// according to https://learn.microsoft.com/en-us/windows/win32/api/winhttp/nf-winhttp-winhttpsendrequest
// the buffer needs to remain available as long as the http request is running, so to make this async and do the loading thing need some other way to allocate the winndows string.... arenas bad?
succeeded = WinHttpSendRequest(hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, (LPVOID)my_request->post_req_body.str, (DWORD)my_request->post_req_body.size, (DWORD)my_request->post_req_body.size, 0);
if(!succeeded)
{
Log("Couldn't do the web: %u\n", GetLastError());
my_request->status = 2;
}
if(succeeded)
{
WinAssertWithErrorCode(WinHttpReceiveResponse(hRequest, 0));
DWORD status_code;
DWORD status_code_size = sizeof(status_code);
WinAssertWithErrorCode(WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &status_code, &status_code_size, WINHTTP_NO_HEADER_INDEX));
Log("Status code: %u\n", status_code);
DWORD dwSize = 0;
MD_String8List received_data_list = {0};
do
{
dwSize = 0;
WinAssertWithErrorCode(WinHttpQueryDataAvailable(hRequest, &dwSize));
if(dwSize == 0)
{
Log("Didn't get anything back.\n");
}
else
{
MD_u8* out_buffer = MD_PushArray(my_request->arena, MD_u8, dwSize + 1);
DWORD dwDownloaded = 0;
WinAssertWithErrorCode(WinHttpReadData(hRequest, (LPVOID)out_buffer, dwSize, &dwDownloaded));
out_buffer[dwDownloaded - 1] = '\0';
Log("Got this from http, size %d: %s\n", dwDownloaded, out_buffer);
MD_S8ListPush(my_request->arena, &received_data_list, MD_S8(out_buffer, dwDownloaded));
}
} while (dwSize > 0);
MD_String8 received_data = MD_S8ListJoin(my_request->arena, received_data_list, &(MD_StringJoin){0});
MD_String8 ai_response = MD_S8Substring(received_data, 1, received_data.size);
if(ai_response.size > ARRLEN(my_request->generated))
{
Log("%lld too big for %lld\n", ai_response.size, ARRLEN(my_request->generated));
my_request->status = 2;
return;
}
memcpy(my_request->generated, ai_response.str, ai_response.size);
my_request->generated_length = (int)ai_response.size;
my_request->status = 1;
}
}
int make_generation_request(MD_String8 post_req_body)
{
ChatRequest *to_return = 0;
if(requests_free_list)
{
to_return = requests_free_list;
requests_free_list = requests_free_list->next;
//MD_StackPop(requests_free_list);
*to_return = (ChatRequest){0};
}
else
{
to_return = MD_PushArrayZero(persistent_arena, ChatRequest, 1);
}
to_return->arena = MD_ArenaAlloc();
to_return->id = next_request_id;
next_request_id += 1;
to_return->post_req_body.str = MD_PushArrayZero(to_return->arena, MD_u8, post_req_body.size);
to_return->post_req_body.size = post_req_body.size;
memcpy(to_return->post_req_body.str, post_req_body.str, post_req_body.size);
to_return->thread_handle = _beginthread(generation_thread, 0, to_return);
assert(to_return->thread_handle);
MD_DblPushBack(requests_first, requests_last, to_return);
return to_return->id;
}
// should never return null
// @TODO @IMPORTANT this doesn't work with save games because it assumes the id is always
// valid but saved IDs won't be valid on reboot
ChatRequest *get_by_id(int id)
{
for(ChatRequest *cur = requests_first; cur; cur = cur->next)
{
if(cur->id == id)
{
return cur;
}
}
assert(false);
return 0;
}
void done_with_request(int id)
{
ChatRequest *req = get_by_id(id);
MD_ArenaRelease(req->arena);
MD_DblRemove(requests_first, requests_last, req);
MD_StackPush(requests_free_list, req);
}
#else
#error "Only know how to do desktop http requests on windows"
#endif // WINDOWS
#endif // DESKTOP
MD_String8 tprint(char *format, ...)
{
MD_String8 to_return = {0};
@ -937,6 +1091,9 @@ bool perform_action(Entity *from, Action a)
bool propagate_to_party = from->is_character || (from->is_npc && from->standing == STANDING_JOINED);
if(action_target == player) propagate_to_party = true;
if(context.eavesdropped_from_party) propagate_to_party = false;
if(propagate_to_party)
@ -952,6 +1109,8 @@ bool perform_action(Entity *from, Action a)
}
}
// npcs in party when they talk should have their speech heard by who the player is talking to
if(from->is_npc && from->standing == STANDING_JOINED)
{
if(gete(player->talking_to) && gete(player->talking_to) != from)
@ -1219,6 +1378,7 @@ void init(void)
#endif
frame_arena = MD_ArenaAlloc();
persistent_arena = MD_ArenaAlloc();
Log("Size of entity struct: %zu\n", sizeof(Entity));
Log("Size of %d gs.entities: %zu kb\n", (int)ARRLEN(gs.entities), sizeof(gs.entities) / 1024);
@ -2954,15 +3114,23 @@ void frame(void)
ENTITIES_ITER(gs.entities)
{
assert(!(it->exists && it->generation == 0));
#ifdef WEB
if (it->is_npc)
{
if (it->gen_request_id != 0)
{
assert(it->gen_request_id > 0);
#ifdef DESKTOP
int status = get_by_id(it->gen_request_id)->status;
#else
#ifdef WEB
int status = EM_ASM_INT( {
return get_generation_request_status($0);
}, it->gen_request_id);
#else
#error "Don't know how to do this stuff on this platform."
#endif // WEB
#endif // DESKTOP
if (status == 0)
{
// simply not done yet
@ -2973,10 +3141,16 @@ void frame(void)
{
// done! we can get the string
char sentence_cstr[MAX_SENTENCE_LENGTH] = { 0 };
#ifdef WEB
EM_ASM( {
let generation = get_generation_request_content($0);
stringToUTF8(generation, $1, $2);
}, it->gen_request_id, sentence_cstr, ARRLEN(sentence_cstr) - 1); // I think minus one for null terminator...
#endif
#ifdef DESKTOP
memcpy(sentence_cstr, get_by_id(it->gen_request_id)->generated, get_by_id(it->gen_request_id)->generated_length);
#endif
MD_String8 sentence_str = MD_S8CString(sentence_cstr);
@ -2997,9 +3171,14 @@ void frame(void)
MD_ReleaseScratch(scratch);
#ifdef WEB
EM_ASM( {
done_with_generation_request($0);
}, it->gen_request_id);
#endif
#ifdef DESKTOP
done_with_request(it->gen_request_id);
#endif
}
else if (status == 2)
{
@ -3008,7 +3187,7 @@ void frame(void)
Action to_perform = {0};
MD_String8 speech_mdstring = MD_S8Lit("I'm not sure...");
memcpy(to_perform.speech, speech_mdstring.str, speech_mdstring.size);
to_perform.speech_length = speech_mdstring.size;
to_perform.speech_length = (int)speech_mdstring.size;
perform_action(it, to_perform);
}
else if (status == -1)
@ -3023,7 +3202,6 @@ void frame(void)
}
}
}
#endif
if (fabsf(it->vel.x) > 0.01f)
@ -3578,99 +3756,39 @@ void frame(void)
else
{
MD_String8 post_request_body = MD_S8Fmt(scratch.arena, "|%.*s", MD_S8VArg(prompt_str));
it->gen_request_id = make_generation_request(post_request_body);
}
#define WinAssertWithErrorCode(X) if( !( X ) ) { unsigned int error = GetLastError(); Log("Error %u in %s\n", error, #X); assert(false); }
HINTERNET hSession = WinHttpOpen(L"PlayGPT winhttp backend", WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0);
WinAssertWithErrorCode(hSession);
LPCWSTR windows_server_name = windows_string(MD_S8Lit(SERVER_DOMAIN));
HINTERNET hConnect = WinHttpConnect(hSession, windows_server_name, SERVER_PORT, 0);
WinAssertWithErrorCode(hConnect);
int security_flags = 0;
if(IS_SERVER_SECURE)
// something to mock
if(ai_response.size > 0)
{
Log("Mocking...\n");
Action a = {0};
MD_String8 error_message = MD_S8Lit("Something really bad happened bro. File " STRINGIZE(__FILE__) " Line " STRINGIZE(__LINE__));
if(succeeded)
{
security_flags = WINHTTP_FLAG_SECURE;
error_message = parse_chatgpt_response(scratch.arena, it, ai_response, &a);
}
HINTERNET hRequest = WinHttpOpenRequest(hConnect, L"POST", L"completion", 0, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, security_flags);
WinAssertWithErrorCode(hRequest);
// @IMPORTANT @TODO the windows_string allocates on the frame arena, but
// according to https://learn.microsoft.com/en-us/windows/win32/api/winhttp/nf-winhttp-winhttpsendrequest
// the buffer needs to remain available as long as the http request is running, so to make this async and do the loading thing need some other way to allocate the winndows string.... arenas bad?
succeeded = WinHttpSendRequest(hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, (LPVOID)post_request_body.str, (DWORD)post_request_body.size, (DWORD)post_request_body.size, 0);
if(!succeeded)
if(mocking_the_ai_response)
{
Log("Couldn't do the web: %u\n", GetLastError());
assert(succeeded);
assert(error_message.size == 0);
perform_action(it, a);
}
if(succeeded)
else
{
WinAssertWithErrorCode(WinHttpReceiveResponse(hRequest, 0));
DWORD status_code;
DWORD status_code_size = sizeof(status_code);
WinAssertWithErrorCode(WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &status_code, &status_code_size, WINHTTP_NO_HEADER_INDEX));
Log("Status code: %u\n", status_code);
DWORD dwSize = 0;
MD_String8List received_data_list = {0};
do
if(succeeded)
{
dwSize = 0;
WinAssertWithErrorCode(WinHttpQueryDataAvailable(hRequest, &dwSize));
if(dwSize == 0)
if (error_message.size == 0)
{
Log("Didn't get anything back.\n");
perform_action(it, a);
}
else
{
MD_u8* out_buffer = MD_PushArray(scratch.arena, MD_u8, dwSize + 1);
DWORD dwDownloaded = 0;
WinAssertWithErrorCode(WinHttpReadData(hRequest, (LPVOID)out_buffer, dwSize, &dwDownloaded));
out_buffer[dwDownloaded - 1] = '\0';
Log("Got this from http, size %d: %s\n", dwDownloaded, out_buffer);
MD_S8ListPush(scratch.arena, &received_data_list, MD_S8(out_buffer, dwDownloaded));
Log("There was an error with the AI: %.*s", MD_S8VArg(error_message));
remember_error(it, error_message);
}
} while (dwSize > 0);
MD_String8 received_data = MD_S8ListJoin(scratch.arena, received_data_list, &(MD_StringJoin){0});
ai_response = MD_S8Substring(received_data, 1, received_data.size);
}
else
{
it->perceptions_dirty = true;
}
}
Action a = {0};
MD_String8 error_message = MD_S8Lit("Something really bad happened bro. File " STRINGIZE(__FILE__) " Line " STRINGIZE(__LINE__));
if(succeeded)
{
error_message = parse_chatgpt_response(scratch.arena, it, ai_response, &a);
}
if(mocking_the_ai_response)
{
assert(succeeded);
assert(error_message.size == 0);
perform_action(it, a);
}
else
{
if(succeeded)
{
if (error_message.size == 0)
{
perform_action(it, a);
}
else
{
Log("There was an error with the AI: %.*s", MD_S8VArg(error_message));
remember_error(it, error_message);
}
}
}
@ -3919,12 +4037,10 @@ void frame(void)
PROFILE_SCOPE("entity rendering")
ENTITIES_ITER(gs.entities)
{
#ifdef WEB
if (it->gen_request_id != 0)
{
draw_quad((DrawParams) { true, quad_centered(AddV2(it->pos, V2(0.0, 50.0)), V2(100.0, 100.0)), IMG(image_thinking), WHITE });
}
#endif
Color col = LerpV4(WHITE, it->damage, RED);
if (it->is_npc)
@ -4540,6 +4656,8 @@ void frame(void)
void cleanup(void)
{
free(fontBuffer);
MD_ArenaRelease(frame_arena);
MD_ArenaRelease(persistent_arena);
sg_shutdown();
hmfree(imui_state);
Log("Cleaning up\n");

@ -236,9 +236,7 @@ typedef struct Entity
NPCPlayerStanding standing;
NpcKind npc_kind;
PathCacheHandle cached_path;
#ifdef WEB
int gen_request_id;
#endif
bool walking;
double shotgun_timer;
bool moved;

17
thirdparty/md.h vendored

@ -362,12 +362,13 @@
(zchk(f)?\
((f)=(l)=(n),zset((n)->next),zset((n)->prev)):\
((n)->prev=(l),(l)->next=(n),(l)=(n),zset((n)->next)))
#define MD_DblRemove_NPZ(f,l,n,next,prev,zset) (((f)==(n)?\
((f)=(f)->next,zset((f)->prev)):\
(l)==(n)?\
((l)=(l)->prev,zset((l)->next)):\
((n)->next->prev=(n)->prev,\
(n)->prev->next=(n)->next)))
#define MD_DblRemove_NPZ(f,l,n,next,prev,zchk,zset) (((f)==(n))?\
((f)=(f)->next, (zchk(f) ? (zset(l)) : zset((f)->prev))):\
((l)==(n))?\
((l)=(l)->prev, (zchk(l) ? (zset(f)) : zset((l)->next))):\
((zchk((n)->next) ? (0) : ((n)->next->prev=(n)->prev)),\
(zchk((n)->prev) ? (0) : ((n)->prev->next=(n)->next))))
// compositions
#define MD_QueuePush(f,l,n) MD_QueuePush_NZ(f,l,n,next,MD_CheckNull,MD_SetNull)
@ -376,11 +377,11 @@
#define MD_StackPop(f) MD_StackPop_NZ(f,next,MD_CheckNull)
#define MD_DblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNull,MD_SetNull)
#define MD_DblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNull,MD_SetNull)
#define MD_DblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_SetNull)
#define MD_DblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_CheckNull,MD_SetNull)
#define MD_NodeDblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNil,MD_SetNil)
#define MD_NodeDblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNil,MD_SetNil)
#define MD_NodeDblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_SetNil)
#define MD_NodeDblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_CheckNil,MD_SetNil)
//~ Memory Operations

Loading…
Cancel
Save