diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e271022..af3a609 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,6 +24,7 @@ target_sources(${DUSK_TARGET_NAME} # Subdirs add_subdirectory(assert) +add_subdirectory(client) add_subdirectory(console) add_subdirectory(error) add_subdirectory(server) diff --git a/src/client/CMakeLists.txt b/src/client/CMakeLists.txt index e2217ee..f37255e 100644 --- a/src/client/CMakeLists.txt +++ b/src/client/CMakeLists.txt @@ -9,4 +9,5 @@ target_sources(${DUSK_TARGET_NAME} client.c ) -# Subdirs \ No newline at end of file +# Subdirs +add_subdirectory(networked) \ No newline at end of file diff --git a/src/client/client.c b/src/client/client.c index e69de29..1bfd5ea 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -0,0 +1,82 @@ +/** + * Copyright (c) 2025 Dominic Masters + * + * This software is released under the MIT License. + * https://opensource.org/licenses/MIT + */ + +#include "client.h" +#include "assert/assert.h" +#include "util/memory.h" +#include "console/console.h" +#include "server/server.h" + +client_t CLIENT; + +void cmdJoin(const consolecmdexec_t *exec) { + clientconnect_t connect; + + if(exec->argc > 0) { + connect.type = CLIENT_TYPE_NETWORKED; + connect.networked.port = SERVER_DEFAULT_PORT; + } else { + connect.type = CLIENT_TYPE_SINGLE_PLAYER; + } + + errorret_t ret = clientConnect(connect); + if(ret == ERROR_OK) { + consolePrint("Connected to server"); + } else { + consolePrint("Failed to connect to server: %s", errorString()); + } + errorFlush(); +} + +void cmdLeave(const consolecmdexec_t *exec) { + clientDisconnect(); +} + +void clientInit() { + memoryZero(&CLIENT, sizeof(client_t)); + + consoleRegCmd("join", cmdJoin); + consoleRegCmd("leave", cmdLeave); +} + +errorret_t clientConnect(const clientconnect_t connect) { + errorret_t ret; + if(CLIENT.state != CLIENT_STATE_DISCONNECTED) { + return error("Client is already connected"); + } + + CLIENT.type = connect.type; + + switch(connect.type) { + case CLIENT_TYPE_NETWORKED: + ret = networkedClientConnect(&CLIENT, connect.networked); + break; + + default: + assertUnreachable("Invalid client type"); + } + + if(ret != ERROR_OK) CLIENT.state = CLIENT_STATE_DISCONNECTED; + return ret; +} + +void clientDisconnect() { + if(CLIENT.state == CLIENT_STATE_DISCONNECTED) return; + + switch(CLIENT.type) { + case CLIENT_TYPE_NETWORKED: + networkedClientDisconnect(&CLIENT); + break; + + default: + assertUnreachable("Invalid client type"); + } +} + +void clientDispose() { + clientDisconnect(); +} \ No newline at end of file diff --git a/src/client/client.h b/src/client/client.h index 35a89f9..cbbefca 100644 --- a/src/client/client.h +++ b/src/client/client.h @@ -6,3 +6,59 @@ */ #pragma once +#include "client/networked/networkedclient.h" + +typedef enum { + CLIENT_TYPE_NETWORKED, + CLIENT_TYPE_SINGLE_PLAYER, +} clienttype_t; + +typedef enum { + CLIENT_STATE_DISCONNECTED, + CLIENT_STATE_CONNECTING, + CLIENT_STATE_CONNECTED, + CLIENT_STATE_DISCONNECTING, +} clientstate_t; + +typedef struct clientconnect_s { + clienttype_t type; + union { + networkedclientconnect_t networked; + }; +} clientconnect_t; + +typedef struct client_s { + clientstate_t state; + clienttype_t type; + + union { + networkedclient_t networked; + }; +} client_t; + +extern client_t CLIENT; + +/** + * Initializes the client. + * + * @return Error code indicating success or failure. + */ +void clientInit(); + +/** + * Connects to a server. + * + * @param connect Connection information. + * @return Error code indicating success or failure. + */ +errorret_t clientConnect(const clientconnect_t connect); + +/** + * Disconnects the client from the server. + */ +void clientDisconnect(); + +/** + * Cleans up the client resources. + */ +void clientDispose(); \ No newline at end of file diff --git a/src/client/networked/CMakeLists.txt b/src/client/networked/CMakeLists.txt new file mode 100644 index 0000000..b0ac150 --- /dev/null +++ b/src/client/networked/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright (c) 2025 Dominic Masters +# +# This software is released under the MIT License. +# https://opensource.org/licenses/MIT + +# Sources +target_sources(${DUSK_TARGET_NAME} + PRIVATE + networkedclient.c +) \ No newline at end of file diff --git a/src/client/networked/networkedclient.c b/src/client/networked/networkedclient.c new file mode 100644 index 0000000..02f8de5 --- /dev/null +++ b/src/client/networked/networkedclient.c @@ -0,0 +1,261 @@ +/** + * Copyright (c) 2025 Dominic Masters + * + * This software is released under the MIT License. + * https://opensource.org/licenses/MIT + */ + +#include "client/client.h" +#include "assert/assert.h" +#include "console/console.h" + +errorret_t networkedClientConnect( + client_t *client, + const networkedclientconnect_t connInfo +) { + int32_t ret; + packet_t packet; + errorret_t err; + char_t *ip = "127.0.0.1"; + + + assertNotNull(client, "Client is NULL"); + assertTrue(client->type == CLIENT_TYPE_NETWORKED, "Client is not networked"); + assertIsMainThread("Client connect must be on main thread"); + + client->state = CLIENT_STATE_CONNECTING; + consolePrint("Connecting to server %s:%d", ip, connInfo.port); + + // Create a socket + client->networked.socket = socket(AF_INET, SOCK_STREAM, 0); + if(client->networked.socket < 0) { + return error("Failed to create socket %s", strerror(errno)); + } + + // Set ip address and port + client->networked.address.sin_family = AF_INET; + client->networked.address.sin_port = htons(connInfo.port); + client->networked.address.sin_addr.s_addr = inet_addr(ip); + + ret = inet_pton(AF_INET, ip, &client->networked.address.sin_addr); + if(ret <= 0) { + close(client->networked.socket); + return error("Invalid or bad IP address %s: %s", ip, strerror(errno)); + } + + // Connect to the server + ret = connect( + client->networked.socket, + (struct sockaddr *)&client->networked.address, + sizeof(client->networked.address) + ); + if(ret < 0) { + close(client->networked.socket); + switch(errno) { + case ECONNREFUSED: + return error("Failed to connect: Connection refused"); + case ETIMEDOUT: + return error("Failed to connect: Connection timed out"); + case ENETUNREACH: + return error("Failed to connect: Network unreachable"); + default: + return error("Failed to connect: Unknown error"); + } + } + + // Send the version + { + const char_t *message = "DUSK|"DUSK_VERSION; + ssize_t sent = send( + client->networked.socket, + message, + strlen(message), + 0 + ); + } + + // We should now receive a welcome packet + err = networkedClientReadPacket(client, &packet); + if(err) return err; + switch(packet.type) { + case PACKET_TYPE_DISCONNECT: + err = packetDisconnectClient(&packet); + if(err) return err; + break; + + case PACKET_TYPE_WELCOME: + err = packetWelcomeClient(&packet); + if(err) return err; + break; + + default: + return error("Server did not send welcome message."); + } + + // Connection was established, hand off to thread + ret = pthread_create( + &client->networked.thread, + NULL, + clientThread, + client + ); + + if(ret != 0) { + close(client->networked.socket); + return error("Failed to create client thread %s", strerror(errno)); + } + + // Wait for the thread to start + while(client->state == CLIENT_STATE_CONNECTING) { + usleep(1000); + } + + return ERROR_OK; +} + +void networkedClientDisconnect(client_t *client) { + assertNotNull(client, "Client is NULL"); + assertIsMainThread("Client disconnect must be on main thread"); + assertTrue(client->type == CLIENT_TYPE_NETWORKED, "Client is not networked"); + assertTrue(client->state == CLIENT_STATE_CONNECTED, "Client not connected"); + + client->state = CLIENT_STATE_DISCONNECTING; + + int32_t maxAttempts = 0; + while(client->state == CLIENT_STATE_DISCONNECTING) { + usleep(1000); + maxAttempts++; + if(maxAttempts > 15) { + consolePrint("Client disconnect timed out, force closing"); + break; + } + } + + client->state = CLIENT_STATE_DISCONNECTED; + if(client->networked.thread) { + pthread_join(client->networked.thread, NULL); + client->networked.thread = 0; + } + + if(client->networked.socket) { + shutdown(client->networked.socket, SHUT_RDWR); + close(client->networked.socket); + client->networked.socket = 0; + } + + consolePrint("Client disconnected"); +} + +errorret_t networkedClientWrite( + const client_t *client, + const uint8_t *data, + const size_t len +) { + assertNotNull(client, "Client is NULL"); + assertNotNull(data, "Data is NULL"); + assertTrue(len > 0, "Data length is 0"); + assertTrue(client->type == CLIENT_TYPE_NETWORKED, "Client is not networked"); + + if(client->state == CLIENT_STATE_DISCONNECTED) { + return error("Client is disconnected"); + } + + ssize_t sent = send(client->networked.socket, data, len, 0); + if(sent < 0) return error("Failed to send data"); + return ERROR_OK; +} + +errorret_t networkedClientWritePacket( + const client_t *client, + const packet_t *packet +) { + assertNotNull(packet, "Packet is NULL"); + assertTrue(packet->type != PACKET_TYPE_INVALID, "Packet type is INVALID"); + assertTrue(packet->length > 0, "Packet length is 0"); + assertTrue( + packet->length <= sizeof(packetdata_t), + "Packet length is too large (1)" + ); + + size_t fullSize = sizeof(packet_t) - sizeof(packet->data) + packet->length; + assertTrue(fullSize <= sizeof(packet_t), "Packet size is too large (2)"); + return networkedClientWrite(client, (const uint8_t *)packet, fullSize); +} + +errorret_t networkedClientReadPacket( + const client_t *client, + packet_t *packet +) { + uint8_t buffer[sizeof(packet_t)]; + + assertNotNull(client, "Client is NULL"); + assertNotNull(packet, "Packet is NULL"); + assertTrue(client->type == CLIENT_TYPE_NETWORKED, "Client is not networked"); + + if(client->state == CLIENT_STATE_DISCONNECTED) { + return error("Client is disconnected"); + } + + // Read the packet header + ssize_t read = recv( + client->networked.socket, + buffer, + sizeof(packettype_t), + 0 + ); + if(read != sizeof(packettype_t)) { + return error("Failed to read packet header %s", strerror(errno)); + } + + packet->type = *(packettype_t *)buffer; + + // Read the packet length + read = recv( + client->networked.socket, + buffer, + sizeof(uint32_t), + 0 + ); + if(read != sizeof(uint32_t)) { + return error("Failed to read packet length %s", strerror(errno)); + } + if(read > sizeof(packetdata_t)) { + return error("Packet length is too large"); + } + packet->length = *(uint32_t *)buffer; + + // Now, read the packet data + read = recv( + client->networked.socket, + (uint8_t *)&packet->data, + packet->length, + 0 + ); + if(read != packet->length) { + return error("Failed to read packet data %s", strerror(errno)); + } + return ERROR_OK; +} + +void * clientThread(void *arg) { + assertNotNull(arg, "Client thread argument is NULL"); + assertNotMainThread("Client thread must not be on main thread"); + + client_t *client = (client_t *)arg; + assertTrue( + client->type == CLIENT_TYPE_NETWORKED, + "Client thread argument is not networked" + ); + assertTrue( + client->state == CLIENT_STATE_CONNECTING, + "Client thread argument is not connecting" + ); + + client->state = CLIENT_STATE_CONNECTED; + + while(client->state == CLIENT_STATE_CONNECTED) { + usleep(1000 * 1000); + } + + printf("Client thread exiting\n"); +} \ No newline at end of file diff --git a/src/client/networked/networkedclient.h b/src/client/networked/networkedclient.h new file mode 100644 index 0000000..b6055f9 --- /dev/null +++ b/src/client/networked/networkedclient.h @@ -0,0 +1,85 @@ +/** + * Copyright (c) 2025 Dominic Masters + * + * This software is released under the MIT License. + * https://opensource.org/licenses/MIT + */ + +#pragma once +#include "error/error.h" +#include +#include "server/packet/packet.h" + +typedef struct client_s client_t; +typedef struct clientconnect_s clientconnect_t; + +typedef struct { + uint16_t port; +} networkedclientconnect_t; + +typedef struct { + int32_t socket; + struct sockaddr_in address; + pthread_t thread; +} networkedclient_t; + +/** + * Connects to a networked server. + * + * @param client Pointer to the client structure. + * @param connect Connection information. + */ +errorret_t networkedClientConnect( + client_t *client, + const networkedclientconnect_t connect +); + +/** + * Closes the connection to a networked server. + * + * @param client Pointer to the client structure. + */ +void networkedClientDisconnect(client_t *client); + +/** + * Writes data to the networked server. + * + * @param client Pointer to the client structure. + * @param data Data to write. + * @param len Length of the data. + * @return Error code. + */ +errorret_t networkedClientWrite( + const client_t *client, + const uint8_t *data, + const size_t len +); + +/** + * Writes a packet to the networked server. + * + * @param client Pointer to the client structure. + * @param packet Pointer to the packet structure. + * @return Error code. + */ +errorret_t networkedClientWritePacket( + const client_t *client, + const packet_t *packet +); + +/** + * Reads a packet from the networked server. + * + * @param client Pointer to the client structure. + * @param packet Pointer to the packet structure to read into. + * @return Error code. + */ +errorret_t networkedClientReadPacket(const client_t *client, packet_t *packet); + +/** + * Thread function for handling networked client connections. + * + * @param arg Pointer to the client structure. + * @return NULL. + */ +void * networkedClientThread(void *arg); \ No newline at end of file diff --git a/src/main.c b/src/main.c index b7fb0cd..9d39454 100644 --- a/src/main.c +++ b/src/main.c @@ -6,6 +6,7 @@ */ #include "console/console.h" +#include "client/client.h" #include "server/server.h" #include "util/string.h" #include "assert/assert.h" @@ -16,41 +17,13 @@ void cmdExit(const consolecmdexec_t *exec) { exitRequested = true; } -void cmdServe(const consolecmdexec_t *exec) { - uint16_t port = 3030; - if(exec->argc != 0) { - if(!stringToU16(exec->argv[0], &port)) { - consolePrint("Invalid port number: %s", exec->argv[0]); - return; - } - } - - errorret_t ret = serverStart((serverstart_t){ - .type = SERVER_TYPE_NETWORKED, - .networked = { - .port = 3030 - } - }); - - if(ret != ERROR_OK) { - consolePrint("Failed to start server: %s", errorString()); - errorFlush(); - return; - } -} - -void cmdClose(const consolecmdexec_t *exec) { - serverStop(); -} - int main(void) { assertInit(); consoleInit(); + clientInit(); serverInit(); consoleRegCmd("exit", cmdExit); - consoleRegCmd("serve", cmdServe); - consoleRegCmd("close", cmdClose); InitWindow(1280, 720, DUSK_NAME); @@ -68,7 +41,9 @@ int main(void) { } CloseWindow(); + serverDispose(); + clientDispose(); return EXIT_SUCCESS; } \ No newline at end of file diff --git a/src/server/networked/networkedserver.h b/src/server/networked/networkedserver.h index 6f44ccb..40af3a5 100644 --- a/src/server/networked/networkedserver.h +++ b/src/server/networked/networkedserver.h @@ -10,6 +10,10 @@ #include #include +typedef struct { + uint16_t port; +} networkedserverstart_t; + typedef struct { int socket; struct sockaddr_in address; diff --git a/src/server/networked/networkedserverclient.c b/src/server/networked/networkedserverclient.c index 7dea028..d1701a0 100644 --- a/src/server/networked/networkedserverclient.c +++ b/src/server/networked/networkedserverclient.c @@ -207,6 +207,7 @@ void * networkedServerClientThread(void *arg) { return NULL; } + buffer[read] = '\0'; // Null-terminate the string if(strncmp(buffer, expecting, strlen(expecting)) != 0) { packetDisconnectCreate(&packet, PACKET_DISCONNECT_REASON_INVALID_VERSION); @@ -238,17 +239,16 @@ void * networkedServerClientThread(void *arg) { if(SERVER.state != SERVER_STATE_RUNNING) { + packetDisconnectCreate( + &packet, + PACKET_DISCONNECT_REASON_SERVER_SHUTDOWN + ); + networkedServerClientWritePacket(client, &packet); + if(errorCheck()) errorPrint(); networkedServerClientCloseOnThread(client, "Server is shutting down"); break; } } - - packetDisconnectCreate( - &packet, - PACKET_DISCONNECT_REASON_SERVER_SHUTDOWN - ); - networkedServerClientWritePacket(client, &packet); - if(errorCheck()) errorPrint(); client->state = SERVER_CLIENT_STATE_DISCONNECTED; return NULL; diff --git a/src/server/packet/packetdisconnect.c b/src/server/packet/packetdisconnect.c index c6205e7..2986f01 100644 --- a/src/server/packet/packetdisconnect.c +++ b/src/server/packet/packetdisconnect.c @@ -7,6 +7,7 @@ #include "packet.h" #include "util/memory.h" +#include "assert/assert.h" void packetDisconnectCreate( packet_t *packet, @@ -14,4 +15,32 @@ void packetDisconnectCreate( ) { packetInit(packet, PACKET_TYPE_DISCONNECT, sizeof(packetdisconnect_t)); packet->data.disconnect.reason = reason; +} + +errorret_t packetDisconnectClient(packet_t *packet) { + assertNotNull(packet, "Packet is NULL"); + assertTrue( + packet->type == PACKET_TYPE_DISCONNECT, + "Packet type is not DISCONNECT" + ); + + if(packet->length != sizeof(packetdisconnect_t)) { + return error("Disconnect packet length is not correct"); + } + + packetdisconnect_t *data = (packetdisconnect_t *)&packet->data; + switch(data->reason) { + case PACKET_DISCONNECT_REASON_UNKNOWN: + return error("Server disconnected: Unknown reason"); + case PACKET_DISCONNECT_REASON_INVALID_VERSION: + return error("Server disconnected: Invalid version"); + case PACKET_DISCONNECT_REASON_MALFORMED_PACKET: + return error("Server disconnected: Malformed packet"); + case PACKET_DISCONNECT_REASON_SERVER_FULL: + return error("Server disconnected: Server full"); + case PACKET_DISCONNECT_REASON_SERVER_SHUTDOWN: + return error("Server disconnected: Server shutdown"); + default: + return error("Server disconnected: Unknown reason"); + } } \ No newline at end of file diff --git a/src/server/packet/packetdisconnect.h b/src/server/packet/packetdisconnect.h index 4b4e6d5..fa16693 100644 --- a/src/server/packet/packetdisconnect.h +++ b/src/server/packet/packetdisconnect.h @@ -6,7 +6,7 @@ */ #pragma once -#include "dusk.h" +#include "error/error.h" typedef struct packet_s packet_t; @@ -31,4 +31,12 @@ typedef struct { void packetDisconnectCreate( packet_t *packet, const packetdisconnectreason_t reason -); \ No newline at end of file +); + +/** + * Handles a disconnect packet received FROM a server INTO a client. + * + * @param packet Pointer to the packet structure to handle. + * @return ERROR_OK on success, or an error code on failure. + */ +errorret_t packetDisconnectClient(packet_t *packet); \ No newline at end of file diff --git a/src/server/packet/packetwelcome.c b/src/server/packet/packetwelcome.c index f571808..0fadbfb 100644 --- a/src/server/packet/packetwelcome.c +++ b/src/server/packet/packetwelcome.c @@ -7,10 +7,30 @@ #include "packet.h" #include "util/memory.h" +#include "assert/assert.h" void packetWelcomeCreate(packet_t *packet) { packetInit(packet, PACKET_TYPE_WELCOME, PACKET_WELCOME_SIZE); memoryCopy( packet->data.welcome.dusk, PACKET_WELCOME_STRING, PACKET_WELCOME_SIZE ); +} + +errorret_t packetWelcomeClient(packet_t *packet) { + assertNotNull(packet, "Packet is NULL"); + assertTrue(packet->type == PACKET_TYPE_WELCOME, "Packet type is not WELCOME"); + + if(packet->length != PACKET_WELCOME_SIZE) { + return error("Welcome packet length is not %d", PACKET_WELCOME_SIZE); + } + + if( + memoryCompare( + packet->data.welcome.dusk, PACKET_WELCOME_STRING, PACKET_WELCOME_SIZE + ) != 0 + ) { + return error("Welcome packet data is not %s", PACKET_WELCOME_STRING); + } + + return ERROR_OK; } \ No newline at end of file diff --git a/src/server/packet/packetwelcome.h b/src/server/packet/packetwelcome.h index ed63405..dedc047 100644 --- a/src/server/packet/packetwelcome.h +++ b/src/server/packet/packetwelcome.h @@ -6,7 +6,7 @@ */ #pragma once -#include "dusk.h" +#include "error/error.h" typedef struct packet_s packet_t; @@ -22,4 +22,12 @@ typedef struct { * * @param packet Pointer to the packet structure to initialize. */ -void packetWelcomeCreate(packet_t *packet); \ No newline at end of file +void packetWelcomeCreate(packet_t *packet); + +/** + * Handles a welcome packet received FROM a server INTO a client. + * + * @param packet Pointer to the packet structure to handle. + * @return ERROR_OK on success, or an error code on failure. + */ +errorret_t packetWelcomeClient(packet_t *packet); \ No newline at end of file diff --git a/src/server/server.c b/src/server/server.c index 284a7c5..18084e4 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -9,15 +9,61 @@ #include "util/memory.h" #include "assert/assert.h" #include "console/console.h" +#include "util/string.h" server_t SERVER; +void cmdStart(const consolecmdexec_t *exec) { + serverstart_t start; + + if(exec->argc > 0) { + if(stringCompare(exec->argv[1], "0") == 0) { + start.type = SERVER_TYPE_SINGLE_PLAYER; + } else if(stringCompare(exec->argv[0], "1") == 0) { + start.type = SERVER_TYPE_NETWORKED; + } else { + consolePrint("Invalid server type: %s", exec->argv[0]); + return; + } + } else { + start.type = SERVER_TYPE_SINGLE_PLAYER; + } + + if(exec->argc > 1) { + if(start.type == SERVER_TYPE_NETWORKED) { + if(!stringToU16(exec->argv[1], &start.networked.port)) { + consolePrint("Invalid port number: %s", exec->argv[0]); + return; + } + } + } else { + if(start.type == SERVER_TYPE_NETWORKED) { + start.networked.port = SERVER_DEFAULT_PORT; + } + } + + errorret_t ret = serverStart(start); + if(ret != ERROR_OK) { + consolePrint("Failed to start server: %s", errorString()); + errorFlush(); + return; + } +} + +void cmdClose(const consolecmdexec_t *exec) { + serverStop(); +} + void serverInit() { memoryZero(&SERVER, sizeof(server_t)); SERVER.state = SERVER_STATE_STOPPED; + + consoleRegCmd("start", cmdStart); + consoleRegCmd("close", cmdClose); } errorret_t serverStart(const serverstart_t start) { + errorret_t ret; assertIsMainThread("Server start must be on main thread"); // Do not start a running server. @@ -30,12 +76,15 @@ errorret_t serverStart(const serverstart_t start) { // Hand off to relevant server type to start. switch(start.type) { case SERVER_TYPE_NETWORKED: - return networkedServerStart(&SERVER, start); + ret = networkedServerStart(&SERVER, start); break; default: assertUnreachable("Invalid server type"); } + + if(ret != ERROR_OK) SERVER.state = SERVER_STATE_STOPPED; + return ret; } uint8_t serverGetClientCount() { diff --git a/src/server/server.h b/src/server/server.h index 13395ab..13e83ce 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -10,6 +10,7 @@ #include "server/networked/networkedserver.h" #define SERVER_MAX_CLIENTS 32 +#define SERVER_DEFAULT_PORT 3030 typedef enum { SERVER_STATE_STOPPED, @@ -26,9 +27,7 @@ typedef enum { typedef struct serverstart_s { servertype_t type; union { - struct { - uint16_t port; - } networked; + networkedserverstart_t networked; }; } serverstart_t; diff --git a/src/server/serverclient.c b/src/server/serverclient.c index 9f54212..71600ef 100644 --- a/src/server/serverclient.c +++ b/src/server/serverclient.c @@ -14,6 +14,7 @@ errorret_t serverClientAccept( serverclient_t *client, const serverclientaccept_t accept ) { + errorret_t ret; memoryZero(client, sizeof(serverclient_t)); assertNotNull(accept.server, "Server is NULL"); assertNotMainThread("Server client accept must not be main thread"); @@ -22,11 +23,14 @@ errorret_t serverClientAccept( switch(accept.server->type) { case SERVER_TYPE_NETWORKED: - return networkedServerClientAccept(client, accept); + ret = networkedServerClientAccept(client, accept); default: assertUnreachable("Unknown server type"); } + + if(ret != ERROR_OK) memoryZero(client, sizeof(serverclient_t)); + return ret; } void serverClientClose(serverclient_t *client) { diff --git a/src/util/memory.c b/src/util/memory.c index 78c5b26..8387b48 100644 --- a/src/util/memory.c +++ b/src/util/memory.c @@ -57,4 +57,15 @@ void memoryMove(void *dest, const void *src, const size_t size) { assertTrue(size > 0, "Cannot move 0 bytes of memory."); assertTrue(dest != src, "Cannot move memory to itself."); memmove(dest, src, size); +} + +ssize_t memoryCompare( + const void *a, + const void *b, + const size_t size +) { + assertNotNull(a, "Cannot compare NULL memory."); + assertNotNull(b, "Cannot compare NULL memory."); + assertTrue(size > 0, "Cannot compare 0 bytes of memory."); + return memcmp(a, b, size); } \ No newline at end of file diff --git a/src/util/memory.h b/src/util/memory.h index 153a463..ee50d3d 100644 --- a/src/util/memory.h +++ b/src/util/memory.h @@ -73,4 +73,18 @@ void memoryCopyRangeSafe( * @param src The source to move from. * @param size The size of the memory to move. */ -void memoryMove(void *dest, const void *src, const size_t size); \ No newline at end of file +void memoryMove(void *dest, const void *src, const size_t size); + +/** + * Compares memory. + * + * @param a The first memory to compare. + * @param b The second memory to compare. + * @param size The size of the memory to compare. + * @return 0 if the memory is equal, < 0 if a < b, > 0 if a > b. + */ +ssize_t memoryCompare( + const void *a, + const void *b, + const size_t size +); \ No newline at end of file