Commit 2ca5a0c8 authored by Christian Marius Lillelund's avatar Christian Marius Lillelund
Browse files

added some test scripts

parent f04df37c
Pipeline #48323 passed with stage
in 3 minutes and 50 seconds
%% Cell type:code id: tags:
``` python
import pandas as pd
import numpy as np
import datetime as dt
from tools import preprocessor, data_loader
import config as cfg
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
tf.get_logger().setLevel('ERROR')
# Dataset under analysis.
filename = "complete.csv"
# Read every ATS column as a string so device codes keep their leading zeros.
converters = {f"{i}Ats": str for i in range(1, cfg.ATS_RESOLUTION + 1)}
# Load features and target, then merge them into a single frame.
dl = data_loader.CompleteDataLoader(file_name=filename, converters=converters).load_data()
X, y = dl.get_data()
df = pd.concat([X, y], axis=1)
# Derive age from birth year (121 looks like an anonymization offset — TODO confirm).
df['Age'] = df['BirthYear'].apply(lambda birth_year: 121 - birth_year)
```
%% Cell type:code id: tags:
``` python
# Preview the first five rows of the merged feature/target frame.
df.head()
```
%% Output
Gender BirthYear Cluster LoanPeriod NumberAts 1Ats \\n0 0 21 3 516.0 15 Brusestole \n1 0 45 13 315.0 2 Rollatorer \n2 0 45 13 141.0 8 Rollatorer \n3 0 45 13 142.0 8 Rollatorer \n4 0 45 13 159.0 9 Rollatorer \n\n 2Ats 3Ats 4Ats 5Ats \\n0 ToiletforhøjereStativ Gangborde Gangborde Nødalarmsystemer \n1 Brusestole 0 0 0 \n2 Brusestole RamperMobile Vendehjælpemidler Gangborde \n3 Brusestole RamperMobile Vendehjælpemidler Gangborde \n4 Brusestole RamperMobile Vendehjælpemidler Gangborde \n\n ... 43Ats 44Ats 45Ats 46Ats 47Ats 48Ats 49Ats 50Ats Complete Age \n0 ... 0 0 0 0 0 0 0 0 0 100 \n1 ... 0 0 0 0 0 0 0 0 1 76 \n2 ... 0 0 0 0 0 0 0 0 1 76 \n3 ... 0 0 0 0 0 0 0 0 1 76 \n4 ... 0 0 0 0 0 0 0 0 1 76 \n\n[5 rows x 57 columns]
Gender BirthYear Cluster LoanPeriod NumberAts 1Ats 2Ats 3Ats \\n0 1 32 0 318.0 3 093304 093307 120606 \n1 0 29 0 1666.0 6 093307 093304 120606 \n2 1 37 5 0.0 0 0 0 0 \n3 1 37 5 0.0 0 0 0 0 \n4 1 37 0 0.0 0 0 0 0 \n\n 4Ats 5Ats ... 43Ats 44Ats 45Ats 46Ats 47Ats 48Ats 49Ats 50Ats \\n0 0 0 ... 0 0 0 0 0 0 0 0 \n1 091218 120606 ... 0 0 0 0 0 0 0 0 \n2 0 0 ... 0 0 0 0 0 0 0 0 \n3 0 0 ... 0 0 0 0 0 0 0 0 \n4 0 0 ... 0 0 0 0 0 0 0 0 \n\n Complete Age \n0 0 89 \n1 0 92 \n2 1 84 \n3 1 84 \n4 1 84 \n\n[5 rows x 57 columns]
%% Cell type:code id: tags:
``` python
# Class balance of the Complete target (counts per label).
df.Complete.value_counts()
```
%% Output
1 1543\n0 601\nName: Complete, dtype: int64
%% Cell type:code id: tags:
``` python
# Bar chart of the Complete target distribution, saved to the reports folder.
import seaborn as sns
counts = df['Complete'].value_counts()
plt.figure()
plt.bar(counts.index, counts)
plt.xticks(counts.index, counts.index.values)
plt.ylabel("Frequency")
plt.title('Complete')
file_name = f"Complete bar.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches="tight")
```
%% Output
%% Cell type:code id: tags:
``` python
# Age vs. number of assistive devices, colored by the Complete outcome.
ax = sns.scatterplot(data=df, x="Age", y="NumberAts", hue="Complete")
plt.title("Scatter plot of NumberAts vs Age")
fig = ax.get_figure()
file_name = f"Complete scatter NumberAts Age.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches="tight")
```
%% Output
%% Cell type:code id: tags:
``` python
# Age distribution per Complete outcome, one facet per class.
# `sns.distplot` is deprecated (and removed in seaborn >= 0.14); `histplot`
# with stat="density" and kde=True is the supported axes-level replacement
# that mirrors distplot's default output.
g = sns.FacetGrid(df, col="Complete")
g.map(sns.histplot, "Age", bins=25, stat="density", kde=True)
g.fig.suptitle("Number of citizens who complete given age")
g.fig.subplots_adjust(top=.8)
file_name = f"Complete facetgrid age.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches="tight")
```
%% Output
C:\Users\cml\miniconda3\envs\py38-air\lib\site-packages\seaborn\distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
C:\Users\cml\miniconda3\envs\py38-air\lib\site-packages\seaborn\distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
%% Cell type:code id: tags:
``` python
def get_ats_list(df, resolution=None):
    """Collect every non-zero ATS device code from the slot columns of *df*.

    Scans the columns ``1Ats`` .. ``{resolution}Ats``; each cell may hold a
    comma-separated list of codes, and the placeholder ``"0"`` marks an empty
    slot and is skipped.

    Args:
        df: Frame (or mapping of column name -> iterable of strings) with the
            ATS slot columns.
        resolution: Number of ATS slot columns to scan. Defaults to
            ``cfg.ATS_RESOLUTION`` for backward compatibility.

    Returns:
        list[str]: All codes in column order, then row order, duplicates kept.
    """
    if resolution is None:
        resolution = cfg.ATS_RESOLUTION
    all_ats = []
    for ats_col in (f"{i}Ats" for i in range(1, resolution + 1)):
        for ats_string in df[ats_col]:
            for ats in ats_string.split(","):
                if ats != "0":  # "0" is the empty-slot placeholder, not a code
                    all_ats.append(ats)
    return all_ats
# Top-20 ATS codes per outcome group, scaled by each group's total code count.
ats_no_complete = pd.Series(get_ats_list(df.loc[df['Complete'] == 0]))
ats_complete = pd.Series(get_ats_list(df.loc[df['Complete'] == 1]))
top_no_complete = pd.DataFrame(ats_no_complete.value_counts()[:20], columns=['No complete quantity'])
top_complete = pd.DataFrame(ats_complete.value_counts()[:20], columns=['Complete quantity'])
# Union of the two top-20 lists; a code missing from one group counts as 0.
ats_df = pd.concat([top_no_complete, top_complete], axis=1).fillna(0)
ats_df.index.names = ['Ats']
ats_df = ats_df.reset_index()
# Normalize counts to shares so the two groups are comparable.
ats_df['No complete quantity'] /= len(ats_no_complete)
ats_df['Complete quantity'] /= len(ats_complete)
```
%% Cell type:code id: tags:
``` python
# Stacked bar chart comparing scaled ATS usage between the two outcomes.
labels = ats_df["Ats"]
no_complete_share = ats_df["No complete quantity"]
complete_share = ats_df["Complete quantity"]
plt.bar(labels, no_complete_share, label="No complete")
plt.bar(labels, complete_share, bottom=no_complete_share, label="Complete")
plt.legend()
plt.xticks(rotation=90)
plt.ylabel("Scaled ats usage")
plt.title('Scaled plot of ats usage for complete')
file_name = f"Complete scaled ats usage.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches="tight")
```
%% Output
......
......@@ -349,7 +349,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.8.8-final"
},
"orig_nbformat": 2
},
......
import pandas as pd
import numpy as np
from tools import preprocessor, file_reader, explainer
import config as cfg
import os
import csv
import joblib
from pathlib import Path
def main():
    """Score one hand-crafted citizen with the trained complete-prediction model.

    Loads the XGBoost model and the per-slot ATS embeddings from disk, builds a
    single synthetic input row, and prints the predicted class, its probability,
    and the per-feature SHAP contributions.
    """
    model = file_reader.read_joblib(cfg.COMPLETE_XGB_DIR,
                                    'complete_xgboost.joblib')

    # Single synthetic citizen to score.
    citizen = pd.DataFrame.from_dict({"Gender": [0],
                                      "BirthYear": [46],
                                      "LoanPeriod": [360],
                                      "Ats": ["093307,222718,181210"]})
    citizen['NumberAts'] = len(citizen['Ats'][0].split(","))

    # Expand the comma-separated Ats string into one column per device slot.
    df = preprocessor.split_cat_columns(citizen, col_to_split='Ats',
                                        tag='Ats',
                                        resolution=cfg.ATS_RESOLUTION)
    ats_columns = [f"{slot}Ats" for slot in range(1, cfg.ATS_RESOLUTION + 1)]
    # Align columns with the model's training layout; missing slots become '0'.
    df = df.reindex(columns=['Gender', 'BirthYear', 'Cluster',
                             'LoanPeriod', 'NumberAts'] + ats_columns)
    df = df.fillna('0')
    df['Cluster'] = 14
    df['Cluster'] = pd.to_numeric(df['Cluster'])

    # Replace every raw Ats code with its learned embedding value.
    embeddings_dir = Path.joinpath(cfg.PROCESSED_DATA_DIR, 'embeddings')
    for slot in range(1, cfg.ATS_RESOLUTION + 1):
        mapping = file_reader.read_embedding(embeddings_dir, f'complete_{slot}Ats.csv')
        col = f'{slot}Ats'
        df[col] = df[col].replace(to_replace=mapping)
        df[col] = pd.to_numeric(df[col])

    prediction = model.predict(df)
    probability = model.predict_proba(df).max()

    # Per-feature SHAP contributions for the single input row.
    _, shap_values = explainer.get_shap_tree_explainer(model, X_test=df)
    contributions = dict(zip(df.columns,
                             (round(float(v), 3) for v in shap_values[0])))

    print(f"Predicted {int(prediction[0])} with probability {round(float(probability), 3)*100}%")
    for item, amount in contributions.items():
        print("{} ({})".format(item, amount))


if __name__ == "__main__":
    main()
\ No newline at end of file
import pandas as pd
import numpy as np
from tools import preprocessor, file_reader, explainer
import config as cfg
from pathlib import Path
def main():
    """Score the first citizen from the test CSV with the complete-prediction model.

    Loads the XGBoost model, reads the prepared test citizens, maps their ATS
    codes to embedding values, and prints the prediction, its probability, and
    the per-feature SHAP contributions for the first row.
    """
    model = file_reader.read_joblib(cfg.COMPLETE_XGB_DIR,
                                    'complete_xgboost.joblib')

    # Read every Ats column as a string so codes keep their leading zeros.
    converters = {f"{i}Ats": str for i in range(1, cfg.ATS_RESOLUTION + 1)}
    df = file_reader.read_csv(cfg.TESTS_FILES_DIR,
                              'test_citizens.csv',
                              converters=converters)

    # Replace every raw Ats code with its learned embedding value.
    embeddings_dir = Path.joinpath(cfg.PROCESSED_DATA_DIR, 'embeddings')
    for slot in range(1, cfg.ATS_RESOLUTION + 1):
        mapping = file_reader.read_embedding(embeddings_dir, f'complete_{slot}Ats.csv')
        col = f'{slot}Ats'
        df[col] = df[col].replace(to_replace=mapping)
        df[col] = pd.to_numeric(df[col])

    # Keep only the first citizen, as a one-row frame (predict expects 2-D).
    test_citizen = pd.DataFrame(df.iloc[0]).T
    print(test_citizen)

    prediction = model.predict(test_citizen)
    probability = model.predict_proba(test_citizen).max()

    # Per-feature SHAP contributions for the scored row.
    _, shap_values = explainer.get_shap_tree_explainer(model, X_test=test_citizen)
    contributions = dict(zip(test_citizen.columns,
                             (round(float(v), 3) for v in shap_values[0])))

    print(f"Predicted {int(prediction[0])} with probability {round(float(probability), 3)*100}%")
    for item, amount in contributions.items():
        print("{} ({})".format(item, amount))


if __name__ == "__main__":
    main()
\ No newline at end of file
import pandas as pd
import numpy as np
from tools import preprocessor, file_reader
import config as cfg
import os
import csv
import joblib
from pathlib import Path
def main():
    """Compare compliance predictions for one synthetic citizen across both genders.

    Loads the trained compliance XGBoost model, then builds two otherwise
    identical input rows (Gender 0 and Gender 1) and prints the predicted
    class and its probability for each.
    """
    model = file_reader.read_joblib(cfg.COMPLIANCE_XGB_DIR,
                                    'compliance_xgboost.joblib')
    for gender in range(0, 2):
        # One synthetic citizen; only Gender varies between iterations.
        input_data = {"Gender": [gender],
                      "BirthYear": [72],
                      "LoanPeriod": [360],
                      "Ats": ["222718,093307,181210"]}
        new_data_df = pd.DataFrame.from_dict(input_data)
        # Number of assistive devices in the comma-separated Ats string.
        new_data_df['NumberAts'] = len(new_data_df['Ats'][0].split(","))
        # Expand the Ats string into one column per device slot.
        df = preprocessor.split_cat_columns(new_data_df, col_to_split='Ats',
                                            tag='Ats',
                                            resolution=cfg.ATS_RESOLUTION)
        cols_ats = [str(i)+'Ats' for i in range(1, cfg.ATS_RESOLUTION+1)]
        header_list = ['Gender', 'BirthYear', 'Cluster',
                       'LoanPeriod', 'NumberAts'] + cols_ats
        # Align columns with the model's training layout; missing slots become '0'.
        df = df.reindex(columns=header_list)
        df = df.fillna('0')
        df['Cluster'] = 0
        df['Cluster'] = pd.to_numeric(df['Cluster'])
        for i in range(1, cfg.ATS_RESOLUTION+1):
            path = Path.joinpath(cfg.PROCESSED_DATA_DIR, 'embeddings')
            # NOTE(review): this reads the 'complete_' embeddings even though the
            # model is the compliance one — confirm the embeddings are shared.
            embedding = file_reader.read_embedding(path, f'complete_{i}Ats.csv')
            column = f'{i}Ats'
            df[column] = df[column].replace(to_replace=embedding)
            df[column] = pd.to_numeric(df[column])
        prediction = model.predict(df)
        probability = model.predict_proba(df).max()
        print(f"Using gender {gender}, predicted " +
              f"{int(prediction[0])} with probability {round(float(probability), 3)*100}%")


if __name__ == "__main__":
    main()
\ No newline at end of file
......@@ -110,13 +110,17 @@ def get_all_ex():
@app.get("/ats/hmi")
def get_ats_by_hmi(hmi: str):
    """Return the ATS entry for *hmi* from ats.csv, or 404 if unknown.

    The flattened diff residue left a stale unguarded lookup
    (``read_csv('ats.csv')[hmi]``) and an early return that made the
    404 handling unreachable; only the guarded version is kept.
    """
    ats = read_csv('ats.csv')
    if hmi in ats:
        return {"hmi": hmi, "ats": ats[hmi]}
    raise HTTPException(status_code=404, detail="Not found")
@app.get("/ex/ex_id")
def get_ex_by_id(ex_id: str):
    """Return the exercise entry for *ex_id* from ex.csv, or 404 if unknown.

    Fixes two defects: the stale unguarded ``read_csv('ex.csv')[ex_id]``
    lookup left by the flattened diff (which made the 404 path unreachable),
    and the copy-paste bug that responded with ``{"hmi": ..., "ats": ...}``
    keys — restoring the endpoint's original ``{"ex_id": ..., "ex": ...}``
    response shape.
    """
    ex = read_csv('ex.csv')
    if ex_id in ex:
        return {"ex_id": ex_id, "ex": ex[ex_id]}
    raise HTTPException(status_code=404, detail="Not found")
@app.post('/login')
def login(user: User, Authorize: AuthJWT = Depends()):
......
......@@ -4,7 +4,7 @@ import config as cfg
import pandas as pd
from tools import file_reader, file_writer, explainer
from utility import metrics
from sklearn.metrics import accuracy_score, precision_score,
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import recall_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
......@@ -20,15 +20,15 @@ def main():
model_dir = cfg.COMPLETE_XGB_DIR
target_name = "Complete"
elif case == "Compliance":
df = file_reader.read_csv(DATA_DIR, 'compliance_count.csv')
df = file_reader.read_csv(DATA_DIR, 'compliance_emb.csv')
model_dir = cfg.COMPLIANCE_XGB_DIR
target_name = "Compliance"
elif case == "Fall":
df = file_reader.read_csv(DATA_DIR, 'fall_count.csv')
df = file_reader.read_csv(DATA_DIR, 'fall_emb.csv')
model_dir = cfg.FALL_XGB_DIR
target_name = "Fall"
else:
df = file_reader.read_csv(DATA_DIR, 'fall_test_count.csv')
df = file_reader.read_csv(DATA_DIR, 'fall_test_emb.csv')
model_dir = cfg.FALL_TEST_XGB_DIR
target_name = "Fall"
......@@ -36,12 +36,12 @@ def main():
X = df.drop([target_name], axis=1)
y = df[target_name]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
neg, pos = np.bincount(y)
scale_pos_weight = neg / pos
params = {"n_estimators": 800,
params = {"n_estimators": 400,
"objective": "binary:logistic",
"scale_pos_weight": scale_pos_weight,
"use_label_encoder": False,
......@@ -79,19 +79,21 @@ def main():
valid_pre.append(precision_score(y_valid_split, y_valid_scores))
valid_recall.append(recall_score(y_valid_split, y_valid_scores))
valid_roc_auc.append(roc_auc_score(y_valid_split, y_valid_pred.iloc[valid_index]))
print(f"Accuracy: {np.around(np.mean(valid_acc), decimals=3)}")
print(f"Precision: {np.around(np.mean(valid_pre), decimals=3)}")
print(f"Recall: {np.around(np.mean(valid_recall), decimals=3)}")
print(f"ROC AUC: {np.around(np.mean(valid_roc_auc), decimals=3)}\n")
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)[:,1]
file_writer.write_cm_plot(y_test, y_pred, cfg.REPORTS_PLOTS_DIR,
f'{case.lower()}_xgb_cm.pdf', case)
file_writer.write_joblib(model, model_dir, f'{case.lower()}_xgboost.joblib')
print(f"Accuracy: {np.around(accuracy_score(y_test, y_pred), decimals=3)}")
print(f"Precision: {np.around(precision_score(y_test, y_pred), decimals=3)}")
print(f"Recall: {np.around(recall_score(y_test, y_pred), decimals=3)}")
print(f"ROC AUC: {np.around(roc_auc_score(y_test, y_proba), decimals=3)}\n")
feature_names = X.columns
shap_explainer, shap_values = explainer.get_shap_tree_explainer(model, X_test=X_test)
importance_df = pd.DataFrame()
importance_df['feature'] = feature_names
importance_df['shap_values'] = np.around(abs(np.array(shap_values)[:,:]).mean(0), decimals=3)
......
Gender,BirthYear,Cluster,LoanPeriod,NumberAts,1Ats,2Ats,3Ats,4Ats,5Ats,6Ats,7Ats,8Ats,9Ats,10Ats,11Ats,12Ats,13Ats,14Ats,15Ats,16Ats,17Ats,18Ats,19Ats,20Ats,21Ats,22Ats,23Ats,24Ats,25Ats,26Ats,27Ats,28Ats,29Ats,30Ats,31Ats,32Ats,33Ats,34Ats,35Ats,36Ats,37Ats,38Ats,39Ats,40Ats,41Ats,42Ats,43Ats,44Ats,45Ats,46Ats,47Ats,48Ats,49Ats,50Ats
0,39,14,255.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,35,4,157.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,24,8,14.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,30,0,84.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,38,8,2650.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,45,3,873.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,31,13,2459.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,38,8,1699.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,23,6,718.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,37,14,2619.0,12,181810,120606,123103,093307,093307,120603,093307,091203,091203,120603,120606,120724,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment