/**
 * Copyright (c) 2025 Dominic Masters
 * 
 * This software is released under the MIT License.
 * https://opensource.org/licenses/MIT
 */

#include "server/server.h"
#include "assert/assert.h"
#include "util/memory.h"
#include "console/console.h"
#include <sys/socket.h>

/**
 * Accepts a freshly connected networked client into the given client slot.
 *
 * Stores the accepted socket on the client, applies an 8 second receive
 * timeout to it, initializes the read/write mutexes and then spawns the
 * client's read thread.
 *
 * The timeout is applied BEFORE the read thread is started so the thread can
 * never observe a socket without the timeout, and so that a failure here can
 * be reported through the return value without touching a running thread.
 *
 * @param client Client slot to initialize.
 * @param accept Accept information; supplies the owning server and the socket.
 * @return ERROR_OK on success, otherwise an error. On failure the client state
 *         is SERVER_CLIENT_STATE_DISCONNECTED and no thread has been started.
 */
errorret_t networkedServerClientAccept(
  serverclient_t *client,
  const serverclientaccept_t accept
) {
  assertNotNull(client, "Client is NULL");
  assertNotNull(accept.server, "Server is NULL");
  assertTrue(
    accept.server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );

  client->state = SERVER_CLIENT_STATE_ACCEPTING;
  client->networked.socket = accept.networked.socket;

  // Set timeout to 8 seconds
  client->networked.timeout.tv_sec = 8;
  client->networked.timeout.tv_usec = 0;

  // Set socket timeout. This runs on the caller's thread, so failure must be
  // returned rather than handled by a helper that terminates the thread.
  // The socket is not closed here, matching the thread creation failure path.
  // NOTE(review): confirm the caller closes the socket when this returns an
  // error.
  if(setsockopt(
    client->networked.socket,
    SOL_SOCKET,
    SO_RCVTIMEO,
    &client->networked.timeout,
    sizeof(client->networked.timeout)
  ) < 0) {
    client->state = SERVER_CLIENT_STATE_DISCONNECTED;
    return error("Failed to set socket timeout");
  }

  // Initialize mutexes
  pthread_mutex_init(&client->networked.readLock, NULL);
  pthread_mutex_init(&client->networked.writeLock, NULL);

  // Create a read thread for the client
  const int ret = pthread_create(
    &client->networked.readThread,
    NULL,
    networkedServerClientReadThread,
    client
  );
  if(ret != 0) {
    // No thread exists, so nothing else can be holding these mutexes.
    pthread_mutex_destroy(&client->networked.readLock);
    pthread_mutex_destroy(&client->networked.writeLock);
    client->state = SERVER_CLIENT_STATE_DISCONNECTED;
    return error("Failed to create client read thread");
  }

  return ERROR_OK;
}

/**
 * Closes a networked client from the main thread.
 *
 * Marks the client as disconnecting, waits for its read and write threads to
 * finish, destroys both mutexes, closes the socket and finally marks the
 * client as disconnected.
 *
 * @param client Client to close. Must belong to a networked server.
 */
void networkedServerClientClose(serverclient_t *client) {
  assertIsMainThread("Server client close must be on main thread.");
  assertNotNull(client, "Client is NULL");
  assertNotNull(client->server, "Server is NULL");
  assertTrue(
    client->server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );

  // Mark client as disconnecting
  client->state = SERVER_CLIENT_STATE_DISCONNECTING;

  // Wait for the read thread to finish. Taking and releasing the lock first
  // lets any critical section currently running on that thread complete.
  pthread_mutex_lock(&client->networked.readLock);
  pthread_mutex_unlock(&client->networked.readLock);
  // A zero handle means there is no thread to join (the handle is zeroed when
  // a thread closes itself or after a previous join); joining it is invalid.
  if(client->networked.readThread != 0) {
    pthread_join(client->networked.readThread, NULL);
    client->networked.readThread = 0;
  }
  pthread_mutex_destroy(&client->networked.readLock);

  // Signal and wait for the write thread to finish
  pthread_mutex_lock(&client->networked.writeLock);
  pthread_mutex_unlock(&client->networked.writeLock);
  if(client->networked.writeThread != 0) {
    pthread_join(client->networked.writeThread, NULL);
    client->networked.writeThread = 0;
  }
  pthread_mutex_destroy(&client->networked.writeLock);

  // Remember the descriptor for the log message; it is reset to -1 below.
  const int socketFd = client->networked.socket;

  // Close the socket
  if(socketFd != -1) {
    close(socketFd);
    client->networked.socket = -1;
  }

  client->state = SERVER_CLIENT_STATE_DISCONNECTED;
  consolePrint("Client %d disconnected.", socketFd);
}

/**
 * Closes a networked client from one of its own (non-main) threads and then
 * terminates the calling thread.
 *
 * The socket is closed, the client is marked as disconnected so the slot can
 * be used again, and the calling thread exits via pthread_exit. This function
 * therefore never returns to its caller.
 *
 * NOTE(review): the read thread handle is zeroed here without the thread being
 * joined or detached in this function - confirm the thread's resources are
 * reclaimed elsewhere.
 *
 * @param client Client to close. Must belong to a networked server.
 * @param reason Human readable reason, included in the console message.
 */
void networkedServerClientCloseOnThread(
  serverclient_t *client,
  const char_t *reason
) {
  assertNotNull(client, "Client is NULL");
  assertNotNull(reason, "Reason is NULL");
  assertNotNull(client->server, "Server is NULL");
  assertTrue(
    client->server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );
  assertNotMainThread("Client close must not be main thread");

  client->state = SERVER_CLIENT_STATE_DISCONNECTING;

  // Remember the descriptor so the log shows the real socket rather than the
  // -1 it is reset to below.
  const int socketFd = client->networked.socket;

  // Terminate the socket
  if(socketFd != -1) {
    close(socketFd);
  }
  client->networked.socket = -1;
  client->networked.readThread = 0;
  consolePrint("Client %d disconnected: %s", socketFd, reason);

  // Mark this client as disconnected so it can be used again.
  client->state = SERVER_CLIENT_STATE_DISCONNECTED;
  pthread_exit(NULL);
}

/**
 * Receives raw bytes from a networked client's socket.
 *
 * A single recv call is issued, so fewer than len bytes may be returned.
 *
 * @param client Client to read from. Must be connected, accepting or
 *               disconnecting, and belong to a networked server.
 * @param buffer Destination for the received bytes.
 * @param len Capacity of buffer in bytes. Must be greater than zero.
 * @return The value returned by recv for the client's socket.
 */
ssize_t networkedServerClientRead(
  const serverclient_t * client,
  uint8_t *buffer,
  const size_t len
) {
  // Argument validation.
  assertNotNull(client, "Client is NULL");
  assertNotNull(buffer, "Buffer is NULL");
  assertNotNull(client->server, "Server is NULL");
  assertTrue(
    client->server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );
  assertTrue(len > 0, "Buffer length is 0");

  // Reading is only allowed while the client still has a live socket.
  const int stateAllowsRead = (
    client->state == SERVER_CLIENT_STATE_CONNECTED ||
    client->state == SERVER_CLIENT_STATE_ACCEPTING ||
    client->state == SERVER_CLIENT_STATE_DISCONNECTING
  );
  assertTrue(
    stateAllowsRead,
    "Client is not connected, accepting or disconnecting"
  );

  const ssize_t received = recv(client->networked.socket, buffer, len, 0);
  return received;
}

/**
 * Reads exactly len bytes from the client, issuing as many reads as needed.
 *
 * A stream socket may deliver fewer bytes than requested per call, so a single
 * short read is not treated as a failure.
 *
 * @param client Client to read from.
 * @param buffer Destination for the bytes.
 * @param len Number of bytes required. Must be greater than zero.
 * @return 1 if all len bytes were read, 0 if a read returned an error or zero
 *         bytes before the buffer was filled.
 */
static int networkedServerClientReadExact(
  const serverclient_t *client,
  uint8_t *buffer,
  const size_t len
) {
  size_t total = 0;

  while(total < len) {
    const ssize_t got = networkedServerClientRead(
      client,
      buffer + total,
      len - total
    );
    if(got <= 0) return 0;
    total += (size_t)got;
  }

  return 1;
}

/**
 * Reads one packet from a networked client.
 *
 * The packet is read as three consecutive fields: the packet type, a 32-bit
 * length, then that many bytes of packet data.
 *
 * @param client Client to read from. Must belong to a networked server.
 * @param packet Packet to fill with the received type, length and data.
 * @return ERROR_OK on success, otherwise an error describing which field
 *         could not be read or why it was rejected.
 */
errorret_t networkedServerClientReadPacket(
  const serverclient_t * client,
  packet_t *packet
) {
  assertNotNull(client, "Client is NULL");
  assertNotNull(packet, "Packet is NULL");
  assertNotNull(client->server, "Server is NULL");
  assertTrue(
    client->server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );

  // Read packet ID. Reading straight into a correctly typed local avoids
  // reinterpreting a byte buffer through an incompatible pointer type.
  packettype_t type;
  if(!networkedServerClientReadExact(client, (uint8_t *)&type, sizeof(type))) {
    return error("Failed to read packet ID");
  }

  packet->type = type;
  if(packet->type == PACKET_TYPE_INVALID) {
    return error("Invalid packet type");
  }

  // Read length
  uint32_t length;
  if(!networkedServerClientReadExact(
    client,
    (uint8_t *)&length,
    sizeof(length)
  )) {
    return error("Failed to read packet length");
  }

  packet->length = length;

  // The length comes from the remote peer; reject zero here instead of
  // letting it reach the "length > 0" assertion in the read function.
  if(length == 0) {
    return error("Packet length is zero");
  }
  if(length > sizeof(packetdata_t)) {
    return error("Packet length is too large");
  }

  // Read data
  if(!networkedServerClientReadExact(
    client,
    (uint8_t *)&packet->data,
    (size_t)length
  )) {
    return error("Failed to read packet data");
  }

  return ERROR_OK;
}

/**
 * Sends raw bytes to a networked client.
 *
 * send may accept fewer bytes than requested, so it is called repeatedly until
 * the whole buffer has been handed to the socket.
 *
 * @param client Client to write to. Must belong to a networked server.
 * @param data Bytes to send.
 * @param len Number of bytes in data. Must be greater than zero.
 * @return ERROR_OK once all bytes were sent, otherwise an error if the client
 *         is disconnected or sending failed.
 */
errorret_t networkedServerClientWrite(
  serverclient_t * client,
  const uint8_t *data,
  const size_t len
) {
  assertNotNull(client, "Client is NULL");
  assertNotNull(data, "Data is NULL");
  assertTrue(len > 0, "Data length is 0");
  assertNotNull(client->server, "Server is NULL");
  assertTrue(
    client->server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );

  if(client->state == SERVER_CLIENT_STATE_DISCONNECTED) {
    return error("Client is disconnected");
  }

  size_t total = 0;
  while(total < len) {
    const ssize_t sent = send(
      client->networked.socket,
      data + total,
      len - total,
      0
    );
    // Treat zero progress as failure too, so this loop can never spin.
    if(sent <= 0) return error("Failed to send data");
    total += (size_t)sent;
  }

  return ERROR_OK;
}

/**
 * Sends a packet to a networked client.
 *
 * Only the used part of the packet is sent: everything in packet_t other than
 * the data member, followed by the first packet->length bytes of data.
 *
 * @param client Client to send the packet to.
 * @param packet Packet to send. Must have a valid type and a length between 1
 *               and sizeof(packetdata_t).
 * @return The result of writing the packet bytes to the client.
 */
errorret_t networkedServerClientWritePacket(
  serverclient_t * client,
  const packet_t *packet
) {
  // Validate the packet before any bytes are sent.
  assertNotNull(packet, "Packet is NULL");
  assertTrue(packet->type != PACKET_TYPE_INVALID, "Packet type is INVALID");
  assertTrue(packet->length > 0, "Packet length is 0");
  assertTrue(
    packet->length <= sizeof(packetdata_t),
    "Packet length is too large (1)"
  );

  // Bytes in packet_t that are not part of the data member.
  const size_t headerSize = sizeof(packet_t) - sizeof(packet->data);
  const size_t wireSize = headerSize + packet->length;
  assertTrue(wireSize <= sizeof(packet_t), "Packet size is too large (2)");

  const uint8_t *bytes = (const uint8_t *)packet;
  return networkedServerClientWrite(client, bytes, wireSize);
}

/**
 * Entry point of a networked client's read thread.
 *
 * Performs the handshake (expects a message starting with "DUSK|" followed by
 * the server version, then replies with a welcome packet), marks the client
 * as connected, starts the client's write thread and then loops for as long
 * as the client stays connected.
 *
 * Handshake failures close the client with
 * networkedServerClientCloseOnThread, which ends this thread via pthread_exit.
 *
 * @param arg The serverclient_t this thread serves.
 * @return Always NULL.
 */
void * networkedServerClientReadThread(void *arg) {
  assertNotNull(arg, "Client is NULL");
  assertNotMainThread("Client thread must not be main thread");

  serverclient_t *client = (serverclient_t *)arg;
  char_t buffer[1024];
  ssize_t received;
  errorret_t err;
  packet_t packet;

  assertNotNull(client->server, "Server is NULL");
  assertTrue(
    client->server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );
  assertTrue(
    client->state == SERVER_CLIENT_STATE_ACCEPTING,
    "Client is not accepting?"
  );

  // First message from the client should always be "DUSK|VERSION" to match
  // the server version.
  {
    const char_t *expecting = "DUSK|"DUSK_VERSION;

    // Read at most sizeof(buffer) - 1 bytes so the terminator written below
    // always lands inside the buffer, even when the client fills it.
    received = networkedServerClientRead(
      client,
      (uint8_t *)buffer,
      sizeof(buffer) - 1
    );
    if(received <= 0) {
      // Best effort: tell the client why before closing.
      packetDisconnectCreate(&packet, PACKET_DISCONNECT_REASON_INVALID_VERSION);
      err = networkedServerClientWritePacket(client, &packet);
      networkedServerClientCloseOnThread(client, "Failed to receive version");
      return NULL;
    }

    buffer[received] = '\0'; // Null-terminate the string

    // Only the leading strlen(expecting) characters are compared.
    if(strncmp(buffer, expecting, strlen(expecting)) != 0) {
      packetDisconnectCreate(&packet, PACKET_DISCONNECT_REASON_INVALID_VERSION);
      err = networkedServerClientWritePacket(client, &packet);
      networkedServerClientCloseOnThread(client, "Invalid version");
      return NULL;
    }
  }

  // Send DUSK back!
  packetWelcomeCreate(&packet);
  err = networkedServerClientWritePacket(client, &packet);
  if(err != ERROR_OK) {
    networkedServerClientCloseOnThread(client, "Failed to send welcome message");
    return NULL;
  }

  // Client is connected.
  client->state = SERVER_CLIENT_STATE_CONNECTED;

  // Start the write thread after the handshake
  int32_t ret = pthread_create(
    &client->networked.writeThread,
    NULL,
    networkedServerClientWriteThread,
    client
  );
  if(ret != 0) {
    networkedServerClientCloseOnThread(client, "Failed to create write thread");
    return NULL;
  }

  // Start listening for packets. The critical section is currently empty, so
  // this loop only takes and releases the read lock until the state changes.
  while(client->state == SERVER_CLIENT_STATE_CONNECTED) {
    pthread_mutex_lock(&client->networked.readLock);

    pthread_mutex_unlock(&client->networked.readLock);
  }

  pthread_mutex_lock(&client->networked.readLock);
  client->state = SERVER_CLIENT_STATE_DISCONNECTED;
  pthread_mutex_unlock(&client->networked.readLock);

  return NULL;
}

void * networkedServerClientWriteThread(void *arg) {
  assertNotNull(arg, "Client is NULL");
  assertNotMainThread("Client thread must not be main thread");
  assertTrue(
    ((serverclient_t *)arg)->server->type == SERVER_TYPE_NETWORKED,
    "Server is not networked"
  );
  assertTrue(
    ((serverclient_t *)arg)->state == SERVER_CLIENT_STATE_CONNECTED,
    "Client is not connected"
  );

  serverclient_t *client = (serverclient_t *)arg;

  while(client->state == SERVER_CLIENT_STATE_CONNECTED) {
    pthread_mutex_lock(&client->networked.writeLock);

    pthread_mutex_unlock(&client->networked.writeLock);
  }

  return NULL;
}