Make desktop AI gens asynchronous with threads

main
parent eb8948a24c
commit bcab5ad15e

302
main.c

@ -119,6 +119,7 @@ void web_arena_set_auto_align(WebArena *arena, size_t align)
#include "md.c" #include "md.c"
#pragma warning(pop) #pragma warning(pop)
MD_Arena *persistent_arena = 0; // watch out, arenas have limited size.
#include <math.h> #include <math.h>
@ -131,13 +132,6 @@ void web_arena_set_auto_align(WebArena *arena, size_t align)
#include "profiling.h" #include "profiling.h"
#ifdef DESKTOP
#ifdef WINDOWS
#include <WinHttp.h>
#else
#error "Only know how to do desktop http requests on windows"
#endif // WINDOWS
#endif // DESKTOP
double clamp(double d, double min, double max) double clamp(double d, double min, double max)
@ -438,6 +432,166 @@ LPCWSTR windows_string(MD_String8 s)
} }
#endif #endif
#ifdef DESKTOP
#ifdef WINDOWS
#include <WinHttp.h>
#include <process.h>
// One asynchronous AI generation request. make_generation_request fills it in and hands it
// to a worker thread (generation_thread); frame() polls it through get_by_id.
typedef struct ChatRequest
{
// Links for the requests_first/requests_last list. `next` is also used to chain retired
// nodes on requests_free_list.
struct ChatRequest *next;
struct ChatRequest *prev;
// Identifier returned by make_generation_request and looked up by get_by_id.
int id;
// 0 = not done yet, 1 = finished and `generated` is valid, 2 = failed.
// Written by the worker thread and read by frame().
// NOTE(review): this is a plain int shared across threads with no atomics or barriers — confirm that is acceptable.
int status;
// Response text copied in by the worker on success; generated_length is its size in bytes.
char generated[MAX_SENTENCE_LENGTH];
int generated_length;
// Value returned by _beginthread for the worker servicing this request.
uintptr_t thread_handle;
// Arena created per request; backs post_req_body and the worker's receive buffers,
// and is released in done_with_request.
MD_Arena *arena;
MD_String8 post_req_body; // private copy of the POST body, allocated on `arena`
} ChatRequest;
// Head and tail of the doubly linked list of requests that have been made and not yet
// retired with done_with_request.
ChatRequest *requests_first = 0;
ChatRequest *requests_last = 0;
// Id given to the next request. Starts at 1 because callers treat a gen_request_id of 0 as "no request".
int next_request_id = 1;
// Stack (chained through `next`) of retired ChatRequest nodes that make_generation_request reuses
// before allocating new ones from persistent_arena.
ChatRequest *requests_free_list = 0;
void generation_thread(void* my_request_voidptr)
{
ChatRequest *my_request = (ChatRequest*)my_request_voidptr;
bool succeeded = true;
#define WinAssertWithErrorCode(X) if( !( X ) ) { unsigned int error = GetLastError(); Log("Error %u in %s\n", error, #X); my_request->status = 2; return; }
HINTERNET hSession = WinHttpOpen(L"PlayGPT winhttp backend", WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0);
WinAssertWithErrorCode(hSession);
LPCWSTR windows_server_name = windows_string(MD_S8Lit(SERVER_DOMAIN));
HINTERNET hConnect = WinHttpConnect(hSession, windows_server_name, SERVER_PORT, 0);
WinAssertWithErrorCode(hConnect);
int security_flags = 0;
if(IS_SERVER_SECURE)
{
security_flags = WINHTTP_FLAG_SECURE;
}
HINTERNET hRequest = WinHttpOpenRequest(hConnect, L"POST", L"completion", 0, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, security_flags);
WinAssertWithErrorCode(hRequest);
// @IMPORTANT @TODO the windows_string allocates on the frame arena, but
// according to https://learn.microsoft.com/en-us/windows/win32/api/winhttp/nf-winhttp-winhttpsendrequest
// the buffer needs to remain available as long as the http request is running, so to make this async and do the loading thing need some other way to allocate the winndows string.... arenas bad?
succeeded = WinHttpSendRequest(hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, (LPVOID)my_request->post_req_body.str, (DWORD)my_request->post_req_body.size, (DWORD)my_request->post_req_body.size, 0);
if(!succeeded)
{
Log("Couldn't do the web: %u\n", GetLastError());
my_request->status = 2;
}
if(succeeded)
{
WinAssertWithErrorCode(WinHttpReceiveResponse(hRequest, 0));
DWORD status_code;
DWORD status_code_size = sizeof(status_code);
WinAssertWithErrorCode(WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &status_code, &status_code_size, WINHTTP_NO_HEADER_INDEX));
Log("Status code: %u\n", status_code);
DWORD dwSize = 0;
MD_String8List received_data_list = {0};
do
{
dwSize = 0;
WinAssertWithErrorCode(WinHttpQueryDataAvailable(hRequest, &dwSize));
if(dwSize == 0)
{
Log("Didn't get anything back.\n");
}
else
{
MD_u8* out_buffer = MD_PushArray(my_request->arena, MD_u8, dwSize + 1);
DWORD dwDownloaded = 0;
WinAssertWithErrorCode(WinHttpReadData(hRequest, (LPVOID)out_buffer, dwSize, &dwDownloaded));
out_buffer[dwDownloaded - 1] = '\0';
Log("Got this from http, size %d: %s\n", dwDownloaded, out_buffer);
MD_S8ListPush(my_request->arena, &received_data_list, MD_S8(out_buffer, dwDownloaded));
}
} while (dwSize > 0);
MD_String8 received_data = MD_S8ListJoin(my_request->arena, received_data_list, &(MD_StringJoin){0});
MD_String8 ai_response = MD_S8Substring(received_data, 1, received_data.size);
if(ai_response.size > ARRLEN(my_request->generated))
{
Log("%lld too big for %lld\n", ai_response.size, ARRLEN(my_request->generated));
my_request->status = 2;
return;
}
memcpy(my_request->generated, ai_response.str, ai_response.size);
my_request->generated_length = (int)ai_response.size;
my_request->status = 1;
}
}
// Starts an asynchronous generation request and returns its id (always > 0).
//
// post_req_body: body of the HTTP POST. It is copied onto the request's own arena, so the
// caller's buffer may be freed or reused as soon as this function returns.
//
// The request node is taken from requests_free_list when one is available, otherwise it is
// allocated from persistent_arena. Poll the result with get_by_id(id)->status and retire the
// request with done_with_request(id).
//
// If the worker thread cannot be started the request is still returned, with status already
// set to 2 (failed), so callers handle it through the same path as a failed generation.
int make_generation_request(MD_String8 post_req_body)
{
	ChatRequest *to_return = 0;
	if(requests_free_list)
	{
		// Pop a retired node and wipe whatever the previous request left in it.
		to_return = requests_free_list;
		requests_free_list = requests_free_list->next;
		*to_return = (ChatRequest){0};
	}
	else
	{
		to_return = MD_PushArrayZero(persistent_arena, ChatRequest, 1);
	}

	to_return->arena = MD_ArenaAlloc();
	to_return->id = next_request_id;
	next_request_id += 1;

	to_return->post_req_body.str = MD_PushArrayZero(to_return->arena, MD_u8, post_req_body.size);
	to_return->post_req_body.size = post_req_body.size;
	if(post_req_body.size > 0)
	{
		memcpy(to_return->post_req_body.str, post_req_body.str, post_req_body.size);
	}

	// Link the request before the worker starts so the node is fully set up and findable
	// by the time anything can observe it.
	MD_DblPushBack(requests_first, requests_last, to_return);

	// _beginthread signals failure with -1L, not 0, so a plain truthiness check never fires.
	to_return->thread_handle = _beginthread(generation_thread, 0, to_return);
	if(to_return->thread_handle == (uintptr_t)-1L)
	{
		Log("Couldn't start the generation thread for request %d\n", to_return->id);
		to_return->thread_handle = 0;
		to_return->status = 2;
	}

	return to_return->id;
}
// Looks up a request by id, walking the list that starts at requests_first.
// Callers assume the result is never null: an id that is not in the list trips the assert.
// @TODO @IMPORTANT this doesn't work with save games because it assumes the id is always
// valid, but ids stored in a save won't be valid after a restart.
ChatRequest *get_by_id(int id)
{
	ChatRequest *match = requests_first;
	while(match && match->id != id)
	{
		match = match->next;
	}
	if(match)
	{
		return match;
	}
	assert(false);
	return 0;
}
void done_with_request(int id)
{
ChatRequest *req = get_by_id(id);
MD_ArenaRelease(req->arena);
MD_DblRemove(requests_first, requests_last, req);
MD_StackPush(requests_free_list, req);
}
#else
#error "Only know how to do desktop http requests on windows"
#endif // WINDOWS
#endif // DESKTOP
MD_String8 tprint(char *format, ...) MD_String8 tprint(char *format, ...)
{ {
MD_String8 to_return = {0}; MD_String8 to_return = {0};
@ -937,6 +1091,9 @@ bool perform_action(Entity *from, Action a)
bool propagate_to_party = from->is_character || (from->is_npc && from->standing == STANDING_JOINED); bool propagate_to_party = from->is_character || (from->is_npc && from->standing == STANDING_JOINED);
if(action_target == player) propagate_to_party = true;
if(context.eavesdropped_from_party) propagate_to_party = false; if(context.eavesdropped_from_party) propagate_to_party = false;
if(propagate_to_party) if(propagate_to_party)
@ -952,6 +1109,8 @@ bool perform_action(Entity *from, Action a)
} }
} }
// npcs in party when they talk should have their speech heard by who the player is talking to
if(from->is_npc && from->standing == STANDING_JOINED) if(from->is_npc && from->standing == STANDING_JOINED)
{ {
if(gete(player->talking_to) && gete(player->talking_to) != from) if(gete(player->talking_to) && gete(player->talking_to) != from)
@ -1219,6 +1378,7 @@ void init(void)
#endif #endif
frame_arena = MD_ArenaAlloc(); frame_arena = MD_ArenaAlloc();
persistent_arena = MD_ArenaAlloc();
Log("Size of entity struct: %zu\n", sizeof(Entity)); Log("Size of entity struct: %zu\n", sizeof(Entity));
Log("Size of %d gs.entities: %zu kb\n", (int)ARRLEN(gs.entities), sizeof(gs.entities) / 1024); Log("Size of %d gs.entities: %zu kb\n", (int)ARRLEN(gs.entities), sizeof(gs.entities) / 1024);
@ -2954,15 +3114,23 @@ void frame(void)
ENTITIES_ITER(gs.entities) ENTITIES_ITER(gs.entities)
{ {
assert(!(it->exists && it->generation == 0)); assert(!(it->exists && it->generation == 0));
#ifdef WEB
if (it->is_npc) if (it->is_npc)
{ {
if (it->gen_request_id != 0) if (it->gen_request_id != 0)
{ {
assert(it->gen_request_id > 0); assert(it->gen_request_id > 0);
#ifdef DESKTOP
int status = get_by_id(it->gen_request_id)->status;
#else
#ifdef WEB
int status = EM_ASM_INT( { int status = EM_ASM_INT( {
return get_generation_request_status($0); return get_generation_request_status($0);
}, it->gen_request_id); }, it->gen_request_id);
#else
#error "Don't know how to do this stuff on this platform."
#endif // WEB
#endif // DESKTOP
if (status == 0) if (status == 0)
{ {
// simply not done yet // simply not done yet
@ -2973,10 +3141,16 @@ void frame(void)
{ {
// done! we can get the string // done! we can get the string
char sentence_cstr[MAX_SENTENCE_LENGTH] = { 0 }; char sentence_cstr[MAX_SENTENCE_LENGTH] = { 0 };
#ifdef WEB
EM_ASM( { EM_ASM( {
let generation = get_generation_request_content($0); let generation = get_generation_request_content($0);
stringToUTF8(generation, $1, $2); stringToUTF8(generation, $1, $2);
}, it->gen_request_id, sentence_cstr, ARRLEN(sentence_cstr) - 1); // I think minus one for null terminator... }, it->gen_request_id, sentence_cstr, ARRLEN(sentence_cstr) - 1); // I think minus one for null terminator...
#endif
#ifdef DESKTOP
memcpy(sentence_cstr, get_by_id(it->gen_request_id)->generated, get_by_id(it->gen_request_id)->generated_length);
#endif
MD_String8 sentence_str = MD_S8CString(sentence_cstr); MD_String8 sentence_str = MD_S8CString(sentence_cstr);
@ -2997,9 +3171,14 @@ void frame(void)
MD_ReleaseScratch(scratch); MD_ReleaseScratch(scratch);
#ifdef WEB
EM_ASM( { EM_ASM( {
done_with_generation_request($0); done_with_generation_request($0);
}, it->gen_request_id); }, it->gen_request_id);
#endif
#ifdef DESKTOP
done_with_request(it->gen_request_id);
#endif
} }
else if (status == 2) else if (status == 2)
{ {
@ -3008,7 +3187,7 @@ void frame(void)
Action to_perform = {0}; Action to_perform = {0};
MD_String8 speech_mdstring = MD_S8Lit("I'm not sure..."); MD_String8 speech_mdstring = MD_S8Lit("I'm not sure...");
memcpy(to_perform.speech, speech_mdstring.str, speech_mdstring.size); memcpy(to_perform.speech, speech_mdstring.str, speech_mdstring.size);
to_perform.speech_length = speech_mdstring.size; to_perform.speech_length = (int)speech_mdstring.size;
perform_action(it, to_perform); perform_action(it, to_perform);
} }
else if (status == -1) else if (status == -1)
@ -3023,7 +3202,6 @@ void frame(void)
} }
} }
} }
#endif
if (fabsf(it->vel.x) > 0.01f) if (fabsf(it->vel.x) > 0.01f)
@ -3578,99 +3756,39 @@ void frame(void)
else else
{ {
MD_String8 post_request_body = MD_S8Fmt(scratch.arena, "|%.*s", MD_S8VArg(prompt_str)); MD_String8 post_request_body = MD_S8Fmt(scratch.arena, "|%.*s", MD_S8VArg(prompt_str));
it->gen_request_id = make_generation_request(post_request_body);
}
#define WinAssertWithErrorCode(X) if( !( X ) ) { unsigned int error = GetLastError(); Log("Error %u in %s\n", error, #X); assert(false); } // something to mock
if(ai_response.size > 0)
{
Log("Mocking...\n");
HINTERNET hSession = WinHttpOpen(L"PlayGPT winhttp backend", WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0); Action a = {0};
WinAssertWithErrorCode(hSession); MD_String8 error_message = MD_S8Lit("Something really bad happened bro. File " STRINGIZE(__FILE__) " Line " STRINGIZE(__LINE__));
if(succeeded)
LPCWSTR windows_server_name = windows_string(MD_S8Lit(SERVER_DOMAIN));
HINTERNET hConnect = WinHttpConnect(hSession, windows_server_name, SERVER_PORT, 0);
WinAssertWithErrorCode(hConnect);
int security_flags = 0;
if(IS_SERVER_SECURE)
{ {
security_flags = WINHTTP_FLAG_SECURE; error_message = parse_chatgpt_response(scratch.arena, it, ai_response, &a);
} }
HINTERNET hRequest = WinHttpOpenRequest(hConnect, L"POST", L"completion", 0, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, security_flags); if(mocking_the_ai_response)
WinAssertWithErrorCode(hRequest);
// @IMPORTANT @TODO the windows_string allocates on the frame arena, but
// according to https://learn.microsoft.com/en-us/windows/win32/api/winhttp/nf-winhttp-winhttpsendrequest
// the buffer needs to remain available as long as the http request is running, so to make this async and do the loading thing need some other way to allocate the winndows string.... arenas bad?
succeeded = WinHttpSendRequest(hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, (LPVOID)post_request_body.str, (DWORD)post_request_body.size, (DWORD)post_request_body.size, 0);
if(!succeeded)
{ {
Log("Couldn't do the web: %u\n", GetLastError()); assert(succeeded);
assert(error_message.size == 0);
perform_action(it, a);
} }
if(succeeded) else
{ {
WinAssertWithErrorCode(WinHttpReceiveResponse(hRequest, 0)); if(succeeded)
DWORD status_code;
DWORD status_code_size = sizeof(status_code);
WinAssertWithErrorCode(WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &status_code, &status_code_size, WINHTTP_NO_HEADER_INDEX));
Log("Status code: %u\n", status_code);
DWORD dwSize = 0;
MD_String8List received_data_list = {0};
do
{ {
dwSize = 0; if (error_message.size == 0)
WinAssertWithErrorCode(WinHttpQueryDataAvailable(hRequest, &dwSize));
if(dwSize == 0)
{ {
Log("Didn't get anything back.\n"); perform_action(it, a);
} }
else else
{ {
MD_u8* out_buffer = MD_PushArray(scratch.arena, MD_u8, dwSize + 1); Log("There was an error with the AI: %.*s", MD_S8VArg(error_message));
DWORD dwDownloaded = 0; remember_error(it, error_message);
WinAssertWithErrorCode(WinHttpReadData(hRequest, (LPVOID)out_buffer, dwSize, &dwDownloaded));
out_buffer[dwDownloaded - 1] = '\0';
Log("Got this from http, size %d: %s\n", dwDownloaded, out_buffer);
MD_S8ListPush(scratch.arena, &received_data_list, MD_S8(out_buffer, dwDownloaded));
} }
} while (dwSize > 0);
MD_String8 received_data = MD_S8ListJoin(scratch.arena, received_data_list, &(MD_StringJoin){0});
ai_response = MD_S8Substring(received_data, 1, received_data.size);
}
else
{
it->perceptions_dirty = true;
}
}
Action a = {0};
MD_String8 error_message = MD_S8Lit("Something really bad happened bro. File " STRINGIZE(__FILE__) " Line " STRINGIZE(__LINE__));
if(succeeded)
{
error_message = parse_chatgpt_response(scratch.arena, it, ai_response, &a);
}
if(mocking_the_ai_response)
{
assert(succeeded);
assert(error_message.size == 0);
perform_action(it, a);
}
else
{
if(succeeded)
{
if (error_message.size == 0)
{
perform_action(it, a);
}
else
{
Log("There was an error with the AI: %.*s", MD_S8VArg(error_message));
remember_error(it, error_message);
} }
} }
} }
@ -3919,12 +4037,10 @@ void frame(void)
PROFILE_SCOPE("entity rendering") PROFILE_SCOPE("entity rendering")
ENTITIES_ITER(gs.entities) ENTITIES_ITER(gs.entities)
{ {
#ifdef WEB
if (it->gen_request_id != 0) if (it->gen_request_id != 0)
{ {
draw_quad((DrawParams) { true, quad_centered(AddV2(it->pos, V2(0.0, 50.0)), V2(100.0, 100.0)), IMG(image_thinking), WHITE }); draw_quad((DrawParams) { true, quad_centered(AddV2(it->pos, V2(0.0, 50.0)), V2(100.0, 100.0)), IMG(image_thinking), WHITE });
} }
#endif
Color col = LerpV4(WHITE, it->damage, RED); Color col = LerpV4(WHITE, it->damage, RED);
if (it->is_npc) if (it->is_npc)
@ -4540,6 +4656,8 @@ void frame(void)
void cleanup(void) void cleanup(void)
{ {
free(fontBuffer); free(fontBuffer);
MD_ArenaRelease(frame_arena);
MD_ArenaRelease(persistent_arena);
sg_shutdown(); sg_shutdown();
hmfree(imui_state); hmfree(imui_state);
Log("Cleaning up\n"); Log("Cleaning up\n");

@ -236,9 +236,7 @@ typedef struct Entity
NPCPlayerStanding standing; NPCPlayerStanding standing;
NpcKind npc_kind; NpcKind npc_kind;
PathCacheHandle cached_path; PathCacheHandle cached_path;
#ifdef WEB
int gen_request_id; int gen_request_id;
#endif
bool walking; bool walking;
double shotgun_timer; double shotgun_timer;
bool moved; bool moved;

17
thirdparty/md.h vendored

@ -362,12 +362,13 @@
(zchk(f)?\ (zchk(f)?\
((f)=(l)=(n),zset((n)->next),zset((n)->prev)):\ ((f)=(l)=(n),zset((n)->next),zset((n)->prev)):\
((n)->prev=(l),(l)->next=(n),(l)=(n),zset((n)->next))) ((n)->prev=(l),(l)->next=(n),(l)=(n),zset((n)->next)))
#define MD_DblRemove_NPZ(f,l,n,next,prev,zset) (((f)==(n)?\ #define MD_DblRemove_NPZ(f,l,n,next,prev,zchk,zset) (((f)==(n))?\
((f)=(f)->next,zset((f)->prev)):\ ((f)=(f)->next, (zchk(f) ? (zset(l)) : zset((f)->prev))):\
(l)==(n)?\ ((l)==(n))?\
((l)=(l)->prev,zset((l)->next)):\ ((l)=(l)->prev, (zchk(l) ? (zset(f)) : zset((l)->next))):\
((n)->next->prev=(n)->prev,\ ((zchk((n)->next) ? (0) : ((n)->next->prev=(n)->prev)),\
(n)->prev->next=(n)->next))) (zchk((n)->prev) ? (0) : ((n)->prev->next=(n)->next))))
// compositions // compositions
#define MD_QueuePush(f,l,n) MD_QueuePush_NZ(f,l,n,next,MD_CheckNull,MD_SetNull) #define MD_QueuePush(f,l,n) MD_QueuePush_NZ(f,l,n,next,MD_CheckNull,MD_SetNull)
@ -376,11 +377,11 @@
#define MD_StackPop(f) MD_StackPop_NZ(f,next,MD_CheckNull) #define MD_StackPop(f) MD_StackPop_NZ(f,next,MD_CheckNull)
#define MD_DblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNull,MD_SetNull) #define MD_DblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNull,MD_SetNull)
#define MD_DblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNull,MD_SetNull) #define MD_DblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNull,MD_SetNull)
#define MD_DblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_SetNull) #define MD_DblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_CheckNull,MD_SetNull)
#define MD_NodeDblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNil,MD_SetNil) #define MD_NodeDblPushBack(f,l,n) MD_DblPushBack_NPZ(f,l,n,next,prev,MD_CheckNil,MD_SetNil)
#define MD_NodeDblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNil,MD_SetNil) #define MD_NodeDblPushFront(f,l,n) MD_DblPushBack_NPZ(l,f,n,prev,next,MD_CheckNil,MD_SetNil)
#define MD_NodeDblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_SetNil) #define MD_NodeDblRemove(f,l,n) MD_DblRemove_NPZ(f,l,n,next,prev,MD_CheckNil,MD_SetNil)
//~ Memory Operations //~ Memory Operations

Loading…
Cancel
Save