/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Comp       eq              Equal
      Comp       ne              Not equal
      Comp       le              Less equal
      Comp       lt              Less than
      Comp       ge              Greater equal
      Comp       gt              Greater than
*/

#include <cdi.h>

#include <utility>

#include "cdo_options.h"
#include "functs.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "cdo_cdi_wrapper.h"
#include "cdo_fill.h"

// Element-wise "a1 == a2" over the first n points.
// Writes 1.0 / 0.0 into a3, or mv1 where a1 equals mv1 or a2 equals mv2
// (i.e. where either operand is a missing value).
static void
operatorIsEQ(size_t n, double mv1, double mv2, const Varray<double> &a1, const Varray<double> &a2, Varray<double> &a3)
{
  for (size_t k = 0; k < n; ++k)
    {
      const double lhs = a1[k];
      const double rhs = a2[k];
      if (DBL_IS_EQUAL(lhs, mv1) || DBL_IS_EQUAL(rhs, mv2))
        a3[k] = mv1;
      else
        a3[k] = static_cast<double>(IS_EQUAL(lhs, rhs));
    }
}

// Element-wise "a1 != a2" over the first n points.
// Writes 1.0 / 0.0 into a3, or mv1 where a1 equals mv1 or a2 equals mv2
// (i.e. where either operand is a missing value).
static void
operatorIsNE(size_t n, double mv1, double mv2, const Varray<double> &a1, const Varray<double> &a2, Varray<double> &a3)
{
  for (size_t k = 0; k < n; ++k)
    {
      const double lhs = a1[k];
      const double rhs = a2[k];
      if (DBL_IS_EQUAL(lhs, mv1) || DBL_IS_EQUAL(rhs, mv2))
        a3[k] = mv1;
      else
        a3[k] = static_cast<double>(IS_NOT_EQUAL(lhs, rhs));
    }
}

// Element-wise "a1 <= a2" over the first n points.
// Writes 1.0 / 0.0 into a3, or mv1 where a1 equals mv1 or a2 equals mv2
// (i.e. where either operand is a missing value).
static void
operatorIsLE(size_t n, double mv1, double mv2, const Varray<double> &a1, const Varray<double> &a2, Varray<double> &a3)
{
  for (size_t k = 0; k < n; ++k)
    {
      const double lhs = a1[k];
      const double rhs = a2[k];
      if (DBL_IS_EQUAL(lhs, mv1) || DBL_IS_EQUAL(rhs, mv2))
        a3[k] = mv1;
      else
        a3[k] = static_cast<double>(lhs <= rhs);
    }
}

// Element-wise "a1 < a2" over the first n points.
// Writes 1.0 / 0.0 into a3, or mv1 where a1 equals mv1 or a2 equals mv2
// (i.e. where either operand is a missing value).
static void
operatorIsLT(size_t n, double mv1, double mv2, const Varray<double> &a1, const Varray<double> &a2, Varray<double> &a3)
{
  for (size_t k = 0; k < n; ++k)
    {
      const double lhs = a1[k];
      const double rhs = a2[k];
      if (DBL_IS_EQUAL(lhs, mv1) || DBL_IS_EQUAL(rhs, mv2))
        a3[k] = mv1;
      else
        a3[k] = static_cast<double>(lhs < rhs);
    }
}

// Element-wise "a1 >= a2" over the first n points.
// Writes 1.0 / 0.0 into a3, or mv1 where a1 equals mv1 or a2 equals mv2
// (i.e. where either operand is a missing value).
static void
operatorIsGE(size_t n, double mv1, double mv2, const Varray<double> &a1, const Varray<double> &a2, Varray<double> &a3)
{
  for (size_t k = 0; k < n; ++k)
    {
      const double lhs = a1[k];
      const double rhs = a2[k];
      if (DBL_IS_EQUAL(lhs, mv1) || DBL_IS_EQUAL(rhs, mv2))
        a3[k] = mv1;
      else
        a3[k] = static_cast<double>(lhs >= rhs);
    }
}

// Element-wise "a1 > a2" over the first n points.
// Writes 1.0 / 0.0 into a3, or mv1 where a1 equals mv1 or a2 equals mv2
// (i.e. where either operand is a missing value).
static void
operatorIsGT(size_t n, double mv1, double mv2, const Varray<double> &a1, const Varray<double> &a2, Varray<double> &a3)
{
  for (size_t k = 0; k < n; ++k)
    {
      const double lhs = a1[k];
      const double rhs = a2[k];
      if (DBL_IS_EQUAL(lhs, mv1) || DBL_IS_EQUAL(rhs, mv2))
        a3[k] = mv1;
      else
        a3[k] = static_cast<double>(lhs > rhs);
    }
}

void *
Comp(void *process)
{
  enum
  {
    FILL_NONE,
    FILL_TS,
    FILL_REC
  };
  int filltype = FILL_NONE;
  int nrecs, nrecs2;
  int varID, levelID;
  Varray2D<double> vardata;

  cdoInitialize(process);

  const auto EQ = cdoOperatorAdd("eq", 0, 0, nullptr);
  const auto NE = cdoOperatorAdd("ne", 0, 0, nullptr);
  const auto LE = cdoOperatorAdd("le", 0, 0, nullptr);
  const auto LT = cdoOperatorAdd("lt", 0, 0, nullptr);
  const auto GE = cdoOperatorAdd("ge", 0, 0, nullptr);
  const auto GT = cdoOperatorAdd("gt", 0, 0, nullptr);

  const auto operatorID = cdoOperatorID();

  operatorCheckArgc(0);

  auto streamID1 = cdoOpenRead(0);
  auto streamID2 = cdoOpenRead(1);

  auto vlistID1 = cdoStreamInqVlist(streamID1);
  auto vlistID2 = cdoStreamInqVlist(streamID2);

  auto taxisID1 = vlistInqTaxis(vlistID1);
  auto taxisID2 = vlistInqTaxis(vlistID2);

  auto ntsteps1 = vlistNtsteps(vlistID1);
  auto ntsteps2 = vlistNtsteps(vlistID2);
  if (ntsteps1 == 0) ntsteps1 = 1;
  if (ntsteps2 == 0) ntsteps2 = 1;

  auto fillstream1 = false;

  if (vlistNrecs(vlistID1) != 1 && vlistNrecs(vlistID2) == 1)
    {
      filltype = FILL_REC;
      cdoPrint("Filling up stream2 >%s< by copying the first record.", cdoGetStreamName(1));
      if (ntsteps2 != 1) cdoAbort("stream2 has more than 1 timestep!");
    }
  else if (vlistNrecs(vlistID1) == 1 && vlistNrecs(vlistID2) != 1)
    {
      filltype = FILL_REC;
      cdoPrint("Filling up stream1 >%s< by copying the first record.", cdoGetStreamName(0));
      if (ntsteps1 != 1) cdoAbort("stream1 has more than 1 timestep!");
      fillstream1 = true;
      std::swap(streamID1, streamID2);
      std::swap(vlistID1, vlistID2);
      std::swap(taxisID1, taxisID2);
    }

  if (filltype == FILL_NONE) vlistCompare(vlistID1, vlistID2, CMP_ALL);

  nospec(vlistID1);
  nospec(vlistID2);

  const auto gridsizemax = vlistGridsizeMax(vlistID1);
  Varray<double> array1(gridsizemax), array2(gridsizemax), array3(gridsizemax);

  double *arrayx1 = array1.data();
  double *arrayx2 = array2.data();

  if (Options::cdoVerbose) cdoPrint("Number of timesteps: file1 %d, file2 %d", ntsteps1, ntsteps2);

  if (filltype == FILL_NONE)
    {
      if (ntsteps1 != 1 && ntsteps2 == 1)
        {
          filltype = FILL_TS;
          cdoPrint("Filling up stream2 >%s< by copying the first timestep.", cdoGetStreamName(1));
        }
      else if (ntsteps1 == 1 && ntsteps2 != 1)
        {
          filltype = FILL_TS;
          cdoPrint("Filling up stream1 >%s< by copying the first timestep.", cdoGetStreamName(0));
          fillstream1 = true;
          std::swap(streamID1, streamID2);
          std::swap(vlistID1, vlistID2);
          std::swap(taxisID1, taxisID2);
        }

      if (filltype == FILL_TS) cdoFillTs(vlistID2, vardata);
    }

  if (fillstream1)
    {
      arrayx1 = array2.data();
      arrayx2 = array1.data();
    }

  VarList varList1, varList2;
  varListInit(varList1, vlistID1);
  varListInit(varList2, vlistID2);

  const auto vlistID3 = vlistDuplicate(vlistID1);

  const auto taxisID3 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID3, taxisID3);

  const auto streamID3 = cdoOpenWrite(2);
  cdoDefVlist(streamID3, vlistID3);

  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      if (tsID == 0 || filltype == FILL_NONE)
        {
          nrecs2 = cdoStreamInqTimestep(streamID2, tsID);
          if (nrecs2 == 0) cdoAbort("Input streams have different number of timesteps!");
        }

      taxisCopyTimestep(taxisID3, taxisID1);
      cdoDefTimestep(streamID3, tsID);

      for (int recID = 0; recID < nrecs; recID++)
        {
          size_t nmiss1;
          cdoInqRecord(streamID1, &varID, &levelID);
          cdoReadRecord(streamID1, arrayx1, &nmiss1);

          if (tsID == 0 || filltype == FILL_NONE)
            {
              if (recID == 0 || filltype != FILL_REC)
                {
                  size_t nmiss2;
                  cdoInqRecord(streamID2, &varID, &levelID);
                  cdoReadRecord(streamID2, arrayx2, &nmiss2);
                }

              if (filltype == FILL_TS)
                {
                  const auto offset = varList2[varID].gridsize * levelID;
                  arrayCopy(varList2[varID].gridsize, arrayx2, &vardata[varID][offset]);
                }
            }
          else if (filltype == FILL_TS)
            {
              const auto offset = varList2[varID].gridsize * levelID;
              arrayCopy(varList2[varID].gridsize, &vardata[varID][offset], arrayx2);
            }

          const auto datatype1 = varList1[varID].datatype;
          const auto gridsize1 = varList1[varID].gridsize;
          auto missval1 = varList1[varID].missval;

          const auto xvarID = (filltype == FILL_REC) ? 0 : varID;
          const auto datatype2 = varList2[xvarID].datatype;
          const auto gridsize2 = varList2[xvarID].gridsize;
          auto missval2 = varList2[xvarID].missval;

          if (gridsize1 != gridsize2)
            cdoAbort("Streams have different gridsize (gridsize1 = %zu; gridsize2 = %zu)!", gridsize1, gridsize2);

          const auto gridsize = gridsize1;

          if (datatype1 != datatype2)
            {
              if (datatype1 == CDI_DATATYPE_FLT32 && datatype2 == CDI_DATATYPE_FLT64)
                {
                  missval2 = (float) missval2;
                  for (size_t i = 0; i < gridsize; i++) array2[i] = (float) array2[i];
                }
              else if (datatype1 == CDI_DATATYPE_FLT64 && datatype2 == CDI_DATATYPE_FLT32)
                {
                  missval1 = (float) missval1;
                  for (size_t i = 0; i < gridsize; i++) array1[i] = (float) array1[i];
                }
            }

          if (nmiss1 > 0) cdo_check_missval(missval1, varList1[varID].name);
          // if (nmiss2 > 0) cdo_check_missval(missval2, varList2[varID].name);

          // clang-format off
          if      (operatorID == EQ) operatorIsEQ(gridsize, missval1, missval2, array1, array2, array3);
          else if (operatorID == NE) operatorIsNE(gridsize, missval1, missval2, array1, array2, array3);
          else if (operatorID == LE) operatorIsLE(gridsize, missval1, missval2, array1, array2, array3);
          else if (operatorID == LT) operatorIsLT(gridsize, missval1, missval2, array1, array2, array3);
          else if (operatorID == GE) operatorIsGE(gridsize, missval1, missval2, array1, array2, array3);
          else if (operatorID == GT) operatorIsGT(gridsize, missval1, missval2, array1, array2, array3);
          else cdoAbort("Operator not implemented!");
          // clang-format on

          const auto nmiss3 = varrayNumMV(gridsize, array3, missval1);
          cdoDefRecord(streamID3, varID, levelID);
          cdoWriteRecord(streamID3, array3.data(), nmiss3);
        }

      tsID++;
    }

  cdoStreamClose(streamID3);
  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
