#!/usr/bin/env python
# test_model_baseline.py
import numpy as np
import paths as pt
from typing import List
5
from tools import file_writer, preprocessor, data_loader
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from tools.classifiers import KnnClassifier, SvmClassifier, LrClassifier
from tools.classifiers import XgbClassifier, RfClassifier, MlpClassifier
from pathlib import Path
import csv
from utility.settings import load_settings
from utility.metrics import compute_mean, compute_std

def get_version_subtitle(version):
    """Return the plot subtitle describing a dataset version.

    Args:
        version: Version label; ``main`` passes one of "NoAts",
            "Embedded", "Counts" or "OneHot".

    Returns:
        A short human-readable description of the feature encoding.
        Labels that are not recognized fall back to "with counts".
    """
    subtitles = {
        "NoAts": "without Ats and/or Ex columns",
        "Embedded": "with embeddings",
        "Counts": "with counts",
        # "OneHot" used to fall through to the "with counts" default,
        # which mislabeled the one-hot plots.
        "OneHot": "with one-hot encoding",
    }
    return subtitles.get(version, "with counts")

def prepare_data(X, y, settings):
    """Normalize and scale the features, then convert both inputs to arrays.

    Args:
        X: Feature table handed to the ``preprocessor`` helpers.
        y: Target values.
        settings: Mapping providing the 'features_to_normalize' and
            'features_to_scale' entries.

    Returns:
        Tuple ``(X, y)`` in which both elements are NumPy arrays.
    """
    to_normalize = settings['features_to_normalize']
    to_scale = settings['features_to_scale']
    # Normalization runs first; scaling is applied to its output.
    normalized = preprocessor.normalize_data(X, to_normalize)
    scaled = preprocessor.scale_data(normalized, to_scale)
    return np.array(scaled), np.array(y)

def load_data_embedded(case, settings):
    """Load features and target of the embedded dataset for a case.

    Args:
        case: Case name. "Complete", "Compliance" and "Alarm" select
            their own loader; any other value uses the fall loader.
        settings: Settings object forwarded to the data loader.

    Returns:
        The ``(X, y)`` pair produced by the loader's ``get_data()``.
    """
    sources = {
        "Complete": (data_loader.CompleteDataLoader, "complete_emb.csv"),
        "Compliance": (data_loader.ComplianceDataLoader, "compliance_emb.csv"),
        "Alarm": (data_loader.AlarmDataLoader, "alarm_emb.csv"),
    }
    # Unlisted cases fall back to the fall dataset.
    fallback = (data_loader.FallDataLoader, "fall_emb.csv")
    loader_cls, file_name = sources.get(case, fallback)
    loaded = loader_cls(file_name, settings).load_data()
    X, y = loaded.get_data()
    return X, y

45
def load_data_count(case, settings):
    """Load features and target of the count-encoded dataset for a case.

    Args:
        case: Case name. "Complete", "Compliance" and "Alarm" select
            their own loader; any other value uses the fall loader.
        settings: Settings object forwarded to the data loader.

    Returns:
        The ``(X, y)`` pair produced by the loader's ``get_data()``.
    """
    sources = {
        "Complete": (data_loader.CompleteDataLoader, "complete_count.csv"),
        "Compliance": (data_loader.ComplianceDataLoader, "compliance_count.csv"),
        "Alarm": (data_loader.AlarmDataLoader, "alarm_count.csv"),
    }
    # Unlisted cases fall back to the fall dataset.
    fallback = (data_loader.FallDataLoader, "fall_count.csv")
    loader_cls, file_name = sources.get(case, fallback)
    loaded = loader_cls(file_name, settings).load_data()
    X, y = loaded.get_data()
    return X, y

60
def load_data_ohe(case, settings):
    """Load features and target of the one-hot encoded dataset for a case.

    Args:
        case: Case name. "Complete", "Compliance" and "Alarm" select
            their own loader; any other value uses the fall loader.
        settings: Settings object forwarded to the data loader.

    Returns:
        The ``(X, y)`` pair produced by the loader's ``get_data()``.
    """
    sources = {
        "Complete": (data_loader.CompleteDataLoader, "complete_ohe.csv"),
        "Compliance": (data_loader.ComplianceDataLoader, "compliance_ohe.csv"),
        "Alarm": (data_loader.AlarmDataLoader, "alarm_ohe.csv"),
    }
    # Unlisted cases fall back to the fall dataset.
    fallback = (data_loader.FallDataLoader, "fall_ohe.csv")
    loader_cls, file_name = sources.get(case, fallback)
    loaded = loader_cls(file_name, settings).load_data()
    X, y = loaded.get_data()
    return X, y

75
def main():
    """Run the baseline classifier comparison for every configured case.

    For each case the report CSV is (re)created with a header row. Then,
    for each dataset version, the data is loaded and prepared, the
    classifiers are evaluated via ``train_clf`` (which appends to the
    report) and the per-metric plots are produced via ``make_plots``.
    The settings file ``<case>_emb.yaml`` is used for all versions.
    """
    num_iter = 5
    clf_names = ['KNN', 'SVM', 'LR', 'XGB', 'RF', 'MLP']
    num_clfs = len(clf_names)
    metrics = ['accuracy', 'precision', 'recall', 'roc_auc', 'average_precision', 'f1']
    versions = ['NoAts', 'Embedded', 'Counts', 'OneHot']

    # Report columns: one mean/std pair per metric, in the order of `metrics`
    # (the 'average_precision' metric is reported under the name 'pr_auc').
    column_stems = ['accuracy', 'precision', 'recall', 'roc_auc', 'pr_auc', 'f1']
    header = ['clf', 'version']
    for stem in column_stems:
        header += [f'{stem}_mean', f'{stem}_std']

    #cases = ["Complete", "Compliance", "Fall", "Risk"]
    cases = ['Alarm']

    for case in cases:
        settings = load_settings(f'{case.lower()}_emb.yaml')
        output_filename = f"{case} model baseline.csv"

        # Mode 'w' starts a fresh report containing only the header row.
        report_path = pt.REPORTS_DIR.joinpath(output_filename)
        with open(report_path, 'w', encoding='UTF8', newline='') as report:
            csv.writer(report).writerow(header)

        for version in versions:
            if version in ('NoAts', 'Embedded'):
                X, y = load_data_embedded(case, settings)
                if version == 'NoAts':
                    # Embedded data with the 1Ats..10Ats columns removed.
                    X = X.drop([f"{i}Ats" for i in range(1, 11)], axis=1)
            elif version == "Counts":
                X, y = load_data_count(case, settings)
            else:
                X, y = load_data_ohe(case, settings)

            X, y = prepare_data(X, y, settings)
            results = train_clf(X, y, version, output_filename, metrics, num_iter)
            make_plots(results, metrics, num_iter, num_clfs, clf_names,
                       case, version, get_version_subtitle(version))
111
                
112
def train_clf(X, y, version, output_filename, metrics, num_iter):
    """Evaluate every baseline classifier and append the scores to the report.

    Each of the ``num_iter`` iterations evaluates all six classifiers,
    passing the iteration index as the second argument of ``evaluate``.
    One CSV row per (iteration, classifier) is appended to the report,
    holding the classifier name, the version and a mean/std pair for
    every metric.

    Args:
        X: Feature matrix passed to each classifier.
        y: Target vector passed to each classifier.
        version: Dataset version label written to every row.
        output_filename: Name of the report file inside ``pt.REPORTS_DIR``.
        metrics: Metric names; used as keys into each evaluation result.
        num_iter: Number of evaluation iterations.

    Returns:
        A list with one dict per iteration, mapping classifier name to
        the result returned by that classifier's ``evaluate``.
    """
    iteration_results = []
    for k in range(num_iter):
        iteration_results.append({
            'KNN': KnnClassifier(X, y).evaluate(metrics, k),
            'SVM': SvmClassifier(X, y).evaluate(metrics, k),
            'LR': LrClassifier(X, y).evaluate(metrics, k),
            'XGB': XgbClassifier(X, y).evaluate(metrics, k),
            'RF': RfClassifier(X, y).evaluate(metrics, k),
            'MLP': MlpClassifier(X, y).evaluate(metrics, k),
        })

    # Build all rows first so the report is opened once rather than once
    # per row.
    rows = []
    for iter_result in iteration_results:
        for clf_name, result in iter_result.items():
            row = [clf_name, version]
            for metric in metrics:
                row.extend((compute_mean(result[metric]),
                            compute_std(result[metric])))
            rows.append(row)

    # Leave the report untouched when there is nothing to append.
    if rows:
        report_path = Path.joinpath(pt.REPORTS_DIR, output_filename)
        with open(report_path, 'a', encoding='UTF8', newline='') as f:
            csv.writer(f).writerows(rows)

    return iteration_results

137
138
def make_plots(results: np.ndarray, metrics: List[str], num_iter: int,
               num_clfs: int, clf_names, case: str, version: str, case_subtitle):
    """Write one cross-validation plot per metric.

    Args:
        results: Iterable with one entry per iteration; each entry is
            iterated with ``.items()`` and maps a classifier to a result
            indexable by metric name.
        metrics: Metric names to plot.
        num_iter: Number of iterations, forwarded to the plot writer.
        num_clfs: Length of the per-iteration mean/std vectors.
        clf_names: Classifier labels forwarded to the plot writer.
        case: Case name used in the output file name.
        version: Dataset version used in the output file name.
        case_subtitle: Subtitle forwarded to the plot writer.
    """
    for metric in metrics:
        mean_columns, std_columns = [], []
        for iter_result in results:
            clf_means = np.zeros(num_clfs)
            clf_stds = np.zeros(num_clfs)
            for idx, result in enumerate(iter_result.values()):
                scores = result[metric]
                clf_means[idx] = compute_mean(scores)
                clf_stds[idx] = compute_std(scores)
            mean_columns.append(clf_means)
            std_columns.append(clf_stds)

        # Stacking on the last axis yields (num_clfs, iterations) arrays.
        stacked_means = np.stack(mean_columns, axis=-1)
        stacked_stds = np.stack(std_columns, axis=-1)
        plot_name = f"{case} version {version} - {metric}.pdf"
        file_writer.write_cv_plot(stacked_means, stacked_stds, metric,
                                  num_iter, clf_names, pt.REPORTS_PLOTS_DIR,
                                  plot_name, case_subtitle)
154
155
156
       
# Run the baseline evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()