{.experimental: "codeReordering".}
{.deadCodeElim: on.}

import vulkan 
, buffer 
, swapchain
, depth_stencil
, render_pass
, render_passes
, uniform_buffer
, vulkan_utils

import std/os

import glm 

#import ../drawable/shape_types

type 
 # Stage tag stored on a pipeline (`Graphics_Pipeline_Obj.pipeline_stage`).
 # NOTE(review): presumably selects when the pipeline is updated relative to
 # the frame loop (never / always / pre / normal / post / render) — the
 # meaning of each value is not established in this file; TODO confirm.
 Render_Update_Stage* = enum
  Never, Always, Pre, Normal, Post, Render

 # Per-vertex data read by the pipeline built in `a_graphics_pipeline`:
 # `position` feeds shader location 0 and `color` location 1, both as
 # VK_FORMAT_R32G32B32_SFLOAT, with a binding stride of `sizeof(Vertex)`.
 Vertex* = object
  position*: array[3,float32]
  color*: array[3, float32]
  #texture_coordinates*: array[3, float32]

 # Value-type state of a graphics pipeline; use the `Graphics_Pipeline` ref.
 # `a_graphics_pipeline` assigns descriptor_pool, descriptor_set_layout,
 # render_pass, depth_stencil (only when requested), pipeline_layout,
 # descriptor_set and pipeline; the remaining fields are not written by it.
 Graphics_Pipeline_Obj = object of RootObj
  #Mode mode
  #Shader* shader

  pipeline_stage*: Render_Update_Stage
  shader_stages*: seq[string]
  #shader_defines*: Shader_Define
  #depth*: Depth
  pushDescriptors*: bool 
  dynamic_states*: seq[VkDynamicState]
  shader_modules*: seq[VkShaderModule]
  shaders*: seq[string]
  # Vulkan pipeline handle and the layout it was created with.
  pipeline*: VkPipeline 
  pipeline_layout*: VkPipelineLayout   
  pipeline_layout_create_info*: VkPipelineLayoutCreateInfo
  pipelineBindPoint*: VkPipelineBindPoint
  # Descriptor objects for the uniform buffer bound at binding 0.
  descriptor_pool*: VkDescriptorPool 
  descriptor_set*: VkDescriptorSet 
  descriptor_set_layout*: VkDescriptorSetLayout 
  depth_stencil*: Depth_Stencil 
  topology*: VkPrimitiveTopology
  polygonMode*: VkPolygonMode
  cullMode*: VkCullModeFlags
  frontFace*: VkFrontFace
  render_pass*: Render_Pass

  # Fixed-function state slots mirroring VkGraphicsPipelineCreateInfo members.
  blend_attachment_states*: array[1, VkPipelineColorBlendAttachmentState]
  pipeline_shader_stages*: seq[VkPipelineShaderStageCreateInfo]
  vertexInputStateCreateInfo*: VkPipelineVertexInputStateCreateInfo
  inputAssemblyState*: VkPipelineInputAssemblyStateCreateInfo
  rasterizationState*: VkPipelineRasterizationStateCreateInfo
  colourBlendState*: VkPipelineColorBlendStateCreateInfo
  depthStencilState*: VkPipelineDepthStencilStateCreateInfo
  viewportState*: VkPipelineViewportStateCreateInfo
  multisampleState*: VkPipelineMultisampleStateCreateInfo
  dynamicState*: VkPipelineDynamicStateCreateInfo
  tessellationState*: VkPipelineTessellationStateCreateInfo
  render_pass_begin_info  *: VkRenderPassBeginInfo
 
 # Heap-allocated (ref) handle to a graphics pipeline; returned by
 # `a_graphics_pipeline`.
 Graphics_Pipeline* = ref object of Graphics_Pipeline_Obj


proc a_graphics_pipeline*( vk_device: VkDevice
                         , swapchain: Swapchain
                         , descriptor_buffer_info: VkDescriptorBufferInfo
                         , pipeline_cache: var VkPipelineCache
                         , gpu_memory_properties: VkPhysicalDeviceMemoryProperties
                         , shaders: seq[string] = @[]
                         , hollow: bool = false
                         , make_depth_stencil: bool = false 
                         ): Graphics_Pipeline = 
 ## Creates a `Graphics_Pipeline` that draws `Vertex` data (position at
 ## location 0, color at location 1) as a filled triangle list with alpha
 ## blending, depth testing and dynamic viewport / scissor / line width.
 ##
 ## Besides `result.pipeline` it creates and stores the descriptor pool,
 ## descriptor set layout, render pass, pipeline layout and a descriptor set
 ## that binds `descriptor_buffer_info` as a uniform buffer.
 ##
 ## * `shaders` – either empty (no shader stages are attached) or exactly two
 ##   shader file paths: `shaders[0]` vertex, `shaders[1]` fragment. Any other
 ##   length fails a `doAssert`.
 ## * `pipeline_cache` – forwarded to `vkCreateGraphicsPipelines`.
 ## * `hollow` – currently not read by this proc.
 ## * `make_depth_stencil` – when true, also creates `result.depth_stencil`.
 ##
 ## Shader modules are destroyed once pipeline creation has returned.
 ## Fails a `doAssert` if `vkCreateGraphicsPipelines` does not return
 ## `VK_SUCCESS`.

 result = Graphics_Pipeline( descriptor_pool: create_descriptor_pool( vk_device )
                           , descriptor_set_layout: create_descriptor_set_layout( vk_device )
                           , render_pass: create_basic_shape_render_pass( vk_device
                                                                        , swapchain.color_format
                                                                        , swapchain.depth_format
                                                                        )
                           )
 if make_depth_stencil: 
  result.depth_stencil = create_depth_stencil( vk_device
                                            , swapchain.color_format
                                            , swapchain.depth_format
                                            , swapchain.current_extent
                                            , gpu_memory_properties
                                            ) 
 result.pipeline_layout = create_pipeline_layout( vk_device
                                                , result.descriptor_set_layout
                                                , 0
                                                )
 
 result.descriptor_set = create_descriptor_sets( vk_device 
                                               , result.descriptor_set_layout
                                               , result.descriptor_pool
                                               , descriptor_buffer_info
                                               ) 
 let 
  topo: VkPrimitiveTopology = VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST
  pMode: VkPolygonMode = VK_POLYGON_MODE_FILL
 
 var 
  pipeline_info = VkGraphicsPipelineCreateInfo(
     sType: VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO
     , layout: result.pipeline_layout
     , render_pass: result.render_pass.vk_handle
      #subpass: 0,
      #basePipelineHandle: VkPipeline(VK_NULL_HANDLE), # optional
      #basePipelineIndex: -1, # optional
    )
  
  # Input assembly state describes how primitives are assembled.
  # This pipeline assembles vertex data as a triangle list.
  inputAssembly = VkPipelineInputAssemblyStateCreateInfo(
     sType: VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO
     , topology: topo
      #primitiveRestartEnable: VkBool32(VK_FALSE),
    )
  
  rasterizer = VkPipelineRasterizationStateCreateInfo(
     sType: VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO
     , depthClampEnable: VkBool32(VK_FALSE)
     , rasterizerDiscardEnable: VkBool32(VK_FALSE)
     , polygonMode: pMode
     , lineWidth: 1.float32
     , cullMode: VkCullModeFlags VK_CULL_MODE_NONE
     , frontFace: VK_FRONT_FACE_COUNTER_CLOCKWISE
     , depthBiasEnable: VkBool32(VK_FALSE)
    #  , depthBiasConstantFactor: 0f # optional
    #  , depthBiasClamp: 0f # optional
    #  , depthBiasSlopeFactor: 0f # optional
    )
  
  # Standard "over" alpha blending:
  #   color = src.rgb * src.a + dst.rgb * (1 - src.a)
  #   alpha = src.a
  colorBlendAttachment = VkPipelineColorBlendAttachmentState(
      blendEnable:         VKBool32 VK_TRUE 
    , srcColorBlendFactor: VkBlendFactor VK_BLEND_FACTOR_SRC_ALPHA
    , dstColorBlendFactor: VkBlendFactor VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA
    , colorBlendOp:        VkBlendOp VK_BLEND_OP_ADD
    , srcAlphaBlendFactor: VkBlendFactor VK_BLEND_FACTOR_ONE
    , dstAlphaBlendFactor: VkBlendFactor VK_BLEND_FACTOR_ZERO
    , alphaBlendOp:        VkBlendOp VK_BLEND_OP_ADD
    , colorWriteMask:      VkColorComponentFlags VK_COLOR_COMPONENT_R_BIT.ord or 
                                                 VK_COLOR_COMPONENT_G_BIT.ord or 
                                                 VK_COLOR_COMPONENT_B_BIT.ord or 
                                                 VK_COLOR_COMPONENT_A_BIT.ord
    )
  
  colorBlending = VkPipelineColorBlendStateCreateInfo(
     sType: VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO
    , logicOpEnable: VkBool32(VK_FALSE)
    , logicOp: VK_LOGIC_OP_COPY # optional
    , attachmentCount: 1
    , pAttachments: colorBlendAttachment.addr
    , blendConstants: [0f, 0f, 0f, 0f] # optional
    )
  
  # Most state is baked into the pipeline, but a few states can be changed
  # from a command buffer. They must be listed here to be settable later;
  # their actual values are recorded in the command buffer. This pipeline
  # makes the viewport, scissor and line width dynamic.
  dynamicStateEnables: seq[VkDynamicSTate] = @[ VK_DYNAMIC_STATE_VIEWPORT
                                              , VK_DYNAMIC_STATE_SCISSOR
                                              , VK_DYNAMIC_STATE_LINE_WIDTH
                                              ]

  dynamicState = VkPipelineDynamicStateCreateInfo(
   sType: VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO
   , pDynamicStates: addr dynamicStateEnables[0]
   , dynamicStateCount: dynamicStateEnables.len.uint32
   )
  
  # Depth tests and depth writes are enabled with a less-or-equal compare;
  # the stencil test is disabled.
  depthStencilState = VkPipelineDepthStencilStateCreateInfo(
    sType: VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO
    , depthTestEnable: VkBool32 VK_TRUE
    , depthWriteEnable: VkBool32 VK_TRUE
    , depthCompareOp: VK_COMPARE_OP_LESS_OR_EQUAL
    , depthBoundsTestEnable: VkBool32 VK_FALSE
    , back: VkStencilOpState( failOp: VK_STENCIL_OP_KEEP
                            , passOp: VK_STENCIL_OP_KEEP
                            , compareOp: VK_COMPARE_OP_ALWAYS
                            )
    , front: VkStencilOpState( failOp: VK_STENCIL_OP_KEEP
                            , passOp: VK_STENCIL_OP_KEEP
                            , compareOp: VK_COMPARE_OP_ALWAYS
                            )
    , stencilTestEnable: VkBool32 VK_FALSE
  )
  
  # Multisampling is not used (1 sample), but the state must still be supplied.
  multisampleState = VkPipelineMultisampleStateCreateInfo(
    sType: VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO
    , rasterizationSamples: VK_SAMPLE_COUNT_1_BIT
    , pSampleMask: nil
  )

  # Single vertex buffer at binding point 0 (see vkCmdBindVertexBuffers),
  # advanced once per vertex.
  vertexInputBinding = VkVertexInputBindingDescription(
    binding: 0
    , stride: uint32 sizeof(Vertex)
    , inputRate: VK_VERTEX_INPUT_RATE_VERTEX
  )

  # Shader attribute locations and their memory layout within `Vertex`.
  vertex_input_attributes: array[2,VkVertexInputAttributeDescription]

 # Attribute location 0: Position
 vertex_input_attributes[0].binding = 0
 vertex_input_attributes[0].location = 0
 vertex_input_attributes[0].format = VK_FORMAT_R32G32B32_SFLOAT
 vertex_input_attributes[0].offset = uint32 offsetof(Vertex, position)
 
 # Attribute location 1: Color
 vertex_input_attributes[1].binding = 0
 vertex_input_attributes[1].location = 1
 vertex_input_attributes[1].format = VK_FORMAT_R32G32B32_SFLOAT
 vertex_input_attributes[1].offset = offsetof(Vertex, color).uint32
 
 #[
 vertex_input_attributes[2].location = 2
 vertex_input_attributes[2].format = VK_FORMAT_R32G32B32_SFLOAT
 vertex_input_attributes[2].offset = offsetof(Vertex, texture_coordinates).uint32
 ]#
 
 var 
  vertex_input_state = 
   VkPipelineVertexInputStateCreateInfo( sType: VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO
                                       , vertexBindingDescriptionCount: 1
                                       , pVertexBindingDescriptions: addr vertexInputBinding
                                       , vertexAttributeDescriptionCount: uint32 vertex_input_attributes.len
                                       , pVertexAttributeDescriptions: addr vertex_input_attributes[0]
                                       ) 
  # Initial viewport / scissor cover the whole swapchain extent; both are
  # also dynamic states (see dynamicStateEnables).
  viewport = VkViewport( 
    x: 0f, y: 0
    , width:   swapchain.current_extent.width.float32
    ,  height: swapchain.current_extent.height.float32
    ,  minDepth: 0f
    ,  maxDepth: 1f
    )
  
  scissor = VkRect2D( offset: VkOffset2D(x: 0, y: 0)
                    , extent: swapchain.current_extent
                    )
                    
  viewportState = VkPipelineViewportStateCreateInfo(
     sType: VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO
     , viewportCount: 1
     , pViewports: viewport.addr
     , scissorCount: 1
     , pScissors: scissor.addr
     , flags: VkPipelineViewportStateCreateFlags 0
    )
  
 # Shaders.
 # `the_shader_stages` builds exactly one vertex + fragment pair. Reject any
 # other count up front instead of indexing out of range (len == 1) or handing
 # zero-initialised stage structs to Vulkan (len > 2).
 var shader_stages: seq[VkPipelineShaderStageCreateInfo]
 if shaders.len != 0:
  doAssert( shaders.len == 2
          , "a_graphics_pipeline expects 2 shaders (vertex, fragment), got " & $shaders.len
          )
  shader_stages = the_shader_stages( vk_device, shaders )
  pipeline_info.stageCount = shader_stages.len.uint32
  pipeline_info.pStages = addr shader_stages[0]
  
 # Assign the pipeline states to the pipeline creation info structure
 # (the render pass was already set when pipeline_info was constructed).
 pipeline_info.pVertexInputState = addr vertex_input_state
 pipeline_info.pInputAssemblyState = addr inputAssembly
 pipeline_info.pRasterizationState = addr rasterizer
 pipeline_info.pColorBlendState = addr colorBlending
 pipeline_info.pMultisampleState = addr multisampleState
 pipeline_info.pViewportState = addr viewportState
 pipeline_info.pDepthStencilState = addr depthStencilState
 pipeline_info.pDynamicState = addr dynamicState

 # Create the rendering pipeline from the states above.
 let create_status = vkCreateGraphicsPipelines( vk_device
                                              , pipeline_cache
                                              , 1
                                              , addr pipeline_info
                                              , nil
                                              , addr result.pipeline
                                              )
 
 # Shader modules are no longer needed once pipeline creation has returned,
 # so release them before checking the status (no leak on failure).
 for shader_stage in shader_stages:
  vkDestroyShaderModule( vk_device
                       , shader_stage.module
                       , nil
                       )

 doAssert( create_status == VK_SUCCESS
         , "vkCreateGraphicsPipelines failed"
         )

proc a_shader_module( device: VkDevice
                    , code: string
                    ): VkShaderModule  = 
 ## Wraps the SPIR-V bytes in `code` (e.g. the contents returned by
 ## `readFile`) in a new `VkShaderModule` created on `device`.
 ##
 ## Fails a `doAssert` when `code` is empty or its length is not a multiple
 ## of 4 (Vulkan requires `codeSize` to be a multiple of 4), or when
 ## `vkCreateShaderModule` does not return `VK_SUCCESS`.
 doAssert( code.len > 0 and code.len mod 4 == 0
         , "SPIR-V code must be non-empty and a multiple of 4 bytes, got " & $code.len & " bytes"
         )

 var createInfo = VkShaderModuleCreateInfo( sType: VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO
                                          , codeSize: cast[uint](code.len)
                                          , pCode: cast[ptr uint32](addr code[0])
                                          )

 # The create call is kept outside the assertion so that it still runs when
 # assertions are compiled out (e.g. `-d:danger`).
 let status = vkCreateShaderModule( device
                                  , addr createInfo
                                  , nil
                                  , addr result
                                  )
 doAssert( status == VK_SUCCESS, "vkCreateShaderModule failed" )

# TODO: currently only assumes a single vertex <-> fragment pair. (total of 2 shaders)
proc the_shader_stages( vk_device: VkDevice
                      , shaders: seq[string]
                      ): seq[VkPipelineShaderStageCreateInfo] =  
 ## Reads the SPIR-V files `shaders[0]` (vertex) and `shaders[1]` (fragment),
 ## wraps each in a shader module via `a_shader_module`, and returns the two
 ## matching stage descriptions, both with entry point "main".
 ##
 ## At least two paths are required, and every path in `shaders` must exist;
 ## violations fail a `doAssert`. Paths after the second are ignored (see the
 ## TODO above). The created modules are not destroyed here.
 doAssert( shaders.len >= 2
         , "the_shader_stages needs a vertex and a fragment shader, got " & $shaders.len & " path(s)"
         )

 # doAssert (not assert) so that a missing file is reported clearly even
 # when assertions are compiled out, instead of surfacing as an IOError.
 for shader in shaders:
  doAssert( shader.fileExists()
          , "shader file path doesn't exist: " & shader
          )

 result = @[ 
    VkPipelineShaderStageCreateInfo( sType: VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
                                   , pNext: nil
                                   , flags: VkPipelineShaderStageCreateFlags 0
                                   , stage: VK_SHADER_STAGE_VERTEX_BIT
                                   , module: a_shader_module( vk_device
                                                            , readFile shaders[0]
                                                            )
                                   , pName: "main"
                                   , pSpecializationInfo: nil
                                   )
   
   , VkPipelineShaderStageCreateInfo( sType: VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
                                    , pNext: nil
                                    , flags: VkPipelineShaderStageCreateFlags 0
                                    , stage: VK_SHADER_STAGE_FRAGMENT_BIT
                                    , module: a_shader_module( vk_device
                                                             , readFile shaders[1]
                                                             )
                                    , pName: "main"
                                    , pSpecializationInfo: nil
                                    )
        ]

# TODO:
# - Any shape with more than 6 sides is currently treated as an Ngon.
# - Behavior is odd when `hollow` is true, in the method used to render the vertices.
 
proc create_descriptor_set_layout*(  vk_device: VkDevice
                                  ): VkDescriptorSetLayout = 
  ## Creates a descriptor set layout with one binding (binding 0): a single
  ## uniform-buffer descriptor visible to the vertex and fragment stages.
  ## Every shader binding must map to a descriptor set layout binding.
  ##
  ## Fails a `doAssert` if `vkCreateDescriptorSetLayout` does not return
  ## `VK_SUCCESS`, rather than returning a null handle silently.
  var
   layoutBinding: VkDescriptorSetLayoutBinding
   descriptorLayout: VkDescriptorSetLayoutCreateInfo

  # Binding 0 (zero-initialised `binding` field): uniform buffer.
  layoutBinding.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
  layoutBinding.descriptorCount = 1
  layoutBinding.stageFlags = VkShaderStageFlags VK_SHADER_STAGE_VERTEX_BIT.ord or 
                                                VK_SHADER_STAGE_FRAGMENT_BIT.ord
  layoutBinding.pImmutableSamplers = nil

  descriptorLayout.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO
  descriptorLayout.pNext = nil
  descriptorLayout.bindingCount = 1
  descriptorLayout.pBindings = addr layoutBinding

  let status = vkCreateDescriptorSetLayout( vk_device
                                          , addr descriptorLayout
                                          , nil
                                          , addr result
                                          )
  doAssert( status == VK_SUCCESS, "vkCreateDescriptorSetLayout failed" )

proc create_descriptor_sets*( vk_device: VkDevice 
                            , descriptor_set_layout: VkDescriptorSetLayout
                            , descriptor_pool: VkDescriptorPool
                            , descriptor_buffer_info: VkDescriptorBufferInfo 
                            ): VkDescriptorSet =
 ## Allocates one descriptor set with `descriptor_set_layout` from
 ## `descriptor_pool`, then writes `descriptor_buffer_info` into it as the
 ## uniform buffer at binding 0.
 ##
 ## Fails a `doAssert` if `vkAllocateDescriptorSets` does not return
 ## `VK_SUCCESS`.
 var
  allocInfo = VkDescriptorSetAllocateInfo( sType: VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO
                                         , descriptorPool: descriptor_pool
                                         , descriptorSetCount: 1
                                         , pSetLayouts: addr descriptor_set_layout
                                         ) 

 # The allocation is kept outside the assertion so that it still runs when
 # assertions are compiled out (e.g. `-d:danger`).
 let status = vkAllocateDescriptorSets( vk_device
                                      , addr allocInfo
                                      , addr result
                                      )
 doAssert( status == VK_SUCCESS, "vkAllocateDescriptorSets failed" )

 # Point binding 0 of the new set at the uniform buffer.
 # NOTE(review): the cast suggests the binding declares `pBufferInfo` as
 # `ptr ptr VkDescriptorBufferInfo`, while Vulkan expects a pointer to the
 # buffer info itself; the address is reinterpreted to satisfy the type —
 # confirm against the vulkan module.
 var write_descriptor_sets: array[1,VkWriteDescriptorSet] = 
  [ VkWriteDescriptorSet( sType: VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET
                        , dstSet: result 
                        , descriptorCount: 1 
                        , descriptorType: VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
                        , pBufferInfo: cast[ptr ptr VkDescriptorBufferInfo] (addr descriptor_buffer_info) 
                        , dstBinding: 0
                        )
  ]
 
 vkUpdateDescriptorSets( vk_device
                       , 1
                       , addr write_descriptor_sets[0]
                       , 0
                       , nil
                       )

proc create_pipeline_layout( vk_device: VkDevice
                           , descriptor_set_layout: VkDescriptorSetLayout
                           , push_constant_size: uint32
                           , push_constant_range_count: uint32 = 0
                           ): VkPipelineLayout =  
 ## Creates a pipeline layout with the single set layout
 ## `descriptor_set_layout` and, optionally, one push-constant range of
 ## `push_constant_size` bytes (offset 0, vertex + fragment stages).
 ##
 ## * `push_constant_range_count` – 0 (no push constants) or 1. When it is 1,
 ##   `push_constant_size` must be a non-zero multiple of 4, as Vulkan requires.
 ##
 ## Fails a `doAssert` on an invalid push-constant configuration, or when
 ## `vkCreatePipelineLayout` does not return `VK_SUCCESS`.
 doAssert( push_constant_range_count <= 1
         , "create_pipeline_layout supports at most 1 push constant range"
         )

 var 
  pPipelineLayoutCreateInfo: VkPipelineLayoutCreateInfo
  pushConstantRange = VkPushConstantRange( stageFlags: VkShaderStageFlags VK_SHADER_STAGE_VERTEX_BIT.ord or VK_SHADER_STAGE_FRAGMENT_BIT.ord
                                         , offset: 0
                                         , size : push_constant_size # e.g. uint32 sizeof Shape_Pushes
                                         )

 # Pipeline layout used to build the rendering pipelines based on this
 # descriptor set layout. More complex setups would use different layouts
 # for different descriptor set layouts.
 pPipelineLayoutCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
 pPipelineLayoutCreateInfo.pNext = nil
 pPipelineLayoutCreateInfo.setLayoutCount = 1
 pPipelineLayoutCreateInfo.pSetLayouts = addr descriptor_set_layout
 pPipelineLayoutCreateInfo.pushConstantRangeCount = push_constant_range_count
 if push_constant_range_count > 0:
  # A non-zero count needs a valid pPushConstantRanges pointer.
  doAssert( push_constant_size > 0 and push_constant_size mod 4 == 0
          , "push constant size must be a non-zero multiple of 4, got " & $push_constant_size
          )
  pPipelineLayoutCreateInfo.pPushConstantRanges = addr pushConstantRange

 # The create call is kept outside the assertion so that it still runs when
 # assertions are compiled out (e.g. `-d:danger`).
 let status = vkCreatePipelineLayout( vk_device
                                    , addr pPipelineLayoutCreateInfo
                                    , nil
                                    , addr result 
                                    )
 doAssert( status == VK_SUCCESS, "vkCreatePipelineLayout failed" )