-
-
Notifications
You must be signed in to change notification settings - Fork 186
Expand file tree
/
Copy pathcrli_clustering_example.py
More file actions
58 lines (46 loc) · 1.64 KB
/
Copy pathcrli_clustering_example.py
File metadata and controls
58 lines (46 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
A minimalist, standalone example of the PyPOTS CRLI model for time-series clustering.
This script is auto-generated by extracting hyperparameters from the test code.
"""
from benchpots.datasets import preprocess_random_walk
from pypots.clustering import CRLI
from pypots.nn.functional import calc_external_cluster_validation_metrics
def main() -> None:
    """Train CRLI on a synthetic random-walk dataset and print clustering metrics.

    Steps performed:

    1. Generate a random-walk dataset with ``preprocess_random_walk`` (one
       class per target cluster, 10% missing rate).
    2. Wrap the train/validation/test splits in the ``{"X": ...}`` dicts that
       are passed to the model.
    3. Build a CRLI model, fit it on the training split (the validation split
       is passed to ``fit`` as well), and predict clusters for the test split.
    4. Compare the predicted clusters against ``test_y`` and print the Rand
       index, cluster purity, and NMI.
    """
    n_steps = 48
    n_features = 35
    n_clusters = 5

    # 1. Generate a random walk time-series dataset.
    #    The number of generated classes equals the number of clusters the
    #    model is asked to find, so the labels can be used for evaluation.
    dataset = preprocess_random_walk(
        n_steps=n_steps,
        n_features=n_features,
        n_classes=n_clusters,
        n_samples_each_class=40,
        missing_rate=0.1,
    )

    # 2. Extract training, validation, and test sets.
    #    Only the feature arrays are handed to the model; the test labels are
    #    kept aside for the external evaluation below.
    train_set = {"X": dataset["train_X"]}
    val_set = {"X": dataset["val_X"]}
    test_set = {"X": dataset["test_X"]}
    test_y_true = dataset["test_y"]

    # 3. Initialize the model.
    #    Hyperparameters are the ones extracted from the test code (see the
    #    module docstring); epochs=2 is presumably a smoke-test setting.
    model = CRLI(
        n_steps=n_steps,
        n_features=n_features,
        n_clusters=n_clusters,
        n_generator_layers=2,
        rnn_hidden_size=32,
        rnn_cell_type="LSTM",
        epochs=2,
        device="cpu",
    )

    # 4. Train the model
    print("🚀 Training the CRLI clustering model...")
    model.fit(train_set, val_set)

    # 5. Predict clusters
    print("🔮 Predicting clusters for the test set...")
    results = model.predict(test_set)
    clusters = results["clustering"]

    # 6. Evaluate against the ground-truth labels.
    #    The first value printed is the Rand index, hence the "RI" label.
    metrics = calc_external_cluster_validation_metrics(clusters, test_y_true)
    print(
        "✅ CRLI clustering external metrics: \n"
        f"RI: {metrics['rand_index']:.4f}, Purity: {metrics['cluster_purity']:.4f}, NMI: {metrics['nmi']:.4f}"
    )
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()