// Copyright 2020 NVIDIA Corporation. All rights reserved
//
// This sample demonstrates the generic workflow for querying properties of the metrics that are
// exposed by the Profiling APIs. Here we query the number of passes and the collection method for a list of metrics.
//
// Number of passes : The number of passes required to collect the metric. Some metrics cannot be
// collected in a single pass due to hardware or software limitations, so the exact same set of
// GPU workloads has to be replayed multiple times.
//
// Collection method : The source of the metric (HW or SW). Most metrics are provided by the hardware, but
// for some metrics the kernel has to be instrumented. Such metrics cannot be combined with any other
// metrics in the same pass, as otherwise the instrumented code would also contribute to the metric values.
//


#include <stdio.h>
#include <vector>
#include <string>
#include <cstring>
#include <iostream>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <Parser.h>
#include <Utils.h>
#include <nvperf_host.h>
#include <nvperf_cuda_host.h>
#include <cupti_profiler_target.h>
#include <cupti_target.h>

// Evaluates a CUDA driver API call exactly once. When the returned CUresult
// differs from CUDA_SUCCESS, the source location, the call text and the
// numeric error code are written to stderr and the process ends via exit(-1).
#define DRIVER_API_CALL(apiFuncCall)                                           \
do {                                                                           \
    const CUresult _driverStatus = apiFuncCall;                                \
    if (CUDA_SUCCESS != _driverStatus) {                                       \
        fprintf(stderr, "%s:%d: error: function %s failed with error %d.\n",   \
                __FILE__, __LINE__, #apiFuncCall, _driverStatus);              \
        exit(-1);                                                              \
    }                                                                          \
} while (0)

// Evaluates a CUPTI API call exactly once. When the returned CUptiResult
// differs from CUPTI_SUCCESS, the source location, the call text and the
// error description are written to stderr and the process ends via exit(-1).
//
// The description is only looked up on the failure path, and a fallback text
// is printed if cuptiGetResultString() does not provide a string, so that an
// unset pointer is never handed to the "%s" conversion.
#define CUPTI_API_CALL(apiFuncCall)                                            \
do {                                                                           \
    CUptiResult _status = apiFuncCall;                                         \
    if (_status != CUPTI_SUCCESS) {                                            \
        const char *errstr = nullptr;                                          \
        cuptiGetResultString(_status, &errstr);                                \
        fprintf(stderr, "%s:%d: error: function %s failed with error %s.\n",   \
                __FILE__, __LINE__, #apiFuncCall,                              \
                errstr ? errstr : "<unknown CUPTI error>");                    \
        exit(-1);                                                              \
    }                                                                          \
} while (0)

// Writes one row (metric name, number of passes, collection method) to
// `stream`. With `isCSVformat` true the fields are comma separated; otherwise
// they are left-aligned in columns of width 80/15/15 separated by tabs.
//
// The body is wrapped in do { } while (0) so the macro expands to a single
// statement and can be used safely as the body of an unbraced if/else.
// Invocations therefore need a terminating semicolon.
#define FORMAT_METRIC_DETAILS(stream, metricName, numOfPasses, collectionMethod, isCSVformat)   \
do {                                                                                            \
    if (isCSVformat) {                                                                          \
        (stream) << (metricName)  << ","                                                        \
                 << (numOfPasses) << ","                                                        \
                 << (collectionMethod) << "\n";                                                 \
    } else {                                                                                    \
        (stream) << std::setw(80) << std::left << (metricName)  << "\t"                         \
                 << std::setw(15) << std::left << (numOfPasses) << "\t"                         \
                 << std::setw(15) << std::left << (collectionMethod) << "\n";                   \
    }                                                                                           \
} while (0)

// Writes a header row to `stream`, then reads whitespace-separated
// "<metric name> <num passes> <collection method>" triples from `outputStream`
// until extraction fails, writing one formatted row per triple.
// Expands to a single statement; invocations need a terminating semicolon.
#define PRINT_METRIC_DETAILS(stream, outputStream, isCSVformat)                                         \
do {                                                                                                    \
    FORMAT_METRIC_DETAILS(stream, "Metric Name", "Num of Passes", "Collection Method", isCSVformat);    \
    std::string metricName, numOfPasses, collectionMethod;                                              \
    while ((outputStream) >> metricName >> numOfPasses >> collectionMethod) {                           \
        FORMAT_METRIC_DETAILS(stream, metricName, numOfPasses, collectionMethod, isCSVformat);          \
    }                                                                                                   \
} while (0)

// Classifies a metric purely by its name: a name containing the substring
// "sass" is reported as "SW", every other name as "HW".
std::string GetMetricCollectionMethod(std::string metricName)
{
    const char* const softwareMarker = "sass";
    const bool isSoftwareMetric = metricName.find(softwareMarker) != std::string::npos;
    return isSoftwareMetric ? "SW" : "HW";
}

// Expands one metric name into the raw metric requests it depends on.
//
// The metric name is parsed into a request name plus isolated / keepInstances
// flags, the raw dependencies of that request name are appended to
// pMetricDependencies, and one NVPA_RawMetricRequest is appended to
// rawMetricRequests for every entry of pMetricDependencies.
//
// Each request's pMetricName points into the strings held by
// pMetricDependencies, so that vector must stay alive and unmodified for as
// long as the requests are in use.
//
// Returns true once both GetMetricProperties calls have been made; on an NVPW
// failure RETURN_IF_NVPW_ERROR is expected to return false (macro not visible here).
bool GetRawMetricRequests(NVPA_MetricsContext * pMetricsContext,
      std::string metricName,
      std::vector<NVPA_RawMetricRequest> & rawMetricRequests,
      std::vector<std::string> & pMetricDependencies) 
{
    std::string parsedName;
    bool isolated = true;
    bool keepInstances = true;
    NV::Metric::Parser::ParseMetricNameString(metricName, &parsedName, &isolated, &keepInstances);

    // Collecting metrics without instances is affected by a bug, so the
    // parser's choice is overridden and instances are always kept.
    keepInstances = true;

    NVPW_MetricsContext_GetMetricProperties_Begin_Params getMetricPropertiesBeginParams = { NVPW_MetricsContext_GetMetricProperties_Begin_Params_STRUCT_SIZE };
    getMetricPropertiesBeginParams.pMetricsContext = pMetricsContext;
    getMetricPropertiesBeginParams.pMetricName = parsedName.c_str();
    RETURN_IF_NVPW_ERROR(false, NVPW_MetricsContext_GetMetricProperties_Begin(&getMetricPropertiesBeginParams));

    // The dependency list is terminated by a null entry; copy the names
    // before the properties query is ended.
    const char** ppDependency = getMetricPropertiesBeginParams.ppRawMetricDependencies;
    while (*ppDependency)
    {
        pMetricDependencies.push_back(*ppDependency);
        ++ppDependency;
    }

    NVPW_MetricsContext_GetMetricProperties_End_Params getMetricPropertiesEndParams = { NVPW_MetricsContext_GetMetricProperties_End_Params_STRUCT_SIZE };
    getMetricPropertiesEndParams.pMetricsContext = pMetricsContext;
    RETURN_IF_NVPW_ERROR(false, NVPW_MetricsContext_GetMetricProperties_End(&getMetricPropertiesEndParams));

    for (size_t idx = 0; idx < pMetricDependencies.size(); ++idx)
    {
        NVPA_RawMetricRequest request = { NVPA_RAW_METRIC_REQUEST_STRUCT_SIZE };
        request.pMetricName = pMetricDependencies[idx].c_str();
        request.isolated = isolated;
        request.keepInstances = keepInstances;
        rawMetricRequests.push_back(request);
    }

    return true;
}

bool GetMetricDetails(std::string pMetricName, 
    std::string pChipName,
    NVPW_CUDA_MetricsContext_Create_Params pMetricsContextCreateParams, 
    std::stringstream& pOutputStream)
{
    NVPA_RawMetricsConfigOptions metricsConfigOptions = { NVPA_RAW_METRICS_CONFIG_OPTIONS_STRUCT_SIZE };
    metricsConfigOptions.activityKind = NVPA_ACTIVITY_KIND_PROFILER;
    metricsConfigOptions.pChipName = pChipName.c_str();
    NVPA_RawMetricsConfig* pRawMetricsConfig;
    RETURN_IF_NVPW_ERROR(false, NVPA_RawMetricsConfig_Create(&metricsConfigOptions, &pRawMetricsConfig));

    NVPW_RawMetricsConfig_BeginPassGroup_Params beginPassGroupParams = { NVPW_RawMetricsConfig_BeginPassGroup_Params_STRUCT_SIZE };
    beginPassGroupParams.pRawMetricsConfig = pRawMetricsConfig;
    RETURN_IF_NVPW_ERROR(false, NVPW_RawMetricsConfig_BeginPassGroup(&beginPassGroupParams));

    std::vector<NVPA_RawMetricRequest> rawMetricRequests;
    std::vector<std::string> metricDependencies;
    if (!GetRawMetricRequests(pMetricsContextCreateParams.pMetricsContext, pMetricName, rawMetricRequests, metricDependencies)) {
        printf("Error!! Failed to get raw metrics\n");
        return false;
    }

    NVPW_RawMetricsConfig_IsAddMetricsPossible_Params isAddMetricsPossibleParams = { NVPW_RawMetricsConfig_IsAddMetricsPossible_Params_STRUCT_SIZE };
    isAddMetricsPossibleParams.pRawMetricsConfig = pRawMetricsConfig;
    isAddMetricsPossibleParams.pRawMetricRequests = &rawMetricRequests[0];
    isAddMetricsPossibleParams.numMetricRequests = rawMetricRequests.size();
    RETURN_IF_NVPW_ERROR(false, NVPW_RawMetricsConfig_IsAddMetricsPossible(&isAddMetricsPossibleParams));
    
    NVPW_RawMetricsConfig_AddMetrics_Params addMetricsParams = { NVPW_RawMetricsConfig_AddMetrics_Params_STRUCT_SIZE };
    addMetricsParams.pRawMetricsConfig = pRawMetricsConfig;
    addMetricsParams.pRawMetricRequests = &rawMetricRequests[0];
    addMetricsParams.numMetricRequests = rawMetricRequests.size();
    RETURN_IF_NVPW_ERROR(false, NVPW_RawMetricsConfig_AddMetrics(&addMetricsParams));

    NVPW_RawMetricsConfig_EndPassGroup_Params endPassGroupParams = { NVPW_RawMetricsConfig_EndPassGroup_Params_STRUCT_SIZE };
    endPassGroupParams.pRawMetricsConfig = pRawMetricsConfig;
    RETURN_IF_NVPW_ERROR(false, NVPW_RawMetricsConfig_EndPassGroup(&endPassGroupParams));

    NVPW_RawMetricsConfig_GetNumPasses_Params rawMetricsConfigGetNumPassesParams = { NVPW_RawMetricsConfig_GetNumPasses_Params_STRUCT_SIZE };
    rawMetricsConfigGetNumPassesParams.pRawMetricsConfig = pRawMetricsConfig;
    RETURN_IF_NVPW_ERROR(false, NVPW_RawMetricsConfig_GetNumPasses(&rawMetricsConfigGetNumPassesParams));

    // No Nesting of ranges in case of CUPTI_AutoRange, in AutoRange
    // the range is already at finest granularity of every kernel Launch so numNestingLevels = 1
    size_t numNestingLevels = 1;
    size_t numIsolatedPasses = rawMetricsConfigGetNumPassesParams.numIsolatedPasses;
    size_t numPipelinedPasses = rawMetricsConfigGetNumPassesParams.numPipelinedPasses;
    size_t numOfPasses = numPipelinedPasses + numIsolatedPasses * numNestingLevels;
    std::string collectionMethod = GetMetricCollectionMethod(pMetricName);

    NVPW_RawMetricsConfig_Destroy_Params rawMetricsConfigDestroyParams = { NVPW_RawMetricsConfig_Destroy_Params_STRUCT_SIZE };
    rawMetricsConfigDestroyParams.pRawMetricsConfig = pRawMetricsConfig;
    RETURN_IF_NVPW_ERROR(false, NVPW_RawMetricsConfig_Destroy((NVPW_RawMetricsConfig_Destroy_Params*)&rawMetricsConfigDestroyParams));

    pOutputStream << pMetricName << " "
                  << numOfPasses << " "
                  << collectionMethod << "\n";

    return true;
}

// Entry point: prints the number of passes and the collection method for a
// set of metrics.
//
// Options:
//   --help                 print usage and exit with 0
//   --device <num>         CUDA device index used to look up the chip name (default 0)
//   --chip <name>          chip name; when given, no CUDA device is queried
//   --metrics <a,b,...>    comma separated metric names; when omitted, all
//                          metrics reported for the chip are listed
//   --csv                  emit comma separated output instead of columns
//   --file <name>          write the table to <name> instead of stdout
//
// Exit status: 0 on success, -1 for invalid arguments, an NVPW failure or an
// output file that cannot be opened, -2 when no CUDA device is present.
// Failures of individual metrics are reported and skipped.
int main(int argc, char* argv[])
{
    std::vector<std::string> metricNames;
    int deviceNum = 0;
    std::string chipName;
    bool bIsCSVformat = false;
    std::string exportFileName;

    auto printUsage = [&argv]() {
        printf("Usage: %s --device [device_num] --chip [chip name] --metrics [metric_names comma separated] --csv --file [filename]\n", argv[0]);
    };

    for (int i = 1; i < argc; ++i) 
    {
        const char* arg = argv[i];
        if (strcmp(arg, "--help") == 0) {
            printUsage();
            return 0;
        }

        if (strcmp(arg, "--csv") == 0) {
            bIsCSVformat = true;
            continue;
        }

        const bool isDevice  = strcmp(arg, "--device") == 0;
        const bool isChip    = strcmp(arg, "--chip") == 0;
        const bool isMetrics = strcmp(arg, "--metrics") == 0;
        const bool isFile    = strcmp(arg, "--file") == 0;
        const bool takesValue = isDevice || isChip || isMetrics || isFile;

        // Reject unknown options as well as value options given as the last
        // argument, so argv[argc] (a null pointer) is never dereferenced.
        if (!takesValue || i + 1 >= argc) {
            printf("Error!! Invalid Arguments\n");
            printUsage();
            return -1;
        }

        char* value = argv[++i];
        if (isDevice) {
            deviceNum = atoi(value);
        } else if (isChip) {
            chipName = value;
        } else if (isMetrics) {
            // strtok splits the argument in place at each comma.
            for (char* token = strtok(value, ","); token != NULL; token = strtok(NULL, ",")) {
                metricNames.push_back(token);
            }
        } else {
            exportFileName = value;
        }
    }

    if (chipName.empty()) 
    {
        int deviceCount = 0;
        DRIVER_API_CALL(cuInit(0));
        DRIVER_API_CALL(cuDeviceGetCount(&deviceCount));

        if (deviceCount == 0) 
        {
            printf("There is no device supporting CUDA.\n");
            return -2;
        }      
        printf("CUDA Device Number: %d\n", deviceNum);

        /* Get chip name for the cuda  device */
        CUpti_Profiler_Initialize_Params profilerInitializeParams = { CUpti_Profiler_Initialize_Params_STRUCT_SIZE };
        CUPTI_API_CALL(cuptiProfilerInitialize(&profilerInitializeParams));
        
        CUpti_Device_GetChipName_Params getChipNameParams = { CUpti_Device_GetChipName_Params_STRUCT_SIZE };
        getChipNameParams.deviceIndex = deviceNum;
        CUPTI_API_CALL(cuptiDeviceGetChipName(&getChipNameParams));
        chipName = getChipNameParams.pChipName;
    }
    printf("Queried Chip : %s \n", chipName.c_str());    

    // NOTE(review): RETURN_IF_NVPW_ERROR is defined outside this file; it is
    // assumed to return its first argument on failure. -1 is passed so that a
    // failure yields a non-zero exit status (false would convert to 0).
    NVPW_InitializeHost_Params initializeHostParams = { NVPW_InitializeHost_Params_STRUCT_SIZE };
    RETURN_IF_NVPW_ERROR(-1, NVPW_InitializeHost(&initializeHostParams));

    NVPW_CUDA_MetricsContext_Create_Params metricsContextCreateParams = { NVPW_CUDA_MetricsContext_Create_Params_STRUCT_SIZE };
    metricsContextCreateParams.pChipName = chipName.c_str();
    RETURN_IF_NVPW_ERROR(-1, NVPW_CUDA_MetricsContext_Create(&metricsContextCreateParams));
    
    // Collects one "<metric> <passes> <method>" line per successfully queried metric.
    std::stringstream outputStream;
    if (metricNames.empty()) {
        // No metrics requested: enumerate every metric of the chip, hiding sub-metrics.
        const bool listSubMetrics = false;
        NVPW_MetricsContext_GetMetricNames_Begin_Params getMetricNameBeginParams = { NVPW_MetricsContext_GetMetricNames_Begin_Params_STRUCT_SIZE };
        getMetricNameBeginParams.pMetricsContext = metricsContextCreateParams.pMetricsContext;
        getMetricNameBeginParams.hidePeakSubMetrics = !listSubMetrics;
        getMetricNameBeginParams.hidePerCycleSubMetrics = !listSubMetrics;
        getMetricNameBeginParams.hidePctOfPeakSubMetrics = !listSubMetrics;
        RETURN_IF_NVPW_ERROR(-1, NVPW_MetricsContext_GetMetricNames_Begin(&getMetricNameBeginParams));

        std::cout << "Total metrics on the chip " << getMetricNameBeginParams.numMetrics << "\n";
        for (size_t i = 0; i < getMetricNameBeginParams.numMetrics; i++) {
            if (!GetMetricDetails(getMetricNameBeginParams.ppMetricNames[i], chipName, metricsContextCreateParams, outputStream)) {
                printf("Error!! Failed to get the metric details\n");
            }
        }
  
        NVPW_MetricsContext_GetMetricNames_End_Params getMetricNameEndParams = { NVPW_MetricsContext_GetMetricNames_End_Params_STRUCT_SIZE };
        getMetricNameEndParams.pMetricsContext = metricsContextCreateParams.pMetricsContext;
        RETURN_IF_NVPW_ERROR(-1, NVPW_MetricsContext_GetMetricNames_End((NVPW_MetricsContext_GetMetricNames_End_Params *)&getMetricNameEndParams));

    } else {
        for (const auto& requestedMetric : metricNames) {
            if (!GetMetricDetails(requestedMetric, chipName, metricsContextCreateParams, outputStream)) {
                printf("Error!! Failed to get the metric details\n");
            }
        }
    }

    NVPW_MetricsContext_Destroy_Params metricsContextDestroyParams = { NVPW_MetricsContext_Destroy_Params_STRUCT_SIZE };
    metricsContextDestroyParams.pMetricsContext = metricsContextCreateParams.pMetricsContext;
    RETURN_IF_NVPW_ERROR(-1, NVPW_MetricsContext_Destroy((NVPW_MetricsContext_Destroy_Params*)&metricsContextDestroyParams)); 

    if (exportFileName.empty()) {
        PRINT_METRIC_DETAILS(std::cout, outputStream, bIsCSVformat);
    } else {
        std::ofstream outputFile(exportFileName);
        if (!outputFile.is_open())
        {
            // Previously a failed open was silently ignored and reported as success.
            fprintf(stderr, "Error!! Unable to open %s for writing\n", exportFileName.c_str());
            return -1;
        }
        PRINT_METRIC_DETAILS(outputFile, outputStream, bIsCSVformat);
        outputFile.close();
        printf("Metric details has been written to %s file.\n", exportFileName.c_str());
    }

    return 0;
}



