{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7e3a8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q git+https://github.com/srush/MiniChain\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49443595",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "hide_inp"
    ]
   },
   "outputs": [],
   "source": [
    "desc = \"\"\"\n",
    "### Question Answering with Retrieval\n",
    "\n",
    "Chain that answers questions with embedding based retrieval. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/qa.ipynb)\n",
    "\n",
    "(Adapted from [OpenAI Notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5183ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import numpy as np\n",
    "from minichain import prompt, show, OpenAIEmbed, OpenAI\n",
    "from manifest import Manifest"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bf59f0d",
   "metadata": {},
   "source": [
    "We use Hugging Face Datasets as the database by assigning\n",
    "a FAISS index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f371a85e",
   "metadata": {},
   "outputs": [],
   "source": [
    "olympics = datasets.load_from_disk(\"olympics.data\")\n",
    "olympics.add_faiss_index(\"embeddings\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1099002",
   "metadata": {},
   "source": [
    "Fast KNN retrieval prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6881ae0e",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAIEmbed())\n",
    "def get_neighbors(model, inp, k):\n",
    "    \"\"\"Embed `inp` and return the `content` of its `k` nearest `olympics` rows.\"\"\"\n",
    "    embedding = model(inp)\n",
    "    # add_faiss_index indexes float32 vectors by default, so cast the query to match.\n",
    "    query_vector = np.array(embedding, dtype=np.float32)\n",
    "    res = olympics.get_nearest_examples(\"embeddings\", query_vector, k)\n",
    "    return res.examples[\"content\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59cc1355",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAI(),\n",
    "        template_file=\"qa.pmpt.tpl\")\n",
    "def get_result(model, query, neighbors):\n",
    "    \"\"\"Call the OpenAI backend via `qa.pmpt.tpl`, passing `question` and `docs`.\"\"\"\n",
    "    return model(dict(question=query, docs=neighbors))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb2f1101",
   "metadata": {},
   "outputs": [],
   "source": [
    "def qa(query):\n",
    "    \"\"\"Answer `query` using its 3 nearest `olympics` entries as context.\"\"\"\n",
    "    neighbors = get_neighbors(query, 3)\n",
    "    return get_result(query, neighbors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdfcd87",
   "metadata": {},
   "outputs": [],
   "source": [
    "questions = [\"Who won the 2020 Summer Olympics men's high jump?\",\n",
    "             \"Why was the 2020 Summer Olympics originally postponed?\",\n",
    "             \"In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\",\n",
    "             \"What is the total number of medals won by France?\",\n",
    "             \"What is the tallest mountain in the world?\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddce3ec3",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "gradio = show(qa,\n",
    "              examples=questions,\n",
    "              subprompts=[get_neighbors, get_result],\n",
    "              description=desc,\n",
    "              )\n",
    "if __name__ == \"__main__\":\n",
    "    gradio.launch()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "tags,-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}