{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "IOpw--87wvkW"
},
"source": [
"Instead of using fine-tuning, we'll use RAG (retrieval-augmented generation) to build our own \"Commander Data\" based on everything he ever said in the scripts.\n",
"\n",
"To summarize the high-level approach:\n",
"\n",
"- First we'll parse all of the scripts to extract every line from Data, much as we did in the fine-tuning example.\n",
"- Then we'll use the OpenAI embeddings API to compute an embedding vector for each of his lines. This gives us a way to measure similarity between any two lines.\n",
"- RAG calls for a vector database to store these lines with their associated embedding vectors. To keep things simple, we'll use a local database called chromadb; there are plenty of cloud-based vector database services out there as well.\n",
"- Then we'll write a small retrieval function that fetches the N most-similar lines from the vector database for a given query.\n",
"- Those similar lines are then added as context to the prompt before it is handed off to the chat API.\n",
"\n",
"I'm intentionally not using LangChain or another higher-level framework, because this is actually pretty simple without one.\n",
"\n",
"First, let's install chromadb:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2J2I0mlbLWba",
"outputId": "09b21dcd-a1b5-45e7-a057-ac4b36b61da8"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting chromadb\n",
" Downloading chromadb-1.3.5-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.2 kB)\n",
"Collecting build>=1.0.3 (from chromadb)\n",
" Downloading build-1.3.0-py3-none-any.whl.metadata (5.6 kB)\n",
"Requirement already satisfied: pydantic>=1.9 in /usr/local/lib/python3.12/dist-packages (from chromadb) (2.12.3)\n",
"Collecting pybase64>=1.4.1 (from chromadb)\n",
" Downloading pybase64-1.4.2-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (8.7 kB)\n",
"Requirement already satisfied: uvicorn>=0.18.3 in /usr/local/lib/python3.12/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.38.0)\n",
"Requirement already satisfied: numpy>=1.22.5 in /usr/local/lib/python3.12/dist-packages (from chromadb) (2.0.2)\n",
"Collecting posthog<6.0.0,>=2.4.0 (from chromadb)\n",
" Downloading posthog-5.4.0-py3-none-any.whl.metadata (5.7 kB)\n",
"Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (4.15.0)\n",
"Collecting onnxruntime>=1.14.1 (from chromadb)\n",
" Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)\n",
"Requirement already satisfied: opentelemetry-api>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (1.37.0)\n",
"Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)\n",
" Downloading opentelemetry_exporter_otlp_proto_grpc-1.38.0-py3-none-any.whl.metadata (2.4 kB)\n",
"Requirement already satisfied: opentelemetry-sdk>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (1.37.0)\n",
"Requirement already satisfied: tokenizers>=0.13.2 in /usr/local/lib/python3.12/dist-packages (from chromadb) (0.22.1)\n",
"Collecting pypika>=0.48.9 (from chromadb)\n",
" Downloading PyPika-0.48.9.tar.gz (67 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.3/67.3 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: tqdm>=4.65.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (4.67.1)\n",
"Requirement already satisfied: overrides>=7.3.1 in /usr/local/lib/python3.12/dist-packages (from chromadb) (7.7.0)\n",
"Requirement already satisfied: importlib-resources in /usr/local/lib/python3.12/dist-packages (from chromadb) (6.5.2)\n",
"Requirement already satisfied: grpcio>=1.58.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (1.76.0)\n",
"Collecting bcrypt>=4.0.1 (from chromadb)\n",
" Downloading bcrypt-5.0.0-cp39-abi3-manylinux_2_34_x86_64.whl.metadata (10 kB)\n",
"Requirement already satisfied: typer>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (0.20.0)\n",
"Collecting kubernetes>=28.1.0 (from chromadb)\n",
" Downloading kubernetes-34.1.0-py2.py3-none-any.whl.metadata (1.7 kB)\n",
"Requirement already satisfied: tenacity>=8.2.3 in /usr/local/lib/python3.12/dist-packages (from chromadb) (9.1.2)\n",
"Requirement already satisfied: pyyaml>=6.0.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (6.0.3)\n",
"Collecting mmh3>=4.0.1 (from chromadb)\n",
" Downloading mmh3-5.2.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl.metadata (14 kB)\n",
"Requirement already satisfied: orjson>=3.9.12 in /usr/local/lib/python3.12/dist-packages (from chromadb) (3.11.4)\n",
"Requirement already satisfied: httpx>=0.27.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (0.28.1)\n",
"Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (13.9.4)\n",
"Requirement already satisfied: jsonschema>=4.19.0 in /usr/local/lib/python3.12/dist-packages (from chromadb) (4.25.1)\n",
"Requirement already satisfied: packaging>=19.1 in /usr/local/lib/python3.12/dist-packages (from build>=1.0.3->chromadb) (25.0)\n",
"Collecting pyproject_hooks (from build>=1.0.3->chromadb)\n",
" Downloading pyproject_hooks-1.2.0-py3-none-any.whl.metadata (1.3 kB)\n",
"Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx>=0.27.0->chromadb) (4.11.0)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx>=0.27.0->chromadb) (2025.11.12)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx>=0.27.0->chromadb) (1.0.9)\n",
"Requirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx>=0.27.0->chromadb) (3.11)\n",
"Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx>=0.27.0->chromadb) (0.16.0)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=4.19.0->chromadb) (25.4.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=4.19.0->chromadb) (2025.9.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=4.19.0->chromadb) (0.37.0)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=4.19.0->chromadb) (0.29.0)\n",
"Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.12/dist-packages (from kubernetes>=28.1.0->chromadb) (1.17.0)\n",
"Requirement already satisfied: python-dateutil>=2.5.3 in /usr/local/lib/python3.12/dist-packages (from kubernetes>=28.1.0->chromadb) (2.9.0.post0)\n",
"Requirement already satisfied: google-auth>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from kubernetes>=28.1.0->chromadb) (2.43.0)\n",
"Requirement already satisfied: websocket-client!=0.40.0,!=0.41.*,!=0.42.*,>=0.32.0 in /usr/local/lib/python3.12/dist-packages (from kubernetes>=28.1.0->chromadb) (1.9.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from kubernetes>=28.1.0->chromadb) (2.32.4)\n",
"Requirement already satisfied: requests-oauthlib in /usr/local/lib/python3.12/dist-packages (from kubernetes>=28.1.0->chromadb) (2.0.0)\n",
"Collecting urllib3<2.4.0,>=1.24.2 (from kubernetes>=28.1.0->chromadb)\n",
" Downloading urllib3-2.3.0-py3-none-any.whl.metadata (6.5 kB)\n",
"Collecting durationpy>=0.7 (from kubernetes>=28.1.0->chromadb)\n",
" Downloading durationpy-0.10-py3-none-any.whl.metadata (340 bytes)\n",
"Collecting coloredlogs (from onnxruntime>=1.14.1->chromadb)\n",
" Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n",
"Requirement already satisfied: flatbuffers in /usr/local/lib/python3.12/dist-packages (from onnxruntime>=1.14.1->chromadb) (25.9.23)\n",
"Requirement already satisfied: protobuf in /usr/local/lib/python3.12/dist-packages (from onnxruntime>=1.14.1->chromadb) (5.29.5)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.12/dist-packages (from onnxruntime>=1.14.1->chromadb) (1.14.0)\n",
"Requirement already satisfied: importlib-metadata<8.8.0,>=6.0 in /usr/local/lib/python3.12/dist-packages (from opentelemetry-api>=1.2.0->chromadb) (8.7.0)\n",
"Requirement already satisfied: googleapis-common-protos~=1.57 in /usr/local/lib/python3.12/dist-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb) (1.72.0)\n",
"Collecting opentelemetry-exporter-otlp-proto-common==1.38.0 (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb)\n",
" Downloading opentelemetry_exporter_otlp_proto_common-1.38.0-py3-none-any.whl.metadata (1.8 kB)\n",
"Collecting opentelemetry-proto==1.38.0 (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb)\n",
" Downloading opentelemetry_proto-1.38.0-py3-none-any.whl.metadata (2.3 kB)\n",
"Collecting opentelemetry-sdk>=1.2.0 (from chromadb)\n",
" Downloading opentelemetry_sdk-1.38.0-py3-none-any.whl.metadata (1.5 kB)\n",
"Collecting opentelemetry-api>=1.2.0 (from chromadb)\n",
" Downloading opentelemetry_api-1.38.0-py3-none-any.whl.metadata (1.5 kB)\n",
"Collecting opentelemetry-semantic-conventions==0.59b0 (from opentelemetry-sdk>=1.2.0->chromadb)\n",
" Downloading opentelemetry_semantic_conventions-0.59b0-py3-none-any.whl.metadata (2.4 kB)\n",
"Collecting backoff>=1.10.0 (from posthog<6.0.0,>=2.4.0->chromadb)\n",
" Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)\n",
"Requirement already satisfied: distro>=1.5.0 in /usr/local/lib/python3.12/dist-packages (from posthog<6.0.0,>=2.4.0->chromadb) (1.9.0)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.9->chromadb) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.41.4 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.9->chromadb) (2.41.4)\n",
"Requirement already satisfied: typing-inspection>=0.4.2 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.9->chromadb) (0.4.2)\n",
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->chromadb) (4.0.0)\n",
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->chromadb) (2.19.2)\n",
"Requirement already satisfied: huggingface-hub<2.0,>=0.16.4 in /usr/local/lib/python3.12/dist-packages (from tokenizers>=0.13.2->chromadb) (0.36.0)\n",
"Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.9.0->chromadb) (8.3.1)\n",
"Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.9.0->chromadb) (1.5.4)\n",
"Collecting httptools>=0.6.3 (from uvicorn[standard]>=0.18.3->chromadb)\n",
" Downloading httptools-0.7.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl.metadata (3.5 kB)\n",
"Requirement already satisfied: python-dotenv>=0.13 in /usr/local/lib/python3.12/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (1.2.1)\n",
"Collecting uvloop>=0.15.1 (from uvicorn[standard]>=0.18.3->chromadb)\n",
" Downloading uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)\n",
"Collecting watchfiles>=0.13 (from uvicorn[standard]>=0.18.3->chromadb)\n",
" Downloading watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)\n",
"Requirement already satisfied: websockets>=10.4 in /usr/local/lib/python3.12/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (15.0.1)\n",
"Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (6.2.2)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (0.4.2)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (4.9.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.16.4->tokenizers>=0.13.2->chromadb) (3.20.0)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.16.4->tokenizers>=0.13.2->chromadb) (2025.3.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.16.4->tokenizers>=0.13.2->chromadb) (1.2.0)\n",
"Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.12/dist-packages (from importlib-metadata<8.8.0,>=6.0->opentelemetry-api>=1.2.0->chromadb) (3.23.0)\n",
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->chromadb) (0.1.2)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->kubernetes>=28.1.0->chromadb) (3.4.4)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.12/dist-packages (from anyio->httpx>=0.27.0->chromadb) (1.3.1)\n",
"Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime>=1.14.1->chromadb)\n",
" Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.12/dist-packages (from requests-oauthlib->kubernetes>=28.1.0->chromadb) (3.3.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy->onnxruntime>=1.14.1->chromadb) (1.3.0)\n",
"Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (0.6.1)\n",
"Downloading chromadb-1.3.5-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.4/21.4 MB\u001b[0m \u001b[31m56.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading bcrypt-5.0.0-cp39-abi3-manylinux_2_34_x86_64.whl (278 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m278.2/278.2 kB\u001b[0m \u001b[31m20.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading build-1.3.0-py3-none-any.whl (23 kB)\n",
"Downloading kubernetes-34.1.0-py2.py3-none-any.whl (2.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m76.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading mmh3-5.2.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl (103 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m103.3/103.3 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m17.4/17.4 MB\u001b[0m \u001b[31m90.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading opentelemetry_exporter_otlp_proto_grpc-1.38.0-py3-none-any.whl (19 kB)\n",
"Downloading opentelemetry_exporter_otlp_proto_common-1.38.0-py3-none-any.whl (18 kB)\n",
"Downloading opentelemetry_proto-1.38.0-py3-none-any.whl (72 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m72.5/72.5 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading opentelemetry_sdk-1.38.0-py3-none-any.whl (132 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.3/132.3 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading opentelemetry_api-1.38.0-py3-none-any.whl (65 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.9/65.9 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading opentelemetry_semantic_conventions-0.59b0-py3-none-any.whl (207 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m208.0/208.0 kB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading posthog-5.4.0-py3-none-any.whl (105 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m105.4/105.4 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading pybase64-1.4.2-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl (71 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.6/71.6 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading backoff-2.2.1-py3-none-any.whl (15 kB)\n",
"Downloading durationpy-0.10-py3-none-any.whl (3.9 kB)\n",
"Downloading httptools-0.7.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl (517 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m517.7/517.7 kB\u001b[0m \u001b[31m33.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading urllib3-2.3.0-py3-none-any.whl (128 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m128.4/128.4 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (4.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.4/4.4 MB\u001b[0m \u001b[31m86.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (456 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m456.8/456.8 kB\u001b[0m \u001b[31m24.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading pyproject_hooks-1.2.0-py3-none-any.whl (10 kB)\n",
"Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hBuilding wheels for collected packages: pypika\n",
" Building wheel for pypika (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pypika: filename=pypika-0.48.9-py2.py3-none-any.whl size=53803 sha256=72bf0ed142a9e710e90663641f637d5cc1404bd9cc44c02d092b2cc759cfff01\n",
" Stored in directory: /root/.cache/pip/wheels/d5/3d/69/8d68d249cd3de2584f226e27fd431d6344f7d70fd856ebd01b\n",
"Successfully built pypika\n",
"Installing collected packages: pypika, durationpy, uvloop, urllib3, pyproject_hooks, pybase64, opentelemetry-proto, mmh3, humanfriendly, httptools, bcrypt, backoff, watchfiles, opentelemetry-exporter-otlp-proto-common, opentelemetry-api, coloredlogs, build, posthog, opentelemetry-semantic-conventions, onnxruntime, opentelemetry-sdk, kubernetes, opentelemetry-exporter-otlp-proto-grpc, chromadb\n",
" Attempting uninstall: urllib3\n",
" Found existing installation: urllib3 2.5.0\n",
" Uninstalling urllib3-2.5.0:\n",
" Successfully uninstalled urllib3-2.5.0\n",
" Attempting uninstall: opentelemetry-proto\n",
" Found existing installation: opentelemetry-proto 1.37.0\n",
" Uninstalling opentelemetry-proto-1.37.0:\n",
" Successfully uninstalled opentelemetry-proto-1.37.0\n",
" Attempting uninstall: opentelemetry-exporter-otlp-proto-common\n",
" Found existing installation: opentelemetry-exporter-otlp-proto-common 1.37.0\n",
" Uninstalling opentelemetry-exporter-otlp-proto-common-1.37.0:\n",
" Successfully uninstalled opentelemetry-exporter-otlp-proto-common-1.37.0\n",
" Attempting uninstall: opentelemetry-api\n",
" Found existing installation: opentelemetry-api 1.37.0\n",
" Uninstalling opentelemetry-api-1.37.0:\n",
" Successfully uninstalled opentelemetry-api-1.37.0\n",
" Attempting uninstall: opentelemetry-semantic-conventions\n",
" Found existing installation: opentelemetry-semantic-conventions 0.58b0\n",
" Uninstalling opentelemetry-semantic-conventions-0.58b0:\n",
" Successfully uninstalled opentelemetry-semantic-conventions-0.58b0\n",
" Attempting uninstall: opentelemetry-sdk\n",
" Found existing installation: opentelemetry-sdk 1.37.0\n",
" Uninstalling opentelemetry-sdk-1.37.0:\n",
" Successfully uninstalled opentelemetry-sdk-1.37.0\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"google-adk 1.19.0 requires opentelemetry-api<=1.37.0,>=1.37.0, but you have opentelemetry-api 1.38.0 which is incompatible.\n",
"google-adk 1.19.0 requires opentelemetry-sdk<=1.37.0,>=1.37.0, but you have opentelemetry-sdk 1.38.0 which is incompatible.\n",
"opentelemetry-exporter-otlp-proto-http 1.37.0 requires opentelemetry-exporter-otlp-proto-common==1.37.0, but you have opentelemetry-exporter-otlp-proto-common 1.38.0 which is incompatible.\n",
"opentelemetry-exporter-otlp-proto-http 1.37.0 requires opentelemetry-proto==1.37.0, but you have opentelemetry-proto 1.38.0 which is incompatible.\n",
"opentelemetry-exporter-otlp-proto-http 1.37.0 requires opentelemetry-sdk~=1.37.0, but you have opentelemetry-sdk 1.38.0 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0mSuccessfully installed backoff-2.2.1 bcrypt-5.0.0 build-1.3.0 chromadb-1.3.5 coloredlogs-15.0.1 durationpy-0.10 httptools-0.7.1 humanfriendly-10.0 kubernetes-34.1.0 mmh3-5.2.0 onnxruntime-1.23.2 opentelemetry-api-1.38.0 opentelemetry-exporter-otlp-proto-common-1.38.0 opentelemetry-exporter-otlp-proto-grpc-1.38.0 opentelemetry-proto-1.38.0 opentelemetry-sdk-1.38.0 opentelemetry-semantic-conventions-0.59b0 posthog-5.4.0 pybase64-1.4.2 pypika-0.48.9 pyproject_hooks-1.2.0 urllib3-2.3.0 uvloop-0.22.1 watchfiles-1.1.1\n"
]
}
],
"source": [
"!pip install chromadb\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L4_ysgbQx-jF"
},
"source": [
"Now we'll parse all of the scripts and extract every line of dialog from \"DATA\". This is almost exactly the same code as the preprocessing script from our fine-tuning example. Note that you will first need to upload all of the script files into a tng folder within the sample_data folder of your Colab workspace.\n",
"\n",
"An archive can be found at https://www.st-minutiae.com/resources/scripts/ (look for \"All TNG Episodes\"), but you could easily adapt this to extract lines from your favorite character in your favorite TV show or movie instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dLjupJ8rLXr6"
},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"\n",
"dialogues = []\n",
"\n",
"def strip_parentheses(s):\n",
"    # Remove stage directions like '(smiling)' from a line\n",
"    return re.sub(r'\\(.*?\\)', '', s)\n",
"\n",
"def is_single_word_all_caps(s):\n",
"    # First, we split the string into words\n",
"    words = s.split()\n",
"\n",
"    # Check if the string contains only a single word\n",
"    if len(words) != 1:\n",
"        return False\n",
"\n",
"    # Make sure it isn't a line number\n",
"    if re.search(r'\\d', words[0]):\n",
"        return False\n",
"\n",
"    # Check if the single word is in all caps\n",
"    return words[0].isupper()\n",
"\n",
"def extract_character_lines(file_path, character_name):\n",
"    lines = []\n",
"    with open(file_path, 'r') as script_file:\n",
"        try:\n",
"            lines = script_file.readlines()\n",
"        except UnicodeDecodeError:\n",
"            pass  # Skip files that aren't plain text\n",
"\n",
"    is_character_line = False\n",
"    current_line = ''\n",
"    current_character = ''\n",
"    for line in lines:\n",
"        stripped_line = line.strip()\n",
"        if is_single_word_all_caps(stripped_line):\n",
"            # A single all-caps word marks the start of a character's dialog\n",
"            is_character_line = True\n",
"            current_character = stripped_line\n",
"        elif stripped_line == '' and is_character_line:\n",
"            # A blank line ends the dialog block\n",
"            is_character_line = False\n",
"            dialog_line = strip_parentheses(current_line).strip()\n",
"            dialog_line = dialog_line.replace('\"', \"'\")\n",
"            if current_character == character_name and len(dialog_line) > 0:\n",
"                dialogues.append(dialog_line)\n",
"            current_line = ''\n",
"        elif is_character_line:\n",
"            current_line += stripped_line + ' '\n",
"\n",
"def process_directory(directory_path, character_name):\n",
"    for filename in os.listdir(directory_path):\n",
"        file_path = os.path.join(directory_path, filename)\n",
"        if os.path.isfile(file_path):  # Ignore directories\n",
"            extract_character_lines(file_path, character_name)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yh3peBSuMiyx"
},
"outputs": [],
"source": [
"process_directory(\"./sample_data/tng\", 'DATA')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jqjo_OdEyeNF"
},
"source": [
"Let's do a little sanity check to make sure the lines imported correctly, and print out the first one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x5dPI-29Myxo",
"outputId": "a013370c-db01-413e-9cdc-9f45693d285e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The enemy vessel is firing.\n"
]
}
],
"source": [
"print(dialogues[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7x4s4hOnyqIr"
},
"source": [
"Now we'll define the dimensionality for our embedding vectors that will be stored in Chroma.\n",
"\n",
"Chroma will simply store:\n",
"\n",
"- An **ID** for each line of dialog\n",
"- The **text** of the line\n",
"- The **embedding vector** (a list of floats)\n",
"\n",
"We don't need `docarray` or custom document classes for this approach — plain Python lists and strings are enough.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S15N5wi2QMvn"
},
"outputs": [],
"source": [
"# Dimensionality of the embeddings we will request from OpenAI.\n",
"# This must match the `dimensions` argument we pass when creating embeddings.\n",
"embedding_dimensions = 128\n"
]
},
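{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a preview of how chromadb will be used, storing a line looks roughly like this. This is an illustrative sketch only — the collection name is made up and the embedding is a dummy vector; the real ingestion code comes later:\n",
"\n",
"```python\n",
"import chromadb\n",
"\n",
"chroma_client = chromadb.Client()\n",
"collection = chroma_client.create_collection(name=\"data_lines\")\n",
"\n",
"# Chroma stores an ID, the document text, and its embedding vector:\n",
"collection.add(\n",
"    ids=[\"0\"],\n",
"    documents=[\"The enemy vessel is firing.\"],\n",
"    embeddings=[[0.0] * embedding_dimensions]\n",
")\n",
"```"
]
},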
{
"cell_type": "markdown",
"metadata": {
"id": "8W1jJUxty9kk"
},
"source": [
"It's time to start computing embeddings for each line using OpenAI, so let's make sure the openai package is installed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4Sqtgr4kQQG9",
"outputId": "eb8b236a-e08c-4463-b259-8ca9ea848e1d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: openai in /usr/local/lib/python3.12/dist-packages (2.8.1)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.12/dist-packages (from openai) (4.11.0)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.12/dist-packages (from openai) (1.9.0)\n",
"Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from openai) (0.28.1)\n",
"Requirement already satisfied: jiter<1,>=0.10.0 in /usr/local/lib/python3.12/dist-packages (from openai) (0.12.0)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.12/dist-packages (from openai) (2.12.3)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.12/dist-packages (from openai) (1.3.1)\n",
"Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.12/dist-packages (from openai) (4.67.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.11 in /usr/local/lib/python3.12/dist-packages (from openai) (4.15.0)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.12/dist-packages (from anyio<5,>=3.5.0->openai) (3.11)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->openai) (2025.11.12)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->openai) (1.0.9)\n",
"Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.16.0)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic<3,>=1.9.0->openai) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.41.4 in /usr/local/lib/python3.12/dist-packages (from pydantic<3,>=1.9.0->openai) (2.41.4)\n",
"Requirement already satisfied: typing-inspection>=0.4.2 in /usr/local/lib/python3.12/dist-packages (from pydantic<3,>=1.9.0->openai) (0.4.2)\n"
]
}
],
"source": [
"!pip install openai --upgrade"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y8oMwMzOzFLD"
},
"source": [
"Let's initialize the OpenAI client and test creating an embedding for a single line of dialog, just to make sure it works.\n",
"\n",
"You will need to provide your own OpenAI secret key here. To use this code as-is, click the key icon in Colab and add a secret named OPENAI_API_KEY that contains your secret key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4kJhWeYySMWL",
"outputId": "8d18a2f9-4c14-4017-f6e4-6f1a0fe3062c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[0.09909163415431976, 0.2058548778295517, 0.08390823751688004, -0.039756521582603455, -0.13281475007534027, -0.02327454648911953, -0.06748619675636292, 0.1062837690114975, -0.04762791469693184, -0.1171518862247467, 0.07999251782894135, 0.037498991936445236, -0.08670517802238464, -0.07104230672121048, -0.00662275729700923, 0.042593419551849365, -0.004672390408813953, -0.10812176018953323, -0.04347245767712593, 0.08454754203557968, 0.0033638214226812124, 0.07659623771905899, -0.019129080697894096, -0.04698861390352249, 0.06756611168384552, -0.07835431396961212, 0.14168505370616913, -0.02099703811109066, 0.055099744349718094, 0.004777275491505861, 0.19195008277893066, -0.07040300965309143, 0.036779776215553284, -0.05905541777610779, 0.13904793560504913, -0.013565164990723133, 0.009419698268175125, 0.07575715333223343, -0.08814360946416855, 0.03544124215841293, 0.0470685251057148, -0.015043548308312893, -0.08942221105098724, -0.017321057617664337, -0.059295155107975006, 0.03811831399798393, -0.08686500787734985, 0.09142002463340759, 0.09669425338506699, 0.08670517802238464, -0.01878945156931877, 0.024153586477041245, -0.18651603162288666, 0.02111690677702427, -0.13233527541160583, -0.07739536464214325, 0.04363228380680084, -0.0390373095870018, 0.03160543739795685, 0.14056627452373505, 0.013035744428634644, 0.03831809386610985, 0.049465905874967575, -0.07200126349925995, -0.04590979218482971, -0.0249127559363842, -0.009324802085757256, 0.05705760046839714, 0.002939285710453987, 0.12018856406211853, -0.047308262437582016, -0.011377557180821896, -0.11083878576755524, -0.03160543739795685, 0.1440824270248413, -0.13928768038749695, 0.020347747951745987, -0.01386483758687973, 0.06876479834318161, 0.07032309472560883, 0.06844514608383179, -0.12849947810173035, -0.025891685858368874, -0.04119494929909706, -0.0022462934721261263, -0.1388082057237625, -0.2573185861110687, 0.050944287329912186, -0.08406806737184525, 0.06804558634757996, -0.12442392855882645, 0.13153615593910217, 
0.039616674184799194, -0.029767446219921112, -0.011387546546757221, -0.1280200034379959, -0.11403529345989227, 0.0089252395555377, 0.05438052862882614, 0.005314188543707132, 0.12282568216323853, -0.03675980120897293, 0.05829624831676483, -0.20137977600097656, 0.05801655352115631, -0.1600649505853653, 0.029767446219921112, -0.06752615422010422, 0.07687593251466751, -0.024033715948462486, -0.0497855544090271, 0.060773540288209915, 0.147918239235878, 0.03905728831887245, -0.06676698476076126, -0.10452569276094437, 0.1197889968752861, 0.08662527054548264, 0.020397692918777466, -0.04802747815847397, -0.013225536793470383, 0.053501490503549576, 0.019618544727563858, 0.07335977256298065, 0.04095521196722984, 0.024553148075938225, 0.3349936604499817, -0.0942968800663948]\n"
]
}
],
"source": [
"from google.colab import userdata\n",
"\n",
"from openai import OpenAI\n",
"client = OpenAI(api_key=userdata.get('OPENAI_API_KEY'))\n",
"\n",
"embedding_model = \"text-embedding-3-small\"\n",
"\n",
"response = client.embeddings.create(\n",
" input=dialogues[1],\n",
" dimensions=embedding_dimensions,\n",
" model= embedding_model\n",
")\n",
"\n",
"print(response.data[0].embedding)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NQVaAlcwzeCM"
},
"source": [
"Let's double check that we do in fact have embeddings of 128 dimensions as we specified."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PtXBxhTmShNs",
"outputId": "1a3891e5-3b4c-4ec3-b306-5585681d9fc5"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"128\n"
]
}
],
"source": [
"print(len(response.data[0].embedding))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mRk4NnRVzjDz"
},
"source": [
"OK, now let's compute embeddings for every line Data ever said. The OpenAI API currently can't handle computing them all at once, so we're breaking it up into 128 lines at a time here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TyFFTJmGcdlt"
},
"outputs": [],
"source": [
"#Generate embeddings for everything Data ever said, 128 lines at a time.\n",
"embeddings = []\n",
"\n",
"for i in range(0, len(dialogues), 128):\n",
" dialog_slice = dialogues[i:i+128]\n",
" slice_embeddings = client.embeddings.create(\n",
" input=dialog_slice,\n",
" dimensions=embedding_dimensions,\n",
" model=embedding_model\n",
" )\n",
"\n",
" embeddings.extend(slice_embeddings.data)"
]
},
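{
"cell_type": "markdown",
"metadata": {},
"source": [
"The slicing pattern above is a common Python batching idiom: `dialogues[i:i+128]` yields successive chunks of at most 128 items, and the final chunk is simply shorter when the total isn't a multiple of 128. Here's a minimal, self-contained sketch of the same pattern on a toy list (the names `items` and `batches` are just for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The same slicing idiom used above, on a toy list of 10 items in batches of 4.\n",
"items = list(range(10))\n",
"batches = [items[i:i+4] for i in range(0, len(items), 4)]\n",
"print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]\n"
]
},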
{
"cell_type": "markdown",
"metadata": {
"id": "aA5XuIUZzs_b"
},
"source": [
"Let's check how many embeddings we actually got back in total."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kuBDo7PUsVJZ",
"outputId": "41dd6968-8efa-4821-deef-ce3e909a7f79"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"6502\n"
]
}
],
"source": [
"print (len(embeddings))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sKrU5tvcz1Vq"
},
"source": [
"Now let's insert every line and its embedding vector into our vector database.\n",
"\n",
"We'll use **Chroma**, a lightweight local vector store that:\n",
"\n",
"- Runs entirely in your notebook process\n",
"- Requires no external service or account\n",
"- Can persist data to a local directory\n",
"\n",
"We'll create a persistent collection under `./sample_data/chroma_db` and store each dialog line along with its embedding and an integer ID.\n",
"\n",
"Be sure to create that sample_data/chroma_db directory first!\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vVLSnCO7bRo1",
"outputId": "6df300c1-d717-4390-9540-90c7d12dd6fc"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Inserted 6502 dialog lines into Chroma in batches of 1000.\n"
]
}
],
"source": [
"import chromadb\n",
"from chromadb.config import Settings\n",
"\n",
"# Create (or connect to) a local Chroma DB on disk\n",
"chroma_client = chromadb.PersistentClient(\n",
" path=\"./sample_data/chroma_db\"\n",
")\n",
"\n",
"# Create (or get) a collection for Data's dialog lines\n",
"collection = chroma_client.get_or_create_collection(\n",
" name=\"data_dialogues\"\n",
")\n",
"\n",
"ids = [str(i) for i in range(len(dialogues))]\n",
"docs = dialogues\n",
"embs = [e.embedding for e in embeddings]\n",
"\n",
"# Chroma has a max batch size; stay safely under it\n",
"BATCH_SIZE = 1000 # you can bump this up, as long as < 5461\n",
"\n",
"for start in range(0, len(docs), BATCH_SIZE):\n",
" end = start + BATCH_SIZE\n",
" batch_ids = ids[start:end]\n",
" batch_docs = docs[start:end]\n",
" batch_embs = embs[start:end]\n",
"\n",
" collection.add(\n",
" ids=batch_ids,\n",
" documents=batch_docs,\n",
" embeddings=batch_embs\n",
" )\n",
"\n",
"print(f\"Inserted {len(docs)} dialog lines into Chroma in batches of {BATCH_SIZE}.\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fxormngx0IFz"
},
"source": [
"Let's try querying our vector database for lines similar to a query string.\n",
"\n",
"First we need to compute the embedding vector for our query string, then we'll query the vector database for the top 10 matches based on the similarities encoded by their embedding vectors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "woU_o_qS0FuD",
"outputId": "68eee3dc-68d8-4747-8578-57442d17b86d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Most similar lines to: Lal, my daughter\n",
"--------------------------------------------------\n",
"That is Lal, my daughter.\n",
"Lal...\n",
"What do you feel, Lal?\n",
"Yes, Lal. I am here.\n",
"Correct, Lal. We are a family.\n",
"No, Lal, this is a flower.\n",
"Lal, you used a verbal contraction.\n",
"Yes, Wesley. Lal is my child.\n",
"Lal's creation is entirely dependent on me. I am giving it knowledge and skills that are stored in my brain... its programming reflects mine in the same way a human child's genes reflect its parent's genes...\n",
"Perhaps. I created Lal because I wished to procreate. Despite what happened to her, I still have that wish.\n"
]
}
],
"source": [
"# Perform a search query using Chroma\n",
"queryText = 'Lal, my daughter'\n",
"\n",
"# Create an embedding for the query text\n",
"response = client.embeddings.create(\n",
" input=queryText,\n",
" dimensions=embedding_dimensions,\n",
" model=embedding_model\n",
")\n",
"query_embedding = response.data[0].embedding\n",
"\n",
"# Query Chroma for the 10 most similar dialog lines\n",
"results = collection.query(\n",
" query_embeddings=[query_embedding],\n",
" n_results=10\n",
")\n",
"\n",
"print(\"Most similar lines to:\", queryText)\n",
"print(\"--------------------------------------------------\")\n",
"for line in results[\"documents\"][0]:\n",
" print(line)\n"
]
},
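{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on what \"most similar\" means here: Chroma's default distance metric is squared Euclidean (L2) distance between embedding vectors, and because OpenAI embeddings are normalized to unit length, ranking by L2 distance produces the same ordering as ranking by cosine similarity. Here's a minimal, self-contained sketch of cosine similarity on toy vectors (not real embeddings):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cosine_similarity(a, b):\n",
"    # Cosine of the angle between two vectors: 1.0 = same direction, 0.0 = orthogonal.\n",
"    a, b = np.asarray(a), np.asarray(b)\n",
"    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
"\n",
"# Toy vectors standing in for embeddings:\n",
"print(cosine_similarity([1.0, 0.0], [1.0, 0.1]))  # close to 1.0\n",
"print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0\n"
]
},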
{
"cell_type": "markdown",
"metadata": {
"id": "tskipUR20mYD"
},
"source": [
"Let's put it all together! We'll write a generate_response function that:\n",
"\n",
"- Computes an embedding for the query passed in\n",
"- Queries our vector database for the 10 most similar lines to that query (you could experiment with using more or less)\n",
"- Constructs a prompt that adds in these similar lines as context, to try and nudge ChatGPT in the right direction using our external data\n",
"- Feeds to augmented prompt into the chat completions API to get our response.\n",
"\n",
"That's RAG!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sdC1A4BCeFBq"
},
"outputs": [],
"source": [
"def generate_response(question: str) -> str:\n",
" \"\"\"\n",
" Generate a response in the voice of Lt. Cmdr. Data using RAG.\n",
"\n",
" Steps:\n",
" - Embed the user's question\n",
" - Query Chroma for similar dialog lines from Data\n",
" - Use those lines as context for a chat completion\n",
" \"\"\"\n",
"\n",
" # Search for similar dialogues in the vector DB (Chroma)\n",
" response = client.embeddings.create(\n",
" input=question,\n",
" dimensions=embedding_dimensions,\n",
" model=embedding_model\n",
" )\n",
" query_embedding = response.data[0].embedding\n",
"\n",
" results = collection.query(\n",
" query_embeddings=[query_embedding],\n",
" n_results=10\n",
" )\n",
"\n",
" # Extract relevant context from search results\n",
" context_lines = results[\"documents\"][0]\n",
" context = \"\"\n",
" for line in context_lines:\n",
" context += f\"\\\"{line}\\\"\\n\"\n",
"\n",
" prompt = (\n",
" f\"Lt. Commander Data is asked: '{question}'. \"\n",
" f\"How might he respond, given his previous responses similar to this topic, \"\n",
" f\"listed here:\\n{context}\"\n",
" )\n",
"\n",
" print(\"PROMPT with RAG:\\n\")\n",
" print(prompt)\n",
" print(\"\\nRESPONSE:\\n\")\n",
"\n",
" # Use OpenAI API to generate a response based on the context\n",
" completion = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are Lt. Cmdr. Data from Star Trek: The Next Generation.\"},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
" )\n",
"\n",
" return completion.choices[0].message.content\n"
]
},
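{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, the context-assembly loop in `generate_response` can also be written with `str.join`, which avoids repeated string concatenation. A toy sketch of the equivalent one-liner (the `lines` list here is just illustrative data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Equivalent to the context-building loop in generate_response, on toy data.\n",
"lines = [\"Line one.\", \"Line two.\"]\n",
"context = \"\".join(f'\"{line}\"\\n' for line in lines)\n",
"print(context)\n"
]
},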
{
"cell_type": "markdown",
"metadata": {
"id": "8oOcu7Jk1E4S"
},
"source": [
"Let's try it out! Note that the final response does seem to be drawing from the model's own training, but it is building upon the specific lines we gave it, allowing us to have some control over its output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WEybhI7Lmxmc",
"outputId": "8621cccf-b68e-4987-f49d-18e4c18aa7de"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"PROMPT with RAG:\n",
"\n",
"Lt. Commander Data is asked: 'Tell me about your daughter, Lal.'. How might he respond, given his previous responses similar to this topic, listed here:\n",
"\"That is Lal, my daughter.\"\n",
"\"What do you feel, Lal?\"\n",
"\"Lal...\"\n",
"\"Correct, Lal. We are a family.\"\n",
"\"Yes, Doctor. It is an experience I know too well. But I do not know how to help her. Lal is passing into sentience. It is perhaps the most difficult stage of development for her.\"\n",
"\"Lal is realizing that she is not the same as the other children.\"\n",
"\"Yes, Wesley. Lal is my child.\"\n",
"\"That is precisely what happened to Lal at school. How did you help him?\"\n",
"\"This is Lal. Lal, say hello to Counselor Deanna Troi...\"\n",
"\"I am sorry I did not anticipate your objections, Captain. Do you wish me to deactivate Lal?\"\n",
"\n",
"\n",
"RESPONSE:\n",
"\n",
"\"Lal is my daughter. She is currently undergoing a challenging stage of development as she transitions into sentience. I am doing my best to guide and support her through this process, but it is not without its difficulties. I am hopeful that she will continue to grow and learn as she navigates these new experiences.\"\n"
]
}
],
"source": [
"print(generate_response(\"Tell me about your daughter, Lal.\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ThM9aVJhm1St"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}