/**
 * @file src/evalfn.c
 * @brief Define eval_expr_real related functions
 */

#include "benchmarking.h"
#include "elemop.h"
#include "error.h"
#include "evalfn.h"
#include "exproriented.h"
#include "mem.h"
#include "phyconst.h"
#include "rand.h"
#include "rtconf.h"
#include "testing.h"
#include <ctype.h>

#define PUSH (*++m->s.rsp)
#define POP  (*m->s.rsp--)

#define SET_REAL(v) \
  (real_t) { \
    .elem = {.real = v}, .isnum = true \
  }
#define SET_LAMB(v) \
  (real_t) { \
    .elem = {.lamb = v}, .isnum = false \
  }

#define DEF_ARTHMS(tok, op) \
  static void rpx##tok(machine_t *m) { \
    for (; m->s.rbp + 1 < m->s.rsp; \
         m->s.rbp[1].elem.real op## = POP.elem.real); \
  }
APPLY_ARTHM(DEF_ARTHMS)

#define DEF_BINOP(f) \
  static void rpx_##f(machine_t *m) { \
    for (; m->s.rbp + 1 < m->s.rsp; \
         m->s.rbp[1].elem.real = f(m->s.rbp[1].elem.real, POP.elem.real)); \
  }
MAP(DEF_BINOP, fmod, pow)

static void rpxEql(machine_t *m) {
  for (; m->s.rbp + 1 < m->s.rsp
         && eq(m->s.rsp[-1].elem.real, m->s.rsp->elem.real);
       POP);
  m->s.rbp[1].elem.real = m->s.rbp + 1 == m->s.rsp ?: NAN;
  m->s.rsp = m->s.rbp + 1;
}

#define DEF_LTGT(tok, op) \
  static void rpx##tok(machine_t *m) { \
    for (; m->s.rbp + 1 < m->s.rsp \
           && m->s.rsp[-1].elem.real op m->s.rsp->elem.real; \
         POP); \
    m->s.rbp[1].elem.real = m->s.rbp + 1 == m->s.rsp ?: NAN; \
    m->s.rsp = m->s.rbp + 1; \
  }
APPLY_LTGT(DEF_LTGT)

#define DEF_ONEARGFN(f) \
  static void rpx_##f(machine_t *m) { \
    m->s.rsp->elem.real = f(m->s.rsp->elem.real); \
  }
MAP(DEF_ONEARGFN, sin, cos, tan, fabs, tgamma, ceil, floor, round)

#define DEF_MULTI(name, factor) \
  static void rpx##name(machine_t *m) { \
    m->s.rsp->elem.real *= factor; \
  }
MAP_PAIR(DEF_MULTI, (Negate, -1), (ToRad, pi / 180), (ToDeg, 180 / pi))

#define DEF_TWOCHARFN(name, c1, f1, c2, f2, c3, f3) \
  static void rpx##name(machine_t *m) { \
    switch (*++m->c.rip) { \
      OVERWRITE_REAL(c1, f1) \
      OVERWRITE_REAL(c2, f2) \
      OVERWRITE_REAL(c3, f3) \
    default: \
      [[clang::unlikely]]; \
    } \
  }
MAP_PAIR(
  DEF_TWOCHARFN, (Hyp, 's', sinh, 'c', cosh, 't', tanh),
  (Arc, 's', asin, 'c', acos, 't', atan), (Log, '2', log2, 'c', log10, 'e', log)
)

static void rpxLogBase(machine_t *m) {
  double x = POP.elem.real;
  m->s.rsp->elem.real = log(m->s.rsp->elem.real) / log(x);
}

static void rpxConst(machine_t *m) {
  PUSH = SET_REAL(getConst(*++m->c.rip));
}

static void rpxParse(machine_t *m) {
  char *next = nullptr;
  PUSH = SET_REAL(strtod(m->c.rip, &next));
  m->c.rip = next - 1;
}

static void rpxSpace(machine_t *m) {
  m->c.rip += skipByte(m->c.rip, ' ') - 1;
}

#define CASE_TWOARGFN(c, f) \
  case c: { \
    double x = POP.elem.real; \
    m->s.rsp->elem.real = f(m->s.rsp->elem.real, x); \
  } break;
static void rpxIntFn(machine_t *m) {
  switch (*++m->c.rip) {
    CASE_TWOARGFN('g', gcd)
    CASE_TWOARGFN('l', lcm)
    CASE_TWOARGFN('p', permutation)
    CASE_TWOARGFN('c', combination)
  default:
    [[clang::unlikely]];
  }
}

static void rpxSysFn(machine_t *m) {
  switch (*++m->c.rip) {
  case 'a': // ANS
    PUSH = m->e.info.hist[less(m->e.info.histi, buf_size - 1)];
    break;
  case 'c': // set argc
    m->d.argc[m->d.spi] = (char)POP.elem.real;
    break;
  case 'd': // display
    printany(m->s.rsp->elem.real);
    putchar('\n');
    break;
  case 'h':
    m->s.rsp->elem.real
      = m->e.info.hist[m->e.info.histi - (size_t)m->s.rsp->elem.real].elem.real;
    break;
  case 'n':
    PUSH = SET_REAL(NAN);
    break;
  case 'p':
    m->s.rsp[1] = *m->s.rsp;
    m->s.rsp++;
    break;
  case 'r':
    PUSH = SET_REAL(xorsh0to1());
    break;
  case 's':
    *m->s.rsp = *(m->s.rsp - (int)m->s.rsp->elem.real - 1);
    break;
  default:
    [[clang::unlikely]];
  }
}

static real_t handleFnArgs(machine_t *m) {
  char argnum = *m->c.rip - '0';
  if (m->d.argc[m->d.spi] < argnum) m->d.argc[m->d.spi] = argnum;
  return m->e.args[arg_n - argnum];
}

static void rpxLRegs(machine_t *m) {
  *++m->s.rsp = *++m->c.rip == '0'   ? SET_LAMB((char *)m->c.expr)
              : (isdigit(*m->c.rip)) ? handleFnArgs(m)
              : (islower(*m->c.rip)) ? m->e.info.reg[*m->c.rip - 'a']
                                     : *(real_t *)$panic(ERR_CHAR_NOT_FOUND);
}

static void rpxWRegs(machine_t *m) {
  m->e.info.reg[*++m->c.rip - 'a'] = *m->s.rsp;
}

static void rpxEnd(machine_t *m) {
  m->e.iscontinue = false;
}

static void rpxGrpBgn(machine_t *m) {
  PUSH.elem.lamb = (char *)m->s.rbp;
  m->s.rbp = m->s.rsp;
}

static void rpxGrpEnd(machine_t *m) {
  real_t *rbp = m->s.rbp;
  m->s.rbp = *bit_cast(real_t **, m->s.rbp);
  *rbp = *m->s.rsp;
  m->s.rsp = rbp;
}

static void rpxLmdBgn(machine_t *m) {
  m->c.rip++;
  size_t i = 0;
  for (int nest = 1; *m->c.rip; i++)
    if (m->c.rip[i] == '{') nest++;
    else if (m->c.rip[i] == '}' && !--nest) break;

  *++m->s.rsp = SET_LAMB(xalloc(char, i + 1));
  memcpy(m->s.rsp->elem.lamb, m->c.rip, i);
  m->s.rsp->elem.lamb[i] = '\0';
  m->c.rip += i;
}

static void rpxLmbEnd(machine_t *m) {
  DISPERR("missing open bracket: col ", m->c.rip - m->c.expr);
}

static void callFn(machine_t *m, char const *expr) {
  m->d.spi++;
  m->d.exprs[m->d.spi] = m->c;
  m->c.expr = m->c.rip = expr;
  m->d.callstack[m->d.spi] = m->e.args;
  m->d.argc[m->d.spi] = 0;
  m->e.args = m->s.rsp - arg_n;
  rpxGrpBgn(m);
}

static void retFn(machine_t *m) {
  rpxGrpEnd(m);
  real_t ret = *m->s.rsp;
  m->s.rsp = m->e.args + arg_n;
  m->s.rsp -= m->d.argc[m->d.spi];
  *m->s.rsp = ret;
  m->e.args = m->d.callstack[m->d.spi];
  m->c = m->d.exprs[m->d.spi];
  m->d.spi--;
}

static void rpxRunLmd(machine_t *m) {
  callFn(m, m->s.rsp->elem.lamb);
  rpxEval(m);
  retFn(m);
}

static void rpxCond(machine_t *m) {
  m->s.rsp -= 2;
  real_t *rsp = m->s.rsp;
  *rsp = *(rsp + isnan(rsp[2].elem.real));
}

static void rpxUndfned(machine_t *m) {
  DISPERR(
    codetomsg(ERR_UNKNOWN_CHAR),
    ": ",
    *m->c.rip,
    " at col ",
    m->c.rip - m->c.expr
  );
}

void (*const eval_table['~' - ' ' + 1])(machine_t *) = {
  rpxSpace,   // ' '
  rpxRunLmd,  // '!'
  rpxUndfned, // '"'
  rpxUndfned, // '#'
  rpxLRegs,   // '$'
  rpx_fmod,   // '%'
  rpxWRegs,   // '&'
  rpxUndfned, // '''
  rpxGrpBgn,  // '('
  rpxGrpEnd,  // ')'
  rpxMul,     // '*'
  rpxAdd,     // '+'
  rpxUndfned, // ','
  rpxSub,     // '-'
  rpxUndfned, // '.'
  rpxDiv,     // '/'
  rpxParse,   // '0'
  rpxParse,   // '1'
  rpxParse,   // '2'
  rpxParse,   // '3'
  rpxParse,   // '4'
  rpxParse,   // '5'
  rpxParse,   // '6'
  rpxParse,   // '7'
  rpxParse,   // '8'
  rpxParse,   // '9'
  rpxUndfned, // ':'
  rpxEnd,     // ';'
  rpxLt,      // '<'
  rpxEql,     // '='
  rpxGt,      // '>'
  rpxCond,    // '?'
  rpxSysFn,   // '@'
  rpx_fabs,   // 'A'
  rpxUndfned, // 'B'
  rpx_ceil,   // 'C'
  rpxUndfned, // 'D'
  rpxUndfned, // 'E'
  rpx_floor,  // 'F'
  rpxUndfned, // 'G'
  rpxUndfned, // 'H'
  rpxUndfned, // 'I'
  rpxUndfned, // 'J'
  rpxUndfned, // 'K'
  rpxLogBase, // 'L'
  rpxUndfned, // 'M'
  rpxUndfned, // 'N'
  rpxUndfned, // 'O'
  rpxUndfned, // 'P'
  rpxUndfned, // 'Q'
  rpx_round,  // 'R'
  rpxUndfned, // 'S'
  rpxUndfned, // 'T'
  rpxUndfned, // 'U'
  rpxUndfned, // 'V'
  rpxUndfned, // 'W'
  rpxUndfned, // 'X'
  rpxUndfned, // 'Y'
  rpxUndfned, // 'Z'
  rpxUndfned, // '['
  rpxConst,   // '\'
  rpxUndfned, // ']'
  rpx_pow,    // '^'
  rpxUndfned, // '_'
  rpxUndfned, // '`'
  rpxArc,     // 'a'
  rpxUndfned, // 'b'
  rpx_cos,    // 'c'
  rpxToDeg,   // 'd'
  rpxUndfned, // 'e'
  rpxUndfned, // 'f'
  rpx_tgamma, // 'g'
  rpxHyp,     // 'h'
  rpxIntFn,   // 'i'
  rpxUndfned, // 'j'
  rpxUndfned, // 'k'
  rpxLog,     // 'l'
  rpxNegate,  // 'm'
  rpxUndfned, // 'n'
  rpxUndfned, // 'o'
  rpxUndfned, // 'p'
  rpxUndfned, // 'q'
  rpxToRad,   // 'r'
  rpx_sin,    // 's'
  rpx_tan,    // 't'
  rpxUndfned, // 'u'
  rpxUndfned, // 'v'
  rpxUndfned, // 'w'
  rpxUndfned, // 'x'
  rpxUndfned, // 'y'
  rpxUndfned, // 'z'
  rpxLmdBgn,  // '{'
  rpxUndfned, // '|'
  rpxLmbEnd,  // '}'
  rpxUndfned, // '~'
};

void (*getEvalTable(char c))(machine_t *) {
  return eval_table[c - ' '];
}

void rpxEval(machine_t *restrict m) {
  for (; *m->c.rip && m->e.iscontinue; m->c.rip++) [[clang::likely]]
    getEvalTable (*m->c.rip)(m);
}

void initEvalinfo(machine_t *restrict ret) {
  ret->s.rbp = ret->s.rsp = (real_t *)ret->s.payload - 1;
  ret->e.info = getRRuntimeInfo();
  ret->e.iscontinue = true;
  ret->c.rip = ret->c.expr;
  ret->d.spi = 0;
}

stack_t evalExprRealStack(char const *restrict a_expr) {
  machine_t m;
  initEvalinfo(&m);
  m.c.expr = m.c.rip = a_expr;
  rpxEval(&m);
  setRRuntimeInfo(m.e.info);
  return m.s;
}

/**
 * @brief Evaluate real number expression
 * @param a_expr String of expression
 * @return Expression evaluation result
 */
elem_t evalExprReal(char const *restrict a_expr) {
  machine_t m;
  initEvalinfo(&m);
  m.c.expr = m.c.rip = a_expr;
  rpxEval(&m);
  if (++m.e.info.histi < buf_size) m.e.info.hist[m.e.info.histi] = *m.s.rsp;
  setRRuntimeInfo(m.e.info);
  return (elem_t){
    {m.s.rsp->elem.real},
    m.s.rsp->isnum ? RTYPE_REAL : RTYPE_LAMB,
  };
}

test_table(
  eval_lamb_fn, evalExprReal, (elem_t, char const *),
  {
    {{.rtype = RTYPE_LAMB, .elem = {.lamb = "$1$1+"}}, "{$1$1+}&f"},
    {   {.rtype = RTYPE_REAL, .elem = {.real = 10.0}},      "5$f!"},
}
)

#define eval_expr_real_return_double(expr) evalExprReal(expr).elem.real
test_table(
  eval_real, eval_expr_real_return_double, (double, char const *),
  {
    {11.0,              "5 6 + &x"}, // write reg
    {22.0,                "$x 2 *"}, // load reg
    {33.0, "4 5 (5 6 (6 7 +) +) +"}, // nest grp
    {66.0,               "@a @a +"}, // ans
    { 4.0,            "1 1 + @p +"}, // prev
    { 0.0,                 "\\P s"}, // const
    { 0.0,               "0;hello"}, // comment
    { 5.0,    "3 4 5 {3 @c $1}! +"}, // @c
}
)
#define eval_expr_real_cond(expr) !isnan(eval_expr_real_return_double(expr))
test_table(
  eval_real_cond, eval_expr_real_cond, (bool, char const *),
  {
    { true,      "1.5 1.5 ="},
    {false,          "0 1 ="},
    { true,          "1s1c>"},
    {false,          "1s1c<"},
    { true, "10 100 90 ig ="},
    {false,          "@n@n="},
}
)
test_table(
  eval_lamb, eval_expr_real_return_double, (double, char const *),
  {
    { 8.0,                                      "4 {$1 2 *}!"}, // lamb
    {19.0, "1 5 {$1 3 +}! {5 $1 * {$1 4 -}! {$1 2 /}! $2 +}!"}, // nest lamb
}
)
#undef eval_expr_real_return_double

bench (eval_expr_real) {
  evalExprReal("1 2 3 4 5 +");
  evalExprReal("4 5 ^");
  evalExprReal("1s2^(1c2^)+");
  evalExprReal("  5    6    10    - 5  /");
  evalExprReal("5");
  evalExprReal("@a");
  evalExprReal("10 &x");
  evalExprReal("$x 2 *");
  evalExprReal("2 3 ^ (4 5 *) + (6 7 /) -");
  evalExprReal("\\P 2 / s");
  evalExprReal("\\P 4 / c");
  evalExprReal("2 l2");
  evalExprReal("100 lc");
  evalExprReal("1 0 /");
}