Commit 4c83bce4 authored by Alexander's avatar Alexander

Merge branch 'feature/independent-clustering' into develop

Clustering of two single layers (location and time) and storing results in MongoDb
parents ad1ef89c 9a50d2ec
...@@ -78,57 +78,57 @@ paths: ...@@ -78,57 +78,57 @@ paths:
400: 400:
description: "Invalid input" description: "Invalid input"
/clusters: /location-clusters:
get: get:
operationId: "rest.cluster.get" operationId: "rest.cluster.get_locations"
tags: tags:
- "Clusters" - "Clusters"
summary: "Get user communities per date per hour" summary: "Get user communities clustered by location"
parameters: [] parameters: []
responses: responses:
200: 200:
description: "Successful operation" description: "Successful operation"
schema: schema:
$ref: "#/definitions/UserClusterCollection" $ref: "#/definitions/LocationClusterCollection"
/clusters/cluster.png: # /clusters/cluster.png:
get: # get:
operationId: "rest.cluster.get_image" # operationId: "rest.cluster.get_image"
tags: # tags:
- "Clusters" # - "Clusters"
summary: "Get user communities per date per hour as image" # summary: "Get user communities per date per hour as image"
parameters: [] # parameters: []
produces: # produces:
- "image/png" # - "image/png"
responses: # responses:
200: # 200:
description: "Successful operation" # description: "Successful operation"
/agi/clusters: /time-clusters:
get: get:
operationId: "rest.agi_cluster.get" operationId: "rest.cluster.get_times"
tags: tags:
- "Clusters" - "Clusters"
summary: "Get user communities per date per hour from agi data" summary: "Get user communities clustered by time per hour"
parameters: [] parameters: []
responses: responses:
200: 200:
description: "Successful operation" description: "Successful operation"
schema: schema:
$ref: "#/definitions/UserClusterCollection" $ref: "#/definitions/TimeClusterCollection"
/agi/clusters/cluster.png: # /agi/clusters/cluster.png:
get: # get:
operationId: "rest.agi_cluster.get_image" # operationId: "rest.agi_cluster.get_image"
tags: # tags:
- "Clusters" # - "Clusters"
summary: "Get user communities per date per hour from agi data as image" # summary: "Get user communities per date per hour from agi data as image"
parameters: [] # parameters: []
produces: # produces:
- "image/png" # - "image/png"
responses: # responses:
200: # 200:
description: "Successful operation" # description: "Successful operation"
definitions: definitions:
Location: Location:
...@@ -152,8 +152,27 @@ definitions: ...@@ -152,8 +152,27 @@ definitions:
items: items:
$ref: "#/definitions/Location" $ref: "#/definitions/Location"
UserCluster: LocationCluster:
type: "object" type: object
properties:
id:
type: string
cluster_label:
type: number
clusters:
type: array
items:
$ref: "#/definitions/Location"
# example:
# 0: [1dc61b1a0602de0eaee9dba7eece9279c2844202, b4b31bbe5e12f55737e3a910827c81595fbca3eb]
LocationClusterCollection:
type: array
items:
$ref: "#/definitions/LocationCluster"
TimeCluster:
type: object
properties: properties:
id: id:
type: string type: string
...@@ -161,16 +180,16 @@ definitions: ...@@ -161,16 +180,16 @@ definitions:
type: string type: string
hour: hour:
type: number type: number
cluster_label:
type: number
clusters: clusters:
type: object
additionalProperties:
type: array type: array
items: items:
type: string $ref: "#/definitions/Location"
example: # example:
0: [1dc61b1a0602de0eaee9dba7eece9279c2844202, b4b31bbe5e12f55737e3a910827c81595fbca3eb] # 0: [1dc61b1a0602de0eaee9dba7eece9279c2844202, b4b31bbe5e12f55737e3a910827c81595fbca3eb]
UserClusterCollection: TimeClusterCollection:
type: array type: array
items: items:
$ref: "#/definitions/UserCluster" $ref: "#/definitions/TimeCluster"
\ No newline at end of file \ No newline at end of file
from db.entities.location import Location from db.entities.location import Location
from db.entities.popular_location import PopularLocation from db.entities.popular_location import PopularLocation
from db.entities.user_cluster import UserCluster from db.entities.cluster import LocationCluster, TimeCluster
\ No newline at end of file \ No newline at end of file
import json
from typing import List, Dict
from datetime import date, datetime
class Cluster:
def __init__(self, cluster_label: int = None, clusters: List = None):
self.cluster_label = cluster_label
self.clusters = clusters
class LocationCluster(Cluster):
def __init__(self, cluster_label: int = None, clusters: List = None,
location_dict: Dict = None, from_db=False):
super().__init__(cluster_label, clusters)
self.id = f'{self.cluster_label}'
if location_dict is not None:
self.from_serializable_dict(location_dict, from_db)
def to_serializable_dict(self, for_db=False) -> Dict:
return {
"id": self.id,
"cluster_label": self.cluster_label,
"clusters": json.dumps(self.clusters) if for_db else self.clusters
}
def from_serializable_dict(self, location_dict: Dict, from_db=False):
self.id = location_dict["id"]
self.cluster_label = location_dict["cluster_label"]
self.clusters = json.loads(location_dict["clusters"]) \
if from_db else location_dict["clusters"]
def __repr__(self):
return json.dumps(self.to_serializable_dict())
def __str__(self):
return f"LocationCluster({self.__repr__()})"
class TimeCluster(Cluster):
def __init__(self, date: date = None, hour: int = None, cluster_label: int = None, clusters: List = None,
time_dict: Dict = None, from_db=False):
super().__init__(cluster_label, clusters)
self.date = date
self.hour = hour
self.id = f'{self.date}-{self.hour}-{self.cluster_label}'
if time_dict is not None:
self.from_serializable_dict(time_dict, from_db)
def to_serializable_dict(self, for_db=False) -> Dict:
return {
"id": self.id,
"date": str(self.date),
"hour": self.hour,
"cluster_label": self.cluster_label,
"clusters": json.dumps(self.clusters) if for_db else self.clusters
}
def from_serializable_dict(self, time_dict: Dict, from_db=False):
self.id = time_dict["id"]
self.date = datetime.strptime(time_dict["date"], '%Y-%m-%d').date()
self.hour = time_dict["hour"]
self.cluster_label = time_dict["cluster_label"]
self.clusters = json.loads(time_dict["clusters"]) \
if from_db else time_dict["clusters"]
def __repr__(self):
return json.dumps(self.to_serializable_dict())
def __str__(self):
return f"TimeCluster({self.__repr__()})"
import json
class UserCluster:
def __init__(self, date, hour, clusters):
super().__init__()
self.date = date
self.hour = hour
self.clusters = clusters
self.id = f'{self.date}-{self.hour}'
def to_serializable_dict(self, for_db=False):
return {
"id": self.id,
"date": str(self.date),
"hour": self.hour,
"clusters": json.dumps(self.clusters) if for_db else self.clusters
}
def __repr__(self):
return json.dumps(self.to_serializable_dict())
def __str__(self):
return f"UserCluster({self.__repr__()})"
from __future__ import annotations
class LocationDatastore:
'''This Singelton simulates a location database'''
_instance = None
@staticmethod
def get_instance() -> LocationDatastore:
if LocationDatastore._instance == None:
LocationDatastore._instance = LocationDatastore()
return LocationDatastore._instance
def __init__(self):
if LocationDatastore._instance != None:
raise Exception("This class is a singleton!")
self.locations = []
def add(self, location):
self.locations.append(location)
def get(self):
return self.locations
\ No newline at end of file
...@@ -5,18 +5,21 @@ import json ...@@ -5,18 +5,21 @@ import json
from db.agi.agi_repository import AgiRepository from db.agi.agi_repository import AgiRepository
from db.entities import Location, UserCluster, PopularLocation from db.entities import Location, TimeCluster, PopularLocation, LocationCluster
from typing import List from typing import List
class Repository(MongoRepositoryBase): class Repository(MongoRepositoryBase):
'''This repository stores and loads locations and clusters with MongoDb.'''
def __init__(self, agi_data=False): def __init__(self):
super().__init__(netconst.COMMUNITY_DETECTION_DB_HOSTNAME, super().__init__(netconst.COMMUNITY_DETECTION_DB_HOSTNAME,
netconst.COMMUNITY_DETECTION_DB_PORT, 'communityDetectionDb') netconst.COMMUNITY_DETECTION_DB_PORT,
'communityDetectionDb')
self._location_collection = 'location_agi' if agi_data else 'location' self._location_collection = 'location'
self._cluster_collection = 'cluster_agi' if agi_data else 'cluster' self._location_cluster_collection = 'location_cluster'
self._time_cluster_collection = 'time_cluster'
self.agi_repo = AgiRepository() self.agi_repo = AgiRepository()
...@@ -31,12 +34,18 @@ class Repository(MongoRepositoryBase): ...@@ -31,12 +34,18 @@ class Repository(MongoRepositoryBase):
agi_locations = self.agi_repo.getLocations() agi_locations = self.agi_repo.getLocations()
return [Location(agi_loc) for agi_loc in agi_locations] return [Location(agi_loc) for agi_loc in agi_locations]
def add_user_cluster(self, cluster: UserCluster): def add_location_cluster(self, cluster: LocationCluster):
super().insert_entry(self._cluster_collection, cluster.to_serializable_dict(for_db=True)) super().insert_entry(self._location_cluster_collection,
cluster.to_serializable_dict(for_db=True))
def get_user_clusters(self) -> List[UserCluster]: def get_location_clusters(self) -> List[LocationCluster]:
clusters = super().get_entries(self._cluster_collection) clusters = super().get_entries(self._location_cluster_collection)
return [UserCluster(c['date'], int(c['hour']), json.loads(c['clusters'])) for c in clusters] return [LocationCluster(location_dict=c, from_db=True) for c in clusters]
def add_popular_location(self, popular_location: PopularLocation): def add_time_cluster(self, cluster: TimeCluster):
pass super().insert_entry(self._time_cluster_collection,
cluster.to_serializable_dict(for_db=True))
def get_time_clusters(self) -> List[TimeCluster]:
clusters = super().get_entries(self._time_cluster_collection)
return [TimeCluster(time_dict=c, from_db=True) for c in clusters]
import sys
import os
modules_path = '../../../modules/'
if os.path.exists(modules_path):
sys.path.insert(1, modules_path)
from db.repository import Repository
if __name__ == "__main__":
repo = Repository()
locs = repo.get_agi_locations()
for l in locs:
repo.add_location(l)
...@@ -21,7 +21,7 @@ class Clusterer: ...@@ -21,7 +21,7 @@ class Clusterer:
partition_info = labels partition_info = labels
) )
def _draw_locations(self, locations:np.ndarray=None, centroids:np.ndarray=None, partition_info=None) -> plt.Figure: def _draw_locations(self, locations:np.ndarray=None, centroids:np.ndarray=None, partition_info:List=None) -> plt.Figure:
fig = plt.Figure() fig = plt.Figure()
axis = fig.add_subplot(1, 1, 1) axis = fig.add_subplot(1, 1, 1)
...@@ -43,41 +43,57 @@ class Clusterer: ...@@ -43,41 +43,57 @@ class Clusterer:
return fig return fig
def create_labels(self, locations:List) -> List: def create_labels(self, features:np.ndarray) -> List:
if locations is None or len(locations) == 0: if features is None or len(features) == 0:
return locations # trash in trash out return features # trash in trash out
locations = self.extract_location_data(locations)
dbsc = DBSCAN(eps = self.epsilon, min_samples = self.min_points) dbsc = DBSCAN(eps = self.epsilon, min_samples = self.min_points)
dbsc = dbsc.fit(locations) dbsc = dbsc.fit(features)
labels = dbsc.labels_ labels = dbsc.labels_
return labels.tolist() return labels.tolist()
def label_locations(self, locations:List[Dict], labels:List) -> List: def extract_location_features(self, locations: List[dict]) -> np.ndarray:
if locations is None or labels is None: return np.asarray([(float(l['latitude']), float(l['longitude'])) for l in locations])
def extract_time_features(self, times: List[Dict]) -> np.ndarray:
return np.asarray([((t['timestamp']), 0) for t in times])
def label_dataset(self, dataset:List[Dict], labels:List) -> List:
if dataset is None or labels is None:
return return
if len(locations) != len(labels): if len(dataset) != len(labels):
raise ValueError("locations and labels has to have same length") raise ValueError("dataset and labels has to have same length")
for i in range(len(locations)): for i in range(len(dataset)):
locations[i]['cluster_label'] = labels[i] dataset[i]['cluster_label'] = labels[i]
def run(self, locations:List[Dict]) -> Dict[int, List[Dict]]: def group_by_clusters(self, dataset:List[Dict], labels:List) -> Dict[int, List[Dict]]:
clusters = {}
for label in labels:
clusters[label] = [ds for ds in dataset if ds['cluster_label'] == label]
return clusters
def cluster_locations(self, locations:List[Dict]) -> Dict[int, List[Dict]]:
'''Returns a dictionary with identified clusters and their locations copied from the input'''
if locations is None or len(locations) == 0: if locations is None or len(locations) == 0:
# raise Exception("locations has to contain something") # raise Exception("locations has to contain something")
return {} return {}
labels = self.create_labels(locations) features = self.extract_location_features(locations)
self.label_locations(locations, labels)
clusters = {} labels = self.create_labels(features)
for label in labels: self.label_dataset(locations, labels)
clusters[label] = [l for l in locations if l['cluster_label'] == label]
return clusters return self.group_by_clusters(locations, labels)
def extract_location_data(self, locations: List[dict]) -> np.ndarray: def cluster_times(self, times:List[Dict]) -> Dict[int, List[Dict]]:
return np.asarray([(float(l['latitude']), float(l['longitude'])) for l in locations]) '''Returns a dictionary with identified clusters and their times copied from the input'''
\ No newline at end of file features = self.extract_time_features(times)
labels = self.create_labels(features)
self.label_dataset(times, labels)
return self.group_by_clusters(times, labels)
\ No newline at end of file
import io
from flask import request, Response
from db.repository import Repository
from processing.clusterer import Clusterer
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
repo = Repository(agi_data=True)
clusterer = Clusterer()
def get():
clusters = repo.get_user_clusters()
return [c.to_serializable_dict() for c in clusters]
def get_image():
return Response(status=501)
# todo
locations = repo.getLocations()
fig = clusterer.draw_locations(locations)
output = io.BytesIO()
FigureCanvas(fig).print_png(output)
return Response(output.getvalue(), mimetype="image/png")
\ No newline at end of file
...@@ -7,11 +7,28 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas ...@@ -7,11 +7,28 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
repo = Repository() repo = Repository()
clusterer = Clusterer() clusterer = Clusterer()
def get(): def get_locations():
clusters = repo.get_user_clusters() clusters = repo.get_location_clusters()
return [c.to_serializable_dict() for c in clusters] return [c.to_serializable_dict() for c in clusters]
def get_image(): def get_times():
clusters = repo.get_time_clusters()
return [c.to_serializable_dict() for c in clusters]
def get_image_1():
return Response(status=501)
# todo
locations = repo.getLocations()
fig = clusterer.draw_locations(locations)
output = io.BytesIO()
FigureCanvas(fig).print_png(output)
return Response(output.getvalue(), mimetype="image/png")
def get_image_2():
return Response(status=501) return Response(status=501)
# todo # todo
......
...@@ -6,17 +6,17 @@ repo = Repository() ...@@ -6,17 +6,17 @@ repo = Repository()
def post(): def post():
body = request.json body = request.json
insert_location(body) _insert_location(body)
return Response(status=201) return Response(status=201)
def post_many(): def post_many():
body = request.json body = request.json
for location in body: for location in body:
insert_location(location) _insert_location(location)
return Response(status=201) return Response(status=201)
def get(): def get():
return [l.to_serializable_dict() for l in repo.get_locations()] return [l.to_serializable_dict() for l in repo.get_locations()]
def insert_location(location_data: dict): def _insert_location(location_data: dict):
repo.add_location(Location(location_data)) repo.add_location(Location(location_data))
...@@ -4,31 +4,33 @@ modules_path = '../../../modules/' ...@@ -4,31 +4,33 @@ modules_path = '../../../modules/'
if os.path.exists(modules_path): if os.path.exists(modules_path):
sys.path.insert(1, modules_path) sys.path.insert(1, modules_path)
from processing.clusterer import Clusterer from db.entities import Location, PopularLocation, LocationCluster, TimeCluster
from db.repository import Repository
from datetime import datetime, timedelta
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
from db.entities import Location, PopularLocation, UserCluster from db.repository import Repository
import statistics from processing.clusterer import Clusterer
from collections import Counter
import json
DEBUG = False DEBUG = False
NR_DECIMAL_FOR_BEST_LOCATIONS = 4
# used to cluster locations of a single user to detect main location per time slice repo = Repository()
main_loc_clusterer = Clusterer()
# used to cluster the users based on their main location
user_clusterer = Clusterer()
time_slices = list(range(24)) def run_location_clustering():
user_clusterer = Clusterer()
repo = Repository(agi_data=True) all_location_traces = repo.get_locations()
cluster_result = user_clusterer.cluster_locations(
[l.to_serializable_dict() for l in all_location_traces])
def run_location_clustering(): clusters = [LocationCluster(key, value)
user_clusters: List[UserCluster] = [] for key, value in cluster_result.items()]
popular_locations: List[PopularLocation] = []
store_clusters('locations', clusters)
def run_time_clustering():
clusters: List[TimeCluster] = []
user_clusterer = Clusterer(epsilon=600) # clustered within 10 minutes
all_location_traces = repo.get_locations() all_location_traces = repo.get_locations()
...@@ -38,107 +40,39 @@ def run_location_clustering(): ...@@ -38,107 +40,39 @@ def run_location_clustering():
traces_for_cur_date = [ traces_for_cur_date = [
trace for trace in all_location_traces if trace.timestamp.date() == cur_date] trace for trace in all_location_traces if trace.timestamp.date() == cur_date]
location_counter: Dict[str, int] = {}
# for each hour of that day # for each hour of that day
for cur_hour in time_slices: for cur_hour in list(range(24)):
traces_for_time_slice = [ traces_for_time_slice = [
trace for trace in traces_for_cur_date if trace.timestamp.hour - cur_hour == 0] trace for trace in traces_for_cur_date if trace.timestamp.hour == cur_hour]
if len(traces_for_time_slice) == 0: if len(traces_for_time_slice) == 0:
continue continue
main_locations = [] # clustering per hour
cluster_result = user_clusterer.cluster_times(
# store the main location for each user [t.to_serializable_dict() for t in traces_for_time_slice])
users = {trace.user for trace in traces_for_time_slice} cur_clusters = [TimeCluster(cur_date, cur_hour, key, value)
for user in users: for key, value in cluster_result.items()]
main_loc = get_main_location_for_user(
traces_for_time_slice, user)
main_loc['user'] = user
main_locations.append(main_loc)
# cluster the main locations for all users
cluster_result = user_clusterer.run(main_locations)
clusters = {}
for key, vals in cluster_result.items():
clusters[key] = [v['user'] for v in vals]
# print(f"{cur_date} @ {cur_hour}h-{cur_hour+1}h (Group #{key}): {[v['user'] for v in vals]}")
# add the clusters for the cur_hour to the global cluster list clusters.extend(cur_clusters)
user_clusters.append(UserCluster(cur_date, cur_hour, clusters))
# add locations for cur_hour to location counter store_clusters('times', clusters)
for main_l in main_locations:
key = json.dumps({'lat': round(main_l['latitude'], NR_DECIMAL_FOR_BEST_LOCATIONS),
'long': round(main_l['longitude'], NR_DECIMAL_FOR_BEST_LOCATIONS)})
if key not in location_counter:
location_counter[key] = 0
location_counter[key] += 1
# print(f"{cur_date} @ {cur_hour}h-{cur_hour+1}h: {main_locations}")
# add the top three locations to the global popular location list def store_clusters(type: str, clusters: List):
top_locations = get_top_three_locations(location_counter)
top_locations = [json.loads(l[0]) for l in top_locations]
popular_locations.append(PopularLocation(cur_date, top_locations))
store_user_clusters(user_clusters)
store_popular_locations(popular_locations)
def get_main_location_for_user(location_traces: List[Location], user: str) -> dict:
# cluster based on locations
locations_for_user = [t for t in location_traces if t.user == user]
clusters = main_loc_clusterer.run([l.__dict__
for l in locations_for_user])
# largest cluster has most locations
max_c = {'id': -1, 'size': 0}
for cluster_key, cluster_vals in clusters.items():
if len(cluster_vals) > max_c['size']:
max_c['id'] = cluster_key
max_c['size'] = len(cluster_vals)
# calculate center of the location from the largest cluster
locations_of_largest_cluster = clusters[max_c['id']]
center = get_center_of_2d_points(locations_of_largest_cluster)
return center
def get_center_of_2d_points(points, nr_decimal_places=5) -> dict:
center = {}
center['latitude'] = round(statistics.mean(
[p['latitude'] for p in points]), nr_decimal_places)
center['longitude'] = round(statistics.mean(
[p['longitude'] for p in points]), nr_decimal_places)
return center
def get_top_three_locations(location_counts: Dict[str, int]) -> List[Tuple[str, int]]:
cnter = Counter(location_counts)
max_three = cnter.most_common(3)
return max_three
def store_user_clusters(user_clusters: List[UserCluster]):
if DEBUG: if DEBUG:
print(user_clusters) print(clusters)
return return
for c in user_clusters: if type == 'locations':
repo.add_user_cluster(c) for c in clusters:
repo.add_location_cluster(c)
def store_popular_locations(popular_locations: List[PopularLocation]):
if DEBUG:
print(popular_locations)
return
for l in popular_locations: if type == 'times':
repo.add_popular_location(l) for c in clusters:
repo.add_time_cluster(c)
if __name__ == "__main__": if __name__ == "__main__":
run_location_clustering() run_location_clustering()
run_time_clustering()
import unittest
import sys
sys.path.insert(1, './')
# python -m unittest discover -v tests
from db.entities.cluster import Cluster
from db.entities import TimeCluster, LocationCluster
from datetime import date, datetime
import json
class TestCluster(unittest.TestCase):
def test_init_Cluster(self):
c = Cluster(1, [1, 2, 3])
self.assertEqual(1, c.cluster_label)
self.assertEqual([1, 2, 3], c.clusters)
class TestLocationCluster(unittest.TestCase):
def setUp(self):
self.c = LocationCluster(1, [1, 2, 3])
def test_init_individualArguments(self):
c = LocationCluster(1, [1, 2, 3])
self.assertEqual('1', c.id)
self.assertEqual(1, c.cluster_label)
self.assertEqual([1, 2, 3], c.clusters)
def test_init_dictArgument(self):
dict_ = {'id': '123', 'cluster_label': 1, 'clusters': [1, 2, 3]}
c = LocationCluster(location_dict=dict_)
self.assertEqual('123', c.id)
self.assertEqual(1, c.cluster_label)
self.assertEqual([1, 2, 3], c.clusters)
def test_init_dictArgument_fromDb(self):
dict_ = {'id': '123', 'cluster_label': 1, 'clusters': '[1, 2, 3]'}
c = LocationCluster(location_dict=dict_, from_db=True)
self.assertEqual('123', c.id)
self.assertEqual(1, c.cluster_label)
self.assertEqual([1, 2, 3], c.clusters)
def test_to_serializable_dict_noDb(self):
c_dict = self.c.to_serializable_dict()
self.assertEqual(self.c.id, c_dict['id'])
self.assertEqual(self.c.cluster_label, c_dict['cluster_label'])
self.assertEqual(self.c.clusters, c_dict['clusters'])
def test_from_serializable_dict_noDb(self):
new_c = LocationCluster()
new_c.from_serializable_dict(self.c.to_serializable_dict())
self.assertEqual(self.c.id, new_c.id)
self.assertEqual(str(self.c), str(new_c))
def test_to_serializable_dict_db_jsonClusters(self):
c_dict = self.c.to_serializable_dict(for_db=True)
self.assertEqual(self.c.id, c_dict['id'])
self.assertEqual(self.c.cluster_label, c_dict['cluster_label'])
self.assertEqual(self.c.clusters, json.loads(c_dict['clusters']))
def test_from_serializable_dict_fromDb(self):
new_c = LocationCluster()
new_c.from_serializable_dict(
self.c.to_serializable_dict(for_db=True), from_db=True)
self.assertEqual(self.c.id, new_c.id)
self.assertEqual(str(self.c), str(new_c))
class TestTimeCluster(unittest.TestCase):
def setUp(self):
self.date_ = date(2020, 1, 1)
self.c = TimeCluster(self.date_, 14, 1, [1, 2, 3])
def test_init_individualArguments(self):
c = TimeCluster(self.date_, 14, 1, [1, 2, 3])
self.assertEqual(f'{self.date_}-14-1', c.id)
self.assertEqual(self.date_, c.date)
self.assertEqual(14, c.hour)
self.assertEqual(1, c.cluster_label)
self.assertEqual([1, 2, 3], c.clusters)
def test_init_dictArgument(self):
dict_ = {'id': '123', 'cluster_label': 1, 'clusters': [1, 2, 3],
'date': str(self.date_), 'hour': 14}
c = TimeCluster(time_dict=dict_)
self.assertEqual('123', c.id)
self.assertEqual(self.date_, c.date)
self.assertEqual(14, c.hour)
self.assertEqual(1, c.cluster_label)
self.assertEqual([1, 2, 3], c.clusters)
def test_init_dictArgument_fromDb(self):
dict_ = {'id': '123', 'cluster_label': 1, 'clusters': '[1, 2, 3]',
'date': str(self.date_), 'hour': 14}
c = TimeCluster(time_dict=dict_, from_db=True)
self.assertEqual('123', c.id)
self.assertEqual(self.date_, c.date)
self.assertEqual(14, c.hour)
self.assertEqual(1, c.cluster_label)
self.assertEqual([1, 2, 3], c.clusters)
def test_to_serializable_dict_noDb(self):
c_dict = self.c.to_serializable_dict()
self.assertEqual(self.c.id, c_dict['id'])
self.assertEqual(self.c.cluster_label, c_dict['cluster_label'])
self.assertEqual(self.c.clusters, c_dict['clusters'])
self.assertEqual(self.c.date, datetime.strptime(
c_dict['date'], '%Y-%m-%d').date())
self.assertEqual(self.c.hour, c_dict['hour'])
def test_from_serializable_dict_noDb(self):
new_c = TimeCluster()
new_c.from_serializable_dict(self.c.to_serializable_dict())
self.assertEqual(self.c.id, new_c.id)
self.assertEqual(str(self.c), str(new_c))
def test_to_serializable_dict_fromDb_jsonClusters(self):
c_dict = self.c.to_serializable_dict(for_db=True)
self.assertEqual(self.c.id, c_dict['id'])
self.assertEqual(self.c.cluster_label, c_dict['cluster_label'])
self.assertEqual(self.c.clusters, json.loads(c_dict['clusters']))
self.assertEqual(self.c.date, datetime.strptime(
c_dict['date'], '%Y-%m-%d').date())
self.assertEqual(self.c.hour, c_dict['hour'])
def test_from_serializable_dict_fromDb(self):
new_c = TimeCluster()
new_c.from_serializable_dict(
self.c.to_serializable_dict(for_db=True), from_db=True)
self.assertEqual(self.c.id, new_c.id)
self.assertEqual(str(self.c), str(new_c))
if __name__ == '__main__':
unittest.main()
...@@ -20,13 +20,15 @@ class TestClusterer(unittest.TestCase): ...@@ -20,13 +20,15 @@ class TestClusterer(unittest.TestCase):
self.assertEqual([], labels) self.assertEqual([], labels)
def test_create_labels_singleInput_singleCluster(self): def test_create_labels_singleInput_singleCluster(self):
labels = self.clusterer.create_labels([self.location(1,2)]) features = self.clusterer.extract_location_features([self.location(1,2)])
labels = self.clusterer.create_labels(features)
self.assertEqual(1, len(labels)) self.assertEqual(1, len(labels))
def test_create_labels_nearInputs_singleCluster(self): def test_create_labels_nearInputs_singleCluster(self):
locations = [self.location(1,2), self.location(2,2)] locations = [self.location(1,2), self.location(2,2)]
labels = self.clusterer.create_labels(locations) features = self.clusterer.extract_location_features(locations)
labels = self.clusterer.create_labels(features)
self.assertEqual(2, len(labels)) self.assertEqual(2, len(labels))
self.assertEqual(labels[0], labels[1]) self.assertEqual(labels[0], labels[1])
...@@ -34,36 +36,37 @@ class TestClusterer(unittest.TestCase): ...@@ -34,36 +36,37 @@ class TestClusterer(unittest.TestCase):
def test_create_labels_nearInputs_twoClusters(self): def test_create_labels_nearInputs_twoClusters(self):
locations = [self.location(1,2), self.location(2,2), self.location(20,20)] locations = [self.location(1,2), self.location(2,2), self.location(20,20)]
labels = self.clusterer.create_labels(locations) features = self.clusterer.extract_location_features(locations)
labels = self.clusterer.create_labels(features)
self.assertEqual(3, len(labels)) self.assertEqual(3, len(labels))
self.assertEqual(labels[0], labels[1]) self.assertEqual(labels[0], labels[1])
self.assertNotEqual(labels[0], labels[2]) self.assertNotEqual(labels[0], labels[2])
def test_label_locations_NoneLocations_NoException(self): def test_label_locations_NoneLocations_NoException(self):
self.clusterer.label_locations(None, []) self.clusterer.label_dataset(None, [])
def test_label_locations_NoneLabels_NoException(self): def test_label_locations_NoneLabels_NoException(self):
self.clusterer.label_locations([], None) self.clusterer.label_dataset([], None)
def test_label_locations_emptyInput_emptyOutput(self): def test_label_locations_emptyInput_emptyOutput(self):
locations = [] locations = []
self.clusterer.label_locations(locations, []) self.clusterer.label_dataset(locations, [])
self.assertEqual(0, len(locations)) self.assertEqual(0, len(locations))
def test_label_locations_diffInputLengths_ValueError_1(self): def test_label_locations_diffInputLengths_ValueError_1(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.clusterer.label_locations([], [1]) self.clusterer.label_dataset([], [1])
def test_label_locations_diffInputLengths_ValueError_2(self): def test_label_locations_diffInputLengths_ValueError_2(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.clusterer.label_locations([self.location(1,2)], []) self.clusterer.label_dataset([self.location(1,2)], [])
def test_label_locations_multInput_correctlyLabeled(self): def test_label_locations_multInput_correctlyLabeled(self):
locations = [self.location(1,2), self.location(2,2), self.location(20,20)] locations = [self.location(1,2), self.location(2,2), self.location(20,20)]
labels = [17,2,20] labels = [17,2,20]
self.clusterer.label_locations(locations, labels) self.clusterer.label_dataset(locations, labels)
self.assertEqual(3, len(locations)) self.assertEqual(3, len(locations))
self.assertHaveLabelsAsNewKey(locations, labels) self.assertHaveLabelsAsNewKey(locations, labels)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment