#include <solver/algebra.h>
#include <stdlib.h>
#include <stdio.h>

/* Test assertion: combines `lhs` and `rhs` with solver_Algebra_equals and
 * requires solver_Algebra_to_boolean of that result to match
 * solver_Algebra_boolean(a, true).
 *
 * On failure, prints both operands (via solver_Algebra_show) and terminates
 * the whole test process with EXIT_FAILURE; it never returns in that case. */
static void assert_is_equals(struct solver_Algebra *a, uint64_t lhs, uint64_t rhs) {
  uint64_t eq = solver_Algebra_equals(a, lhs, rhs);
  if(solver_Algebra_to_boolean(a, eq) != solver_Algebra_boolean(a, true)) {
    printf("assert_is_equals\n");
    printf("lhs: ");
    solver_Algebra_show(a, lhs); /* presumably writes to stdout — TODO confirm */
    printf("rhs: ");
    solver_Algebra_show(a, rhs);
    /* The diagnostics above go to stdout, which is fully buffered when
     * redirected, while the marker below goes to stderr. Flush first so the
     * two streams come out in order when they share a log. */
    fflush(stdout);
    fprintf(stderr, "failed\n");
    exit(EXIT_FAILURE);
  }
}

/* Test assertion: parses `input` into the algebra `a` and requires
 * solver_Algebra_to_boolean of the parsed expression to match
 * solver_Algebra_boolean(a, true).
 *
 * On failure, prints the source text and the parsed expression, then
 * terminates the whole test process with EXIT_FAILURE. */
static void assert_is_true(struct solver_Algebra *a, const char *input) {
  uint64_t expr = solver_Algebra_parse(a, input);
  if(solver_Algebra_to_boolean(a, expr) != solver_Algebra_boolean(a, true)) {
    printf("assert_is_true\n");
    printf("input: %s\n", input);
    printf("parsed: ");
    solver_Algebra_show(a, expr); /* presumably writes to stdout — TODO confirm */
    /* Diagnostics are on (possibly fully buffered) stdout while the marker
     * is on stderr; flush so they appear in order in a shared log. */
    fflush(stdout);
    fprintf(stderr, "failed\n");
    exit(EXIT_FAILURE);
  }
}

/* Test assertion: parses `input` into the algebra `a` and requires
 * solver_Algebra_to_boolean of the parsed expression to match
 * solver_Algebra_boolean(a, false).
 *
 * On failure, prints the source text and the parsed expression, then
 * terminates the whole test process with EXIT_FAILURE. */
static void assert_is_false(struct solver_Algebra *a, const char *input) {
  uint64_t expr = solver_Algebra_parse(a, input);
  if(solver_Algebra_to_boolean(a, expr) != solver_Algebra_boolean(a, false)) {
    printf("assert_is_false\n");
    printf("input: %s\n", input);
    printf("parsed: ");
    solver_Algebra_show(a, expr); /* presumably writes to stdout — TODO confirm */
    /* Diagnostics are on (possibly fully buffered) stdout while the marker
     * is on stderr; flush so they appear in order in a shared log. */
    fflush(stdout);
    fprintf(stderr, "failed\n");
    exit(EXIT_FAILURE);
  }
}

/* Basic parse-and-evaluate checks. Each string is an equation that
 * assert_is_true expects to evaluate to boolean true; the single
 * assert_is_false case expects boolean false. All cases share one algebra
 * instance, which is released at the end. */
static void test_simple(void) {
  struct solver_Algebra a;
  solver_Algebra_init(&a);

  assert_is_true(&a, "2 * x == x + x");
  assert_is_false(&a, "3 * x == x + x + 2");
  assert_is_true(&a, "3 * x == x + x + x");
  assert_is_true(&a, "5 + x + 2 == 7 + x");
  assert_is_true(&a, "3 + x + 5 + x == 8 + 2 * x");
  assert_is_true(&a, "x + y + z + x + y + z == 2*x + 2*y + 2*z");
  assert_is_true(&a, "10 - x == 10 + x * -1");
  /* Double negation: multiplying by -1 twice must cancel out. */
  assert_is_true(&a, "10 + x == 10 + x * -1 * -1");
  assert_is_true(&a, "10 == 100 / 10");

  solver_Algebra_free(&a);
}

/* Substitution check: parses the four expression strings into `a`, asks
 * solver_Algebra_substitute to put `replacement_` in place of `needle_`
 * inside `input_`, and hands the result to assert_is_equals against the
 * parsed `expected_` (which exits the process on mismatch). */
static void substitute(struct solver_Algebra *a, const char *input_, const char *needle_, const char *replacement_, const char *expected_) {
  /* Every parse receives `a`, so the parses stay as separate, explicitly
   * ordered statements rather than being folded into one expression. */
  const uint64_t subject = solver_Algebra_parse(a, input_);
  const uint64_t pattern = solver_Algebra_parse(a, needle_);
  const uint64_t with = solver_Algebra_parse(a, replacement_);
  const uint64_t want = solver_Algebra_parse(a, expected_);

  assert_is_equals(a, solver_Algebra_substitute(a, subject, pattern, with), want);
}

static void test_substitute(void) {
  struct solver_Algebra a; 
  solver_Algebra_init(&a);

  substitute(&a, "10", "10", "20", "20");
  substitute(&a, "10", "15", "20", "10");
  substitute(&a, "x + y + z", "x", "y + y", "3 * y + z");

  solver_Algebra_free(&a); 
}

/* Entry point for the algebra test suite. Runs each test group in turn;
 * a failing check exits the process from inside the assertion helpers, so
 * the completion message is printed only when every check passed. */
void test_solver_algebra(void) {
  test_simple();
  test_substitute();
  puts("tests done");
}