Upload 4 files
- data_prep.ipynb +1013 -0
- data_prep.pdf +0 -0
- training.ipynb +472 -0
- training.pdf +0 -0
data_prep.ipynb
ADDED
@@ -0,0 +1,1013 @@
## Data prep for retrieving beliefs for dialogs

**Goal:** Create a dataset to match dialogs with (possibly) relevant facts

**Method:**
- [x] Use stacked_samsum as the training dataset
- [x] Prepare datasets
  - [x] remove unnecessary columns
  - [x] expand the stacked dataset
  - [x] truncate on the right to create dangling examples
  - [x] augment dialogs using OpenAI to make them longer

### Constants

```python
model_name = "BAAI/bge-small-en-v1.5"
max_len = 512
next_concept_sep = "\n[NEXT_CONCEPT]\n"
training_input_file = "./data/train-soft.jsonl"
eval_input_file = "./data/eval.jsonl"
training_hn_file = "./data/train.jsonl"
eval_size = 12_500
seed = 42
query_prefix = "Represent this sentence for searching relevant passages: "
hf_repo_name = "julep-ai/dfe-stacked_samsum"
```

### Imports and utils

```python
from functools import partial
import os
import random
import time

from datasets import load_dataset, load_from_disk
from FlagEmbedding import FlagModel
from FlagEmbedding.baai_general_embedding.finetune.hn_mine import find_knn_neg
from huggingface_hub import HfApi
import jsonlines as jsonl
import langchain
from langchain.cache import SQLiteCache
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from math import ceil
from numpy import cumsum, dot
from numpy.linalg import norm
from tqdm.auto import tqdm
from transformers import AutoTokenizer
```

#### Tokenizer

```python
tokenizer = AutoTokenizer.from_pretrained(model_name)
```

#### LLM

```python
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0.7)
```

```python
prompt_template = PromptTemplate.from_template(
"""\
You are a dialog writer. Given a dialog, continue it for {n} more turns in the same style as the original speakers. You can be creative in coming up with the next turns as long as you make sure that the new dialog is consistent with the previous messages.

### Example Dialog

Ken: Hi, how are you?
Ang: Just peachy! You?
Ken: I'm okay...
Ang: Just okay? What's wrong?
Ken: Just stressed; work stuff, fighting with Brad, too much going on at mom's.
Ang: Hang in there, it will get better!
Ken: I know, but it's a lot.
Ang: Can I do anything to help?
Ken: You are! Listening to me vent! LOL!
Ang: Are you at least doing anything fun this weekend?
Ken: Show Saturday night, then seeing the grandkids on Sunday at the zoo.

### Continuation

Ang: Sounds great! That will cheer you up!
Ken: Gotta run, work calls. Love you!
Ang: Love you too! Have a fantastic day!
Ken: You too!

### Input Dialog

{input_dialog}

### Continuation
"""
)

def gen_continuation(input_dialog, n=4):
    # Brief random delay between API calls
    wait = round(random.uniform(0.3, 1.2), 3)
    time.sleep(wait)

    prompt = prompt_template.format(n=n, input_dialog=input_dialog)
    continuation = llm(prompt).strip()

    return continuation
```
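A quick way to sanity-check the augmentation step is to call `gen_continuation` on a tiny dialog. This is a minimal sketch, assuming an `OPENAI_API_KEY` is configured; the dialog text is invented for illustration, and responses are cached in `.langchain.db` by the SQLite cache above.

```python
# Hypothetical smoke test for the continuation generator (dialog invented for illustration)
sample_dialog = "Amy: Did you finish the report?\nBen: Almost, I'll send it over tonight."
print(gen_continuation(sample_dialog, n=2))
```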
#### Dataset load

```python
# Get everything, we'll split it later
dataset = load_dataset(
    "stacked-summaries/stacked-samsum-1024",
)

# Remove unnecessary columns
dataset = dataset.remove_columns(['chapter_length', 'summary_length', 'is_stacked'])

# Remove empty/null dialogs
dataset = dataset.filter(
    lambda row: row["dialogue"]
)

# Convert windows-style line endings to unix-style
dataset = dataset.map(
    lambda row: dict(dialogue=row["dialogue"].replace("\r\n", '\n'))
)
```

#### Dataset prep

```python
def count_tokens(row):
    """Count tokens using the tokenizer"""

    dialogue = row["dialogue"]
    tokens = tokenizer.encode(dialogue, add_special_tokens=False)

    return dict(token_count=len(tokens))
```

```python
# Add token count to every row in dataset
dataset = dataset.map(count_tokens)
```

```python
def offset_left(
    dialogue: str,
    split_offset=0,
    splits=1,
    max_len=max_len,
):
    # Split dialog lines
    lines = dialogue.split("\n")

    # Count tokens per line
    toks_by_line = [
        len(tokenizer.encode(line, add_special_tokens=False))
        for line in lines
    ]

    # Cumulative sum of tokens per line
    cum_toks_by_line = cumsum(toks_by_line)

    # Total no. of tokens
    total_tokens = sum(toks_by_line)

    # Return as is if total tokens is less than max len of model
    if total_tokens <= max_len:
        return dialogue

    # Calculate step size
    step_size = ceil(total_tokens / (splits * 2))

    # Calculate left index
    left_index = 0
    for cum_toks in cum_toks_by_line:
        if cum_toks > (split_offset * step_size):
            break

        left_index += 1

    # Calculate right index
    right_index = 0
    for last_cum_toks in cum_toks_by_line[::-1]:
        if last_cum_toks < max_len:
            break

        right_index -= 1

    # Calc final section
    if right_index == 0:
        lines = lines[left_index:]
    else:
        lines = lines[left_index:right_index]

    return "\n".join(lines)
```
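`offset_left` returns short dialogs unchanged; for longer ones it slices out a window of lines whose position is controlled by `split_offset`. A small illustrative check on a synthetic dialog (not from the dataset):

```python
# Illustrative only: synthetic dialog, not from stacked-samsum
short = "A: hi\nB: hello"
assert offset_left(short) == short  # under max_len, returned unchanged

long_dialog = "\n".join(f"A: this is filler message number {i}" for i in range(300))
prefix = offset_left(long_dialog)  # with split_offset=0, keeps the longest prefix under max_len tokens
print(len(long_dialog.split("\n")), "->", len(prefix.split("\n")))
```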
```python
def truncate_lines(dialog, num=3, min=5):
    """
    Split dialog into lines and then drop the last `num` lines,
    making sure there are at least `min` lines remaining.
    """

    lines = dialog.split("\n")

    # If too short, return as is
    if len(lines) - num < min:
        return dialog

    if num > 0:
        return "\n".join(lines[:-num])
    else:
        return "\n".join(lines[-num:])
```
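A quick illustrative check of the truncation behaviour on synthetic dialogs (not from the dataset):

```python
# Illustrative only: a 10-line synthetic dialog loses its last 3 lines, while a
# 2-line dialog is returned unchanged because fewer than `min` lines would remain.
d = "\n".join(f"A: line {i}" for i in range(10))
print(len(truncate_lines(d).split("\n")))  # 7
print(truncate_lines("A: hi\nB: hello"))   # unchanged
```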
```python
def expand_stacked(rows):
    """Expand stacked samsum dataset by splitting concepts in every summary per dialog"""

    # Get fields by batch
    dialogues = rows["dialogue"]
    summaries = rows["summary"]

    # Containers for final results
    is_augmented = []
    is_truncated = []
    final_dialogues = []
    final_summaries = []

    # Process every dialog and summary
    for dialogue, summary in tqdm(zip(dialogues, summaries)):
        # Split the summary by the NEXT_CONCEPT separator from the dataset
        ss = summary.split(next_concept_sep)

        # Split different conversations within the sample
        # offset on the left to try to match relevance
        dd = [
            offset_left(d, split_offset=1) for d in dialogue.split("\n\n")
        ]

        is_truncated += [False] * len(dd)
        is_augmented += [False] * len(dd)
        final_dialogues += dd
        final_summaries += ss

        # ---
        # Now truncate and add
        truncated = [truncate_lines(d) for d in dd]

        is_augmented += [False] * len(dd)
        is_truncated += [t != d for t, d in zip(truncated, dd)]
        final_dialogues += truncated
        final_summaries += ss

        # ---
        # Now augment and add

        augmented = [
            truncate_lines(d + gen_continuation(d), num=-4)
            for d in dd
        ]

        is_truncated += [False] * len(dd)
        is_augmented += [True] * len(dd)
        final_dialogues += augmented
        final_summaries += ss

    # note: is_augmented is collected above but not included in the returned columns
    return dict(
        dialogue=final_dialogues,
        summary=final_summaries,
        is_truncated=is_truncated,
        token_count=[None]*len(final_summaries),
    )
```

```python
# Use batched mode to be able to expand the size of the dataset
dataset = dataset.map(expand_stacked, batch_size=10, batched=True, num_proc=75)
dataset = dataset.remove_columns(["token_count"])
```

Output: a warning that the transform could not be hashed (so results are not cached), followed by Map progress bars (num_proc=75) over 29441, 1633, and 1637 examples for the three splits.

```python
dataset.push_to_hub(hf_repo_name)
```

Output: progress bars for pushing dataset shards and creating parquet files (339, 20, and 19 Arrow batches) for the three splits.

### Prepare dataset for finetuning
[Docs](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune)

Format:
```json
{"query": str, "pos": List[str], "neg": List[str]}
```

Keys:
- query: belief
- pos: list of matching conversations
- neg: list of random conversations from dataset
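For concreteness, one record in this format would look like the following (values invented for illustration):

```python
# Hypothetical record; the text values are invented for illustration only
example_record = {
    "query": "Ken is stressed about work and family but has plans for the weekend.",
    "pos": ["Ken: I'm okay...\nAng: Just okay? What's wrong?\nKen: Just stressed; work stuff."],
    "neg": ["Tom: Did you get the tickets?\nSue: Not yet, the site was down."],
}
```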
```python
dataset = load_dataset(hf_repo_name)
```

Output: downloads the pushed data (81.5 MB, 3.91 MB, 3.84 MB) and generates the splits: train 338127, validation 19131, test 18381 examples.

```python
def pick_random(dataset, split="train", far_from=0):
    ds = dataset[split]
    ds_len = len(ds)
    mid = ds_len // 2
    which_half = far_from // mid

    start = (1 - which_half) * mid
    end = ds_len - which_half * mid
    idx = random.randrange(start, end)

    return ds[idx]
```
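`pick_random` draws each random negative from the half of the split that does not contain index `far_from`, so negatives come from the opposite end of the data rather than the positive's neighborhood. A rough illustration (indices arbitrary):

```python
# Illustration of the sampling intent (indices arbitrary)
n = len(dataset["train"])
early_negative = pick_random(dataset, split="train", far_from=0)       # sampled from the upper half
late_negative = pick_random(dataset, split="train", far_from=n - 100)  # sampled from the lower half
```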
```python
with jsonl.open(training_input_file, mode='w') as writer:
    for i, row in enumerate(tqdm(dataset["train"], total=len(dataset["train"]))):
        query = row["summary"]
        pos = [row["dialogue"]]

        neg = [
            pick_random(dataset, split="train", far_from=i)["dialogue"]
            for _ in range(3)
        ]

        writer.write(dict(query=query, pos=pos, neg=neg))
```

Output: progress bar over 338127 rows.

```python
with jsonl.open(eval_input_file, mode='w') as writer:
    for i, row in enumerate(tqdm(dataset["validation"], total=eval_size)):
        if i > eval_size:  # note: this writes eval_size + 1 records before breaking
            break

        query = row["summary"]
        pos = [row["dialogue"]]

        neg = [
            pick_random(dataset, split="validation", far_from=i)["dialogue"]
            for _ in range(3)
        ]

        writer.write(dict(query=query, pos=pos, neg=neg))
```

Output: progress bar over 12500 rows.

### Mine hard negatives

```python
model = FlagModel(
    model_name,
    query_instruction_for_retrieval=query_prefix,
)
```

```python
find_knn_neg(
    model,
    input_file=training_input_file,
    candidate_pool=None,
    output_file=training_hn_file,
    sample_range=list(range(2, 200)),
    negative_number=10,
    use_gpu=True,
)
```

Output:
```
inferencing embedding for corpus (number=37361)--------------
Inference Embeddings: 100%|██████████| 146/146 [00:37<00:00,  3.87it/s]
inferencing embedding for queries (number=338127)--------------
Inference Embeddings: 100%|██████████| 1321/1321 [00:52<00:00, 25.34it/s]
create index and search------------------
Batches: 100%|██████████| 5284/5284 [00:07<00:00, 740.63it/s]
```
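According to the FlagEmbedding docs, the mined file keeps the same `{"query", "pos", "neg"}` schema as the input, with model-mined hard negatives in place of the random ones. A quick sanity check (assumes the mining step above has completed):

```python
# Assumes ./data/train.jsonl was produced by the mining step above
with jsonl.open(training_hn_file) as reader:
    first = next(iter(reader))
print(len(first["pos"]), len(first["neg"]))  # expect 1 positive and roughly negative_number=10 negatives
```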
### Add processed files to hf dataset

```python
hf_api = HfApi()

for path in [
    training_input_file,
    eval_input_file,
    training_hn_file,
]:
    hf_api.upload_file(
        path_or_fileobj=path,
        path_in_repo=path.split('/')[-1],
        repo_id=hf_repo_name,
        repo_type="dataset",
    )
```

Output: upload progress for train.jsonl (2.42 GB).

Notebook metadata: Python 3 (ipykernel), Python 3.10.6.
data_prep.pdf
ADDED
Binary file (63.1 kB).
training.ipynb
ADDED
@@ -0,0 +1,472 @@
## Train Dialog-Fact Encoder

**Goal:** Train an embedding model to match dialogs with (possibly) relevant facts

### Constants

```python
model_name = "BAAI/bge-base-en-v1.5"
query_prefix = "Represent this sentence for searching relevant passages: "
max_len = 512
training_hn_file = "./data/train.jsonl"
eval_file = "./data/eval.jsonl"
batch_size = 1350
output_model_path = "./dfe-base-en"
hf_repo_name = "julep-ai/dfe-base-en"
```

### Imports

```python
import itertools as it

import graphviz
import jsonlines as jsonl
from lion_pytorch import Lion
from sentence_transformers import InputExample, SentenceTransformer, losses as ls, models as ml, util
from sentence_transformers.evaluation import SimilarityFunction, TripletEvaluator
import torch
from torch.utils.data import DataLoader, IterableDataset
from tqdm.auto import tqdm
```

### Dataset

```python
def hn_output(file):
    with jsonl.open(file) as reader:
        for entry in reader:
            query = entry["query"]
            pos = [dict(dialog=dialog) for dialog in entry["pos"]]
            neg = [dict(dialog=dialog) for dialog in entry["neg"]]

            for combined in it.product(
                [dict(fact=query)],
                pos,
                neg,
            ):
                yield InputExample(texts=list(combined))
```
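Each yielded `InputExample` is an (anchor, positive, negative) triplet whose texts are dicts keyed by the Asym branch (defined below) they should be routed through. An invented example of what one triplet looks like:

```python
# Values invented for illustration; the dict keys select the projection branch
example = InputExample(texts=[
    {"fact": "Ken is stressed about work."},          # anchor (the belief / summary concept)
    {"dialog": "Ken: I'm okay...\nAng: Just okay?"},  # positive: a matching dialog
    {"dialog": "Tom: Did you get the tickets?"},      # negative: a random dialog
])
```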
```python
training_data = list(tqdm(hn_output(training_hn_file)))
eval_data = list(tqdm(hn_output(eval_file)))
```

```python
dataloader = DataLoader(training_data, shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(eval_data, shuffle=True, batch_size=batch_size // 10)
```

### DFE Model Architecture

```python
# Base model
base_model = SentenceTransformer(model_name)
```

```python
# Freeze base transformer layers
for param in base_model.parameters():
    param.requires_grad = False
```

```python
device = torch.device("cuda:0")

# Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset
# the body location
base_model._target_device = device
base_model = base_model.to(device)
```

```python
emb_dims = base_model._first_module().get_word_embedding_dimension()  # 768

def dense_projector(dims: int):
    proj_dims = dims * 2  # 1536

    return [
        ml.Dense(dims, proj_dims),       # 768 -> 1536
        ml.Dense(proj_dims, proj_dims),  # 1536 -> 1536
        ml.Dropout(0.1),
        ml.Dense(proj_dims, proj_dims),  # 1536 -> 1536
        ml.Dense(proj_dims, dims),       # 1536 -> 768
    ]

def asym_module(dims: int, keys: list[str], allow_empty_key: bool = False):
    return ml.Asym(
        {
            key: dense_projector(dims)
            for key in keys
        },
        allow_empty_key=allow_empty_key,
    )
```

```python
base_model._modules["2"] = asym_module(emb_dims, ["dialog", "fact"])
```
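With the asymmetric head attached, dialogs and facts are encoded through separate projection branches by wrapping each input in the matching key. A minimal sketch (texts invented):

```python
# Sketch only: texts are invented; keys must match the Asym branches defined above
dialog_emb = base_model.encode([{"dialog": "Ken: I'm okay...\nAng: Just okay? What's wrong?"}])
fact_emb = base_model.encode([{"fact": "Ken is feeling stressed."}])
print(util.cos_sim(dialog_emb, fact_emb))
```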
```python
base_model._modules
```

Output:
```
OrderedDict([('0',
              Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel ),
             ('1',
              Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})),
             ('2',
              Asym(
                (dialog-0): Dense({'in_features': 768, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (dialog-1): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (dialog-2): Dropout(
                  (dropout_layer): Dropout(p=0.1, inplace=False)
                )
                (dialog-3): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (dialog-4): Dense({'in_features': 1536, 'out_features': 768, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (fact-0): Dense({'in_features': 768, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (fact-1): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (fact-2): Dropout(
                  (dropout_layer): Dropout(p=0.1, inplace=False)
                )
                (fact-3): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (fact-4): Dense({'in_features': 1536, 'out_features': 768, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
              ))])
```
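Since the BERT body is frozen, only the two newly added projection branches should be trainable. A quick illustrative check:

```python
# Illustrative check: only the Asym projection heads should require gradients
trainable = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in base_model.parameters())
print(f"trainable: {trainable:,} / total: {total:,}")
```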
### Prepare training loss and evaluator

```python
train_loss = ls.TripletLoss(model=base_model)
```

```python
triplet_evaluator = TripletEvaluator.from_input_examples(
    eval_data,  # Each triplet is ({fact: <belief>}, {dialog: <matching dialog>}, {dialog: <random dialog>}), as produced by hn_output
    batch_size=batch_size // 10,
    main_distance_function=SimilarityFunction.COSINE,
    show_progress_bar=True,
    write_csv=True,
)
```

### Train model

```python
base_model.fit(
    train_objectives=[(dataloader, train_loss)],
    evaluator=triplet_evaluator,
    checkpoint_save_steps=600,
    evaluation_steps=600,
    checkpoint_path=f"{output_model_path}/ckpts",
    scheduler="WarmupCosine",
    save_best_model=True,
    epochs=15,
    warmup_steps=200,
    optimizer_class=Lion,
    optimizer_params=dict(lr=1e-4, weight_decay=1e-2),
    use_amp=True,
    output_path=output_model_path,
    checkpoint_save_total_limit=4,
)
```

Output: progress bars for 15 epochs of 2505 iterations each, interleaved with evaluation passes of 278 batches.
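Because `save_best_model=True` and an evaluator are passed, the checkpoint scoring best on the triplet evaluator ends up in `output_model_path`. A minimal sketch of reloading it for inference, assuming training finished and the Asym head round-trips through sentence-transformers' module save/load:

```python
# Assumes ./dfe-base-en was written by the fit() call above; texts are invented
dfe = SentenceTransformer(output_model_path)
scores = util.cos_sim(
    dfe.encode([{"dialog": "Ken: I'm okay...\nAng: Just okay?"}]),
    dfe.encode([{"fact": "Ken is feeling stressed."}]),
)
print(scores)
```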
```python
base_model.push_to_hub(hf_repo_name)
```

```python
graphviz.set_jupyter_format('png')
```

```python
# draw_graph is not imported anywhere in the original notebook; it matches torchview's API
# (draw_graph(...).visual_graph), so the import below is the likely intent.
from torchview import draw_graph

model_graph = draw_graph(base_model, input_size=(1, 512), device='meta')
model_graph.visual_graph
```

Notebook metadata: Python 3 (ipykernel), Python 3.10.6.
training.pdf
ADDED
Binary file (46.1 kB).