/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

#include "cdo_varlist.h"
#include "cdo_cdi_wrapper.h"
#include "cdo_output.h"
#include "util_string.h"
#include "compare.h"
#include "stdnametable.h"
#include "cdo_vlist.h"

// Returns true if dataType is one of the CDI integer data types
// CDI_DATATYPE_UINT8, CDI_DATATYPE_UINT16 or CDI_DATATYPE_INT16.
static bool
isIntType(int dataType)
{
  if (dataType == CDI_DATATYPE_UINT8) return true;
  if (dataType == CDI_DATATYPE_UINT16) return true;
  return (dataType == CDI_DATATYPE_INT16);
}

// Returns true if dataType is CDI_DATATYPE_FLT32 or CDI_DATATYPE_CPX32.
static bool
isFloatType(int dataType)
{
  if (dataType == CDI_DATATYPE_FLT32) return true;
  return (dataType == CDI_DATATYPE_CPX32);
}

// Resizes varList to the number of variables in the CDI variable list vlistID and fills
// one entry per variable with its names, grid and z-axis IDs (plus their types and sizes),
// time type, data type, missing value, code/param and the memory type to be used.
void
varList_init(VarList &varList, int vlistID)
{
  const auto numVars = vlistNvars(vlistID);
  varList.resize(numVars);

  for (int varID = 0; varID < numVars; ++varID)
    {
      auto &var = varList[varID];

      // Names and units
      var.name = cdo::inq_var_name(vlistID, varID);
      var.longname = cdo::inq_var_longname(vlistID, varID);
      var.units = cdo::inq_var_units(vlistID, varID);

      // Grid, z-axis and time information; the type/size entries are derived from the IDs
      var.gridID = vlistInqVarGrid(vlistID, varID);
      var.zaxisID = vlistInqVarZaxis(vlistID, varID);
      var.timetype = vlistInqVarTimetype(vlistID, varID);
      var.tsteptype = vlistInqVarTsteptype(vlistID, varID);
      var.gridType = gridInqType(var.gridID);
      var.gridsize = gridInqSize(var.gridID);
      var.zaxisType = zaxisInqType(var.zaxisID);
      var.nlevels = zaxisInqSize(var.zaxisID);

      // Data description
      var.datatype = vlistInqVarDatatype(vlistID, varID);
      var.missval = vlistInqVarMissval(vlistID, varID);
      var.code = vlistInqVarCode(vlistID, varID);
      var.param = vlistInqVarParam(vlistID, varID);
      var.nwpv = vlistInqVarNumber(vlistID, varID);
      var.isConstant = (var.timetype == TIME_CONSTANT);

      if (Options::CDO_Memtype != MemType::Native)
        {
          // A non-native global memory type is used for every variable as is.
          var.memType = Options::CDO_Memtype;
          continue;
        }

      // Native memory type: a variable counts as packed when at least one of the
      // add-offset / scale-factor inquiries returns CDI_NOERR.
      double addoffset = 0.0, scalefactor = 1.0;
      const auto addoffsetStatus = cdiInqKeyFloat(vlistID, varID, CDI_KEY_ADDOFFSET, &addoffset);
      const auto scalefactorStatus = cdiInqKeyFloat(vlistID, varID, CDI_KEY_SCALEFACTOR, &scalefactor);
      const auto isUnpacked = (addoffsetStatus != CDI_NOERR && scalefactorStatus != CDI_NOERR);

      // Float memory is used for an undefined data type, for the types accepted by isFloatType()
      // and for unpacked types accepted by isIntType(); everything else uses Double.
      auto useFloatType = (var.datatype == CDI_UNDEFID) || isFloatType(var.datatype);
      if (!useFloatType && isUnpacked) useFloatType = isIntType(var.datatype);

      if (useFloatType) { var.memType = MemType::Float; }
      else { var.memType = MemType::Double; }
    }
}

// Assigns memType to the memType member of every entry in varList.
void
varListSetMemtype(VarList &varList, MemType memType)
{
  int numVars = varList.size();
  for (int varID = 0; varID < numVars; ++varID) { varList[varID].memType = memType; }
}

// Makes the memory type of all variables uniform: if the entries of varList do not all
// share the memory type of the first entry, every entry is switched to MemType::Double.
// An empty list is left untouched.
void
varListSetUniqueMemtype(VarList &varList)
{
  if (varList.size() == 0) return;

  auto firstMemType = varList[0].memType;

  auto isMixed = false;
  for (const auto &var : varList)
    {
      if (var.memType != firstMemType)
        {
          isMixed = true;
          break;
        }
    }

  if (isMixed) varListSetMemtype(varList, MemType::Double);
}

int
varList_numConstVars(const VarList &varList)
{
  int numConstVars = 0;
  int numVars = varList.size();
  for (int varID = 0; varID < numVars; ++varID)
    {
      const auto &var = varList[varID];
      if (var.timetype == TIME_CONSTANT) numConstVars++;
    }
  return numConstVars;
}

int
varList_numVaryingVars(const VarList &varList)
{
  int numVaryingVars = 0;
  int numVars = varList.size();
  for (int varID = 0; varID < numVars; ++varID)
    {
      const auto &var = varList[varID];
      if (var.timetype == TIME_VARYING) numVaryingVars++;
    }
  return numVaryingVars;
}

// Scans all variables of vlistID and records in the returned VarIDs which variable provides
// sgeopotID, geopotID, tempID, psID, lnpsID, lnpsID2, gheightID and humID.
//
// A variable is identified by its GRIB code, compared against the code set selected from its
// table number. If the variable has no usable code (code <= 0 or code == 255), the code is
// derived from the standard name or, failing that, from the lower-cased variable name.
// A match is only accepted if the number of levels fits: one level for the surface fields,
// numFullLevels for the 3D fields (numFullLevels or numFullLevels + 1 for gheightID).
//
// varList       per-variable information; entries are indexed by the variable ID of vlistID
// vlistID       CDI variable list the variable IDs refer to
// numFullLevels number of levels a 3D field must have to be accepted
//
// Returns the VarIDs structure; members without a match keep their initial value
// (presumably -1, which is what the name matching below tests against).
VarIDs
search_varIDs(const VarList &varList, int vlistID, int numFullLevels)
{
  VarIDs varIDs;

  auto numVars = vlistNvars(vlistID);

  // Code tables are used as soon as one variable has a table number in the range 1..254.
  auto useTable = false;
  for (int varID = 0; varID < numVars; ++varID)
    {
      auto tableNum = tableInqNum(vlistInqVarTable(vlistID, varID));
      if (tableNum > 0 && tableNum < 255)
        {
          useTable = true;
          break;
        }
    }

  if (Options::cdoVerbose && useTable) cdo_print("Using code tables!");

  char paramstr[32];
  gribcode_t gribcodes;

  for (int varID = 0; varID < numVars; ++varID)
    {
      auto &var = varList[varID];
      auto nlevels = var.nlevels;
      auto instNum = institutInqCenter(vlistInqVarInstitut(vlistID, varID));
      auto tableNum = tableInqNum(vlistInqVarTable(vlistID, varID));

      auto code = var.code;

      cdiParamToString(var.param, paramstr, sizeof(paramstr));
      int pnum, pcat, pdis;
      cdiDecodeParam(var.param, &pnum, &pcat, &pdis);
      // A discipline in the range 0..254 invalidates the code, so the name lookup below is used.
      if (pdis >= 0 && pdis < 255) code = -1;

      // Select the set of GRIB codes to compare against.
      // NOTE(review): with code tables in use and a table number matching none of the cases below,
      // gribcodes keeps its previous contents (from the last iteration or its initial state) - confirm this is intended.
      if (useTable)
        {
          if (tableNum == 2) { wmo_gribcodes(&gribcodes); }
          else if (tableNum == 128 || tableNum == 0 || tableNum == 255) { echam_gribcodes(&gribcodes); }
          //  KNMI: HIRLAM model version 7.2 uses tableNum=1    (LAMH_D11*)
          //  KNMI: HARMONIE model version 36 uses tableNum=1   (grib*) (operational NWP version)
          //  KNMI: HARMONIE model version 38 uses tableNum=253 (grib,grib_md) and tableNum=1 (grib_sfx) (research version)
          else if (tableNum == 1 || tableNum == 253) { hirlam_harmonie_gribcodes(&gribcodes); }
        }
      else { echam_gribcodes(&gribcodes); }

      if (Options::cdoVerbose)
        cdo_print("Center=%d  TableNum=%d  Code=%d  Param=%s  Varname=%s  varID=%d", instNum, tableNum, code, paramstr, var.name,
                  varID);

      if (code <= 0 || code == 255)
        {
          auto varname = string_to_lower(cdo::inq_var_name(vlistID, varID));
          auto stdname = string_to_lower(cdo::inq_key_string(vlistID, varID, CDI_KEY_STDNAME));

          code = stdname_to_echamcode(stdname);
          if (code == -1)
            {
              // Fall back to the variable name; a name is only considered while the
              // corresponding variable ID is still unset (-1). The order of the tests matters.
              //                                  ECHAM                 ECMWF
              // clang-format off
              if      (-1 == varIDs.sgeopotID && (varname == "geosp" || varname == "z")) code = gribcodes.geopot;
              else if (-1 == varIDs.tempID    && (varname == "st"    || varname == "t")) code = gribcodes.temp;
              else if (-1 == varIDs.psID      && (varname == "aps"   || varname == "sp")) code = gribcodes.ps;
              else if (-1 == varIDs.psID      &&  varname == "ps") code = gribcodes.ps;
              else if (-1 == varIDs.lnpsID    && (varname == "lsp"   || varname == "lnsp")) code = gribcodes.lsp;
              else if (-1 == varIDs.lnpsID2   &&  varname == "lnps") code = 777;
              else if (-1 == varIDs.geopotID  &&  stdname == "geopotential_full") code = gribcodes.geopot;
              else if (-1 == varIDs.humID     &&  varname == "q") code = gribcodes.hum;
              // else if (varname == "clwc") code = 246;
              // else if (varname == "ciwc") code = 247;
              // clang-format on
            }
        }

      // Accept the variable only if its number of levels fits the field it was identified as.
      // clang-format off
      if      (code == gribcodes.geopot  && nlevels == 1)             varIDs.sgeopotID = varID;
      else if (code == gribcodes.geopot  && nlevels == numFullLevels) varIDs.geopotID = varID;
      else if (code == gribcodes.temp    && nlevels == numFullLevels) varIDs.tempID = varID;
      else if (code == gribcodes.ps      && nlevels == 1)             varIDs.psID = varID;
      else if (code == gribcodes.lsp     && nlevels == 1)             varIDs.lnpsID = varID;
      else if (code == 777               && nlevels == 1)             varIDs.lnpsID2 = varID;
      else if (code == gribcodes.gheight && nlevels == numFullLevels) varIDs.gheightID = varID;
      else if (code == gribcodes.gheight && nlevels == numFullLevels + 1) varIDs.gheightID = varID;
      else if (code == gribcodes.hum     && nlevels == numFullLevels) varIDs.humID = varID;
      // else if (code == 246 && nlevels == numFullLevels) varIDs.clwcID = varID;
      // else if (code == 247 && nlevels == numFullLevels) varIDs.ciwcID = varID;
      // clang-format on
    }

  return varIDs;
}

void
varList_map(const VarList &varList1, const VarList &varList2, CmpVlist cmpFlag, int mapFlag, std::map<int, int> &mapOfVarIDs)
{
  auto flag = static_cast<int>(cmpFlag);
  int nvars1 = varList1.size();
  int nvars2 = varList2.size();

  if (mapFlag == 2)
    {
      for (int varID2 = 0; varID2 < nvars2; ++varID2)
        {
          int varID1;
          for (varID1 = 0; varID1 < nvars1; ++varID1)
            {
              if (varList1[varID1].name == varList2[varID2].name) break;
            }
          if (varID1 == nvars1) { cdo_abort("Variable %s not found in first input stream!", varList2[varID2].name); }
          else { mapOfVarIDs[varID1] = varID2; }
        }
    }
  else
    {
      for (int varID1 = 0; varID1 < nvars1; ++varID1)
        {
          int varID2;
          for (varID2 = 0; varID2 < nvars2; ++varID2)
            {
              if (varList1[varID1].name == varList2[varID2].name) break;
            }
          if (varID2 == nvars2)
            {
              if (mapFlag == 3) continue;
              cdo_abort("Variable %s not found in second input stream!", varList1[varID1].name);
            }
          else { mapOfVarIDs[varID1] = varID2; }
        }
    }

  if (mapOfVarIDs.empty()) cdo_abort("No variable found that occurs in both streams!");

  if (Options::cdoVerbose)
    for (int varID1 = 0; varID1 < nvars1; ++varID1)
      {
        const auto &it = mapOfVarIDs.find(varID1);
        if (it != mapOfVarIDs.end())
          cdo_print("Variable %d:%s mapped to %d:%s", varID1, varList1[varID1].name, it->second, varList2[it->second].name);
      }

  if (mapOfVarIDs.size() > 1)
    {
      auto varID2 = mapOfVarIDs.begin()->second;
      for (auto it = ++mapOfVarIDs.begin(); it != mapOfVarIDs.end(); ++it)
        {
          if (it->second < varID2)
            cdo_abort("Variable names must be sorted, use CDO option --sortname to sort the parameter by name (NetCDF only)!");

          varID2 = it->second;
        }
    }

  for (auto it = mapOfVarIDs.begin(); it != mapOfVarIDs.end(); ++it)
    {
      auto varID1 = it->first;
      auto varID2 = it->second;

      if (flag & static_cast<int>(CmpVlist::GridSize))
        {
          if (varList1[varID1].gridsize != varList2[varID2].gridsize) cdo_abort("Grid size of the input fields do not match!");
        }

      if (flag & static_cast<int>(CmpVlist::NumLevels))
        {
          if (zaxis_check_levels(varList1[varID1].zaxisID, varList2[varID2].zaxisID) != 0) break;
        }

      if (flag & static_cast<int>(CmpVlist::Grid) && varID1 == mapOfVarIDs.begin()->first)
        {
          auto gridID1 = varList1[varID1].gridID;
          auto gridID2 = varList2[varID2].gridID;
          if (gridID1 != gridID2) cdo_compare_grids(gridID1, gridID2);
        }
    }
}

int
varList_get_psvarid(const VarList &varList, int zaxisID)
{
  auto psname = cdo::inq_key_string(zaxisID, CDI_GLOBAL, CDI_KEY_PSNAME);
  if (psname.size())
    {
      for (int varID = 0, numVars = varList.size(); varID < numVars; ++varID)
        {
          if (varList[varID].name == psname) return varID;
        }
      if (Options::cdoVerbose) cdo_warning("Surface pressure variable not found - %s", psname);
    }

  return -1;
}

// Prints a hint to use the CDO option --sortname if both lists contain the same set of
// variable names once the names are sorted.
// Both lists are read up to the size of varList1, so varList2 must be at least as long.
static void
varList_check_names(const VarList &varList1, const VarList &varList2)
{
  const int numVars = varList1.size();

  std::vector<std::string> sortedNames1(numVars);
  std::vector<std::string> sortedNames2(numVars);
  for (int varID = 0; varID < numVars; ++varID)
    {
      sortedNames1[varID] = varList1[varID].name;
      sortedNames2[varID] = varList2[varID].name;
    }

  ranges::sort(sortedNames1);
  ranges::sort(sortedNames2);

  // Identical sorted name lists mean the streams differ only in the order of the variables.
  if (sortedNames1 == sortedNames2) cdo_print("Use CDO option --sortname to sort the parameter by name (NetCDF only)!");
}

static void
varList_print_missing_vars(const VarList &varList1, const VarList &varList2)
{
  int numVars1 = varList1.size();
  int numVars2 = varList2.size();

  if (numVars1 > numVars2)
    {
      for (int varID1 = 0; varID1 < numVars1; ++varID1)
        {
          int varID2;
          for (varID2 = 0; varID2 < numVars2; ++varID2)
            {
              if (varList1[varID1].name == varList2[varID2].name) break;
            }
          if (varID2 == numVars2) cdo_print("Variable %s not found in second input stream!", varList1[varID1].name);
        }
    }
  else
    {
      for (int varID2 = 0; varID2 < numVars2; ++varID2)
        {
          int varID1;
          for (varID1 = 0; varID1 < numVars1; ++varID1)
            {
              if (varList1[varID1].name == varList2[varID2].name) break;
            }
          if (varID1 == numVars1) cdo_print("Variable %s not found in first input stream!", varList2[varID2].name);
        }
    }
}

static int
varList_numRecs(const VarList &varList)
{
  int numRecs = 0;
  for (int varID = 0, numVars = varList.size(); varID < numVars; ++varID) numRecs += varList[varID].nlevels;
  return numRecs;
}

// Compares two variable lists entry by entry according to the bits set in cmpFlag.
//
// Aborts if the lists differ in the number of variables (the missing variables are printed
// first) or in the total number of levels.
// CmpVlist::Name:      with more than one variable, differing names (case-insensitive) produce
//                      one warning; the name check is then switched off and a --sortname hint
//                      may be printed at the end.
// CmpVlist::GridSize:  aborts on the first variable pair with different grid sizes.
// CmpVlist::NumLevels: calls zaxis_check_levels() per variable pair; a non-zero result ends
//                      the per-variable checks.
// CmpVlist::Grid:      compares the grids of the first variable pair if their grid IDs differ.
void
varList_compare(const VarList &varList1, const VarList &varList2, CmpVlist cmpFlag)
{
  auto flag = static_cast<int>(cmpFlag);
  auto doCheckNames = false;

  const int numVars = varList1.size();

  if (numVars != static_cast<int>(varList2.size()))
    {
      varList_print_missing_vars(varList1, varList2);
      cdo_abort("Input streams have different number of variables per timestep!");
    }

  if (varList_numRecs(varList1) != varList_numRecs(varList2))
    cdo_abort("Input streams have different number of %s per timestep!", (numVars == 1) ? "layers" : "records");

  for (int varID = 0; varID < numVars; ++varID)
    {
      // Names are only compared when there is more than one variable.
      if (numVars > 1 && (flag & static_cast<int>(CmpVlist::Name)))
        {
          if (string_to_lower(varList1[varID].name) != string_to_lower(varList2[varID].name))
            {
              cdo_warning("Input streams have different parameter names!");
              doCheckNames = true;
              // Clear the Name bit so the warning is printed only once.
              flag -= static_cast<int>(CmpVlist::Name);
            }
        }

      if (flag & static_cast<int>(CmpVlist::GridSize))
        {
          if (varList1[varID].gridsize != varList2[varID].gridsize)
            {
              cdo_abort("Grid size of the input field '%s' do not match!", varList1[varID].name);
            }
        }

      if (flag & static_cast<int>(CmpVlist::NumLevels))
        {
          if (zaxis_check_levels(varList1[varID].zaxisID, varList2[varID].zaxisID) != 0) break;
        }
    }

  // The grid comparison uses the first variable of each list, so it must be skipped for
  // empty lists to avoid reading element 0 of an empty container.
  if (numVars > 0 && (flag & static_cast<int>(CmpVlist::Grid)))
    {
      if (varList1[0].gridID != varList2[0].gridID) cdo_compare_grids(varList1[0].gridID, varList2[0].gridID);
    }

  if (doCheckNames) varList_check_names(varList1, varList2);
}

void
vlist_compare(int vlistID1, int vlistID2, CmpVlist cmpFlag)
{
  VarList varList1;
  VarList varList2;
  varList_init(varList1, vlistID1);
  varList_init(varList2, vlistID2);
  varList_compare(varList1, varList2, cmpFlag);
}
