// Copyright (c) 2023 Dominic Masters
// 
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

#include "display/shader/UIShader.hpp"
#include "util/Macro.hpp"

using namespace Dawn;

/**
 * Builds the shader stages, uniform parameters and uniform-block structures
 * used by the UI shader for the requested OpenGL variant.
 *
 * Only ShaderOpenGLVariant::GLSL_330_CORE is supported; any other variant
 * hits assertUnreachable and nothing is appended to the output vectors.
 *
 * Vertex stage: each vertex's aPos.z encodes (quadIndex * 4 + corner). The
 * corner selects which edges of quads[quadIndex].quad / .uv are used:
 *   0 -> (x, y), 1 -> (z, y), 2 -> (x, w), 3 -> (z, w).
 * Fragment stage: style[1] picks the texture slot (-1 = plain white) and
 * style[0] picks how texture and colour combine (0 = multiply, 1 = colour rgb
 * with texture red channel scaling alpha; unknown values fall back to 0).
 *
 * @param variant    OpenGL shading-language variant to emit source for.
 * @param rel        Data block whose members are bound to the uniforms. Its
 *                   member addresses are handed to ShaderParameter /
 *                   ShaderStructure; presumably they are read later, so rel
 *                   must outlive the shader — TODO confirm.
 * @param stages     Receives the vertex and fragment stages (in that order).
 * @param parameters Receives u_Projection, u_View, u_Model and u_Texture.
 * @param structures Receives the std140 "ub_Quad" uniform block.
 */
void UIShader::getStages(
  const enum ShaderOpenGLVariant variant,
  const struct UIShaderData *rel,
  std::vector<std::shared_ptr<ShaderStage>> &stages,
  std::vector<struct ShaderParameter> &parameters,
  std::vector<struct IShaderStructure> &structures
) {
  // Stages
  std::shared_ptr<ShaderStage> vertex;
  std::shared_ptr<ShaderStage> fragment;

  switch(variant) {
    case ShaderOpenGLVariant::GLSL_330_CORE:
      vertex = std::make_shared<ShaderStage>(
        ShaderStageType::VERTEX,
        "#version 330 core\n"
        "layout (location = 0) in vec3 aPos;\n"
        "layout (location = 1) in vec2 aTexCoord;\n"
        "uniform mat4 u_Projection;\n"
        "uniform mat4 u_View;\n"
        "uniform mat4 u_Model;\n"
        "struct UIShaderQuad {\n"
          "vec4 quad;\n"
          "vec4 uv;\n"
          "vec4 color;\n"
          "vec4 style;\n"
        "};\n"
        "layout (std140) uniform ub_Quad {\n"
          "UIShaderQuad quads[" MACRO_STRINGIFY(UI_SHADER_QUAD_COUNT) "];\n"
        "};\n"
        "out vec2 v_TextCoord;\n"
        "out vec4 v_Color;\n"
        "out vec4 v_Style;\n"
        "void main() {\n"
          "vec4 pos;\n"
          "vec2 coord;\n"
          "int index = int(aPos.z);\n"
          "int quadIndex = index / 4;\n"
          "int vertexIndex = index % 4;\n"
          "UIShaderQuad quad = quads[quadIndex];\n"
          "if(vertexIndex == 0) {\n"
            "pos.x = quad.quad.x;\n"
            "pos.y = quad.quad.y;\n"
            "coord.x = quad.uv.x;\n"
            "coord.y = quad.uv.y;\n"
          "} else if(vertexIndex == 1) {\n"
            "pos.x = quad.quad.z;\n"
            "pos.y = quad.quad.y;\n"
            "coord.x = quad.uv.z;\n"
            "coord.y = quad.uv.y;\n"
          "} else if(vertexIndex == 2) {\n"
            "pos.x = quad.quad.x;\n"
            "pos.y = quad.quad.w;\n"
            "coord.x = quad.uv.x;\n"
            "coord.y = quad.uv.w;\n"
          "} else if(vertexIndex == 3) {\n"
            "pos.x = quad.quad.z;\n"
            "pos.y = quad.quad.w;\n"
            "coord.x = quad.uv.z;\n"
            "coord.y = quad.uv.w;\n"
          "}\n"
          "pos.z = 0;\n"
          "pos.w = 1;\n"
          "gl_Position = u_Projection * u_View * u_Model * pos;\n"
          "v_TextCoord = coord;\n"
          "v_Color = quad.color;\n"
          "v_Style = quad.style;\n"
        "}"
      );

      // NOTE(review): the texture switch below hard-codes slots 0..5. It must
      // stay in sync with UI_SHADER_TEXTURE_COUNT: fewer slots makes the GLSL
      // fail to compile, and more slots are unreachable from the shader.
      fragment = std::make_shared<ShaderStage>(
        ShaderStageType::FRAGMENT,
        "#version 330 core\n"
        "in vec2 v_TextCoord;\n"
        "in vec4 v_Color;\n"
        "in vec4 v_Style;\n"
        "uniform sampler2D u_Texture[" MACRO_STRINGIFY(UI_SHADER_TEXTURE_COUNT) "];\n"
        "out vec4 o_Color;\n"
        "void main() {\n"
          "vec4 texColor = vec4(1, 1, 1, 1);\n"
          "int vStyle = int(round(v_Style[0]));\n"
          "int vTextInd = int(round(v_Style[1]));\n"
          "switch(vTextInd) {\n"
            "case -1:\n"
              "texColor = vec4(1, 1, 1, 1);\n"
              "break;\n"
            "case 0:\n"
              "texColor = texture(u_Texture[0], v_TextCoord);\n"
              "break;\n"
            "case 1:\n"
              "texColor = texture(u_Texture[1], v_TextCoord);\n"
              "break;\n"
            "case 2:\n"
              "texColor = texture(u_Texture[2], v_TextCoord);\n"
              "break;\n"
            "case 3:\n"
              "texColor = texture(u_Texture[3], v_TextCoord);\n"
              "break;\n"
            "case 4:\n"
              "texColor = texture(u_Texture[4], v_TextCoord);\n"
              "break;\n"
            "case 5:\n"
              "texColor = texture(u_Texture[5], v_TextCoord);\n"
              "break;\n"
          "}\n"
          // Every path must write o_Color; an unrecognised style would
          // otherwise leave the fragment output undefined.
          "switch(vStyle) {\n"
            "case 0:\n"
              "o_Color = texColor * v_Color;\n"
              "break;\n"
            "case 1:\n"
              "o_Color.rgb = v_Color.rgb;\n"
              "o_Color.a = texColor.r * v_Color.a;\n"
              "break;\n"
            "default:\n"
              "o_Color = texColor * v_Color;\n"
              "break;\n"
          "}\n"
        "}\n"
      );
      break;

    default:
      assertUnreachable("Unsupported ShaderOpenGLVariant");
      // Guard for builds where the assert does not abort: never hand null
      // stages to the caller.
      return;
  }

  // Add stages
  stages.push_back(vertex);
  stages.push_back(fragment);

  // Parameters
  parameters.push_back(ShaderParameter(
    "u_Projection",
    &rel->projection,
    ShaderParameterType::MAT4
  ));

  parameters.push_back(ShaderParameter(
    "u_View",
    &rel->view,
    ShaderParameterType::MAT4
  ));

  parameters.push_back(ShaderParameter(
    "u_Model",
    &rel->model,
    ShaderParameterType::MAT4
  ));

  parameters.push_back(ShaderParameter(
    "u_Texture",
    &rel->textures,
    ShaderParameterType::TEXTURE,
    UI_SHADER_TEXTURE_COUNT
  ));

  // The callback captures nothing: it only describes the member layout of a
  // single UIShaderQuad, so it stays valid if ShaderStructure stores it.
  structures.push_back(ShaderStructure<struct UIShaderQuad>(
    "ub_Quad",
    &rel->quads,
    ShaderOpenGLStructureType::STD140,
    [](
      const struct UIShaderQuad &quad,
      std::vector<struct ShaderParameter> &members
    ) {
      members.push_back(ShaderParameter(
        "quad",
        &quad.quad,
        ShaderParameterType::VEC4
      ));

      members.push_back(ShaderParameter(
        "uv",
        &quad.uv,
        ShaderParameterType::VEC4
      ));

      members.push_back(ShaderParameter(
        "color",
        &quad.color,
        ShaderParameterType::COLOR
      ));

      members.push_back(ShaderParameter(
        "style",
        &quad.style,
        ShaderParameterType::VEC4
      ));
    },
    UI_SHADER_QUAD_COUNT
  ));
}