From 2dfb52fb8772fae258a970691485915f20eb77d4 Mon Sep 17 00:00:00 2001 From: System Administrator Date: Thu, 1 Oct 2020 19:28:10 -0700 Subject: [PATCH] Example file for recall at k --- examples/example_recall_at_k.ipynb | 839 +++++++++++++++++++++++++++++ 1 file changed, 839 insertions(+) create mode 100644 examples/example_recall_at_k.ipynb diff --git a/examples/example_recall_at_k.ipynb b/examples/example_recall_at_k.ipynb new file mode 100644 index 0000000..c6614aa --- /dev/null +++ b/examples/example_recall_at_k.ipynb @@ -0,0 +1,839 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Example metrix.ipynb", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "CMI1mfyW34c3" + }, + "source": [ + "# Example: How to call recall_at_k for a binary classification ranking model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9djGHUjV3o8g" + }, + "source": [ + "In this example we create a model to rank LETOR04 MQ2008 dataset using xgboost XGBRanker. The below model utilizes code from https://github.com/dmlc/xgboost/tree/master/demo/rank" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wwcKwSV04T1D" + }, + "source": [ + "We start with installing metriks" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Rno5yhswgox6", + "outputId": "a72b580f-06ae-4574-fc28-b83262c2eef2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "!pip install metriks" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Requirement already satisfied: metriks in /usr/local/lib/python3.6/dist-packages (0.0.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from metriks) (1.18.5)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "EMa2-j1KhEKA" + }, + "source": [ + "import random\n", + "import numpy as np\n", + "import xgboost as xgb\n", + "from sklearn.datasets import load_svmlight_file\n", + "from sklearn.utils.extmath import softmax\n", + "from xgboost import DMatrix\n", + "from metriks.ranking import recall_at_k" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QIpR-zTl4m-v" + }, + "source": [ + "We download MQ2008 data, unzip the same and copy the train, test and valid from Fold1 to the current folder." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "YXRZfVbnhqZT", + "outputId": "6c24189f-56f3-4023-b654-43f45300a828", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 833 + } + }, + "source": [ + "!wget https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.rar\n", + "!unrar x MQ2008.rar\n", + "!mv -f MQ2008/Fold1/*.txt ." + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "--2020-10-02 01:46:53-- https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.rar\n", + "Resolving s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)... 52.218.250.248\n", + "Connecting to s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)|52.218.250.248|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 15448795 (15M) [application/x-rar-compressed]\n", + "Saving to: ‘MQ2008.rar’\n", + "\n", + "MQ2008.rar 100%[===================>] 14.73M 9.01MB/s in 1.6s \n", + "\n", + "2020-10-02 01:46:55 (9.01 MB/s) - ‘MQ2008.rar’ saved [15448795/15448795]\n", + "\n", + "\n", + "UNRAR 5.50 freeware Copyright (c) 1993-2017 Alexander Roshal\n", + "\n", + "\n", + "Extracting from MQ2008.rar\n", + "\n", + "Creating MQ2008 OK\n", + "Creating MQ2008/Fold1 OK\n", + "Extracting MQ2008/Fold1/test.txt \b\b\b\b 0%\b\b\b\b 1%\b\b\b\b 2%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold1/train.txt \b\b\b\b 2%\b\b\b\b 3%\b\b\b\b 4%\b\b\b\b 5%\b\b\b\b 6%\b\b\b\b 7%\b\b\b\b 8%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold1/vali.txt \b\b\b\b 9%\b\b\b\b 10%\b\b\b\b\b OK \n", + "Creating MQ2008/Fold2 OK\n", + "Extracting MQ2008/Fold2/test.txt \b\b\b\b 10%\b\b\b\b 11%\b\b\b\b 12%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold2/train.txt \b\b\b\b 12%\b\b\b\b 13%\b\b\b\b 14%\b\b\b\b 15%\b\b\b\b 16%\b\b\b\b 17%\b\b\b\b 18%\b\b\b\b 19%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold2/vali.txt \b\b\b\b 19%\b\b\b\b 20%\b\b\b\b 21%\b\b\b\b\b OK \n", + "Creating MQ2008/Fold3 OK\n", + "Extracting MQ2008/Fold3/test.txt \b\b\b\b 21%\b\b\b\b 22%\b\b\b\b 23%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold3/train.txt \b\b\b\b 24%\b\b\b\b 25%\b\b\b\b 26%\b\b\b\b 27%\b\b\b\b 28%\b\b\b\b 29%\b\b\b\b 30%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold3/vali.txt \b\b\b\b 30%\b\b\b\b 31%\b\b\b\b 32%\b\b\b\b\b OK \n", + "Creating MQ2008/Fold4 OK\n", + "Extracting MQ2008/Fold4/test.txt \b\b\b\b 32%\b\b\b\b 33%\b\b\b\b 34%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold4/train.txt \b\b\b\b 34%\b\b\b\b 35%\b\b\b\b 36%\b\b\b\b 37%\b\b\b\b 38%\b\b\b\b 39%\b\b\b\b 40%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold4/vali.txt \b\b\b\b 40%\b\b\b\b 41%\b\b\b\b 42%\b\b\b\b\b OK \n", + "Creating MQ2008/Fold5 OK\n", + "Extracting MQ2008/Fold5/test.txt \b\b\b\b 43%\b\b\b\b 44%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold5/train.txt \b\b\b\b 45%\b\b\b\b 46%\b\b\b\b 47%\b\b\b\b 48%\b\b\b\b 49%\b\b\b\b 50%\b\b\b\b 51%\b\b\b\b\b OK \n", + "Extracting MQ2008/Fold5/vali.txt \b\b\b\b 51%\b\b\b\b 52%\b\b\b\b 53%\b\b\b\b\b OK \n", + "Extracting MQ2008/min.txt \b\b\b\b 53%\b\b\b\b 54%\b\b\b\b 55%\b\b\b\b 56%\b\b\b\b 57%\b\b\b\b 58%\b\b\b\b 59%\b\b\b\b 60%\b\b\b\b 61%\b\b\b\b 62%\b\b\b\b 63%\b\b\b\b 64%\b\b\b\b 65%\b\b\b\b 66%\b\b\b\b\b OK \n", + "Extracting MQ2008/NULL.txt \b\b\b\b 66%\b\b\b\b 67%\b\b\b\b 68%\b\b\b\b 69%\b\b\b\b 70%\b\b\b\b 71%\b\b\b\b 72%\b\b\b\b 73%\b\b\b\b 74%\b\b\b\b 75%\b\b\b\b 76%\b\b\b\b 77%\b\b\b\b 78%\b\b\b\b\b OK \n", + "Extracting MQ2008/Querylevelnorm.txt \b\b\b\b 78%\b\b\b\b 79%\b\b\b\b 80%\b\b\b\b 81%\b\b\b\b 82%\b\b\b\b 83%\b\b\b\b 84%\b\b\b\b 85%\b\b\b\b 86%\b\b\b\b 87%\b\b\b\b 88%\b\b\b\b 89%\b\b\b\b\b OK \n", + "Extracting MQ2008/readme.txt \b\b\b\b 89%\b\b\b\b\b OK \n", + "Extracting MQ2008/S1.txt \b\b\b\b 89%\b\b\b\b 90%\b\b\b\b 91%\b\b\b\b\b OK \n", + "Extracting MQ2008/S2.txt \b\b\b\b 91%\b\b\b\b 92%\b\b\b\b 93%\b\b\b\b\b OK \n", + "Extracting MQ2008/S3.txt \b\b\b\b 94%\b\b\b\b 95%\b\b\b\b 96%\b\b\b\b\b OK \n", + "Extracting MQ2008/S4.txt \b\b\b\b 96%\b\b\b\b 97%\b\b\b\b\b OK \n", + "Extracting MQ2008/S5.txt \b\b\b\b 98%\b\b\b\b 99%\b\b\b\b\b OK \n", + "All OK\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "I3XePZiPjDhT", + "outputId": "3c75099e-b923-4537-bf6c-4f62ca4989b2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "#confirm that train, test and vali are copied\n", + "!ls" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "MQ2008\tMQ2008.rar sample_data test.txt train.txt vali.txt\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O6bBu3wP5Ct4" + }, + "source": [ + "MQ2008 - Million Query Track of TREK 2008 (https://arxiv.org/abs/1306.2597)\n", + "Data structure: \n", + "Each row signifies query document pair (how relevant a document is for a query)\n", + " - Column 1: relavance label\n", + " - Column 2: Query id (We will call this as group)\n", + " - rest of the columns: features that have been normalized\n", + " - last column: comment" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y9d4Z6Oc6S0z" + }, + "source": [ + "## Data grouping\n", + "Train, valid and test data are grouped as \n", + "X => contains features\n", + "y => contains labels (relevance label)\n", + "group => contains query ids" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "VanGl9vkh15R" + }, + "source": [ + "#The below code is from xgboost\n", + "def save_data(group_data,output_feature,output_group):\n", + " if len(group_data) == 0:\n", + " return\n", + "\n", + " output_group.write(str(len(group_data))+\"\\n\")\n", + " for data in group_data:\n", + " # only include nonzero features\n", + " feats = [ p for p in data[2:] if float(p.split(':')[1]) != 0.0 ] \n", + " output_feature.write(data[0] + \" \" + \" \".join(feats) + \"\\n\")\n", + "\n", + "def transform(input_file_name, output_file_name, output_group_name):\n", + "\n", + " fi = open(input_file_name)\n", + " output_feature = open(output_file_name,\"w\")\n", + " output_group = open(output_group_name,\"w\")\n", + " \n", + " group_data = []\n", + " group = \"\"\n", + " for line in fi:\n", + " if not line:\n", + " break\n", + " if \"#\" in line: ##docid = GX004-93-7097963 i\n", + " line = line[:line.index(\"#\")]\n", + " splits = line.strip().split(\" \")\n", + " if splits[1] != group:\n", + " save_data(group_data,output_feature,output_group)\n", + " group_data = []\n", + " group = splits[1]\n", + " group_data.append(splits)\n", + "\n", + " save_data(group_data,output_feature,output_group)\n", + "\n", + " fi.close()\n", + " output_feature.close()\n", + " output_group.close()" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vZ233l6I6xos" + }, + "source": [ + "*transform* separates group related details from other details\n", + "group file contains number of items in each group. For instance if the entries are as below\n", + "\n", + "0, group1, features\n", + "1, group1, features\n", + "0, group1, features\n", + "0, group2, features\n", + "1, group2, features\n", + "\n", + "then group file will contain values [3,2] (3 entries of group1, 2 entries of group2)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AS7YZaC4ioqM" + }, + "source": [ + "transform(\"train.txt\", \"mq2008.train\", \"mq2008.train.group\")\n", + "transform(\"vali.txt\",\"mq2008.valid\", \"mq2008.valid.group\")\n", + "transform(\"test.txt\",\"mq2008.test\", \"mq2008.test.group\")" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PdSlzNCb7hDo" + }, + "source": [ + "relevance labels are separated as y values" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dvCx482shKV1" + }, + "source": [ + "x_train, y_train = load_svmlight_file(\"mq2008.train\")\n", + "x_valid, y_valid = load_svmlight_file(\"mq2008.valid\")\n", + "x_test, y_test = load_svmlight_file(\"mq2008.test\")" + ], + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z-_ZsCZS7pc8" + }, + "source": [ + "metrix recall method is relevant for binary classification ranking. However, relevance labels in data set range from 0 to 2. Relevance label signifies how relevant the document is for the query. 0 signifies no relevance and anything greater than 0 is relevant. So for our purpose, we convert all values of label 2 to 1 to signify relevance." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "yEXr8TS7mKWZ" + }, + "source": [ + "y_train = np.where(y_train == 2, 1, y_train)\n", + "y_valid = np.where(y_valid == 2, 1, y_valid)\n", + "y_test = np.where(y_test == 2, 1, y_test)" + ], + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JLXvCoEj8LoZ" + }, + "source": [ + "Now datastructure for group is constructed as below" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Cw-ZHyGMkL-C" + }, + "source": [ + "group_train = []\n", + "with open(\"mq2008.train.group\", \"r\") as f:\n", + " data = f.readlines()\n", + " for line in data:\n", + " group_train.append(int(line.split(\"\\n\")[0]))\n", + "group_valid = []\n", + "with open(\"mq2008.valid.group\", \"r\") as f:\n", + " data = f.readlines()\n", + " for line in data:\n", + " group_valid.append(int(line.split(\"\\n\")[0]))\n", + "\n", + "group_test = []\n", + "with open(\"mq2008.test.group\", \"r\") as f:\n", + " data = f.readlines()\n", + " for line in data:\n", + " group_test.append(int(line.split(\"\\n\")[0]))" + ], + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DUFf3Ifk8Xp3" + }, + "source": [ + "## Model\n", + "We use XGBRanker to create a model. Note that XGBRanker requires the group information to be sent as well. Most of the hyper parameters are retained as is from xgboost documentation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "tJzXmaNak8Hu", + "outputId": "ed040783-cb47-4832-c16d-5f17dc6ece66", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + } + }, + "source": [ + "params = {'objective': 'rank:ndcg', 'learning_rate': 0.1,\n", + " 'gamma': 0.1, 'min_child_weight': 0.1,\n", + " 'max_depth': 6, 'n_estimators': 300, 'early_stopping_rounds': 5}\n", + "model = xgb.sklearn.XGBRanker(**params)\n", + "model.fit(x_train, y_train, group_train, verbose=True,\n", + " eval_set=[(x_valid, y_valid)], eval_group=[group_valid])" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[0]\teval_0-map:0.711908\n", + "[1]\teval_0-map:0.715253\n", + "[2]\teval_0-map:0.711601\n", + "[3]\teval_0-map:0.71427\n", + "[4]\teval_0-map:0.716693\n", + "[5]\teval_0-map:0.718969\n", + "[6]\teval_0-map:0.713003\n", + "[7]\teval_0-map:0.713006\n", + "[8]\teval_0-map:0.710418\n", + "[9]\teval_0-map:0.713416\n", + "[10]\teval_0-map:0.710826\n", + "[11]\teval_0-map:0.709178\n", + "[12]\teval_0-map:0.708738\n", + "[13]\teval_0-map:0.710649\n", + "[14]\teval_0-map:0.709609\n", + "[15]\teval_0-map:0.712088\n", + "[16]\teval_0-map:0.71396\n", + "[17]\teval_0-map:0.717295\n", + "[18]\teval_0-map:0.715476\n", + "[19]\teval_0-map:0.714976\n", + "[20]\teval_0-map:0.716271\n", + "[21]\teval_0-map:0.716302\n", + "[22]\teval_0-map:0.716862\n", + "[23]\teval_0-map:0.721199\n", + "[24]\teval_0-map:0.721671\n", + "[25]\teval_0-map:0.720256\n", + "[26]\teval_0-map:0.719076\n", + "[27]\teval_0-map:0.72268\n", + "[28]\teval_0-map:0.72272\n", + "[29]\teval_0-map:0.72302\n", + "[30]\teval_0-map:0.725206\n", + "[31]\teval_0-map:0.725265\n", + "[32]\teval_0-map:0.725606\n", + "[33]\teval_0-map:0.723357\n", + "[34]\teval_0-map:0.727385\n", + "[35]\teval_0-map:0.725793\n", + "[36]\teval_0-map:0.725981\n", + "[37]\teval_0-map:0.72468\n", + "[38]\teval_0-map:0.725571\n", + "[39]\teval_0-map:0.726212\n", + "[40]\teval_0-map:0.727446\n", + "[41]\teval_0-map:0.727837\n", + "[42]\teval_0-map:0.729004\n", + "[43]\teval_0-map:0.728765\n", + "[44]\teval_0-map:0.728633\n", + "[45]\teval_0-map:0.729121\n", + "[46]\teval_0-map:0.731359\n", + "[47]\teval_0-map:0.729415\n", + "[48]\teval_0-map:0.729197\n", + "[49]\teval_0-map:0.727044\n", + "[50]\teval_0-map:0.727603\n", + "[51]\teval_0-map:0.726573\n", + "[52]\teval_0-map:0.726354\n", + "[53]\teval_0-map:0.726937\n", + "[54]\teval_0-map:0.726226\n", + "[55]\teval_0-map:0.727246\n", + "[56]\teval_0-map:0.72859\n", + "[57]\teval_0-map:0.727113\n", + "[58]\teval_0-map:0.726147\n", + "[59]\teval_0-map:0.725589\n", + "[60]\teval_0-map:0.726842\n", + "[61]\teval_0-map:0.726972\n", + "[62]\teval_0-map:0.727406\n", + "[63]\teval_0-map:0.727044\n", + "[64]\teval_0-map:0.72693\n", + "[65]\teval_0-map:0.728665\n", + "[66]\teval_0-map:0.729066\n", + "[67]\teval_0-map:0.728924\n", + "[68]\teval_0-map:0.728094\n", + "[69]\teval_0-map:0.728405\n", + "[70]\teval_0-map:0.731885\n", + "[71]\teval_0-map:0.731965\n", + "[72]\teval_0-map:0.732586\n", + "[73]\teval_0-map:0.732543\n", + "[74]\teval_0-map:0.732088\n", + "[75]\teval_0-map:0.733708\n", + "[76]\teval_0-map:0.733516\n", + "[77]\teval_0-map:0.732902\n", + "[78]\teval_0-map:0.733531\n", + "[79]\teval_0-map:0.735507\n", + "[80]\teval_0-map:0.734008\n", + "[81]\teval_0-map:0.73586\n", + "[82]\teval_0-map:0.734639\n", + "[83]\teval_0-map:0.737235\n", + "[84]\teval_0-map:0.736736\n", + "[85]\teval_0-map:0.735947\n", + "[86]\teval_0-map:0.736221\n", + "[87]\teval_0-map:0.736526\n", + "[88]\teval_0-map:0.734301\n", + "[89]\teval_0-map:0.735078\n", + "[90]\teval_0-map:0.73366\n", + "[91]\teval_0-map:0.732889\n", + "[92]\teval_0-map:0.733543\n", + "[93]\teval_0-map:0.734497\n", + "[94]\teval_0-map:0.733875\n", + "[95]\teval_0-map:0.733973\n", + "[96]\teval_0-map:0.733565\n", + "[97]\teval_0-map:0.733137\n", + "[98]\teval_0-map:0.734203\n", + "[99]\teval_0-map:0.734374\n", + "[100]\teval_0-map:0.73411\n", + "[101]\teval_0-map:0.734861\n", + "[102]\teval_0-map:0.733667\n", + "[103]\teval_0-map:0.734059\n", + "[104]\teval_0-map:0.733248\n", + "[105]\teval_0-map:0.73393\n", + "[106]\teval_0-map:0.733725\n", + "[107]\teval_0-map:0.732423\n", + "[108]\teval_0-map:0.733617\n", + "[109]\teval_0-map:0.733717\n", + "[110]\teval_0-map:0.732184\n", + "[111]\teval_0-map:0.733683\n", + "[112]\teval_0-map:0.734131\n", + "[113]\teval_0-map:0.735578\n", + "[114]\teval_0-map:0.73496\n", + "[115]\teval_0-map:0.737717\n", + "[116]\teval_0-map:0.737842\n", + "[117]\teval_0-map:0.73823\n", + "[118]\teval_0-map:0.73986\n", + "[119]\teval_0-map:0.738743\n", + "[120]\teval_0-map:0.737385\n", + "[121]\teval_0-map:0.737278\n", + "[122]\teval_0-map:0.737692\n", + "[123]\teval_0-map:0.737436\n", + "[124]\teval_0-map:0.737894\n", + "[125]\teval_0-map:0.7339\n", + "[126]\teval_0-map:0.73627\n", + "[127]\teval_0-map:0.735585\n", + "[128]\teval_0-map:0.736068\n", + "[129]\teval_0-map:0.736238\n", + "[130]\teval_0-map:0.73723\n", + "[131]\teval_0-map:0.739613\n", + "[132]\teval_0-map:0.738124\n", + "[133]\teval_0-map:0.737658\n", + "[134]\teval_0-map:0.736468\n", + "[135]\teval_0-map:0.736817\n", + "[136]\teval_0-map:0.736787\n", + "[137]\teval_0-map:0.736389\n", + "[138]\teval_0-map:0.737328\n", + "[139]\teval_0-map:0.738229\n", + "[140]\teval_0-map:0.736682\n", + "[141]\teval_0-map:0.739203\n", + "[142]\teval_0-map:0.738997\n", + "[143]\teval_0-map:0.739836\n", + "[144]\teval_0-map:0.739534\n", + "[145]\teval_0-map:0.737809\n", + "[146]\teval_0-map:0.737347\n", + "[147]\teval_0-map:0.736584\n", + "[148]\teval_0-map:0.737747\n", + "[149]\teval_0-map:0.73822\n", + "[150]\teval_0-map:0.737432\n", + "[151]\teval_0-map:0.737069\n", + "[152]\teval_0-map:0.737183\n", + "[153]\teval_0-map:0.739535\n", + "[154]\teval_0-map:0.738507\n", + "[155]\teval_0-map:0.737281\n", + "[156]\teval_0-map:0.737553\n", + "[157]\teval_0-map:0.737388\n", + "[158]\teval_0-map:0.736938\n", + "[159]\teval_0-map:0.736683\n", + "[160]\teval_0-map:0.735819\n", + "[161]\teval_0-map:0.736208\n", + "[162]\teval_0-map:0.73674\n", + "[163]\teval_0-map:0.737115\n", + "[164]\teval_0-map:0.737517\n", + "[165]\teval_0-map:0.737327\n", + "[166]\teval_0-map:0.739132\n", + "[167]\teval_0-map:0.739013\n", + "[168]\teval_0-map:0.739026\n", + "[169]\teval_0-map:0.73685\n", + "[170]\teval_0-map:0.738533\n", + "[171]\teval_0-map:0.738453\n", + "[172]\teval_0-map:0.738867\n", + "[173]\teval_0-map:0.738928\n", + "[174]\teval_0-map:0.738471\n", + "[175]\teval_0-map:0.737826\n", + "[176]\teval_0-map:0.739183\n", + "[177]\teval_0-map:0.737896\n", + "[178]\teval_0-map:0.737864\n", + "[179]\teval_0-map:0.737262\n", + "[180]\teval_0-map:0.738129\n", + "[181]\teval_0-map:0.739253\n", + "[182]\teval_0-map:0.739119\n", + "[183]\teval_0-map:0.739453\n", + "[184]\teval_0-map:0.73889\n", + "[185]\teval_0-map:0.739268\n", + "[186]\teval_0-map:0.741099\n", + "[187]\teval_0-map:0.740216\n", + "[188]\teval_0-map:0.739937\n", + "[189]\teval_0-map:0.740199\n", + "[190]\teval_0-map:0.740795\n", + "[191]\teval_0-map:0.741213\n", + "[192]\teval_0-map:0.74027\n", + "[193]\teval_0-map:0.739347\n", + "[194]\teval_0-map:0.740801\n", + "[195]\teval_0-map:0.740842\n", + "[196]\teval_0-map:0.741468\n", + "[197]\teval_0-map:0.740685\n", + "[198]\teval_0-map:0.741229\n", + "[199]\teval_0-map:0.741317\n", + "[200]\teval_0-map:0.741305\n", + "[201]\teval_0-map:0.745991\n", + "[202]\teval_0-map:0.745871\n", + "[203]\teval_0-map:0.745469\n", + "[204]\teval_0-map:0.743927\n", + "[205]\teval_0-map:0.743932\n", + "[206]\teval_0-map:0.740531\n", + "[207]\teval_0-map:0.740676\n", + "[208]\teval_0-map:0.738996\n", + "[209]\teval_0-map:0.740067\n", + "[210]\teval_0-map:0.738717\n", + "[211]\teval_0-map:0.740617\n", + "[212]\teval_0-map:0.740406\n", + "[213]\teval_0-map:0.740927\n", + "[214]\teval_0-map:0.739437\n", + "[215]\teval_0-map:0.741421\n", + "[216]\teval_0-map:0.740354\n", + "[217]\teval_0-map:0.739315\n", + "[218]\teval_0-map:0.739679\n", + "[219]\teval_0-map:0.739698\n", + "[220]\teval_0-map:0.739815\n", + "[221]\teval_0-map:0.739901\n", + "[222]\teval_0-map:0.739031\n", + "[223]\teval_0-map:0.738427\n", + "[224]\teval_0-map:0.738775\n", + "[225]\teval_0-map:0.738673\n", + "[226]\teval_0-map:0.739201\n", + "[227]\teval_0-map:0.738857\n", + "[228]\teval_0-map:0.737689\n", + "[229]\teval_0-map:0.736992\n", + "[230]\teval_0-map:0.738442\n", + "[231]\teval_0-map:0.737731\n", + "[232]\teval_0-map:0.737361\n", + "[233]\teval_0-map:0.737804\n", + "[234]\teval_0-map:0.739005\n", + "[235]\teval_0-map:0.738833\n", + "[236]\teval_0-map:0.739417\n", + "[237]\teval_0-map:0.740231\n", + "[238]\teval_0-map:0.73907\n", + "[239]\teval_0-map:0.740635\n", + "[240]\teval_0-map:0.74106\n", + "[241]\teval_0-map:0.740791\n", + "[242]\teval_0-map:0.741241\n", + "[243]\teval_0-map:0.741241\n", + "[244]\teval_0-map:0.74117\n", + "[245]\teval_0-map:0.740896\n", + "[246]\teval_0-map:0.74095\n", + "[247]\teval_0-map:0.740104\n", + "[248]\teval_0-map:0.74039\n", + "[249]\teval_0-map:0.740363\n", + "[250]\teval_0-map:0.74054\n", + "[251]\teval_0-map:0.740668\n", + "[252]\teval_0-map:0.740488\n", + "[253]\teval_0-map:0.739932\n", + "[254]\teval_0-map:0.739938\n", + "[255]\teval_0-map:0.740236\n", + "[256]\teval_0-map:0.739722\n", + "[257]\teval_0-map:0.738819\n", + "[258]\teval_0-map:0.738819\n", + "[259]\teval_0-map:0.738537\n", + "[260]\teval_0-map:0.738537\n", + "[261]\teval_0-map:0.738734\n", + "[262]\teval_0-map:0.740094\n", + "[263]\teval_0-map:0.740129\n", + "[264]\teval_0-map:0.74012\n", + "[265]\teval_0-map:0.739249\n", + "[266]\teval_0-map:0.739353\n", + "[267]\teval_0-map:0.741862\n", + "[268]\teval_0-map:0.741223\n", + "[269]\teval_0-map:0.74177\n", + "[270]\teval_0-map:0.74177\n", + "[271]\teval_0-map:0.741981\n", + "[272]\teval_0-map:0.741888\n", + "[273]\teval_0-map:0.74279\n", + "[274]\teval_0-map:0.742349\n", + "[275]\teval_0-map:0.742349\n", + "[276]\teval_0-map:0.742922\n", + "[277]\teval_0-map:0.743112\n", + "[278]\teval_0-map:0.743356\n", + "[279]\teval_0-map:0.743415\n", + "[280]\teval_0-map:0.744448\n", + "[281]\teval_0-map:0.744281\n", + "[282]\teval_0-map:0.744368\n", + "[283]\teval_0-map:0.744165\n", + "[284]\teval_0-map:0.744095\n", + "[285]\teval_0-map:0.744748\n", + "[286]\teval_0-map:0.744694\n", + "[287]\teval_0-map:0.743812\n", + "[288]\teval_0-map:0.743974\n", + "[289]\teval_0-map:0.743774\n", + "[290]\teval_0-map:0.743526\n", + "[291]\teval_0-map:0.74346\n", + "[292]\teval_0-map:0.743191\n", + "[293]\teval_0-map:0.742809\n", + "[294]\teval_0-map:0.743006\n", + "[295]\teval_0-map:0.743405\n", + "[296]\teval_0-map:0.743001\n", + "[297]\teval_0-map:0.743068\n", + "[298]\teval_0-map:0.743935\n", + "[299]\teval_0-map:0.744034\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "XGBRanker(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n", + " colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=5,\n", + " gamma=0.1, learning_rate=0.1, max_delta_step=0, max_depth=6,\n", + " min_child_weight=0.1, missing=None, n_estimators=300, n_jobs=-1,\n", + " nthread=None, objective='rank:ndcg', random_state=0, reg_alpha=0,\n", + " reg_lambda=1, scale_pos_weight=1, seed=None, silent=None, subsample=1,\n", + " verbosity=1)" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 10 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dgiCvniM8tHn" + }, + "source": [ + "Interestingly XGBRanker model does not expect group data for test. Once prediction is done, it is important to look at the data by group because the predictions are relevant only within that group" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "yVrUoaUnlMrS", + "outputId": "7f0aa817-0a93-46c9-9d44-02c597ac30d8", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "pred = model.predict(x_test)\n", + "\n", + "print(pred[0:group_test[0]])" + ], + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 1.487788 -4.6278214 1.1964784 1.3164926 -0.68694735 -1.1854128\n", + " -3.9207067 -4.8447275 ]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TkC2KiZw9KEh" + }, + "source": [ + "## recall_at_k metrics\n", + "Now it is time to call recall_at_k. \n", + "- Note that each group has varied set of documents so the recall can be made only by group.\n", + "- Also predictions can range from negative to positive since it reflects the rank of documents. In order to get the recall we need probabilities, hence we softmax the prediction output\n", + "- We also used k=6 when calling recall_at_k. You can call any number of relevant documents as it seems logical " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ArG_LHBRrbEB" + }, + "source": [ + "start = 0\n", + "recall = []\n", + "for group in group_test:\n", + " softmax_pred = softmax([pred[start:start+group]])\n", + " y_true = np.expand_dims(y_test[start:start+group],0)\n", + " recall.append(recall_at_k(y_true,softmax_pred,6))" + ], + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y9jGf6L29u6s" + }, + "source": [ + "Now look at the output for random 5 groups" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "HhWj4Pex3Beo", + "outputId": "591049f2-f539-40d2-8183-6e5ce5653aec", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + } + }, + "source": [ + "import random\n", + "rand_num = random.sample(range(len(recall)) ,5)\n", + "for i in rand_num:\n", + " print(f'recall for group{i+1} is: {recall[i]}')" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "text": [ + "recall for group46 is: 0.6666666666666666\n", + "recall for group37 is: 1.0\n", + "recall for group2 is: 0.10526315789473684\n", + "recall for group155 is: 0.2857142857142857\n", + "recall for group16 is: 1.0\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file