Commit a878064e authored by Alexander Lercher's avatar Alexander Lercher

Correctly predicting with scaled metrics data

parent d94b70d7
......@@ -2,316 +2,61 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"execution_count": 52,
"source": [
"use_case = 'community-prediction-youtube-n'\r\n",
"layer_name = 'LikesLayer'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\r\n",
"from pandas import DataFrame\r\n",
"\r\n",
"df: DataFrame = pd.read_csv(f'data/{use_case}/ml_input/single_context/{layer_name}.csv', index_col=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cluster_size</th>\n",
" <th>cluster_variance</th>\n",
" <th>cluster_density</th>\n",
" <th>cluster_import1</th>\n",
" <th>cluster_import2</th>\n",
" <th>cluster_area</th>\n",
" <th>cluster_center_distance</th>\n",
" <th>time_f1</th>\n",
" <th>time_f2</th>\n",
" <th>cluster_size.1</th>\n",
" <th>...</th>\n",
" <th>cluster_size.2</th>\n",
" <th>cluster_variance.2</th>\n",
" <th>cluster_density.2</th>\n",
" <th>cluster_import1.2</th>\n",
" <th>cluster_import2.2</th>\n",
" <th>cluster_area.2</th>\n",
" <th>cluster_center_distance.2</th>\n",
" <th>time_f1.2</th>\n",
" <th>time_f2.2</th>\n",
" <th>evolution_label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>565819</th>\n",
" <td>4.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.00</td>\n",
" <td>0.000336</td>\n",
" <td>0.000168</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.992709</td>\n",
" <td>0.120537</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.992709</td>\n",
" <td>-0.120537</td>\n",
" <td>-1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>565820</th>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.00</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.935016</td>\n",
" <td>-0.354605</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.822984</td>\n",
" <td>-0.568065</td>\n",
" <td>4.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>565821</th>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.00</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.970942</td>\n",
" <td>-0.239316</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.885456</td>\n",
" <td>-0.464723</td>\n",
" <td>-1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>565822</th>\n",
" <td>4.0</td>\n",
" <td>1.089725</td>\n",
" <td>0.75</td>\n",
" <td>0.000334</td>\n",
" <td>0.000166</td>\n",
" <td>3.0</td>\n",
" <td>6.0</td>\n",
" <td>0.885456</td>\n",
" <td>-0.464723</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.748511</td>\n",
" <td>-0.663123</td>\n",
" <td>-1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>565823</th>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.00</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.748511</td>\n",
" <td>-0.663123</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.663123</td>\n",
" <td>-0.748511</td>\n",
" <td>-1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 28 columns</p>\n",
"</div>"
],
"text/plain": [
" cluster_size cluster_variance cluster_density cluster_import1 \\\n",
"565819 4.0 0.000000 0.00 0.000336 \n",
"565820 0.0 0.000000 0.00 0.000000 \n",
"565821 0.0 0.000000 0.00 0.000000 \n",
"565822 4.0 1.089725 0.75 0.000334 \n",
"565823 0.0 0.000000 0.00 0.000000 \n",
"\n",
" cluster_import2 cluster_area cluster_center_distance time_f1 \\\n",
"565819 0.000168 0.0 0.0 0.992709 \n",
"565820 0.000000 0.0 0.0 0.935016 \n",
"565821 0.000000 0.0 0.0 0.970942 \n",
"565822 0.000166 3.0 6.0 0.885456 \n",
"565823 0.000000 0.0 0.0 0.748511 \n",
"\n",
" time_f2 cluster_size.1 ... cluster_size.2 cluster_variance.2 \\\n",
"565819 0.120537 1.0 ... 0.0 0.0 \n",
"565820 -0.354605 1.0 ... 0.0 0.0 \n",
"565821 -0.239316 0.0 ... 0.0 0.0 \n",
"565822 -0.464723 1.0 ... 0.0 0.0 \n",
"565823 -0.663123 1.0 ... 0.0 0.0 \n",
"\n",
" cluster_density.2 cluster_import1.2 cluster_import2.2 \\\n",
"565819 0.0 0.0 0.0 \n",
"565820 0.0 0.0 0.0 \n",
"565821 0.0 0.0 0.0 \n",
"565822 0.0 0.0 0.0 \n",
"565823 0.0 0.0 0.0 \n",
"\n",
" cluster_area.2 cluster_center_distance.2 time_f1.2 time_f2.2 \\\n",
"565819 0.0 0.0 0.992709 -0.120537 \n",
"565820 0.0 0.0 0.822984 -0.568065 \n",
"565821 0.0 0.0 0.885456 -0.464723 \n",
"565822 0.0 0.0 0.748511 -0.663123 \n",
"565823 0.0 0.0 0.663123 -0.748511 \n",
"\n",
" evolution_label \n",
"565819 -1.0 \n",
"565820 4.0 \n",
"565821 -1.0 \n",
"565822 -1.0 \n",
"565823 -1.0 \n",
"\n",
"[5 rows x 28 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.tail()"
]
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"execution_count": 2,
"source": [
"import json\r\n",
"from entities import Cluster\r\n",
"import collections\r\n",
"import numpy as np\r\n",
"from typing import Iterable, Tuple"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"execution_count": 3,
"source": [
"N=3"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"execution_count": 53,
"source": [
"path_in = f\"data/{use_case}/cluster_metrics/{layer_name}.json\"\r\n",
"with open(path_in, 'r') as file:\r\n",
" data = [Cluster.create_from_dict(cl_d) for cl_d in json.loads(file.read())]\r\n",
"\r\n",
"data.sort(key=lambda cl: (eval(cl.cluster_id), eval(cl.time_window_id)))"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'time_window_id': '(2018, 24)', 'cluster_id': '20207', 'size': 0, 'std_dev': 0, 'scarcity': 0, 'importance1': 0, 'importance2': 0, 'range_': 0.0, 'center': [0, 0], 'global_center_distance': 0}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": null,
"source": [
"data[-1]"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"execution_count": 54,
"source": [
"cluster_map = {}\r\n",
"\r\n",
......@@ -325,78 +70,67 @@
" cluster_map[id_] = []\r\n",
"\r\n",
" cluster_map[id_].append(cluster)\r\n"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"execution_count": 55,
"source": [
"{c.cluster_id for c in data} == cluster_map.keys()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
"execution_count": 55
}
],
"source": [
"{c.cluster_id for c in data} == cluster_map.keys()"
]
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"20208"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": null,
"source": [
"len(cluster_map.keys())"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"execution_count": 38,
"source": [
"import numpy as np\r\n",
"\r\n",
"def get_cyclic_time_feature(time: int, max_time_value: int = 52) -> Tuple[float, float]:\r\n",
" return (np.sin(2*np.pi*time/max_time_value),\r\n",
" np.cos(2*np.pi*time/max_time_value))"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"execution_count": 8,
"source": [
"from typing import Tuple\r\n",
"\r\n",
"def get_metrics(cur_cluster: Cluster) -> Tuple:\r\n",
" return (cur_cluster.size, cur_cluster.std_dev, cur_cluster.scarcity, cur_cluster.importance1, cur_cluster.importance2, cur_cluster.range_, cur_cluster.global_center_distance, get_cyclic_time_feature(cur_cluster.get_time_info()))"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"execution_count": 56,
"source": [
"import pickle \r\n",
"\r\n",
......@@ -404,13 +138,25 @@
"\r\n",
"with open(f'data/{use_case}/ml_output/{method}/{layer_name}.model', 'rb') as file:\r\n",
" svc = pickle.load(file)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"execution_count": 63,
"source": [
"import pickle \r\n",
"\r\n",
"with open(f'data/{use_case}/ml_output/{method}/{layer_name}_scaler.model', 'rb') as file:\r\n",
" scaler = pickle.load(file)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"def flatten_metrics_datapoint(datapoint: list) -> Tuple['X', np.array]:\r\n",
" '''\r\n",
......@@ -426,13 +172,13 @@
"\r\n",
" # flat_list.append(datapoint[-1]) # y\r\n",
" return np.asarray(flat_list)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"execution_count": 11,
"source": [
"def increase_time_window(time_window_id: str):\r\n",
" tuple_ = eval(time_window_id)\r\n",
......@@ -443,32 +189,180 @@
" else:\r\n",
" # next week\r\n",
" return str((tuple_[0], tuple_[1]+1))\r\n"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"execution_count": 58,
"source": [
"from entities import PredictionResult\r\n",
"from db.dao import PredictionResult\r\n",
"\r\n",
"prediction_results = []\r\n",
"# prediction_results = []\r\n",
"prediction_cluster_ids = []\r\n",
"prediction_time_windows = []\r\n",
"prediction_metrics = []\r\n",
"\r\n",
"for cluster_id, time_windows in cluster_map.items():\r\n",
" v = [get_metrics(c) for c in time_windows[-N:]] # metrics for last N time windows\r\n",
" v_flattened = flatten_metrics_datapoint(v)\r\n",
" v_flattened = v_flattened.reshape(1, v_flattened.shape[0]) # reshape for ML with only 1 pred value\r\n",
" res = PredictionResult(use_case, use_case, method, layer_name, None, cluster_id, increase_time_window(time_windows[-1].time_window_id), svc.predict(v_flattened)[0])\r\n",
" prediction_results.append(res)"
]
"\r\n",
" prediction_cluster_ids.append(cluster_id)\r\n",
" prediction_time_windows.append(increase_time_window(time_windows[-1].time_window_id))\r\n",
" prediction_metrics.append(v_flattened)\r\n",
"\r\n",
"\r\n",
" # v_flattened = v_flattened.reshape(1, v_flattened.shape[0]) # reshape for ML with only 1 pred value\r\n",
" # res = PredictionResult(use_case, use_case, method, layer_name, None, cluster_id, increase_time_window(time_windows[-1].time_window_id), svc.predict(v_flattened)[0])\r\n",
" # prediction_results.append(res)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"execution_count": 64,
"source": [
"scaler.transform(prediction_metrics[0].reshape(1,27))"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[-0.2525847 , -0.00725354, -0.00748744, -0.26150883, -0.61179695,\n",
" -0.00699078, -0.0156031 , 0.10230883, -1.49959068, -0.25198809,\n",
" -0.00721248, -0.00740694, -0.2559145 , -0.6125857 , -0.0069614 ,\n",
" -0.01582086, -0.22871208, -1.567934 , -0.25144835, -0.00729236,\n",
" -0.00753175, -0.25448947, -0.6134931 , -0.00698498, -0.01589221,\n",
" -0.63013244, -1.62002196]])"
]
},
"metadata": {},
"execution_count": 64
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 65,
"source": [
"prediction_results = svc.predict(scaler.transform(np.array(prediction_metrics)))"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 67,
"source": [
"prediction_metrics[15]"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0.46472317, -0.88545603, 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0.35460489, -0.93501624, 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0.23931566, -0.97094182])"
]
},
"metadata": {},
"execution_count": 67
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 29,
"source": [
"dataa = np.array(prediction_metrics)\r\n",
"svc.predict(dataa[3].reshape(1, 27))"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([3.])"
]
},
"metadata": {},
"execution_count": 29
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 68,
"source": [
"predictions = []\r\n",
"for i in range(len(prediction_cluster_ids)):\r\n",
" predictions.append(\r\n",
" PredictionResult(use_case, use_case, method, layer_name, None, prediction_cluster_ids[i], prediction_time_windows[i], prediction_results[i])\r\n",
" )"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 74,
"source": [
"list(zip(np.unique(prediction_results, return_counts=True)))"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[(array([0., 1., 2., 3., 4.]),),\n",
" (array([ 2740, 596, 1429, 1324, 14119], dtype=int64),)]"
]
},
"metadata": {},
"execution_count": 74
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 70,
"source": [
"prediction_results"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([4., 4., 0., ..., 0., 0., 0.])"
]
},
"metadata": {},
"execution_count": 70
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"source": [
"[r.__dict__ for r in predictions[:10]]"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[{'use_case': 'community-prediction-youtube-n',\n",
......@@ -478,7 +372,7 @@
" 'reference_layer': None,\n",
" 'cluster_id': '0',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0},\n",
" 'prediction': 3.0},\n",
" {'use_case': 'community-prediction-youtube-n',\n",
" 'table': 'community-prediction-youtube-n',\n",
" 'method': 'single_context',\n",
......@@ -486,7 +380,7 @@
" 'reference_layer': None,\n",
" 'cluster_id': '1',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0},\n",
" 'prediction': 3.0},\n",
" {'use_case': 'community-prediction-youtube-n',\n",
" 'table': 'community-prediction-youtube-n',\n",
" 'method': 'single_context',\n",
......@@ -502,7 +396,7 @@
" 'reference_layer': None,\n",
" 'cluster_id': '3',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0},\n",
" 'prediction': 3.0},\n",
" {'use_case': 'community-prediction-youtube-n',\n",
" 'table': 'community-prediction-youtube-n',\n",
" 'method': 'single_context',\n",
......@@ -510,7 +404,7 @@
" 'reference_layer': None,\n",
" 'cluster_id': '4',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0},\n",
" 'prediction': 3.0},\n",
" {'use_case': 'community-prediction-youtube-n',\n",
" 'table': 'community-prediction-youtube-n',\n",
" 'method': 'single_context',\n",
......@@ -518,7 +412,7 @@
" 'reference_layer': None,\n",
" 'cluster_id': '5',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0},\n",
" 'prediction': 3.0},\n",
" {'use_case': 'community-prediction-youtube-n',\n",
" 'table': 'community-prediction-youtube-n',\n",
" 'method': 'single_context',\n",
......@@ -526,7 +420,7 @@
" 'reference_layer': None,\n",
" 'cluster_id': '6',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0},\n",
" 'prediction': 3.0},\n",
" {'use_case': 'community-prediction-youtube-n',\n",
" 'table': 'community-prediction-youtube-n',\n",
" 'method': 'single_context',\n",
......@@ -542,7 +436,7 @@
" 'reference_layer': None,\n",
" 'cluster_id': '8',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0},\n",
" 'prediction': 3.0},\n",
" {'use_case': 'community-prediction-youtube-n',\n",
" 'table': 'community-prediction-youtube-n',\n",
" 'method': 'single_context',\n",
......@@ -550,41 +444,30 @@
" 'reference_layer': None,\n",
" 'cluster_id': '9',\n",
" 'time_window': '(2018, 25)',\n",
" 'prediction': 2.0}]"
" 'prediction': 3.0}]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
"execution_count": 15
}
],
"source": [
"[r.__dict__ for r in prediction_results[:10]]"
]
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
}
],
"source": []
"execution_count": null,
"source": [],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"interpreter": {
"hash": "6f758d9e9b2866087a1d464f700475727f47c3870deef6e7815ca445f120e6ad"
"hash": "f4b37965f8116f61e214526431d03f7da6e57badb249bab76499e8551fed5453"
},
"kernelspec": {
"display_name": "Python 3.7.6 64-bit ('venv': venv)",
"name": "python3"
"name": "python3",
"display_name": "Python 3.7.8 64-bit ('venv': venv)"
},
"language_info": {
"codemirror_mode": {
......@@ -596,7 +479,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.8"
},
"orig_nbformat": 4
},
......
......@@ -17,7 +17,8 @@ from typing import Dict
from typing import Tuple
def get_metrics(cur_cluster: Cluster) -> Tuple:
return (cur_cluster.size, cur_cluster.std_dev, cur_cluster.scarcity, cur_cluster.importance1, cur_cluster.importance2, cur_cluster.range_, cur_cluster.global_center_distance, get_cyclic_time_feature(cur_cluster.get_time_info()))
return (cur_cluster.size, cur_cluster.std_dev, cur_cluster.scarcity, cur_cluster.importance1, cur_cluster.importance2,
cur_cluster.range_, cur_cluster.global_center_distance, get_cyclic_time_feature(cur_cluster.get_time_info()))
####################
import pickle
#####################
......@@ -53,9 +54,8 @@ repo = Repository()
def run_prediction(use_case: str):
for layer in repo.get_layers_for_use_case(use_case):
layer_name = layer.layer_name
print(f"Predicting {method} for {use_case}//{layer_name}")
################
df: DataFrame = pd.read_csv(f'data/{use_case}/ml_input/single_context/{layer_name}.csv', index_col=0)
#################
path_in = f"data/{use_case}/cluster_metrics/{layer_name}.json"
with open(path_in, 'r') as file:
......@@ -75,12 +75,27 @@ def run_prediction(use_case: str):
####################
with open(f'data/{use_case}/ml_output/{method}/{layer_name}.model', 'rb') as file:
svc = pickle.load(file)
####################
with open(f'data/{use_case}/ml_output/{method}/{layer_name}_scaler.model', 'rb') as file:
scaler = pickle.load(file)
#####################
# store id, future time window, and flattened metrics to combine the latter during prediction
prediction_cluster_ids = []
prediction_time_windows = []
prediction_metrics = []
for cluster_id, time_windows in cluster_map.items():
v = [get_metrics(c) for c in time_windows[-N:]] # metrics for last N time windows
v_flattened = flatten_metrics_datapoint(v)
v_flattened = v_flattened.reshape(1, v_flattened.shape[0]) # reshape for ML with only 1 pred value
res = PredictionResult(use_case, use_case, method, layer_name, None, cluster_id, increase_time_window(time_windows[-1].time_window_id), svc.predict(v_flattened)[0])
repo.add_prediction_result(res)
#####################
prediction_cluster_ids.append(cluster_id)
prediction_time_windows.append(increase_time_window(time_windows[-1].time_window_id))
prediction_metrics.append(v_flattened)
# predict all at once for speedup
prediction_results = svc.predict(scaler.transform(np.array(prediction_metrics)))
print(np.unique(prediction_results, return_counts=True))
for i in range(len(prediction_cluster_ids)):
res = PredictionResult(use_case, use_case, method, layer_name, None, prediction_cluster_ids[i], prediction_time_windows[i], prediction_results[i])
repo.add_prediction_result(res)
......@@ -8,10 +8,10 @@ approach = 'single_context'
import pickle
from pathlib import Path
def export_model(model, use_case, layer_name):
def export_model(model, use_case, layer_name, scaler=False):
fpath = f'data/{use_case}/ml_output/{approach}'
Path(fpath).mkdir(parents=True, exist_ok=True)
with open(f'{fpath}/{layer_name}.model', 'wb') as f:
with open(f'{fpath}/{layer_name}{"_scaler" if scaler else ""}.model', 'wb') as f:
pickle.dump(model, f)
#####################
from sklearn.ensemble import RandomForestClassifier
......@@ -45,11 +45,13 @@ def run_training(use_case):
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
train_X = scaler.fit_transform(training)[:,:-1] # all except y
train_X = scaler.fit_transform(training[training.columns[:-1]]) # all except y
train_Y = training[training.columns[-1]]
test_X = scaler.transform(testing)[:,:-1] # all except y
test_X = scaler.transform(testing[testing.columns[:-1]]) # all except y
test_Y = testing[testing.columns[-1]]
export_model(scaler, use_case, layer_name, scaler=True)
########################
from processing import DataSampler
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment