diff --git "a/notebooks/populate_dataset.ipynb" "b/notebooks/populate_dataset.ipynb" new file mode 100644--- /dev/null +++ "b/notebooks/populate_dataset.ipynb" @@ -0,0 +1,446 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3b5fbb8f-5789-45db-a551-f0e6633b4f46", + "metadata": {}, + "source": [ + "# Populate a HDF5 dataset with base64 Pokémon images keyed by energy type" + ] + }, + { + "cell_type": "markdown", + "id": "a9234c78-1ac5-4c71-b11b-3a81be57f3f3", + "metadata": {}, + "source": [ + "Used in [**This Pokémon Does Not Exist**](https://huggingface.co/spaces/ronvolutional/ai-pokemon-card)\n", + "\n", + "Model fine-tuned by [**Max Woolf**](https://huggingface.co/minimaxir/ai-generated-pokemon-rudalle)\n", + "\n", + "ruDALL-E by [**Sber**](https://rudalle.ru/en)" + ] + }, + { + "cell_type": "markdown", + "id": "09850949-235d-4997-b8e3-c5b2aeffe109", + "metadata": {}, + "source": [ + "## Initialise datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "cead6fa8-e9ef-4672-bbfb-beadcaf5f3a0", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import h5py\n", + "\n", + "datasets_dir = './datasets'\n", + "datasets_file = 'pregenerated_pokemon.h5'\n", + "h5_file = os.path.join(datasets_dir, datasets_file)\n", + "\n", + "energy_types = ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']" + ] + }, + { + "cell_type": "raw", + "id": "2df90e94-15c0-4eb6-914e-875ec80b7c24", + "metadata": {}, + "source": [ + "# Only run if the datasets file does not exist\n", + "\n", + "with h5py.File(h5_file, 'x') as datasets:\n", + " for energy in energy_types:\n", + " datasets.create_dataset(energy, (0,1), h5py.string_dtype(encoding='utf-8'), maxshape=(None,1))\n", + "\n", + " print(datasets.keys())" + ] + }, + { + "cell_type": "markdown", + "id": "cdd3eb59-bbf5-4b85-b6bc-35f591317b47", + "metadata": {}, + "source": [ + "### Dataset functions" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fca1947f-8a66-4636-8049-99d6ff0ace93", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "from time import gmtime, strftime, time\n", + "from random import choices, randint\n", + "from IPython import display\n", + "\n", + "def get_stats(h5_file=h5_file):\n", + " with h5py.File(h5_file, 'r') as datasets:\n", + " return {\n", + " \"size_counts\": {key: datasets[key].size.item() for key in datasets.keys()},\n", + " \"size_total\": sum(list(datasets[energy].size.item() for energy in datasets.keys())),\n", + " \"size_mb\": round(os.path.getsize(h5_file) / 1024**2, 1)\n", + " }\n", + "\n", + "\n", + "def add_row(energy, image):\n", + " with h5py.File(h5_file, 'r+') as datasets:\n", + " dataset = datasets[energy]\n", + " dataset.resize(dataset.size + 1, 0)\n", + " dataset[-1] = image\n", + "\n", + "\n", + "def get_image(energy=None, row=None):\n", + " if not energy:\n", + " energy = choices(energy_types)[0]\n", + "\n", + " with h5py.File(h5_file, 'r') as datasets:\n", + " if not row:\n", + " row = randint(0, datasets[energy].size - 1)\n", + "\n", + " return datasets[energy].asstr()[row][0]\n", + "\n", + "def pretty_time(seconds):\n", + " m, s = divmod(seconds, 60)\n", + " h, m = divmod(m, 60)\n", + " return f\"{f'{math.floor(h)}h ' if h else ''}{f'{math.floor(m)}m ' if m else ''}{f'{math.floor(s)}s' if s else ''}\"\n", + " \n", + "def populate_dataset(batches=1, batch_size=1, image_cap=100_000, filesize_cap=4_000):\n", + " initial_stats = get_stats()\n", + "\n", + " iterations = 0\n", + " start_time = time()\n", + "\n", + " while iterations < batches and get_stats()['size_total'] < image_cap and get_stats()['size_mb'] < filesize_cap:\n", + " for energy in energy_types:\n", + " current = get_stats()\n", + " new_images_count = (current['size_total'] - initial_stats['size_total'])\n", + " new_mb_count = round(current['size_mb'] - initial_stats['size_mb'], 1)\n", + " elapsed = time() - start_time\n", + " eta_total = elapsed / (new_images_count or 1) * batches * batch_size * len(energy_types)\n", + "\n", + " display.clear_output(wait=True)\n", + " if new_images_count:\n", + " print(f\"ETA: {pretty_time(eta_total - elapsed)} left of {pretty_time(eta_total)}\")\n", + " print(f\"Images in dataset: {current['size_total']}{f' (+{new_images_count})' if new_images_count else ''}\")\n", + " print(f\"Size of dataset: {current['size_mb']}MB{f' (+{new_mb_count}MB)' if new_mb_count else ''}\")\n", + " print(f\"Batch {iterations + 1} of {batches}:\")\n", + " print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Generating {batch_size} {energy} Pokémon...\")\n", + "\n", + " generate_pokemon(energy, batch_size)\n", + "\n", + " iterations += 1\n", + "\n", + " new_stats = get_stats()\n", + " elapsed = time() - start_time\n", + "\n", + " display.clear_output(wait=True)\n", + " print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Finished populating dataset with {batches} {'batches' if batches > 1 else 'batch'} after {pretty_time(elapsed)}\")\n", + " print(f\"Images in dataset: {new_stats['size_total']} (+{new_stats['size_total'] - initial_stats['size_total']})\")\n", + " print(f\"Size of dataset: {new_stats['size_mb']}MB (+{round(new_stats['size_mb'] - initial_stats['size_mb'], 1)}MB)\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0da582a-2b29-4df6-8a3f-ddd56377af16", + "metadata": {}, + "source": [ + "## Load Pokémon model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ad435db5-bd35-4440-87b2-5a108f4ae385", + "metadata": {}, + "outputs": [], + "source": [ + "from rudalle import get_rudalle_model, get_tokenizer, get_vae\n", + "from huggingface_hub import cached_download, hf_hub_url\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5a2a4e6a-2086-4f98-b2b5-65a3631be61e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GPUs available: 1\n" + ] + } + ], + "source": [ + "print(f\"GPUs available: {torch.cuda.device_count()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "720df30d-f42c-406a-92ba-4a465f6ff1d3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n", + "vae --> ready\n", + "tokenizer --> ready\n", + "GPU[0] memory: 11263Mib\n", + "GPU[0] memory reserved: 5144Mib\n", + "GPU[0] memory allocated: 2767Mib\n" + ] + } + ], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "fp16 = torch.cuda.is_available()\n", + "map_location = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "file_dir = \"./models\"\n", + "file_name = \"pytorch_model.bin\"\n", + "config_file_url = hf_hub_url(repo_id=\"minimaxir/ai-generated-pokemon-rudalle\", filename=file_name)\n", + "cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)\n", + "\n", + "model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)\n", + "model.load_state_dict(torch.load(f\"{file_dir}/{file_name}\", map_location=map_location))\n", + "\n", + "vae = get_vae().to(device)\n", + "tokenizer = get_tokenizer()\n", + "\n", + "print(f\"GPU[0] memory: {int(torch.cuda.get_device_properties(0).total_memory / 1024**2)}Mib\")\n", + "print(f\"GPU[0] memory reserved: {int(torch.cuda.memory_reserved(0) / 1024**2)}Mib\")\n", + "print(f\"GPU[0] memory allocated: {int(torch.cuda.memory_allocated(0) / 1024**2)}Mib\")" + ] + }, + { + "cell_type": "markdown", + "id": "88d413a4-a8c9-401e-9cb5-32a1ae34c179", + "metadata": {}, + "source": [ + "### Model functions" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0624c686-c75f-46c6-afae-bbe6c455caa1", + "metadata": {}, + "outputs": [], + "source": [ + "import base64\n", + "from io import BytesIO\n", + "from time import gmtime, strftime, time\n", + "from rudalle.pipelines import generate_images\n", + "\n", + "def english_to_russian(english):\n", + " word_map = {\n", + " \"colorless\": \"Покемон нормального типа\",\n", + " \"dragon\": \"Покемон типа дракона\",\n", + " \"darkness\": \"Покемон темного типа\",\n", + " \"fairy\": \"Покемон фея\",\n", + " \"fighting\": \"Покемон боевого типа\",\n", + " \"fire\": \"Покемон огня\",\n", + " \"grass\": \"Покемон трава\",\n", + " \"lightning\": \"Покемон электрического типа\",\n", + " \"metal\": \"Покемон из стали типа\",\n", + " \"psychic\": \"Покемон психического типа\",\n", + " \"water\": \"Покемон в воду\"\n", + " }\n", + "\n", + " return word_map[english.lower()]\n", + "\n", + "\n", + "def generate_pokemon(energy, num=1):\n", + " if energy in energy_types:\n", + " russian_prompt = english_to_russian(energy)\n", + " \n", + " images, _ = generate_images(russian_prompt, tokenizer, model, vae, top_k=2048, images_num=num, top_p=0.995)\n", + " \n", + " for image in images:\n", + " buffer = BytesIO()\n", + " image.save(buffer, format=\"JPEG\", quality=100, optimize=True)\n", + " base64_bytes = base64.b64encode(buffer.getvalue())\n", + " base64_string = base64_bytes.decode(\"UTF-8\")\n", + " base64_image = \"data:image/jpeg;base64,\" + base64_string\n", + " add_row(energy, base64_image)" + ] + }, + { + "cell_type": "markdown", + "id": "7b309b8e-0c34-4a80-a411-8f093a105494", + "metadata": {}, + "source": [ + "## Populate dataset" + ] + }, + { + "cell_type": "markdown", + "id": "a96d026f-e300-40c7-88a6-20be0012c584", + "metadata": {}, + "source": [ + "Total number of images per population = `batches` × `len(energy_types)` (11) × `batch_size`" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "74d2c50a-93ae-4040-89a0-87f818187bbb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2022-03-16 05:07:48 Finished populating dataset with 1 batch after 10m 8s\n", + "Images in dataset: 5082 (+66)\n", + "Size of dataset: 199.8MB (+2.5MB)\n" + ] + } + ], + "source": [ + "batches = 1\n", + "batch_size = 6\n", + "image_cap = 100_000\n", + "filesize_cap = 4_000 # MB\n", + "\n", + "populate_dataset(batches, batch_size, image_cap, filesize_cap)" + ] + }, + { + "cell_type": "markdown", + "id": "4eb6f750-aa2d-42b9-89dc-2cb103e8869e", + "metadata": {}, + "source": [ + "## Getting images" + ] + }, + { + "cell_type": "markdown", + "id": "9365ac7b-36e4-42b9-8380-a039d556356b", + "metadata": {}, + "source": [ + "### Random image" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "dca3bc8e-1f65-4566-8385-bf6b9f20eaaf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "display.HTML(f'')" + ] + }, + { + "cell_type": "markdown", + "id": "f3b9278e-b6dc-4bea-9ef6-3aa312739718", + "metadata": {}, + "source": [ + "### Random image of specific energy type" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b6e98817-19f7-44f3-8970-0a11b01cf37b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "display.HTML(f'')" + ] + }, + { + "cell_type": "markdown", + "id": "34ca6f96-625b-459a-ba90-2c58a1d0ea47", + "metadata": {}, + "source": [ + "### Specific image" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5a60fc60-4198-4318-a5b4-26095cb2c0bb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "display.HTML(f'')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}