{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "b57f44dc", "metadata": {}, "outputs": [], "source": [ "!pip install -q git+https://github.com/srush/MiniChain\n", "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " ] }, { "cell_type": "code", "execution_count": null, "id": "f675b525", "metadata": { "lines_to_next_cell": 2, "tags": [ "hide_inp" ] }, "outputs": [], "source": [ "desc = \"\"\"\n", "### Book QA\n", "\n", "Chain that does question answering with Hugging Face embeddings. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/gatsby.ipynb)\n", "\n", "(Adapted from the [LlamaIndex example](https://github.com/jerryjliu/gpt_index/blob/main/examples/gatsby/TestGatsby.ipynb).)\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "eab06e25", "metadata": {}, "outputs": [], "source": [ "import datasets\n", "import numpy as np\n", "from minichain import prompt, show, HuggingFaceEmbed, OpenAI" ] }, { "cell_type": "markdown", "id": "ad893c7e", "metadata": {}, "source": [ "Load data with embeddings (computed beforehand)" ] }, { "cell_type": "code", "execution_count": null, "id": "c64cbf9f", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "gatsby = datasets.load_from_disk(\"gatsby\")\n", "gatsby.add_faiss_index(\"embeddings\")" ] }, { "cell_type": "markdown", "id": "20b8d756", "metadata": {}, "source": [ "Fast KNN retieval prompt" ] }, { "cell_type": "code", "execution_count": null, "id": "a786c893", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "@prompt(HuggingFaceEmbed(\"sentence-transformers/all-mpnet-base-v2\"))\n", "def get_neighbors(model, inp, k=1):\n", " embedding = model(inp)\n", " res = olympics.get_nearest_examples(\"embeddings\", np.array(embedding), k)\n", " return res.examples[\"passages\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "fac56e96", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "@prompt(OpenAI(),\n", " template_file=\"gatsby.pmpt.tpl\")\n", "def ask(model, query, neighbors):\n", " return model(dict(question=query, docs=neighbors))" ] }, { "cell_type": "code", "execution_count": null, "id": "4e46761b", "metadata": {}, "outputs": [], "source": [ "def gatsby(query):\n", " n = get_neighbors(query)\n", " return ask(query, n)" ] }, { "cell_type": "code", "execution_count": null, "id": "e9bafe2a", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "913793a8", "metadata": {}, "outputs": [], "source": [ "gradio = show(gatsby,\n", " subprompts=[get_neighbors, ask],\n", " examples=[\"What did Gatsby do before he met Daisy?\",\n", " \"What did the narrator do after getting back to Chicago?\"],\n", " keys={\"HF_KEY\"},\n", " description=desc,\n", " )\n", "if __name__ == \"__main__\":\n", " gradio.launch()" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "tags,-all", "main_language": "python", "notebook_metadata_filter": "-all" } }, "nbformat": 4, "nbformat_minor": 5 }