#include "AccelStructureBuilder.hpp"


// Builds a primitive acceleration structure for the given geometry and blocks
// until the build command buffer has completed.
//
// pDevice             - device used to size and allocate the structure and its scratch buffer.
// pCommandQueue       - queue on which the build command buffer is created and committed.
// geometryDescriptors - geometry descriptors assigned to the primitive descriptor.
//
// Returns the object obtained from `newAccelerationStructure`. It is not
// released here, so the caller is responsible for releasing it.
MTL::AccelerationStructure* newPrimitiveAccelerationStructure(MTL::Device* pDevice,
                                                             MTL::CommandQueue* pCommandQueue,
                                                             const NS::Array* geometryDescriptors)
{
    assert(pDevice);
    assert(pCommandQueue);
    assert(geometryDescriptors);

    // Describe what goes into the structure.
    auto pDescriptor = NS::TransferPtr(MTL::PrimitiveAccelerationStructureDescriptor::alloc()->init());
    pDescriptor->setGeometryDescriptors(geometryDescriptors);

    // Ask the device how much storage the result and the build scratch space need.
    const auto sizes = pDevice->accelerationStructureSizes(pDescriptor.get());

    MTL::AccelerationStructure* pResult = pDevice->newAccelerationStructure(sizes.accelerationStructureSize);

    // The scratch buffer is only needed during the build; the smart pointer
    // releases it when this function returns.
    auto pScratchBuffer = NS::TransferPtr(pDevice->newBuffer(sizes.buildScratchBufferSize,
                                                             MTL::ResourceStorageModePrivate));

    // Encode the build, submit it, and wait for it to finish before returning.
    MTL::CommandBuffer* pCommandBuffer = pCommandQueue->commandBuffer();
    MTL::AccelerationStructureCommandEncoder* pEncoder = pCommandBuffer->accelerationStructureCommandEncoder();
    pEncoder->buildAccelerationStructure(pResult, pDescriptor.get(), pScratchBuffer.get(), 0);
    pEncoder->endEncoding();

    pCommandBuffer->commit();
    pCommandBuffer->waitUntilCompleted();

    return pResult;
}

// Builds an instance acceleration structure that references the given primitive
// structures, and blocks until the build command buffer has completed.
//
// pDevice             - device used to allocate the buffers and the structure.
// pCommandQueue       - queue on which the build command buffer is created and committed.
// primitiveStructures - primitive structures the instances refer to; must not be empty.
// instances           - instance descriptors copied into a shared buffer; must not be empty.
//
// Returns the object obtained from `newAccelerationStructure`. It is not
// released here, so the caller is responsible for releasing it.
MTL::AccelerationStructure* newInstanceAcceleartionStructure(MTL::Device* pDevice,
                                                             MTL::CommandQueue* pCommandQueue,
                                                             const std::vector<MTL::AccelerationStructure*>& primitiveStructures,
                                                             const std::vector<MTL::AccelerationStructureInstanceDescriptor>& instances)
{
    assert(pDevice);
    assert(pCommandQueue);
    assert(!primitiveStructures.empty());
    assert(!instances.empty());

    // Wrap the primitive structures in a CF array (cast to NS::Array) so the
    // descriptor can reference them.
    auto pInstancedStructures = NS::TransferPtr((NS::Array *)CFArrayCreate(CFAllocatorGetDefault(),
                                                                           (const void **)primitiveStructures.data(),
                                                                           primitiveStructures.size(),
                                                                           &kCFTypeArrayCallBacks));

    const size_t instanceCount = instances.size();
    const size_t instanceBytes = instanceCount * sizeof(MTL::AccelerationStructureInstanceDescriptor);

    auto pDescriptor = NS::TransferPtr(MTL::InstanceAccelerationStructureDescriptor::alloc()->init());
    pDescriptor->setInstancedAccelerationStructures(pInstancedStructures.get());
    pDescriptor->setInstanceCount(instanceCount);

    // Copy the instance descriptors into a shared-storage buffer and hand it to the descriptor.
    auto pInstanceData = NS::TransferPtr(pDevice->newBuffer(instanceBytes, MTL::ResourceStorageModeShared));
    memcpy(pInstanceData->contents(), instances.data(), instanceBytes);
    pDescriptor->setInstanceDescriptorBuffer(pInstanceData.get());

    // Size and allocate the result plus the scratch space needed during the build.
    const auto sizes = pDevice->accelerationStructureSizes(pDescriptor.get());
    MTL::AccelerationStructure* pResult = pDevice->newAccelerationStructure(sizes.accelerationStructureSize);

    auto pScratchBuffer = NS::TransferPtr(pDevice->newBuffer(sizes.buildScratchBufferSize,
                                                             MTL::ResourceStorageModePrivate));

    // Encode the build, submit it, and wait for it to finish before returning.
    MTL::CommandBuffer* pCommandBuffer = pCommandQueue->commandBuffer();
    MTL::AccelerationStructureCommandEncoder* pEncoder = pCommandBuffer->accelerationStructureCommandEncoder();
    pEncoder->buildAccelerationStructure(pResult, pDescriptor.get(), pScratchBuffer.get(), 0);
    pEncoder->endEncoding();

    pCommandBuffer->commit();
    pCommandBuffer->waitUntilCompleted();

    return pResult;
}

// Builds a primitive acceleration structure from the geometry descriptors, then
// an instance acceleration structure over it, and returns the instance structure
// together with the resources it references indirectly.
//
// pDevice             - device used for all allocations.
// pCommandQueue       - queue on which the build command buffers are committed.
// geometryDescriptors - geometry for the primitive structure.
// instances           - instance descriptors for the instance structure; must not be empty.
//
// Returns a struct whose `pAccelStructure` is the instance structure and whose
// `pIndirectResources` is an array holding the primitive structure. Neither is
// released here; release them with releaseAccelerationStructureWithResources().
AccelerationStructureWithResources newAccelerationStructure(MTL::Device* pDevice,
                                                            MTL::CommandQueue* pCommandQueue,
                                                            const NS::Array* geometryDescriptors,
                                                            const std::vector<MTL::AccelerationStructureInstanceDescriptor>& instances)
{
    assert(pDevice);
    assert(pCommandQueue);
    assert(geometryDescriptors);
    assert(instances.size() > 0);
    
    // Hold the reference returned by the builder in a smart pointer. The array
    // created below (kCFTypeArrayCallBacks) retains the structure itself, so
    // without dropping this creation reference the primitive structure would
    // never be deallocated: releasing the array alone leaves one retain behind.
    auto pPrimitiveAccelerationStructure = NS::TransferPtr(newPrimitiveAccelerationStructure(pDevice, pCommandQueue, geometryDescriptors));
    assert(pPrimitiveAccelerationStructure.get());
    
    MTL::AccelerationStructure* pInstanceAccelerationStructure = newInstanceAcceleartionStructure(pDevice,
                                                                                                  pCommandQueue,
                                                                                                  {pPrimitiveAccelerationStructure.get()},
                                                                                                  instances);
    
    // Collect indirect resources (these need to be marked resident by the renderer).
    // After this function returns, the array holds the only remaining reference
    // to the primitive structure.
    const void* indirectResourceValues[] = { pPrimitiveAccelerationStructure.get() };
    NS::Array* pIndirectResources = (NS::Array *)CFArrayCreate(CFAllocatorGetDefault(),
                                                               indirectResourceValues,
                                                               1,
                                                               &kCFTypeArrayCallBacks);
    
    return AccelerationStructureWithResources{
        .pAccelStructure = pInstanceAccelerationStructure,
        .pIndirectResources = pIndirectResources
    };
}

// Releases the acceleration structure and the indirect-resources array held by
// the given struct, then clears both members.
//
// Members that are already null are skipped, and each member is set to nullptr
// after its release so that calling this twice on the same struct does not
// release the objects a second time or leave dangling pointers behind.
//
// pAccelStructureWithResources - struct to release; must not be null.
void releaseAccelerationStructureWithResources(AccelerationStructureWithResources* pAccelStructureWithResources)
{
    assert(pAccelStructureWithResources);

    if (pAccelStructureWithResources->pAccelStructure)
    {
        pAccelStructureWithResources->pAccelStructure->release();
        pAccelStructureWithResources->pAccelStructure = nullptr;
    }

    if (pAccelStructureWithResources->pIndirectResources)
    {
        pAccelStructureWithResources->pIndirectResources->release();
        pAccelStructureWithResources->pIndirectResources = nullptr;
    }
}
