DATA_DIRECTORY = "Data"
# Parts A-D of table T24 from the 2021 Census time-series files, at state level
STATE_TIME_SERIES = [
    f"{DATA_DIRECTORY}/2011_2016_2021/State/2021Census_T24{part}_AUST_STE.csv"
    for part in "ABCD"
]


# Merge several time-series CSV files side by side into one table
def merge_time_series_files(filenames):
    """Join CSV files column-wise into a single list of rows.

    Every file is split on commas (quoted commas are not honoured) and each
    cell is stripped of surrounding whitespace. The first file is used as-is;
    for each later file, every row except its first cell (the state code
    column) is appended to the row with the same index in the result.

    Args:
        filenames: paths of the CSV files to merge, in column order.

    Returns:
        A list of rows (lists of strings); ``[]`` when ``filenames`` is empty.

    Raises:
        IndexError: if a later file has more rows than the first file.
    """
    merged_file = []
    for file_index, filename in enumerate(filenames):
        # `with` closes each file promptly instead of leaking the handle
        with open(filename, "r") as handle:
            table = [[cell.strip() for cell in line.split(",")] for line in handle]

        if file_index == 0:
            # `merged_file` is empty, this is the first file we are processing
            merged_file = table
        else:
            for row_index, row in enumerate(table):
                # Extend with everything in row except state codes
                merged_file[row_index].extend(row[1:])

    return merged_file


def remove_quoted_commas(line):
    """Drop commas that sit inside double quotes, and the quote marks themselves.

    Used to turn thousands separators such as ``"1,000"`` into ``1000`` so
    the text can later be split on commas. For example
    ``'a,"1,000",b'`` becomes ``'a,1000,b'``.

    Args:
        line: CSV text (may span several lines).

    Returns:
        The text with every ``"`` removed and every comma between a pair of
        quotes removed.

    Raises:
        ValueError: if the text contains an odd number of ``"`` characters.
    """
    inside_quote = False
    kept = []
    for character in line:
        if character == '"':
            inside_quote = not inside_quote
        elif character == "," and inside_quote:
            # Comma inside a quoted field: drop it
            continue
        else:
            kept.append(character)

    # An explicit raise (not `assert`) so the check survives `python -O`
    if inside_quote:
        raise ValueError("Unclosed quote in CSV text")

    return "".join(kept)


def extract_sections(relative_path, separator):
    """Read ``DATA_DIRECTORY/relative_path`` and split it into non-empty sections.

    Args:
        relative_path: path of the file relative to ``DATA_DIRECTORY``.
        separator: exact string that separates sections (e.g. a run of commas).

    Returns:
        The pieces between separators, each stripped of surrounding
        whitespace, with empty pieces removed.
    """
    # `with` closes the file instead of leaking the handle
    with open(f"{DATA_DIRECTORY}/{relative_path}", "r") as handle:
        contents = handle.read()

    # Some lines are just section breaks, which leave empty pieces; drop them
    stripped = (section.strip() for section in contents.split(separator))
    return [section for section in stripped if section != ""]


def tabulate(text):
    """Turn comma-separated ``text`` into a list of rows of stripped cells.

    Each line becomes one row; cells are split on every comma (quotes are not
    interpreted) and trimmed of surrounding whitespace.
    """
    # TODO: assert table is rectangular
    return [[cell.strip() for cell in line.split(",")] for line in text.splitlines()]


def parse_financial_bracket(text):
    """Parse a dollar range such as ``'$0-$99'`` into the integer pair ``(0, 99)``.

    Both bounds must carry a ``$`` prefix and the text must contain exactly
    one hyphen. The combined digits are also checked with ``should_keep`` so
    that categories like 'Not stated' are caught before the ``int`` calls.
    """
    lower_text, upper_text = text.split("-")
    assert lower_text.startswith("$")
    assert upper_text.startswith("$")

    # Drop the leading '$' from each bound
    lower_digits = lower_text[1:]
    upper_digits = upper_text[1:]

    # Make sure we have filtered out bad categories such as 'Not stated'
    assert should_keep(lower_digits + upper_digits)

    return (int(lower_digits), int(upper_digits))


# Decide whether a category label is usable for the median calculations
def should_keep(category):
    """Return False if ``category`` contains any rejected fragment, else True.

    Matching is a plain case-sensitive substring test.
    """
    rejected_fragments = (
        "Neg_Ni_inc",  # Negative/nil income
        "more",  # Brackets over a certain threshold (eg $4000more)
        "or more",
        "NS",  # Not stated
        "Notstated",
        "Tot",  # Total
        "PI_S",  # presumably 'partial income stated' — confirm
    )
    return not any(fragment in category for fragment in rejected_fragments)


# Decode a time-series heading into its year and the midpoints of its ranges
def split_category(category):
    """Decode a heading like ``'C11_1_149_R1_74'`` into ``(year, mean_income, mean_rent)``.

    The five underscore-separated parts are: the census-year code (``'C'``
    plus two digits), the income range bounds, then the rent range bounds
    with an ``'R'`` prefix on the lower one. The returned "means" are the
    rounded-up midpoints ``(low + high + 1) // 2`` of each range.

    If ``should_keep`` rejects the heading, a message is printed before the
    assertion fails. Other malformed headings fail an assertion or an
    ``int`` conversion.
    """
    pieces = category.split("_")

    is_valid = should_keep(category)
    if not is_valid:
        print(f"Invalid category: {category}")
    assert is_valid is True
    # All valid categories have 5 parts
    assert len(pieces) == 5

    year = pieces[0]
    low_income = int(pieces[1])
    high_income = int(pieces[2])
    high_rent = int(pieces[4])

    # The year code is 'C' followed by 2 digits
    assert len(year) == 3
    assert year[0] == "C"
    assert year[1:].isdigit()

    # Only the lower rent bound carries the 'R' prefix
    rent_start_text = pieces[3]
    assert rent_start_text[0] == "R"
    low_rent = int(rent_start_text[1:])

    # Both ranges must be strictly increasing
    assert high_rent > low_rent
    assert high_income > low_income

    return (
        year,
        (low_income + high_income + 1) // 2,
        (low_rent + high_rent + 1) // 2,
    )


# Base dataset functionality, can be used to see broad historical trends from 1996-2021
class CensusDataset:
    """Common interface for every census year; subclasses override each hook."""

    def census_year(self):
        """Return the year of the census this dataset describes."""
        raise NotImplementedError()

    def median_rents(self):
        """Return a mapping of rent bracket midpoint -> population (per subclasses here)."""
        raise NotImplementedError()

    def median_incomes(self):
        """Return a mapping of income bracket midpoint -> population (per subclasses here)."""
        raise NotImplementedError()


class Census1996(CensusDataset):
    """1996 Census: persons cross-tabulated by income and weekly rent.

    Both distributions come from one CSV export whose sections are separated
    by rows of bare commas. Section 4 holds the two-line weekly rent heading;
    section 5 holds one row per income bracket with a count per rent column.
    """

    def __init__(self):
        sections = extract_sections("1996/1996_income_by_rent.csv", ",,,,,,,,,,,,,")

        # Section positions are specific to this export's layout
        self.WEEKLY_RENT_HEADING = remove_quoted_commas(sections[4])
        self.POPULATION_BY_INCOME = remove_quoted_commas(sections[5])

    def census_year(self):
        """Return the census year (1996)."""
        return 1996

    def filter_rent_brackets(self):
        """Return the closed weekly rent brackets as strings like '$0-$99'.

        The trailing '$1000 or more', 'Not stated' and 'Total' columns are dropped.
        """
        assert self.WEEKLY_RENT_HEADING.startswith(",,,,,,     Weekly rent,,,,,,,")
        assert self.WEEKLY_RENT_HEADING.count("\n") == 2

        # The headings are split between 2 lines (using example $0-$99):
        # 1st line: start of range (eg $0-)
        # 2nd line: end of range (eg $99)
        [start_ranges, end_ranges] = tabulate(self.WEEKLY_RENT_HEADING)[1:]
        # Drop the first (empty) cell, which sits above the income-label column
        start_ranges = start_ranges[1:]
        end_ranges = end_ranges[1:]

        # Make sure there is a matching number of cells
        assert len(end_ranges) == len(start_ranges)

        # Remove the last 3 columns:
        # 1. $1000 or more
        # 2. Not stated
        # 3. Total
        assert start_ranges[-3:] == ["$1000", "Not", ""]
        assert end_ranges[-3:] == ["or more", "stated", "Total"]

        start_ranges = start_ranges[:-3]
        end_ranges = end_ranges[:-3]

        return [start + end for start, end in zip(start_ranges, end_ranges)]

    def median_rents(self):
        """Return {rent bracket midpoint: population summed over all income rows}."""
        brackets = self.filter_rent_brackets()
        # Tabulate the income rows once, not once per rent bracket
        income_rows = self.filter_income_brackets()
        rents = {}

        for bracket_index, bracket in enumerate(brackets):
            # Store each range as its middle point (eg $50 for $0-$99)
            (start, end) = parse_financial_bracket(bracket)
            midpoint_rent = (start + end + 1) // 2

            # Column 0 of an income row is its label, so rent bracket i is column i + 1
            rents[midpoint_rent] = sum(int(row[bracket_index + 1]) for row in income_rows)

        return rents

    def filter_income_brackets(self):
        """Return the income rows with closed '$start-$end' labels only.

        Negative/nil income, the open-ended top bracket and the not-stated
        rows are removed after checking their labels.
        """
        rows = tabulate(self.POPULATION_BY_INCOME)

        # Remove rows containing irrelevant data
        assert rows[0][0] == "Negative income"
        assert rows[1][0] == "Nil income"
        assert rows[-3][0] == "$2000 or more"
        assert rows[-2][0] == "Partial income stated(a)"
        assert rows[-1][0] == "All incomes not stated(b)"

        return rows[2:-3]

    def median_incomes(self):
        """Return {income bracket midpoint: population} using each row's 'Total' column."""
        incomes = {}

        for row in self.filter_income_brackets():
            (start, end) = parse_financial_bracket(row[0])
            midpoint_income = (start + end + 1) // 2

            # The last column is a 'total' amount
            incomes[midpoint_income] = int(row[-1])

        return incomes

    def population_data(self):
        """Return each income row's population counts as ints, without the label.

        Every column after the label is converted, including the trailing
        non-bracket columns (presumably '$1000 or more', 'Not stated' and
        'Total', matching the heading — confirm).
        """
        # Rows from `filter_income_brackets` are already split into cells by
        # `tabulate`, so the income label is simply column 0
        return [[int(cell) for cell in row[1:]] for row in self.filter_income_brackets()]


class Census2001(CensusDataset):
    """2001 Census: income and rent distributions come from two separate CSV exports."""

    def __init__(self):
        # Section 5 of each export holds the table of interest
        income_sections = extract_sections("2001/Income_2001.csv", ",,,")
        self.INCOME_DATA = remove_quoted_commas(income_sections[5])

        rent_sections = extract_sections("2001/Rent_2001.csv", ",,,,,")
        self.RENT_DATA = remove_quoted_commas(rent_sections[5])

    def census_year(self):
        """Return the census year (2001)."""
        return 2001

    @staticmethod
    def _midpoint_totals(rows):
        """Map each row's '$start-$end' bracket midpoint to the count in its last column."""
        totals = {}
        for row in rows:
            # Brackets are two dollar amounts joined by a hyphen, e.g. '$1-$49'
            low, high = parse_financial_bracket(row[0])
            totals[(low + high + 1) // 2] = int(row[-1])
        return totals

    def filtered_rent_brackets(self):
        """Return the rent rows minus the trailing '$500 or more' and 'Not stated' rows."""
        table = tabulate(self.RENT_DATA)

        assert table[-2][0] == "$500 or more"
        assert table[-1][0] == "Not stated"

        return table[:-2]

    def median_rents(self):
        """Return {rent bracket midpoint: population}."""
        return self._midpoint_totals(self.filtered_rent_brackets())

    def filtered_income_brackets(self):
        """Return the income rows minus the nil, open-ended and not-stated rows."""
        table = tabulate(self.INCOME_DATA)

        assert table[0][0] == "Negative/Nil income"
        assert table[-3][0] == "$2000 or more"
        assert table[-2][0] == "Partial income stated(b)"
        assert table[-1][0] == "All incomes not stated(c)"

        return table[1:-3]

    def median_incomes(self):
        """Return {income bracket midpoint: population}."""
        return self._midpoint_totals(self.filtered_income_brackets())


class TimeSeriesBase(CensusDataset):
    """One census year extracted from the combined time-series tables.

    ``dataset[0]`` is the heading row; as built by ``Census2011_2016_2021``
    each heading is a ``(mean_income, mean_rent)`` tuple. The remaining rows
    hold one count (as a string) per heading.
    """

    def __init__(self, dataset):
        self.categories = dataset[0]
        self.data = dataset[1:]

    def _totals_by_category_part(self, part):
        """Sum every column, then merge column totals that share ``category[part]``."""
        grouped = {}
        for column, category in enumerate(self.categories):
            column_total = sum(int(row[column]) for row in self.data)
            key = category[part]
            grouped[key] = grouped.get(key, 0) + column_total
        return grouped

    def median_incomes(self):
        """Return {income midpoint: population} summed over all rent brackets."""
        return self._totals_by_category_part(0)

    def median_rents(self):
        """Return {rent midpoint: population} summed over all income brackets."""
        return self._totals_by_category_part(1)


class Census2011(TimeSeriesBase):
    """The 2011 slice of the combined 2011/2016/2021 time-series tables."""

    def census_year(self):
        """Return 2011; implements the ``CensusDataset.census_year`` hook."""
        return 2011

    def year(self):
        """Alias of ``census_year`` kept for existing callers."""
        return self.census_year()


class Census2016(TimeSeriesBase):
    """The 2016 slice of the combined 2011/2016/2021 time-series tables."""

    def census_year(self):
        """Return 2016; implements the ``CensusDataset.census_year`` hook."""
        return 2016

    def year(self):
        """Alias of ``census_year`` kept for existing callers."""
        return self.census_year()


class Census2021(TimeSeriesBase):
    """The 2021 slice of the combined 2011/2016/2021 time-series tables."""

    def census_year(self):
        """Return 2021; implements the ``CensusDataset.census_year`` hook."""
        return 2021

    def year(self):
        """Alias of ``census_year`` kept for existing callers."""
        return self.census_year()


class Census2011_2016_2021(CensusDataset):
    """Split the combined 2011/2016/2021 time-series CSVs into per-year datasets.

    Each file is tabulated. Every heading accepted by ``should_keep`` is
    decoded with ``split_category``, and its whole column is routed to the
    matching year. Row 0 of each per-year table holds the decoded
    ``(mean_income, mean_rent)`` headings; later rows hold the raw cell
    strings. The first column of each file (the state code) is skipped.
    """

    def __init__(self, files):
        if not files:
            raise ValueError("At least one time-series file is required")

        census2011 = None
        census2016 = None
        census2021 = None
        tables_by_year = {}

        for file_index, filename in enumerate(files):
            # `with` closes each file instead of leaking the handle
            with open(filename, "r") as handle:
                table = tabulate(handle.read())

            # The first file fixes how many rows every year's table has
            if file_index == 0:
                census2011 = [[] for _ in table]
                census2016 = [[] for _ in table]
                census2021 = [[] for _ in table]
                tables_by_year = {"C11": census2011, "C16": census2016, "C21": census2021}

            # Headings are the first row
            headings = table[0]
            # Make sure to skip the first column (STE_CODE_2021)
            for column in range(1, len(headings)):
                if not should_keep(headings[column]):
                    continue

                (year, mean_income, mean_rent) = split_category(headings[column])
                if year not in tables_by_year:
                    raise ValueError("Unsupported census year")
                year_table = tables_by_year[year]

                # Row 0 gets the decoded heading; every later row gets its raw cell
                year_table[0].append((mean_income, mean_rent))
                for row in range(1, len(table)):
                    year_table[row].append(table[row][column])

        # Make sure headings are consistent across years
        assert census2011[0] == census2016[0] == census2021[0]

        self.census2011 = Census2011(census2011)
        self.census2016 = Census2016(census2016)
        self.census2021 = Census2021(census2021)

    def get_year(self, year):
        """Return the dataset for ``year`` (2011, 2016 or 2021).

        Raises:
            ValueError: for any other year.
        """
        if year == 2011:
            return self.census2011
        elif year == 2016:
            return self.census2016
        elif year == 2021:
            return self.census2021
        else:
            raise ValueError("Unsupported year")


def find_median(data_points):
    """Return the median value of a frequency table.

    ``data_points`` maps a category value to the number of people in that
    category, for example::

        {
            20: 2324,  # $20 income, 2324 people in category
            60: 4108,  # $60 income, 4108 people in category
        }

    Categories are ordered by value, and the population is treated as one
    sorted list of people. For an odd total, the middle person's category
    is returned unchanged. For an even total, the categories of the two
    middle people are averaged and passed through ``round``, which rounds
    halves to even.

    Raises:
        ValueError: if the total population is not positive.
    """
    total_population = sum(data_points.values())
    if total_population <= 0:
        raise ValueError("Cannot find the median of an empty population")

    # 1-indexed positions of the two middle people; equal when the total is odd
    lower_position = (total_population + 1) // 2
    upper_position = total_population // 2 + 1

    lower_median = None
    upper_median = None
    population_seen = 0
    # Walk categories in ascending order, accumulating how many people we've passed
    for key in sorted(data_points):
        population_seen += data_points[key]
        if lower_median is None and population_seen >= lower_position:
            lower_median = key
        if population_seen >= upper_position:
            upper_median = key
            break

    if total_population % 2 == 1:
        return lower_median

    # The mean between two middle points is the median
    return round((lower_median + upper_median) / 2)


# Datasets in chronological order; the two oldest censuses have bespoke parsers
census_years = [Census1996(), Census2001()]

# 2011, 2016 and 2021 are all carved out of the same time-series tables
time_series = Census2011_2016_2021(STATE_TIME_SERIES)
census_years.extend(time_series.get_year(year) for year in (2011, 2016, 2021))

# One line per census: median income, median rent, and rent as a % of income
for census_year in census_years:
    median_income = find_median(census_year.median_incomes())
    median_rent = find_median(census_year.median_rents())
    rent_share_percent = round(median_rent / median_income * 100)

    print(median_income, median_rent, rent_share_percent)