// Copyright (c) 2023 Dominic Masters
// 
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

#pragma once
#include "display/shader/ShaderStage.hpp"
#include "display/shader/IShader.hpp"
#include "assert/assert.hpp"
#include "assert/assertgl.hpp"
#include "display/Color.hpp"
#include "display/Texture.hpp"

#include "ShaderParameter.hpp"
#include "ShaderStructure.hpp"

namespace Dawn {
  /** Alias of GLuint, named for values that identify a shader texture binding. */
  using shadertexturebinding_t = GLuint;

  /**
   * The shading-language flavours a Shader can be built for. Only GLSL 3.30
   * core is currently defined.
   */
  enum class ShaderOpenGLVariant {
    GLSL_330_CORE
  };

  /**
   * OpenGL implementation of IShader. Subclasses describe their stages,
   * uniform parameters and uniform blocks through getStages(); this class
   * links the program, resolves uniform locations / block indices and uploads
   * the CPU-side data held in IShader<T>::data.
   *
   * Owns one GL program object and one uniform buffer per structure; both are
   * released in the destructor.
   *
   * @tparam T The shader's data struct. Must be default constructible (init()
   *           creates a temporary T to resolve member offsets).
   */
  template<typename T>
  class Shader : public IShader<T> {
    private:
      std::vector<std::shared_ptr<ShaderStage>> stages;
      std::vector<struct ShaderParameter> parameters;
      std::vector<struct IShaderStructure> structures;
      enum ShaderOpenGLVariant variant = ShaderOpenGLVariant::GLSL_330_CORE;

      // 0 is OpenGL's "no program" name and glDeleteProgram(0) is silently
      // ignored, so a Shader that was never init()-ed can be destroyed safely.
      GLuint shaderProgram = 0;

    protected:
      /**
       * Overridable function to get the stages for the shader.
       * 
       * @param variant The variant of the shader to use.
       * @param rel The relative data to use.
       * @param stages The stages to add to.
       * @param parameters The parameters to add to.
       * @param structures The structures to add to.
       */
      virtual void getStages(
        const enum ShaderOpenGLVariant variant,
        const T *rel,
        std::vector<std::shared_ptr<ShaderStage>> &stages,
        std::vector<struct ShaderParameter> &parameters,
        std::vector<struct IShaderStructure> &structures
      ) = 0;

    public:
      /**
       * Initializes the shader, this needs to be called before the shader can
       * be used. Collects the stages from getStages(), links them into a
       * program, resolves every parameter's uniform location and every
       * structure's uniform block index, creates one uniform buffer per
       * structure and finally binds the program.
       */
      void init() override {
        // Determine which kind of OpenGL shader to use.
        variant = ShaderOpenGLVariant::GLSL_330_CORE;

        // getStages() receives the address of this temporary; the offsets it
        // records are rebased against that address further down.
        T dummy;
        this->getStages(
          variant,
          &dummy,
          stages,
          parameters,
          structures
        );

        // Create the shader program
        shaderProgram = glCreateProgram();
        assertNoGLError();
        assertTrue(shaderProgram != 0, "Failed to create shader program.");

        // Attach all the stages (by reference: no shared_ptr refcount churn).
        for(const auto &stage : stages) {
          glAttachShader(shaderProgram, stage->id);
          assertNoGLError();
        }

        // Link and verify the program
        glLinkProgram(shaderProgram);
        assertNoGLError();

        GLint status = GL_FALSE;
        glGetProgramiv(shaderProgram, GL_LINK_STATUS, &status);
        assertNoGLError();
        if(status != GL_TRUE) {
          // Pull the driver's link log so the failure says *why* it failed.
          GLint logLength = 0;
          glGetProgramiv(shaderProgram, GL_INFO_LOG_LENGTH, &logLength);
          std::vector<GLchar> log(
            static_cast<size_t>(logLength > 0 ? logLength : 1), '\0'
          );
          glGetProgramInfoLog(
            shaderProgram,
            static_cast<GLsizei>(log.size()),
            nullptr,
            log.data()
          );
          assertTrue(
            status == GL_TRUE,
            "Failed to link shader program: %s",
            log.data()
          );
        }

        // Offsets were recorded relative to the address of `dummy` (see the
        // subtraction below); rebase them so they are offsets into a T.
        const size_t dummyBase = reinterpret_cast<size_t>(&dummy);

        // Map parameters correctly.
        for(auto &param : parameters) {
          param.offset = param.offset - dummyBase;
          param.location = glGetUniformLocation(
            shaderProgram,
            param.name.c_str()
          );
          assertNoGLError();
          assertTrue(
            param.location != -1,
            "Failed to get location for parameter %s.",
            param.name.c_str()
          );
        }

        // Map structures
        for(auto &structure : structures) {
          structure.offset = structure.offset - dummyBase;
          structure.location = glGetUniformBlockIndex(
            shaderProgram,
            structure.structureName.c_str()
          );
          assertNoGLError();
          assertTrue(
            static_cast<GLuint>(structure.location) != GL_INVALID_INDEX,
            "Failed to get location for structure %s.",
            structure.structureName.c_str()
          );

          // Every uniform block starts on binding point 0. upload() binds
          // each buffer at binding point == block index, so route the block
          // to that same binding point; otherwise only block 0 gets data.
          glUniformBlockBinding(
            shaderProgram,
            structure.location,
            structure.location
          );
          assertNoGLError();

          // Create the buffer
          glGenBuffers(1, &structure.buffer);
          assertNoGLError();
        }

        this->bind();
      }

      /**
       * Binds the shader as the current one, does not upload any data, somewhat
       * relies on something else uploading the data.
       */
      void bind() override {
        glUseProgram(shaderProgram);
        assertNoGLError();
      }

      /**
       * Uploads the data to the GPU. Each parameter is sent with the
       * glUniform* call matching its type, reading `count` elements starting
       * at its offset inside this->data. Each structure's bytes are copied
       * into its uniform buffer, which is bound at the binding point equal to
       * the structure's block index.
       */
      void upload() override {
        // Offsets were rebased in init() to be relative to the data struct.
        const size_t dataBase = reinterpret_cast<size_t>(&this->data);

        switch(this->variant) {
          case ShaderOpenGLVariant::GLSL_330_CORE:
            for(const auto &param : parameters) {
              void *value = reinterpret_cast<void*>(dataBase + param.offset);

              switch(param.type) {
                case ShaderParameterType::MAT4: {
                  // glm::mat4 is 16 contiguous floats, so `count` adjacent
                  // matrices upload in a single call.
                  glm::mat4 *matrix = static_cast<glm::mat4 *>(value);
                  glUniformMatrix4fv(
                    param.location,
                    param.count,
                    GL_FALSE,
                    glm::value_ptr(*matrix)
                  );
                  break;
                }
                
                case ShaderParameterType::COLOR: {
                  glUniform4fv(
                    param.location,
                    param.count,
                    static_cast<GLfloat*>(value)
                  );
                  break;
                }

                case ShaderParameterType::BOOLEAN: {
                  // NOTE(review): reads GLint-sized values; if the data member
                  // is a C++ `bool` (1 byte) this reads past it — confirm the
                  // data structs store booleans as 4-byte ints.
                  glUniform1iv(
                    param.location,
                    param.count,
                    static_cast<GLint*>(value)
                  );
                  break;
                }

                case ShaderParameterType::TEXTURE: {
                  glUniform1iv(
                    param.location,
                    param.count,
                    static_cast<GLint*>(value)
                  );
                  break;
                }

                default: {
                  assertUnreachable("Unsupported ShaderParameterType");
                }
              }

              assertNoGLError();
            }
            break;

          default:
            assertUnreachable("Unsupported ShaderOpenGLVariant");
        }

        // Upload structures
        for(const auto &structure : structures) {
          switch(structure.structureType) {
            case ShaderOpenGLStructureType::STD140: {
              // Upload the data
              glBindBuffer(GL_UNIFORM_BUFFER, structure.buffer);
              assertNoGLError();
              glBindBufferBase(
                GL_UNIFORM_BUFFER,
                structure.location,
                structure.buffer
              );
              assertNoGLError();
              glBufferData(
                GL_UNIFORM_BUFFER,
                structure.size * structure.count,
                reinterpret_cast<void*>(
                  dataBase + static_cast<size_t>(structure.offset)
                ),
                GL_STATIC_DRAW
              );
              assertNoGLError();
              break;
            }

            default:
              assertUnreachable("Unsupported ShaderOpenGLStructureType");
          }
        }
      }

      /**
       * Releases the uniform buffers created in init() and the program object.
       * Safe on a Shader that was never initialized: there are no structures
       * and the program handle is 0, which glDeleteProgram ignores.
       */
      ~Shader() {
        // Delete the structures
        for(auto &structure : structures) {
          assertTrue(structure.buffer != -1, "Invalid buffer.");
          glDeleteBuffers(1, &structure.buffer);
          assertNoGLError();
        }

        // Delete the shader program
        glDeleteProgram(shaderProgram);
        assertNoGLError();
      }
  };
}