predict.ipynb 15.9 KB
Newer Older
1 2 3 4
{
 "cells": [
  {
   "cell_type": "code",
5
   "execution_count": 1,
6 7
   "source": [
    "use_case = 'community-prediction-youtube-n'\r\n",
8 9
    "layer_name = 'LikesLayer'\r\n",
    "reference_layer_name = 'ViewsLayer'"
10
   ],
11 12
   "outputs": [],
   "metadata": {}
13 14 15
  },
  {
   "cell_type": "code",
16
   "execution_count": 5,
17 18 19 20 21
   "source": [
    "import json\r\n",
    "from entities import Cluster\r\n",
    "import collections\r\n",
    "import numpy as np\r\n",
22
    "from typing import Iterable, Tuple, List, Dict, Any"
23 24 25
   ],
   "outputs": [],
   "metadata": {}
26 27 28
  },
  {
   "cell_type": "code",
29
   "execution_count": 3,
30
   "source": [
31
    "N=2"
32 33 34
   ],
   "outputs": [],
   "metadata": {}
35 36 37
  },
  {
   "cell_type": "code",
38
   "execution_count": 6,
39
   "source": [
40
    "from entities import Layer, Cluster\r\n",
41
    "\r\n",
42 43 44 45 46 47 48 49
    "with open(f'data/{use_case}/cluster_metrics/{layer_name}.json') as file:\r\n",
    "    cluster_metrics: List[Cluster] = [Cluster.create_from_dict(e) for e in json.loads(file.read())]\r\n",
    "    cluster_ids = {c.cluster_id for c in cluster_metrics}\r\n",
    "    cluster_metrics: Dict[Any, Cluster] = {(c.time_window_id, c.cluster_id): c for c in cluster_metrics}\r\n",
    "        \r\n",
    "with open(f'data/{use_case}/layer_metrics/{reference_layer_name}.json') as file:\r\n",
    "    layer_metrics: List[Layer] = [Layer.create_from_dict(e) for e in json.loads(file.read())]\r\n",
    "    layer_metrics: Dict[Any, Layer] = {l.time_window_id: l for l in layer_metrics}\r\n"
50 51 52
   ],
   "outputs": [],
   "metadata": {}
53 54 55
  },
  {
   "cell_type": "code",
56
   "execution_count": 11,
57
   "source": [
58 59 60
    "# load the time keys chronologically\r\n",
    "ordered_time_keys = list(layer_metrics.keys())\r\n",
    "ordered_time_keys.sort(key=lambda x: eval(x))"
61 62 63
   ],
   "outputs": [],
   "metadata": {}
64 65 66
  },
  {
   "cell_type": "code",
67
   "execution_count": 13,
68
   "source": [
69 70
    "ordered_time_keys = ordered_time_keys[-N:]\r\n",
    "ordered_time_keys"
71
   ],
72 73
   "outputs": [
    {
74
     "output_type": "execute_result",
75 76
     "data": {
      "text/plain": [
77
       "['(2018, 23)', '(2018, 24)']"
78 79 80
      ]
     },
     "metadata": {},
81
     "execution_count": 13
82 83
    }
   ],
84
   "metadata": {}
85 86 87
  },
  {
   "cell_type": "code",
88
   "execution_count": 19,
89
   "source": [
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
    "import numpy as np\r\n",
    "\r\n",
    "def get_cyclic_time_feature(time: int, max_time_value: int = 52) -> Tuple[float, float]:\r\n",
    "    return (np.sin(2*np.pi*time/max_time_value),\r\n",
    "            np.cos(2*np.pi*time/max_time_value))\r\n",
    "\r\n",
    "def get_cyclic_time_feature_from_time_window(time: str) -> Tuple[float, float]:\r\n",
    "    return get_cyclic_time_feature(int(time.replace('(', '').replace(')', '').split(',')[1]))\r\n",
    "\r\n",
    "def get_layer_metrics(layer: Layer) -> Iterable:\r\n",
    "    res = [layer.n_nodes, layer.n_clusters, layer.entropy]\r\n",
    "    res += [layer.cluster_size_agg_metrics[k] for k in ['min', 'max', 'avg', 'sum']]\r\n",
    "    res += [layer.cluster_relative_size_agg_metrics[k] for k in ['min', 'max', 'avg', 'sum']]\r\n",
    "    res += [layer.cluster_center_distance_agg_metrics[k] for k in ['min', 'max', 'avg', 'sum']]\r\n",
    "    res.append(get_cyclic_time_feature_from_time_window(layer.time_window_id))\r\n",
    "    return res"
106 107 108
   ],
   "outputs": [],
   "metadata": {}
109 110 111
  },
  {
   "cell_type": "code",
112
   "execution_count": 25,
113
   "source": [
114
    "prediction_metrics_raw = []"
115 116 117
   ],
   "outputs": [],
   "metadata": {}
118 119 120
  },
  {
   "cell_type": "code",
121
   "execution_count": 26,
122
   "source": [
123 124
    "current_layer_metric = layer_metrics[ordered_time_keys[1]]\r\n",
    "prev_layer_metric = layer_metrics[ordered_time_keys[0]]\r\n",
125
    "\r\n",
126 127 128 129 130 131
    "current_layer_metric_tuple = get_layer_metrics(current_layer_metric)\r\n",
    "prev_layer_metric_tuple = get_layer_metrics(prev_layer_metric)\r\n",
    "\r\n",
    "for cluster_id in cluster_ids:\r\n",
    "    # yield each combination of reference layer metrics to clusters\r\n",
    "    prediction_metrics_raw.append([prev_layer_metric_tuple, current_layer_metric_tuple, int(cluster_id)])"
132 133 134
   ],
   "outputs": [],
   "metadata": {}
135 136 137
  },
  {
   "cell_type": "code",
138
   "execution_count": 38,
139
   "source": [
140 141
    "method = 'cross_context'\r\n",
    "\r\n",
142 143
    "import pickle \r\n",
    "\r\n",
144 145
    "with open(f'data/{use_case}/ml_output/{method}/{layer_name}_{reference_layer_name}.model', 'rb') as file:\r\n",
    "    svc = pickle.load(file)\r\n",
146
    "\r\n",
147 148
    "with open(f'data/{use_case}/ml_output/{method}/{layer_name}_{reference_layer_name}_scaler.model', 'rb') as file:\r\n",
    "    scaler = pickle.load(file)"
149 150 151
   ],
   "outputs": [],
   "metadata": {}
152 153 154
  },
  {
   "cell_type": "code",
155
   "execution_count": 38,
156
   "source": [
157
    "import numpy as np\r\n",
158
    "\r\n",
159 160 161
    "def get_cyclic_time_feature(time: int, max_time_value: int = 52) -> Tuple[float, float]:\r\n",
    "    return (np.sin(2*np.pi*time/max_time_value),\r\n",
    "            np.cos(2*np.pi*time/max_time_value))"
162
   ],
163
   "outputs": [],
164 165 166 167
   "metadata": {}
  },
  {
   "cell_type": "code",
168
   "execution_count": 30,
169
   "source": [
170 171 172
    "import numpy as np\r\n",
    "\r\n",
    "def flatten_layer_metrics_datapoint(datapoint: list) -> Tuple['X', np.array]:\r\n",
173
    "    '''\r\n",
174 175 176 177 178
    "    Flattens a single layer metrics data point in the form:\r\n",
    "    [(n_nodes, n_clusters, entropy,\r\n",
    "     (relative_cluster_size)^M, (distance_from_global_centers)^M, \r\n",
    "     (time1, time2))^N, \r\n",
    "     cluster_number, evolution_label]\r\n",
179 180 181 182
    "    to:\r\n",
    "    (X, y: np.array)\r\n",
    "    '''\r\n",
    "    flat_list = []\r\n",
183 184 185 186 187
    "    for layer_metric_tuple in datapoint[:-1]: # for all x\r\n",
    "        flat_list.extend(layer_metric_tuple[0:-1]) # everything before time\r\n",
    "        flat_list.extend(layer_metric_tuple[-1]) # time1/2\r\n",
    "\r\n",
    "    flat_list.append(datapoint[-1]) # cluster num\r\n",
188 189
    "\r\n",
    "    return np.asarray(flat_list)"
190 191 192
   ],
   "outputs": [],
   "metadata": {}
193 194 195
  },
  {
   "cell_type": "code",
196
   "execution_count": 31,
197 198 199 200 201 202 203 204 205 206
   "source": [
    "def increase_time_window(time_window_id: str):\r\n",
    "    tuple_ = eval(time_window_id)\r\n",
    "    \r\n",
    "    if tuple_[1] == 52:\r\n",
    "        # 1st week next year\r\n",
    "        return (tuple_[0]+1 , 1)\r\n",
    "    else:\r\n",
    "        # next week\r\n",
    "        return str((tuple_[0], tuple_[1]+1))\r\n"
207 208 209
   ],
   "outputs": [],
   "metadata": {}
210 211 212
  },
  {
   "cell_type": "code",
213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
   "execution_count": 33,
   "source": [],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "895\n",
      "[ 1.01800000e+04  6.94600000e+03  1.25669044e+01  1.00000000e+00\n",
      "  1.20000000e+01  1.46559171e+00  1.01800000e+04  9.82318271e-05\n",
      "  1.17878193e-03  1.43967751e-04  1.00000000e+00  0.00000000e+00\n",
      "  2.37254283e+06  1.14923227e+03  7.98256735e+06  3.54604887e-01\n",
      " -9.35016243e-01  4.35300000e+03  3.25600000e+03  1.15021768e+01\n",
      "  1.00000000e+00  1.00000000e+01  1.33691646e+00  4.35300000e+03\n",
      "  2.29726625e-04  2.29726625e-03  3.07125307e-04  1.00000000e+00\n",
      "  0.00000000e+00  2.36405615e+05  3.69147185e+02  1.20194323e+06\n",
      "  2.39315664e-01 -9.70941817e-01  8.95000000e+02]\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 34,
238
   "source": [
239
    "from db.dao import PredictionResult\r\n",
240
    "\r\n",
241
    "prediction_cluster_ids = []\r\n",
242
    "prediction_time_window = increase_time_window(ordered_time_keys[1])\r\n",
243
    "prediction_metrics = []\r\n",
244 245 246
    "    \r\n",
    "for pred in  prediction_metrics_raw:\r\n",
    "    cluster_id = pred[-1]\r\n",
247 248
    "    prediction_cluster_ids.append(cluster_id)\r\n",
    "\r\n",
249 250 251
    "    flat_ = flatten_layer_metrics_datapoint(pred)\r\n",
    "    prediction_metrics.append(flat_)\r\n",
    "    "
252 253 254
   ],
   "outputs": [],
   "metadata": {}
255 256 257
  },
  {
   "cell_type": "code",
258
   "execution_count": 41,
259
   "source": [
260 261 262 263 264 265 266 267 268 269
    "prediction_results = svc.predict(scaler.transform(np.array(prediction_metrics)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "source": [
    "prediction_metrics[15]"
270 271 272 273 274 275
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
276 277 278 279 280 281 282 283 284
       "array([ 1.01800000e+04,  6.94600000e+03,  1.25669044e+01,  1.00000000e+00,\n",
       "        1.20000000e+01,  1.46559171e+00,  1.01800000e+04,  9.82318271e-05,\n",
       "        1.17878193e-03,  1.43967751e-04,  1.00000000e+00,  0.00000000e+00,\n",
       "        2.37254283e+06,  1.14923227e+03,  7.98256735e+06,  3.54604887e-01,\n",
       "       -9.35016243e-01,  4.35300000e+03,  3.25600000e+03,  1.15021768e+01,\n",
       "        1.00000000e+00,  1.00000000e+01,  1.33691646e+00,  4.35300000e+03,\n",
       "        2.29726625e-04,  2.29726625e-03,  3.07125307e-04,  1.00000000e+00,\n",
       "        0.00000000e+00,  2.36405615e+05,  3.69147185e+02,  1.20194323e+06,\n",
       "        2.39315664e-01, -9.70941817e-01,  4.36000000e+03])"
285 286 287
      ]
     },
     "metadata": {},
288
     "execution_count": 42
289 290 291 292 293 294
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
295
   "execution_count": 29,
296
   "source": [
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322
    "dataa = np.array(prediction_metrics)\r\n",
    "svc.predict(dataa[3].reshape(1, 27))"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([3.])"
      ]
     },
     "metadata": {},
     "execution_count": 29
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "source": [
    "predictions = []\r\n",
    "for i in range(len(prediction_cluster_ids)):\r\n",
    "    predictions.append(\r\n",
    "        PredictionResult(use_case, use_case, method, layer_name, None, prediction_cluster_ids[i], prediction_time_window, prediction_results[i])\r\n",
    "    )"
323 324 325 326 327 328
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
329
   "execution_count": 45,
330
   "source": [
331
    "list(zip(np.unique(prediction_results, return_counts=True)))"
332 333 334 335 336 337
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
338 339
       "[(array([0., 1., 2., 3.]),),\n",
       " (array([ 5335,  1511,   355, 13007], dtype=int64),)]"
340 341 342
      ]
     },
     "metadata": {},
343
     "execution_count": 45
344 345 346 347 348 349
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
350
   "execution_count": 46,
351
   "source": [
352
    "prediction_results"
353 354 355 356 357 358
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
359
       "array([3., 0., 0., ..., 0., 3., 3.])"
360 361 362
      ]
     },
     "metadata": {},
363
     "execution_count": 46
364 365 366 367 368 369
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
370
   "execution_count": 51,
371
   "source": [
372 373 374 375 376 377 378 379 380 381 382 383 384 385
    "time = '(2019, 45)'\r\n",
    "int(time.replace('(', '').replace(')', '').split(',')[1])"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "45"
      ]
     },
     "metadata": {},
     "execution_count": 51
    }
386 387 388 389 390
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
391
   "execution_count": 52,
392
   "source": [
393
    "eval(time)[1]"
394 395 396 397 398 399
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
400
       "45"
401 402 403
      ]
     },
     "metadata": {},
404
     "execution_count": 52
405 406 407 408 409 410
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
411
   "execution_count": 53,
412
   "source": [
413
    "int(time.split(',')[1].strip()[:-1])"
414 415 416 417 418 419
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
420
       "45"
421 422 423
      ]
     },
     "metadata": {},
424
     "execution_count": 53
425 426 427 428 429 430
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
431
   "execution_count": 47,
432 433 434
   "source": [
    "[r.__dict__ for r in predictions[:10]]"
   ],
435 436
   "outputs": [
    {
437
     "output_type": "execute_result",
438 439 440 441
     "data": {
      "text/plain": [
       "[{'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
442
       "  'method': 'cross_context',\n",
443 444
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
445
       "  'cluster_id': 895,\n",
446
       "  'time_window': '(2018, 25)',\n",
447
       "  'prediction': 3.0},\n",
448 449
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
450
       "  'method': 'cross_context',\n",
451 452
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
453
       "  'cluster_id': 8947,\n",
454
       "  'time_window': '(2018, 25)',\n",
455
       "  'prediction': 0.0},\n",
456 457
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
458
       "  'method': 'cross_context',\n",
459 460
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
461
       "  'cluster_id': 10464,\n",
462
       "  'time_window': '(2018, 25)',\n",
463
       "  'prediction': 0.0},\n",
464 465
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
466
       "  'method': 'cross_context',\n",
467 468
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
469
       "  'cluster_id': 14671,\n",
470
       "  'time_window': '(2018, 25)',\n",
471
       "  'prediction': 3.0},\n",
472 473
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
474
       "  'method': 'cross_context',\n",
475 476
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
477
       "  'cluster_id': 18000,\n",
478
       "  'time_window': '(2018, 25)',\n",
479
       "  'prediction': 3.0},\n",
480 481
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
482
       "  'method': 'cross_context',\n",
483 484
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
485
       "  'cluster_id': 17895,\n",
486
       "  'time_window': '(2018, 25)',\n",
487
       "  'prediction': 2.0},\n",
488 489
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
490
       "  'method': 'cross_context',\n",
491 492
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
493
       "  'cluster_id': 1234,\n",
494
       "  'time_window': '(2018, 25)',\n",
495
       "  'prediction': 3.0},\n",
496 497
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
498
       "  'method': 'cross_context',\n",
499 500
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
501
       "  'cluster_id': 16236,\n",
502 503 504 505
       "  'time_window': '(2018, 25)',\n",
       "  'prediction': 3.0},\n",
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
506
       "  'method': 'cross_context',\n",
507 508
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
509
       "  'cluster_id': 1995,\n",
510
       "  'time_window': '(2018, 25)',\n",
511
       "  'prediction': 3.0},\n",
512 513
       " {'use_case': 'community-prediction-youtube-n',\n",
       "  'table': 'community-prediction-youtube-n',\n",
514
       "  'method': 'cross_context',\n",
515 516
       "  'layer': 'LikesLayer',\n",
       "  'reference_layer': None,\n",
517
       "  'cluster_id': 5161,\n",
518
       "  'time_window': '(2018, 25)',\n",
519
       "  'prediction': 0.0}]"
520 521 522
      ]
     },
     "metadata": {},
523
     "execution_count": 47
524 525
    }
   ],
526
   "metadata": {}
527 528 529
  },
  {
   "cell_type": "code",
530 531 532 533
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
534 535 536 537
  }
 ],
 "metadata": {
  "interpreter": {
538
   "hash": "f4b37965f8116f61e214526431d03f7da6e57badb249bab76499e8551fed5453"
539 540
  },
  "kernelspec": {
541 542
   "name": "python3",
   "display_name": "Python 3.7.8 64-bit ('venv': venv)"
543 544 545 546 547 548 549 550 551 552 553
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
554
   "version": "3.7.8"
555 556 557 558 559 560
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}