/**
 * @file evalfn-comp.c
 * @brief Define evalExprComplex
 */

#include "benchmarking.h"
#include "elemop.h"
#include "error.h"
#include "evalfn.h"
#include "main.h"
#include "mem.h"
#include "phyconst.h"
#include "rand.h"
#include "testing.h"
#include <ctype.h>

/**
 * @brief Evaluate a postfix (RPN-style) expression over complex numbers and
 *        matrices, leaving the operands on a caller-supplied stack.
 *
 * The input is scanned one character at a time. Numbers and matrix literals
 * are pushed, and single-character operators/functions act on the top of the
 * stack. Scanning stops at `;` (comment), `]`, or the terminating NUL.
 * Matrix literals have the form `[cols; e0 e1 ...]`.
 *
 * Side effects: the final top-of-stack element is appended to the runtime
 * history, and the locally modified runtime info is written back with
 * setRuntimeInfo(). `&x` stores into register `x`.
 *
 * @param[in] expr String of expression
 * @param[in,out] operand_stack Operand stack buffer. Pushes pre-increment the
 *                stack pointer, so pushed values occupy
 *                operand_stack[1..return value], and operand_stack[0] lies
 *                below the first operand.
 * @return Offset of the top-of-stack pointer from operand_stack, i.e. the
 *         index of the result element.
 * @warning Possible stack overflow with very long expressions; no bounds
 *          check is made against the size of operand_stack.
 */
[[gnu::nonnull]] size_t
evalExprComplexStack(char const *expr, elem_t *operand_stack) {
  // rsp: top of the operand stack.
  // rbp: frame slot of the innermost open parenthesised group
  // (operand_stack at top level).
  elem_t *rsp = operand_stack, *rbp = operand_stack;
  // Local copy of the runtime state; written back with setRuntimeInfo() at
  // `end:`.
  rtinfo_t info_c = getRuntimeInfo();

  for (;; expr++) {
    switch (*expr) {
    case ' ':
    case '\t':
    case '\n':
      continue;
    case '0' ... '9':
      // strtod advances expr past the number. Step back one character so
      // that the loop's expr++ lands on the first unread character.
      (++rsp)->rtype = RTYPE_COMP;
      rsp->elem.comp = strtod(expr, (char **)&expr);
      expr--;
      break;
    case '[':
      // Matrix literal `[cols; e0 e1 ...]`. First evaluate the column count,
      // then evaluate the elements onto a scratch stack and copy them, in
      // input order, into a freshly allocated buffer.
      (++rsp)->rtype = RTYPE_MATR;
      expr++;
      // NOTE(review): xalloc presumably reserves mat_init_size elements, but
      // the copy loop below writes `len` of them unchecked. Confirm the bound.
      matrix_t val = {.matrix = xalloc(comp, mat_init_size)};
      matrix curelem = val.matrix;
      // Nested evaluation: stops at the `;` that ends the column count.
      val.cols = (size_t)evalExprComplex(expr).elem.comp;
      expr += skipUntil(expr, ';');
      elem_t stack[buf_size] = {};
      size_t len = evalExprComplexStack(expr, stack);
      // Pushed values start at stack[1] because pushes pre-increment.
      for (size_t i = 0; i < len; i++) *curelem++ = stack[i + 1].elem.comp;
      // NOTE(review): divides by zero if the column count evaluates to 0.
      val.rows = len / val.cols;
      rsp->elem.matr = val;
      // Position expr so that the loop's expr++ steps past the literal. This
      // relies on skipUntil's return convention, which is defined elsewhere.
      expr += skipUntil(expr, ']') - 1;
      continue;
    case '(':
      // Open a group: push a frame slot holding the previous base offset
      // (bit-cast into a double), and make that slot the new base.
      (++rsp)->elem.real = bit_cast(double, rbp - operand_stack);
      rbp = rsp;
      break;
    case ')':
      // Close a group: restore the previous base from the frame slot, then
      // move the top value down one slot, onto the frame slot.
      // NOTE(review): this is only correct if the group reduced to exactly
      // one value.
      rbp = operand_stack + bit_cast(long, rbp->elem.real);
      rsp--;
      *rsp = *(rsp + 1);
      break;

      // Binary arithmetic operators and unary functions. These macros are
      // defined elsewhere and presumably expand to `case` labels.
      MAP_PAIR(OP_CASE_ELEM, (Add, +), (Sub, -), (Mul, *), (Div, /), (Pow, ^))
      MAP_PAIR(OVERWRITE_COMP, ('A', fabs), ('s', sin), ('c', cos), ('t', tan))

    case '~': { // matrix inverse
      // `__` is presumably a cleanup attribute that releases the old matrix
      // on scope exit. Confirm this in its definition.
      matrix_t __ dropmatr = rsp->elem.matr;
      rsp->elem.matr = inverseMatrix(&rsp->elem.matr);
    } break;

    case '=':
      // Equality over the current group: pop while the top two are equal.
      // The group is then left with a single value: 1 if everything
      // collapsed (all equal), otherwise 0.
      for (; rbp + 1 < rsp && eq(rsp - 1, rsp); rsp--);
      (rbp + 1)->elem.real = rbp + 1 == rsp;
      if (rbp + 1 != rsp) rsp = rbp + 1;
      break;

    case 'a': // inverse trigonometric: as, ac, at
      switch (*++expr) {
        MAP_PAIR(OVERWRITE_COMP, ('s', asin), ('c', acos), ('t', atan))
      default:
        DISPERR("unknown fn: ", *expr);
      }
      break;

    case 'h': // hyperbolic: hs, hc, ht
      switch (*++expr) {
        MAP_PAIR(OVERWRITE_COMP, ('s', sinh), ('c', cosh), ('t', tanh))
      default:
        DISPERR("unknown fn: ", *expr);
      }
      break;

    case 'l': // natural logarithm
      rsp->elem.comp = clog(rsp->elem.comp);
      break;

    case 'L': { // log with base: pop the base (its real part) and replace the
                // value below it with log(value) / log(base)
      double x = creal((rsp--)->elem.comp);
      rsp->elem.comp = log(rsp->elem.comp) / log(x);
    } break;

// Unary scaling of the top value: r = deg->rad, d = rad->deg, m = negate,
// i = multiply by the imaginary unit.
#define MUL(c, factor) \
  case c: \
    rsp->elem.comp *= factor; \
    break;

      MAP_PAIR(MUL, ('r', pi / 180), ('d', 180 / pi), ('m', -1), ('i', I))

    case 'p': { // polar: pop theta, replace r with r*cos(theta) + i*r*sin(theta)
      comp theta = (rsp--)->elem.comp;
      rsp->elem.comp
        = rsp->elem.comp * cos(theta) + I * rsp->elem.comp * sin(theta);
    } break;

    case '@':   // system functions, selected by the next character
      switch (*++expr) {
      case 'a': // ANS: push the most recent history entry
        elemSet(++rsp, &info_c.hist[info_c.histi]);
        break;
      case 'd': // display the top value (the stack is unchanged)
        print_complex(rsp->elem.comp);
        break;
      case 'h': // history operation: replace top n with hist[histi - n]
                // NOTE(review): there is no check that n <= histi
        elemSet(rsp, &info_c.hist[info_c.histi - (size_t)rsp->elem.real]);
        break;
      case 'n': // push NaN
        elemSet(++rsp, &(elem_t){.rtype = RTYPE_COMP, .elem = {.comp = NAN}});
        break;
      case 'p': // prev stack value: duplicate the top
        rsp++;
        elemSet(rsp, rsp - 1);
        break;
      case 'r': // push xorsh0to1() (presumably a uniform value in [0, 1))
        (++rsp)->elem.comp = xorsh0to1();
        break;
      case 's': // stack value operation: replace top n with the element
                // n + 1 slots below the top (n = 0 is the one just beneath)
        elemSet(rsp, rsp - (int)rsp->elem.real - 1);
        break;
      default: // unknown system functions are silently ignored
        [[clang::unlikely]];
      }
      break;

    case '\\': // special variables and CONSTANTS, selected by the next char
      (++rsp)->elem.comp = getConst(*++expr);
      break;

    case '$': // register read: push register a..z (other chars are ignored)
      if (islower(*++expr)) [[clang::likely]] {
        elem_t const *rhs = &info_c.reg[*expr - 'a'];
        elemSet(++rsp, rhs);
      }
      break;

    case '&': // register write: copy the top value into register `*expr`
              // NOTE(review): unlike `$`, the letter is not range-checked
      elemSet(&info_c.reg[*++expr - 'a'], rsp);
      break;

    case ';': // comment
    case ']':
    case '\0':
      goto end;
    default:
      DISPERR("unknown char: ", *expr);
    }
  }

end:
  // Record the result in history. If the slot being overwritten holds a
  // matrix, its buffer is released first. The new matrix entry shares its
  // buffer with the result left on the stack.
  // NOTE(review):
  //  - The matrix branch has no capacity check.
  //  - The scalar branch lets ++histi reach buf_size; confirm that hist is
  //    sized for this.
  //  - The scalar branch does not set rtype.
  //  - Nested evaluations (matrix literals) call setRuntimeInfo themselves,
  //    and the info_c copy taken at entry may overwrite their updates here.
  //    Confirm this is intended.
  if (rsp->rtype == RTYPE_MATR) {
    elem_t *rhs = &info_c.hist[++info_c.histi];
    if (rhs->rtype == RTYPE_MATR) nfree(rhs->elem.matr.matrix);
    *rhs = *rsp;
  } else if (info_c.histi < buf_size)
    info_c.hist[++info_c.histi].elem.comp = rsp->elem.comp;
  setRuntimeInfo(info_c);
  return (size_t)(rsp - operand_stack);
}

/**
 * @brief Evaluate an expression and return the element left on top of the
 *        operand stack.
 * @param[in] a_expr Expression in the syntax accepted by evalExprComplexStack
 * @return Copy of the top-of-stack element. For a matrix result, the buffer
 *         pointer is the same one that evalExprComplexStack records in the
 *         history.
 */
elem_t evalExprComplex(char const *a_expr) {
  elem_t operands[buf_size] = {};
  // The returned offset indexes the final top-of-stack slot.
  return operands[evalExprComplexStack(a_expr, operands)];
}

// Adapter so that the test tables can compare the scalar (complex) result
// directly.
#define eval_expr_complex_return_complex(expr) evalExprComplex(expr).elem.comp
// Basic arithmetic. `2i` is the number 2 followed by the `i` operator
// (multiply by I). Expected values: 1 + 2i, (4i)^5 = 1024i,
// and (1+i)(2+2i) = 4i.
test_table(
  eval_complex, eval_expr_complex_return_complex, (comp, char const *),
  {
    {1.0 + 2.0i,        "1 2i +"},
    {   1024.0i,        "4i 5 ^"},
    {      4.0i, "1 1i+(2 2i+)*"},
}
)
// Elementary functions on complex arguments:
//   sin(1+i) = sin(1)cosh(1) + i cos(1)sinh(1);
//   `\P` pushes the constant for 'P' (pi, as the expected value implies),
//   `i` makes it pi*i, and `l` gives log(pi*i) = ln(pi) + (pi/2)i.
test_table(
  eval_complex_comp, eval_expr_complex_return_complex, (comp, char const *),
  {
    {1.2984575814159773 + 0.6349639147847361i, "1 1i+s"},
    {1.1447298858494002 + 1.5707963267948967i,  "\\Pil"},
}
)
#undef eval_expr_complex_return_complex

/*
 * Matrix evaluation: element-wise addition, matrix multiplication, inverse,
 * and scaling by a scalar. Each case checks the result dimensions and then
 * the elements, in the order the matrix literal supplies them.
 */
test (eval_expr_complex) {
  matrix_t resultm;
  char const *expr;

  // Test matrix addition
  expr = "[2; 1 2 3 4][2; 5 6 7 8]+";
  resultm = evalExprComplex(expr).elem.matr;
  expecteq(2, resultm.rows);
  expecteq(2, resultm.cols);
  expecteq(6.0, resultm.matrix[0]);
  expecteq(8.0, resultm.matrix[1]);
  expecteq(10.0, resultm.matrix[2]);
  expecteq(12.0, resultm.matrix[3]);

  // Test matrix multiplication
  expr = "[2; 1 2 3 4][2; 5 6 7 8]*";
  resultm = evalExprComplex(expr).elem.matr;
  expecteq(2, resultm.rows);
  expecteq(2, resultm.cols);
  expecteq(19.0, resultm.matrix[0]);
  expecteq(22.0, resultm.matrix[1]);
  expecteq(43.0, resultm.matrix[2]);
  expecteq(50.0, resultm.matrix[3]);

  // Test matrix inverse: inv([[1, 2], [3, 4]]) = [[-2, 1], [1.5, -0.5]]
  expr = "[2; 1 2 3 4]~";
  resultm = evalExprComplex(expr).elem.matr;
  expecteq(2, resultm.rows);
  expecteq(2, resultm.cols);
  expecteq(-2.0, resultm.matrix[0]);
  expecteq(1.0, resultm.matrix[1]);
  expecteq(1.5, resultm.matrix[2]);
  expecteq(-0.5, resultm.matrix[3]);

  // Scalar multiplication
  expr = "[3; 5 6 7] 5 *";
  resultm = evalExprComplex(expr).elem.matr;
  expecteq(1, resultm.rows);
  expecteq(3, resultm.cols);
  expecteq(25, resultm.matrix[0]);
  expecteq(30, resultm.matrix[1]);
  expecteq(35, resultm.matrix[2]);
}

/* Scalar workload: arithmetic, groups, registers, constants and functions.
 * The order matters, because `$x` reads the register written by `10 &x`. */
bench (eval_expr_complex) {
  static char const *const scalar_cases[] = {
    "1 2 3 4 5 +",
    "4 5 ^",
    "1s2^(1c2^)+",
    "  5    6    10    - 5  /",
    "5",
    "@a",
    "10 &x",
    "$x 2 *",
    "2 3 ^ (4 5 *) + (6 7 /) -",
    "\\P 2 / s",
    "\\P 4 / c",
    "2 l2",
    "100 lc",
    "1 0 /",
  };
  for (size_t k = 0; k < sizeof scalar_cases / sizeof scalar_cases[0]; k++)
    evalExprComplex(scalar_cases[k]);
}

/* Matrix workload: addition, inverse, product, scaling, power and division,
 * evaluated in a fixed order. */
bench (eval_matrix) {
  static char const *const matrix_cases[] = {
    "[2; 1 2 3 4 ][2; 5 6 7 8 ]+",
    "[3; 4 1 4 6 5 7 3 6 7 ]~",
    "[1; 4 5 ][2; 6 7 ]*",
    "[2; 7 6 5 4 ] 6 *",
    "[2; 9 0 5 1 ] 4 ^",
    "[2; 5 4 3 2 ][2; 4 8 2 1 ]/",
  };
  for (size_t k = 0; k < sizeof matrix_cases / sizeof matrix_cases[0]; k++)
    evalExprComplex(matrix_cases[k]);
}