-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkmeans.py
More file actions
43 lines (33 loc) · 1.57 KB
/
kmeans.py
File metadata and controls
43 lines (33 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Union, Dict, List
from sklearn.cluster import KMeans
import numpy as np
class BlockKMeans:
    """Maintain one two-cluster, one-dimensional KMeans model per key.

    Samples are buffered per key in ``kmeans_values``. As soon as a key holds
    more than ``MIN_LENGTH`` samples, a ``KMeans(n_clusters=2)`` estimator is
    created for it in ``kmeans_list`` and re-fitted on every later ``train``
    call that leaves the buffer above that threshold.
    """

    # A key is fitted only once it holds MORE than this many samples.
    MIN_LENGTH = 5
    # Seed passed to every KMeans estimator created by this class.
    RANDOM_STATE = 42

    def __init__(self):
        # key -> fitted estimator (present only after the first fit).
        self.kmeans_list: Dict[str, KMeans] = {}
        # key -> buffered training samples (most recent last).
        self.kmeans_values: Dict[str, list] = {}

    @staticmethod
    def format_data(data: Union[list, np.ndarray, int, float]) -> np.ndarray:
        """Return *data* as a column vector of shape ``(n_samples, 1)``.

        A scalar becomes a ``(1, 1)`` array.
        """
        return np.array(data).reshape((-1, 1))

    def train(self, values: Union[list, np.ndarray, float], key: str, maximum: int = int(1e100)):
        """Append *values* to the buffer of *key* and (re)fit its model.

        Args:
            values: New sample(s): a list/tuple of numbers, a NumPy array
                (flattened before use), or a single real number (Python or
                NumPy scalar).
            key: Name of the buffer/model to update.
            maximum: Number of most recent samples kept for fitting. The
                buffer is only trimmed once it exceeds ``MIN_LENGTH``.

        Raises:
            TypeError: If *values* is none of the accepted types. The buffer
                for *key* is left untouched (and is not created) in that case.
        """
        # Normalise the input first so an invalid type never creates the key.
        if isinstance(values, (list, tuple)):
            incoming = list(values)
        elif isinstance(values, np.ndarray):
            # The signature advertises ndarray support; flatten to plain numbers.
            incoming = values.ravel().tolist()
        elif isinstance(values, (int, float, np.integer, np.floating)):
            incoming = [values]
        else:
            raise TypeError(
                f"Values has invalid type ({type(values)}), "
                "must be list, tuple, np.ndarray or a real number"
            )

        buffer = self.kmeans_values.setdefault(key, [])
        buffer.extend(incoming)

        # Not enough data yet: keep buffering, do not fit.
        if len(buffer) <= self.MIN_LENGTH:
            return

        # Trim BEFORE choosing the initial centroids so they are taken from
        # the window that is actually fitted.
        window = buffer[-maximum:]
        self.kmeans_values[key] = window

        if key not in self.kmeans_list:
            # Seed the two centroids at the extremes of the first window;
            # n_init=1 because the initialisation is explicit.
            # NOTE(review): these seeds are fixed at creation time and reused
            # by every later fit, even after the window has moved on.
            self.kmeans_list[key] = KMeans(
                n_clusters=2,
                init=[[np.min(window)], [np.max(window)]],
                n_init=1,
                random_state=self.RANDOM_STATE,
            )

        self.kmeans_list[key].fit(self.format_data(window))

    def predict(self, values: Union[list, int], key: str) -> Union[np.ndarray, List[str]]:
        """Return the cluster label of each value for the model of *key*.

        Returns:
            The array produced by the estimator's ``predict``, or the
            single-element list ``["UNK"]`` when no model has been fitted for
            *key* yet.
        """
        model = self.kmeans_list.get(key)
        if model is None:
            return ["UNK"]
        return model.predict(self.format_data(values))

    def get_cluster_centers(self, key: str) -> np.ndarray:
        """Return the ``cluster_centers_`` of the model fitted for *key*.

        Raises:
            KeyError: If no model has been fitted for *key*.
        """
        return self.kmeans_list[key].cluster_centers_