/* Forward declaration: embed_product is a static function defined elsewhere in
 * this translation unit; the sum-merging code below builds "n • x" products
 * through it before its definition is seen. */
static uint64_t embed_product(struct solver_Algebra *a, uint32_t count, const uint64_t *values);

/* Return true when x and y are both variables whose unpacked names match
 * according to alias_str_same(). */
static bool sum_same_variable(uint64_t x, uint64_t y) {
  return is_variable(x) && is_variable(y) && alias_str_same(unpack_variable(x), unpack_variable(y));
}

/*
 * Recognize a scaled variable "n • x": a two-element product whose first
 * element is a number and whose second element is a variable.
 *
 * On a match, stores n in *coefficient and x in *variable and returns true.
 * Otherwise returns false and leaves both outputs untouched.
 */
static bool sum_split_scaled_variable(struct solver_Algebra *a, uint64_t term, uint64_t *coefficient, uint64_t *variable) {
  if(!is_product(term) || get_list_count(term) != 2) {
    return false;
  }
  uint64_t num = get_index(a, term, 0);
  if(!is_number(num)) {
    return false;
  }
  uint64_t var = get_index(a, term, 1);
  if(!is_variable(var)) {
    return false;
  }
  *coefficient = num;
  *variable = var;
  return true;
}

/* Build the two-element product "coefficient • variable" via embed_product(). */
static uint64_t sum_scale_variable(struct solver_Algebra *a, uint64_t coefficient, uint64_t variable) {
  uint64_t product_values[2];
  product_values[0] = coefficient;
  product_values[1] = variable;
  return embed_product(a, 2, product_values);
}

/*
 * Try to combine two terms of a sum into a single term.
 *
 * Recognized combinations (anything else is left alone):
 *   c + d          => number_add(c, d)        (both numbers)
 *   x + x          => 2 • x
 *   n • x + x      => (n + 1) • x
 *   x + n • x      => (n + 1) • x
 *   n • x + m • x  => (n + m) • x
 * Here x is a variable and "n • x" is a two-element product [number, variable].
 * A zero resulting coefficient is not special-cased here.
 *
 * @param a       algebra context used for number arithmetic and embedding.
 * @param lhs     first term.
 * @param rhs     second term.
 * @param merged  receives the combined term; written only when true is returned.
 * @return true if lhs and rhs were combined into *merged, false otherwise.
 */
static bool sum_merge_single_term(struct solver_Algebra *a, uint64_t lhs, uint64_t rhs, uint64_t *merged) {
  // c + d => number_add(c, d)
  if(is_number(lhs) && is_number(rhs)) {
    *merged = number_add(a, lhs, rhs);
    return true;
  }

  // x + x => 2 • x
  if(sum_same_variable(lhs, rhs)) {
    *merged = sum_scale_variable(a, LITERAL_TWO, lhs);
    return true;
  }

  uint64_t num;
  uint64_t var;

  if(sum_split_scaled_variable(a, lhs, &num, &var)) {
    // n • x + x => (n + 1) • x
    if(sum_same_variable(var, rhs)) {
      *merged = sum_scale_variable(a, number_add(a, num, LITERAL_ONE), rhs);
      return true;
    }

    // n • x + m • x => (n + m) • x
    // Without this case like terms that are both already scaled were never
    // combined, e.g. sum(2x, 3x) stayed a two-term sum.
    uint64_t rhs_num;
    uint64_t rhs_var;
    if(sum_split_scaled_variable(a, rhs, &rhs_num, &rhs_var) && sum_same_variable(var, rhs_var)) {
      *merged = sum_scale_variable(a, number_add(a, num, rhs_num), var);
      return true;
    }
  }

  // x + n • x => (n + 1) • x
  if(sum_split_scaled_variable(a, rhs, &num, &var) && sum_same_variable(var, lhs)) {
    *merged = sum_scale_variable(a, number_add(a, num, LITERAL_ONE), lhs);
    return true;
  }

  return false;
}

/*
 * Insert `term` into the partially built list of sum terms.
 *
 * The term is combined with the first existing entry that
 * sum_merge_single_term() accepts, replacing that entry with the combined
 * term. If no entry accepts it, the term is appended and *term_count grows by one.
 *
 * `terms` must have room for at least *term_count + 1 entries.
 */
static void sum_absorb_term(struct solver_Algebra *a, uint64_t *terms, uint32_t *term_count, uint64_t term) {
  uint64_t merged;
  for(uint32_t k = 0; k < *term_count; k++) {
    if(sum_merge_single_term(a, terms[k], term, &merged)) {
      terms[k] = merged;
      return;
    }
  }
  terms[(*term_count)++] = term;
}

/*
 * Build a sum term from the `count` terms in `values` (which is not modified).
 *
 * - count == 0 yields LITERAL_ZERO; count == 1 yields values[0] unchanged.
 * - For two terms that sum_merge_single_term() can combine, the combined term
 *   is returned directly.
 * - Otherwise:
 *   - Immediate child sums are flattened one level, so their terms are spliced in.
 *   - Each term is combined with the first compatible term already collected.
 *   - The surviving terms are sorted with qsort_term_compare and embedded as a sum.
 *   - If fewer than two terms survive, LITERAL_ZERO or the single survivor is
 *     returned instead of a sum.
 *
 * NOTE(review): the scratch list comes from alias_stack_allocation(). It is
 * presumably only valid for the duration of this call, and embed_values() is
 * assumed to copy it — confirm.
 */
static uint64_t embed_sum(struct solver_Algebra *a, uint32_t count, const uint64_t *values) {
  if(count == 0) {
    return LITERAL_ZERO;
  }
  if(count == 1) {
    return values[0];
  }
  if(count == 2) {
    uint64_t merged;
    if(sum_merge_single_term(a, values[0], values[1], &merged)) {
      return merged;
    }
  }

  // Upper bound on the flattened term count: child sums contribute each of
  // their terms, every other value contributes itself.
  uint32_t total_count = 0;
  for(uint32_t i = 0; i < count; i++) {
    if(is_sum(values[i])) {
      total_count += get_list_count(values[i]);
    } else {
      total_count++;
    }
  }

  // Only empty child sums: the result is zero, so skip the zero-byte scratch allocation.
  if(total_count == 0) {
    return LITERAL_ZERO;
  }

  uint32_t new_count = 0;
  // alignof takes a type name: applying it to an expression is a compiler
  // extension, not ISO C.
  uint64_t *new_values = alias_stack_allocation(sizeof(*new_values) * total_count, alignof(uint64_t));

  for(uint32_t i = 0; i < count; i++) {
    if(is_sum(values[i])) {
      uint32_t inner_count = get_list_count(values[i]);
      for(uint32_t j = 0; j < inner_count; j++) {
        sum_absorb_term(a, new_values, &new_count, get_index(a, values[i], j));
      }
    } else {
      sum_absorb_term(a, new_values, &new_count, values[i]);
    }
  }

  if(new_count == 0) {
    return LITERAL_ZERO;
  } else if(new_count == 1) {
    return new_values[0];
  }

  tabula_qsort(new_values, new_count, sizeof(*new_values), qsort_term_compare, a);

  struct EncodedList sum;
  sum.count = new_count;
  sum.index = embed_values(a, new_count, new_values);
  return pack_sum(sum);
}

/*
 * Public entry point for building a sum of `count` terms taken from `values`.
 *
 * Thin wrapper over the file-local embed_sum(), which flattens immediate
 * child sums, combines like terms and sorts the surviving terms. An empty
 * input yields LITERAL_ZERO and a single term is returned as-is.
 */
uint64_t solver_Algebra_sum(struct solver_Algebra *a, uint32_t count, const uint64_t *values) {
  const uint64_t result = embed_sum(a, count, values);
  return result;
}