{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reinforcement Learning\n",
    "\n",
    "Let's describe the \"taxi problem\". We want to build a self-driving taxi that can pick up passengers at one of a set of fixed locations, drop them off at another location, and get there in the quickest amount of time while avoiding obstacles.\n",
    "\n",
    "The AI Gym lets us create this environment quickly: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+\n",
      "|R: | : :\u001b[34;1mG\u001b[0m|\n",
      "| : : : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|Y| :\u001b[43m \u001b[0m|\u001b[35mB\u001b[0m: |\n",
      "+---------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "import random\n",
    "\n",
    "random.seed(1234)\n",
    "\n",
    "streets = gym.make(\"Taxi-v3\").env #New versions keep getting released; if -v3 doesn't work, try -v2 or -v4\n",
    "streets.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's break down what we're seeing here:\n",
    "\n",
    "-  R, G, B, and Y are pickup or dropoff locations.\n",
    "-  The BLUE letter indicates where we need to pick someone up from.\n",
    "-  The MAGENTA letter indicates where that passenger wants to go to.\n",
    "-  The solid lines represent walls that the taxi cannot cross.\n",
    "-  The filled rectangle represents the taxi itself - it's yellow when empty, and green when carrying a passenger."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our little world here, which we've called \"streets\", is a 5x5 grid. The state of this world at any time can be defined by:\n",
    "\n",
    "-  Where the taxi is (one of 5x5 = 25 locations)\n",
    "-  What the current destination is (4 possibilities)\n",
    "-  Where the passenger is (5 possibilities: at one of the destinations, or inside the taxi)\n",
    "\n",
    "So there are a total of 25 x 4 x 5 = 500 possible states that describe our world.\n",
    "\n",
    "For each state, there are six possible actions:\n",
    "\n",
    "-  Move South, East, North, or West\n",
    "-  Pickup a passenger\n",
    "-  Drop off a passenger\n",
    "\n",
    "Q-Learning will take place using the following rewards and penalties at each state:\n",
    "\n",
    "-  A successfull drop-off yields +20 points\n",
    "-  Every time step taken while driving a passenger yields a -1 point penalty\n",
    "-  Picking up or dropping off at an illegal location yields a -10 point penalty\n",
    "\n",
    "Moving across a wall just isn't allowed at all.\n",
    "\n",
    "Let's define an initial state, with the taxi at location (2, 3), the passenger at pickup location 2, and the destination at location 0:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+\n",
      "|\u001b[35mR\u001b[0m: | : :G|\n",
      "| : : : : |\n",
      "| : : :\u001b[43m \u001b[0m: |\n",
      "| | : | : |\n",
      "|\u001b[34;1mY\u001b[0m| : |B: |\n",
      "+---------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "initial_state = streets.encode(2, 3, 2, 0)\n",
    "\n",
    "streets.s = initial_state\n",
    "\n",
    "streets.render()"
   ]
  },
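  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The single integer returned by `encode` packs the four state components together. As a minimal sketch, assuming a mixed-radix layout over (taxi row, taxi column, passenger location, destination) that matches the 25 x 5 x 4 = 500 states counted above, the hypothetical helpers `encode_state` and `decode_state` below reproduce the arithmetic without needing Gym:\n",
    "\n",
    "```python\n",
    "# Hypothetical helpers for illustration only; the digit order is an assumption.\n",
    "def encode_state(taxi_row, taxi_col, passenger_location, destination):\n",
    "    index = taxi_row                        # 5 possible rows\n",
    "    index = index * 5 + taxi_col            # 5 possible columns\n",
    "    index = index * 5 + passenger_location  # 4 stops, or 4 meaning inside the taxi\n",
    "    index = index * 4 + destination         # 4 possible destinations\n",
    "    return index\n",
    "\n",
    "\n",
    "def decode_state(index):\n",
    "    # Peel the components off in the reverse order they were packed.\n",
    "    index, destination = divmod(index, 4)\n",
    "    index, passenger_location = divmod(index, 5)\n",
    "    taxi_row, taxi_col = divmod(index, 5)\n",
    "    return taxi_row, taxi_col, passenger_location, destination\n",
    "\n",
    "\n",
    "state_index = encode_state(2, 3, 2, 0)\n",
    "print(state_index)\n",
    "print(decode_state(state_index))\n",
    "```\n",
    "\n",
    "For the initial state above, `encode_state(2, 3, 2, 0)` works out to ((2 * 5 + 3) * 5 + 2) * 4 + 0 = 268, and decoding that index recovers the original four components."
   ]
  },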
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's examine the reward table for this initial state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: [(1.0, 368, -1, False)],\n",
       " 1: [(1.0, 168, -1, False)],\n",
       " 2: [(1.0, 288, -1, False)],\n",
       " 3: [(1.0, 248, -1, False)],\n",
       " 4: [(1.0, 268, -10, False)],\n",
       " 5: [(1.0, 268, -10, False)]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "streets.P[initial_state]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's how to interpret this - each row corresponds to a potential action at this state: move South, North, East, or West, pickup, or dropoff. The four values in each row are the probability assigned to that action, the next state that results from that action, the reward for that action, and whether that action indicates a successful dropoff took place. \n",
    "\n",
    "So for example, moving North from this state would put us into state number 368, incur a penalty of -1 for taking up time, and does not result in a successful dropoff.\n",
    "\n",
    "So, let's do Q-learning! First we need to train our model. At a high level, we'll train over 10,000 simulated taxi runs. For each run, we'll step through time, with a 10% chance at each step of making a random, exploratory step instead of using the learned Q values to guide our actions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "q_table = np.zeros([streets.observation_space.n, streets.action_space.n])\n",
    "\n",
    "learning_rate = 0.1\n",
    "discount_factor = 0.6\n",
    "exploration = 0.1\n",
    "epochs = 10000\n",
    "\n",
    "for taxi_run in range(epochs):\n",
    "    state = streets.reset()\n",
    "    done = False\n",
    "    \n",
    "    while not done:\n",
    "        random_value = random.uniform(0, 1)\n",
    "        if (random_value < exploration):\n",
    "            action = streets.action_space.sample() # Explore a random action\n",
    "        else:\n",
    "            action = np.argmax(q_table[state]) # Use the action with the highest q-value\n",
    "            \n",
    "        next_state, reward, done, info = streets.step(action)\n",
    "        \n",
    "        prev_q = q_table[state, action]\n",
    "        next_max_q = np.max(q_table[next_state])\n",
    "        new_q = (1 - learning_rate) * prev_q + learning_rate * (reward + discount_factor * next_max_q)\n",
    "        q_table[state, action] = new_q\n",
    "        \n",
    "        state = next_state\n",
    "        \n",
    "        "
   ]
  },
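  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The update inside the training loop can be isolated as a small function. This is a minimal sketch of a single tabular Q-learning update using the same learning rate and discount factor as above. The function name `q_update` and the toy two-state table are illustrative assumptions, not part of Gym:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def q_update(q_table, state, action, reward, next_state,\n",
    "             learning_rate=0.1, discount_factor=0.6):\n",
    "    # Blend the old estimate with the reward plus the discounted best next value.\n",
    "    prev_q = q_table[state, action]\n",
    "    next_max_q = np.max(q_table[next_state])\n",
    "    target = reward + discount_factor * next_max_q\n",
    "    q_table[state, action] = (1 - learning_rate) * prev_q + learning_rate * target\n",
    "    return q_table[state, action]\n",
    "\n",
    "\n",
    "# Toy table: 2 states, 2 actions, all zeros except one known value.\n",
    "toy_q = np.zeros((2, 2))\n",
    "toy_q[1, 0] = 5.0\n",
    "\n",
    "# Taking action 1 in state 0 costs -1 and leads to state 1.\n",
    "new_value = q_update(toy_q, state=0, action=1, reward=-1, next_state=1)\n",
    "print(new_value)\n",
    "```\n",
    "\n",
    "Here the target is -1 + 0.6 * 5 = 2, and with a learning rate of 0.1 the estimate moves only a tenth of the way from 0 toward that target. That is why many simulated runs are needed before the table settles."
   ]
  },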
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So now we have a table of Q-values that can be quickly used to determine the optimal next step for any given state! Let's check the table for our initial state above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2.42319388, -2.4076029 , -2.40062912, -2.3639511 , -6.29837216,\n",
       "       -8.02552348])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "q_table[initial_state]"
   ]
  },
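  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To read the greedy action off a row of Q-values, take the index of the largest entry. The sketch below uses the values printed above (yours will differ slightly from run to run) together with an assumed `ACTION_NAMES` list that follows the action order used in this notebook:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Q-values for the initial state, copied from the output above.\n",
    "q_row = np.array([-2.42319388, -2.4076029, -2.40062912,\n",
    "                  -2.3639511, -6.29837216, -8.02552348])\n",
    "\n",
    "# Assumed action order for indices 0 through 5.\n",
    "ACTION_NAMES = ['South', 'North', 'East', 'West', 'Pickup', 'Dropoff']\n",
    "\n",
    "# The greedy policy picks the action with the largest Q-value.\n",
    "best_action = int(np.argmax(q_row))\n",
    "print(best_action, ACTION_NAMES[best_action])\n",
    "```"
   ]
  },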
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lowest q-value here corresponds to the action \"go West\", which makes sense - that's the most direct route toward our destination from that point. It seems to work! Let's see it in action!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trip number 10 Step 13\n",
      "+---------+\n",
      "|R: | : :G|\n",
      "| : : : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|\u001b[35m\u001b[34;1m\u001b[43mY\u001b[0m\u001b[0m\u001b[0m| : |B: |\n",
      "+---------+\n",
      "  (Dropoff)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from IPython.display import clear_output\n",
    "from time import sleep\n",
    "\n",
    "for tripnum in range(1, 11):\n",
    "    state = streets.reset()\n",
    "   \n",
    "    done = False\n",
    "    trip_length = 0\n",
    "    \n",
    "    while not done and trip_length < 25:\n",
    "        action = np.argmax(q_table[state])\n",
    "        next_state, reward, done, info = streets.step(action)\n",
    "        clear_output(wait=True)\n",
    "        print(\"Trip number \" + str(tripnum) + \" Step \" + str(trip_length))\n",
    "        print(streets.render(mode='ansi'))\n",
    "        sleep(.5)\n",
    "        state = next_state\n",
    "        trip_length += 1\n",
    "        \n",
    "    sleep(2)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Your Challenge\n",
    "\n",
    "Modify the block above to keep track of the total time steps, and use that as a metric as to how good our Q-learning system is. You might want to increase the number of simulated trips, and remove the sleep() calls to allow you to run over more samples.\n",
    "\n",
    "Now, try experimenting with the hyperparameters. How low can the number of epochs go before our model starts to suffer? Can you come up with better learning rates, discount factors, or exploration factors to make the training more efficient? The exploration vs. exploitation rate in particular is interesting to experiment with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}