Commit 3326a6e1 authored by Christian Marius Lillelund

updated air demo

parent 2ca5a0c8
Pipeline #49318 passed in 3 minutes and 46 seconds
%% Cell type:code id: tags:
``` python
import pandas as pd
import numpy as np
import datetime as dt
import matplotlib.pyplot as plt
from tools import preprocessor, data_loader, file_writer
import config as cfg
import tensorflow as tf
import seaborn as sns
from pathlib import Path
tf.get_logger().setLevel('ERROR')
# Set dataset
filename = "fall_test.csv"
ex = {f'{i}Ex': str for i in range(1, cfg.EX_RESOLUTION+1)}
ats = {f'{i}Ats': str for i in range(1, cfg.ATS_RESOLUTION+1)}
converters = {**ex, **ats}
# Load data
dl = data_loader.FallDataLoader(file_name=filename, converters=converters).load_data()
X, y = dl.get_data()
df = pd.concat([X, y], axis=1)
# Add age feature (BirthYear is a two-digit year; 121 - x gives the age as of 2021)
df['Age'] = df['BirthYear'].apply(lambda x: 121 - x)
```
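%% Cell type:markdown id: tags:
A quick sanity check (a minimal sketch, not part of the original demo): the `converters` above should force every Ats/Ex column to `str`, and `Age` should follow from the two-digit `BirthYear` with 2021 as the reference year (e.g. `BirthYear` 32 gives `Age` 89).
%% Cell type:code id: tags:
``` python
# All Ats/Ex columns should have been read as strings (object dtype)
cat_cols = df.filter(regex=r'\d+(Ats|Ex)$', axis=1)
assert (cat_cols.dtypes == object).all()
# Age is defined as 121 - BirthYear (two-digit years, reference year 2021)
assert (df['Age'] == 121 - df['BirthYear']).all()
```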
%% Cell type:code id: tags:
``` python
df.MeanEvaluation.max()
```
%% Output
6.0
%% Cell type:code id: tags:
``` python
df.head()
```
%% Output
   Gender  BirthYear  Cluster  LoanPeriod  NumberSplit  NumberScreening  \
0       1         32        0       318.0            0                0
1       1         32        0       360.0            0                1
2       0         29        0      1666.0            0                0
3       0         29        5      1694.0            0                1
4       0         29        5      1705.0            0                2

   NumberWeeks  MeanEvaluation  NumberFalls  NumberTraining  ...  \
0         0.00             0.0            0               1  ...
1        16.57             4.0            0               8  ...
2         0.00             0.0            0               1  ...
3         3.43             4.0            0               6  ...
4         1.57             3.7            0               3  ...

                              2Ex                              3Ex  \
0  LiggendeTilSiddendeSelvhjulpet  RulleOmPåSidenBeggeHænderPåKrop
1         SelvstændigGangfunktion                        StåPåTæer
2                 SkalBrugeArmlæn        GårMedGangredskabOgStøtte
3             StåMedSamledeFødder                     MereEnd16Sek
4  LiggendeTilSiddendeSelvhjulpet           MindreEnd8Oprejsninger

                               4Ex                               5Ex  \
0              StåMedSamledeFødder                      MereEnd16Sek
1               8Til13Oprejsninger    GårUdenGangredskabOgUdenStøtte
2                     MereEnd16Sek               StåMedSamledeFødder
3  RulleOmPåSidenBeggeHænderPåKrop  RygliggendeBækkenløftEtBenStrakt
4                GårMedGangredskab                         8Til12Sek

                                6Ex                                  7Ex  \
0             MereEnd13Oprejsninger                     Gå4SkridtBaglæns
1  SiddeUdenStøtteMereEnd60Sekunder                                    0
2         SiddeUdenStøtte10Sekunder                                    0
3  SiddeUdenStøtteMereEnd60Sekunder  LiggendeTilSiddendeBeggeHænderPåKop
4                         StåPåTæer                     Gå4SkridtBaglæns

                                8Ex                               9Ex  Fall  \
0    GårUdenGangredskabOgUdenStøtte  SiddeUdenStøtteMereEnd60Sekunder     0
1                                 0                                 0     0
2                                 0                                 0     1
3                                 0                                 0     0
4  SiddeUdenStøtteMereEnd60Sekunder                                 0     1

   Age
0   89
1   89
2   92
3   92
4   92

[5 rows x 80 columns]
%% Cell type:code id: tags:
``` python
df.Fall.value_counts()
```
%% Output
0    2748
1     469
Name: Fall, dtype: int64
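%% Cell type:markdown id: tags:
The classes are imbalanced: 469 of 3217 screenings (≈14.6%) register a fall. A minimal check of that rate, assuming the counts above:
%% Cell type:code id: tags:
``` python
# Share of positive samples (Fall == 1); ~0.146 given the counts above
fall_rate = df['Fall'].mean()
print(f"Fall rate: {fall_rate:.3f} ({df['Fall'].sum()} of {len(df)})")
```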
%% Cell type:code id: tags:
``` python
falls = df.loc[df.Fall == 1]  # screenings with a registered fall
df.filter(regex=r'\d+Ex', axis=1)  # the exercise columns (1Ex ... 9Ex)
```
%% Output
                         1Ex                             2Ex  \
0      RygliggendeBækkenløft  LiggendeTilSiddendeSelvhjulpet
1                  8Til12Sek         SelvstændigGangfunktion
2     RulleOmPåSidenMedHjælp                 SkalBrugeArmlæn
3            SkalBrugeArmlæn             StåMedSamledeFødder
4      RygliggendeBækkenløft  LiggendeTilSiddendeSelvhjulpet
...                      ...                             ...
3212   RygliggendeBækkenløft                  RulleOmPåSiden
3213     StåMedSamledeFødder                  RulleOmPåSiden
3214   RygliggendeBækkenløft                  RulleOmPåSiden
3215         SkalBrugeArmlæn             StåMedSamledeFødder
3216          RulleOmPåSiden  LiggendeTilSiddendeSelvhjulpet

                                 3Ex                              4Ex  \
0    RulleOmPåSidenBeggeHænderPåKrop              StåMedSamledeFødder
1                          StåPåTæer               8Til13Oprejsninger
2          GårMedGangredskabOgStøtte                     MereEnd16Sek
3                       MereEnd16Sek  RulleOmPåSidenBeggeHænderPåKrop
4             MindreEnd8Oprejsninger                GårMedGangredskab
...                              ...                              ...
3212  LiggendeTilSiddendeSelvhjulpet                  SkalBrugeArmlæn
3213  LiggendeTilSiddendeSelvhjulpet                GårMedGangredskab
3214  LiggendeTilSiddendeSelvhjulpet           MindreEnd8Oprejsninger
3215        SamleKuglepenOpFraGulvet        SiddeUdenStøtte10Sekunder
3216                       8Til12Sek          SelvstændigGangfunktion

                                   5Ex                               6Ex  \
0                         MereEnd16Sek             MereEnd13Oprejsninger
1       GårUdenGangredskabOgUdenStøtte  SiddeUdenStøtteMereEnd60Sekunder
2                  StåMedSamledeFødder         SiddeUdenStøtte10Sekunder
3     RygliggendeBækkenløftEtBenStrakt  SiddeUdenStøtteMereEnd60Sekunder
4                            8Til12Sek                         StåPåTæer
...                                ...                               ...
3212               StåMedSamledeFødder                      MereEnd16Sek
3213  RygliggendeBækkenløftEtBenStrakt  SiddeUdenStøtteMereEnd60Sekunder
3214                     StåUdenStøtte                 GårMedGangredskab
3215                                 0                                 0
3216                         StåPåTæer                8Til13Oprejsninger

                                      7Ex                               8Ex  \
0                        Gå4SkridtBaglæns    GårUdenGangredskabOgUdenStøtte
1                                       0                                 0
2                                       0                                 0
3     LiggendeTilSiddendeBeggeHænderPåKop                                 0
4                        Gå4SkridtBaglæns  SiddeUdenStøtteMereEnd60Sekunder
...                                   ...                               ...
3212     SiddeUdenStøtteMereEnd60Sekunder                                 0
3213                                    0                                 0
3214                     12Komma1Til16Sek           SelvstændigGangfunktion
3215                                    0                                 0
3216       SiddeUdenStøtte30Til60Sekunder    GårUdenGangredskabOgUdenStøtte

                                   9Ex
0     SiddeUdenStøtteMereEnd60Sekunder
1                                    0
2                                    0
3                                    0
4                                    0
...                                ...
3212                                 0
3213                                 0
3214  SiddeUdenStøtteMereEnd60Sekunder
3215                                 0
3216                                 0

[3217 rows x 9 columns]
%% Cell type:code id: tags:
``` python
var = df['Fall']
varValue = var.value_counts()
plt.figure()
plt.bar(varValue.index, varValue)
plt.xticks(varValue.index, varValue.index.values)
plt.ylabel("Frequency")
plt.title('Fall')
file_name = "Fall test bar.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches="tight")
```
%% Output
%% Cell type:code id: tags:
``` python
plot = sns.scatterplot(data=df, x="Age", y="NumberAts", hue="Fall")
plt.title("Scatter plot of NumberAts vs Age")
file_name = "Fall test scatter NumberAts Age.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches="tight")
```
%% Output
%% Cell type:code id: tags:
``` python
g = sns.FacetGrid(df, col="Fall", margin_titles=True)
g.map(sns.distplot, "Age", bins=25)
g.fig.suptitle("Number of citizens who fall given age")
file_name = "Fall test facetgrid age.pdf"
g.fig.subplots_adjust(top=.8)
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches="tight")
```
%% Output
`distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
`distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
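%% Cell type:markdown id: tags:
As the warning says, `distplot` is deprecated; a drop-in sketch of the same facet grid with the axes-level `histplot` (assuming seaborn >= 0.11):
%% Cell type:code id: tags:
``` python
# Same plot as above, without the deprecation warning
g = sns.FacetGrid(df, col="Fall", margin_titles=True)
g.map(sns.histplot, "Age", bins=25)
g.fig.suptitle("Number of citizens who fall given age")
g.fig.subplots_adjust(top=.8)
```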
%% Cell type:code id: tags:
``` python
def get_ats_list(df):
    """Collect all assistive-aid (Ats) codes, ignoring the '0' padding value."""
    all_ats = []
    ats_cols = [f"{i}Ats" for i in range(1, cfg.ATS_RESOLUTION+1)]
    for ats_col in ats_cols:
        for ats_string in df[ats_col]:
            for ats in ats_string.split(","):
                if ats != "0":
                    all_ats.append(ats)
    return all_ats

ats_no_fall = pd.Series(get_ats_list(df.loc[df['Fall'] == 0]))
ats_fall = pd.Series(get_ats_list(df.loc[df['Fall'] == 1]))
a = pd.DataFrame(ats_no_fall.value_counts(), columns=['No fall quantity'])
b = pd.DataFrame(ats_fall.value_counts(), columns=['Fall quantity'])
ats_df = pd.concat([a, b], axis=1).fillna(0)
ats_df.index.names = ['Ats']
ats_df = ats_df.reset_index()
# Normalize counts to within-group shares and keep the 20 most common aids
ats_df['No fall quantity'] = ats_df['No fall quantity'] / len(ats_no_fall)
ats_df['Fall quantity'] = ats_df['Fall quantity'] / len(ats_fall)
ats_df = ats_df.iloc[:20]
```
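%% Cell type:markdown id: tags:
The manual count-then-divide above can also be written with `value_counts(normalize=True)`; a minimal equivalent sketch:
%% Cell type:code id: tags:
``` python
# Equivalent within-group shares via normalize=True
ats_shares = pd.concat([
    pd.Series(get_ats_list(df.loc[df['Fall'] == 0])).value_counts(normalize=True).rename('No fall quantity'),
    pd.Series(get_ats_list(df.loc[df['Fall'] == 1])).value_counts(normalize=True).rename('Fall quantity'),
], axis=1).fillna(0).rename_axis('Ats').reset_index().iloc[:20]
```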
%% Cell type:code id: tags:
``` python
plt.bar(ats_df["Ats"], ats_df["No fall quantity"], label="No fall")
plt.bar(ats_df["Ats"], ats_df["Fall quantity"], bottom=ats_df["No fall quantity"], label="Fall")
plt.legend()
plt.xticks(rotation=90)
plt.ylabel("Scaled ats usage")
plt.title('Scaled plot of ats usage for fall test')
file_name = f"Fall test scaled ats usage.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches = "tight")
```
%% Output
%% Cell type:code id: tags:
``` python
def get_ex_list(df):
    """Collect all exercise (Ex) codes, ignoring the '0' padding value."""
    all_ex = []
    ex_cols = [f"{i}Ex" for i in range(1, cfg.EX_RESOLUTION+1)]
    for ex_col in ex_cols:
        for ex_string in df[ex_col]:
            for ex in ex_string.split(","):
                if ex != "0":
                    all_ex.append(ex)
    return all_ex

ex_no_fall = pd.Series(get_ex_list(df.loc[df['Fall'] == 0]))
ex_fall = pd.Series(get_ex_list(df.loc[df['Fall'] == 1]))
a = pd.DataFrame(ex_no_fall.value_counts(), columns=['No fall quantity'])
b = pd.DataFrame(ex_fall.value_counts(), columns=['Fall quantity'])
ex_df = pd.concat([a, b], axis=1).fillna(0)
ex_df.index.names = ['Ex']
ex_df = ex_df.reset_index()
# Normalize counts to within-group shares and keep the 20 most common exercises
ex_df['No fall quantity'] = ex_df['No fall quantity'] / len(ex_no_fall)
ex_df['Fall quantity'] = ex_df['Fall quantity'] / len(ex_fall)
ex_df = ex_df.iloc[:20]
```
%% Cell type:code id: tags:
``` python
plt.figure(num=None, figsize=(8, 6), dpi=80, facecolor='w', edgecolor='k')
plt.bar(ex_df["Ex"], ex_df["No fall quantity"], label="No fall")
plt.bar(ex_df["Ex"], ex_df["Fall quantity"], bottom=ex_df["No fall quantity"], label="Fall")
plt.legend()
plt.xticks(rotation=90)
plt.ylabel("Scaled ex usage")
plt.title('Scaled plot of ex usage for fall test')
file_name = f"Fall test scaled ex usage.pdf"
plt.savefig(Path.joinpath(cfg.REPORTS_PLOTS_DIR, file_name), dpi=300, bbox_inches = "tight")
```
%% Output
%% Cell type:code id: tags:
``` python
ats_df['Diff'] = ats_df['Fall quantity'] - ats_df['No fall quantity']
mask = ats_df['Diff'].gt(0)
ats_df_sorted = pd.concat([ats_df[mask].sort_values('Diff', ascending=False),
                           ats_df[~mask].sort_values('Diff', ascending=False)], ignore_index=True)
ats_df_sorted = ats_df_sorted.round(4)
file_writer.write_csv(ats_df_sorted, cfg.REPORTS_DIR, 'ats_df_sorted.csv')
```
%% Cell type:code id: tags:
``` python
ex_df['Diff'] = ex_df['Fall quantity'] - ex_df['No fall quantity']
mask = ex_df['Diff'].gt(0)
ex_df_sorted = pd.concat([ex_df[mask].sort_values('Diff', ascending=False),
                          ex_df[~mask].sort_values('Diff', ascending=False)], ignore_index=True)
ex_df_sorted = ex_df_sorted.round(4)
file_writer.write_csv(ex_df_sorted, cfg.REPORTS_DIR, 'ex_df_sorted.csv')
```
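%% Cell type:markdown id: tags:
Note that concatenating a descending sort of the positive `Diff` rows with a descending sort of the rest yields the same order as one descending sort, so the mask/concat in the two cells above can be collapsed; a minimal sketch:
%% Cell type:code id: tags:
``` python
# Same ordering as the mask/concat approach: largest Diff first, most negative last
ex_df_sorted_alt = ex_df.sort_values('Diff', ascending=False, ignore_index=True).round(4)
```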
......
@@ -12,7 +12,7 @@ tensorflow==2.4.1
openpyxl==3.0.6
xgboost==1.3.3
keras-tuner==1.0.2
-shap==0.39.0
+shap==0.37.0
scikeras==0.2.1
fastapi==0.63.0
uvicorn==0.13.4
......
@@ -14,7 +14,7 @@ def main():
    df = file_reader.read_csv(cfg.INTERIM_DATA_DIR, 'screenings.csv',
                              converters={'CitizenId': str})
-    df = preprocessor.split_cat_columns(df, col='Ats', tag='Ats',
+    df = preprocessor.split_cat_columns(df, col_to_split='Ats', tag='Ats',
                                         resolution=cfg.ATS_RESOLUTION)
    df = feature_maker.make_complete_feature(df)
@@ -51,8 +51,8 @@ def main():
    param_grid = [
        {
            'cluster_maker__init': ['random', 'Huang', 'Cao'],
-            'cluster_maker__n_clusters': [2, 5, 10, 20, 30, 40, 50, 100],
-            'cluster_maker__n_init': [1, 5, 10, 20]
+            'cluster_maker__n_clusters': [2, 5, 10, 20, 30, 40, 50],
+            'cluster_maker__n_init': [1, 5, 10, 15, 20]
        }
    ]
@@ -67,13 +67,13 @@ def main():
    random_search.fit(X, y)
-    print('\n All results:')
+    print('\nAll results:')
    print(random_search.cv_results_)
-    print('\n Best estimator:')
+    print('\nBest estimator:')
    print(random_search.best_estimator_)
-    print('\n Best normalized gini score for %d-fold search with %d parameter combinations:' % (5, 5))
+    print('\nBest normalized gini score for %d-fold search with %d parameter combinations:' % (5, 5))
    print(random_search.best_score_ * 2 - 1)
-    print('\n Best hyperparameters:')
+    print('\nBest hyperparameters:')
    print(random_search.best_params_)
    results = pd.DataFrame(random_search.cv_results_)
    file_writer.write_csv(results, cfg.REPORTS_DIR, 'kmodes-settings-random-grid-search-results.csv')
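The `best_score_ * 2 - 1` above reads as the usual AUC-to-Gini conversion (normalized Gini = 2·AUC − 1), which presumes the search is scored with ROC-AUC; a minimal sketch of the arithmetic:
``` python
# Assuming scoring='roc_auc' (not shown in this hunk): normalized Gini = 2 * AUC - 1
auc = 0.80          # hypothetical cross-validated AUC
gini = 2 * auc - 1  # -> 0.60
```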
......
0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
120606,093307,242103,120606,093307,181503,093307,120606,091218,091212,093307,122203,120606,120606,120606,120606,120606,120606,181210,222704
120606,120606,091218,120606,222718,122203,091218,120606,093307,091212,122203,120606,121212,093307,120606,120606,242103,093307,043306,222403
0,120606,093307,180903,222718,122409,120606,120724,120606,043306,093307,242103,093307,093307,120606,093307,091218,093307,122203,093307
0,122203,120606,093307,122203,181006,120727,093307,123621,181006,091218,093307,120606,181210,120606,222718,123103,091218,181024,222718
0,180903,123109,123106,043303,122203,120306,093307,242103,181224,093307,181006,122203,043306,120606,043306,120606,120606,181210,120606
0,120606,181210,181210,043306,180903,180903,122203,091203,120306,120603,180903,043303,120606,120606,181218,120606,043303,043306,120306
0,120606,043306,122203,181210,183015,122203,181006,123621,120306,242103,181021,183018,091203,120606,122218,120606,222718,093307,181210
0,122203,043306,181218,181218,090305,043303,120606,122203,122203,180903,120606,181006,091218,043303,043303,181003,093307,120306,043306
0,043303,181218,181006,123603,090305,120306,122203,043306,181003,222718,180903,180903,222718,120606,123621,181003,043303,120306,091218
0,091203,091203,122203,091218,183015,120306,120606,181218,181006,222718,120316,091218,123103,120606,123621,093307,091203,091218,123604
0,122203,123112,091218,091203,122303,120306,093307,043303,122203,222718,120724,183015,122203,120606,091203,180907,122203,093307,123621
0,122203,242103,122203,091218,120606,091218,181210,181233,180903,122203,043303,120606,043303,120606,043306,120606,122203,091218,091203
0,122203,122203,122203,091218,091218,122203,181218,091203,181003,043303,181503,222718,180903,0,091203,120606,043303,091218,181503
0,091203,123604,043303,242103,091218,122409,123603,123603,122203,043303,181210,181210,120606,0,181210,091218,120606,043303,122203
0,043303,123621,043303,242103,122203,181210,123621,181210,120306,123103,043306,043306,181228,0,123603,120606,0,091218,043303
0,222718,123621,120612,120606,122203,181218,123621,181218,120306,123103,123103,123106,123103,0,122203,180903,0,091218,123603
0,122203,043306,120606,120612,180903,091203,091203,123621,183010,222718,123103,091203,123103,0,043303,121212,0,122303,123621
0,242103,043306,122203,222718,122203,120606,122218,091203,043609,0,123112,181024,120606,0,180315,091218,0,122203,181228
0,093307,043306,122203,043306,181003,043609,123603,123621,093307,0,180903,122203,123103,0,122203,043303,0,120612,123612
0,122203,043306,120724,043303,122203,120606,091203,242103,222718,0,0,122203,123103,0,043303,043303,0,093307,0
0,181210,123603,120715,120606,043303,181024,122218,122203,222718,0,0,043303,120606,0,123612,181006,0,091236,0
0,043306,123621,122203,0,122303,222718,123621,043303,222718,0,0,122203,0,0,123621,121212,0,180903,0
0,123103,123621,043303,0,123103,183015,123612,122218,0,0,0,043303,0,0,123621,121212,0,120612,0
0,123103,181210,043303,0,120606,0,122203,122203,0,0,0,222718,0,0,123103,120603,0,222718,0
0,123103,043306,043303,0,090903,0,122218,120606,0,0,0,0,0,0,123106,091218,0,043303,0
0,123103,091209,043303,0,120606,0,180315,0,0,0,0,0,0,0,122203,120606,0,093307,0
0,123621,0,043303,0,120612,0,043303,0,0,0,0,0,0,0,091203,120606,0,0,0
0,0,0,043303,0,120612,0,043303,0,0,0,0,0,0,0,043303,222718,0,0,0
0,0,043303,043303,0,123109,0,0,0,0,0,0,0,0,0,123103,122203,0,0,0
0,0,0,043303,0,093307,0,0,0,0,0,0,0,0,0,122203,043303,0,0,0
0,0,0,043303,0,122306,0,0,0,0,0,0,0,0,0,122203,180903,0,0,0
0,0,0,181006,0,120603,0,0,0,0,0,0,0,0,0,0,181006,0,0,0
0,0,0,122203,0,043303,0,0,0,0,0,0,0,0,0,0,180907,0,0,0
0,0,0,093307,0,043303,0,0,0,0,0,0,0,0,0,0,123103,0,0,0
0,0,0,043303,0,181210,0,0,0,0,0,0,0,0,0,0,123103,0,0,0
0,0,0,122203,0,043306,0,0,0,0,0,0,0,0,0,0,242103,0,0,0
0,0,0,043303,0,181218,0,0,0,0,0,0,0,0,0,0,122203,0,0,0
0,0,0,043303,0,122203,0,0,0,0,0,0,0,0,0,0,043306,0,0,0
0,0,0,043303,0,043303,0,0,0,0,0,0,0,0,0,0,121212,0,0,0
0,0,0,043303,0,242103,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,181218,0,123103,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,123103,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,091218,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,242103,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,091203,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,043306,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,181228,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,122203,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,043303,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,0,0,0,0,043303,0,0,0,0,0,0,0,0,0,0,0,0,0,0
@@ -5,7 +5,7 @@ import os
import csv
import joblib
import pandas as pd
-import json
+import io
from typing import List, Optional
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
@@ -14,6 +14,7 @@ from fastapi_jwt_auth.exceptions import AuthJWTException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
ATS_RESOLUTION = 50
@@ -88,41 +89,61 @@ class InputData(pydantic.BaseModel):
class OutputData(pydantic.BaseModel):
    Prediction: int
    Probability: float
+    ClusterId: int
+    ClusterValues: List
    ShapValues: dict
@app.get('/')
def index():
    return {'message': f'AIR API v. 0.1'}
-@app.get('/user', dependencies=[Depends(JWTBearer())])
+@app.get('/user', dependencies=[Depends(JWTBearer())], tags=["login"])
def user(Authorize: AuthJWT = Depends()):
    Authorize.jwt_required()
    current_user = Authorize.get_jwt_subject()
    return {"user": current_user}
-@app.get('/ats')
+@app.get('/ats', tags=["data"])
def get_all_ats():
    return read_csv('ats.csv')

-@app.get('/ex')
+@app.get('/ex', tags=["data"])
def get_all_ex():
    return read_csv('ex.csv')
@app.get("/ats/hmi")
@app.get("/ats/hmi", tags=["data"])
def get_ats_by_hmi(hmi: str):
ats = read_csv('ats.csv')
if hmi in ats:
return {"hmi": hmi, "ats": ats[hmi]}
raise HTTPException(status_code=404, detail="Not found")
@app.get("/ex/ex_id")
@app.get("/ex/ex_id", tags=["data"])
def get_ex_by_id(ex_id: str):
ex = read_csv('ex.csv')
if ex_id in ex:
return {"hmi": ex_id, "ats": ex[ex_id]}
raise HTTPException(status_code=404, detail="Not found")
-@app.post('/login')
+@app.get("/cluster", tags=["data"])
+def get_all_clusters():
+    df = read_dataframe('clusters.csv')
+    stream = io.StringIO()
+    df.to_csv(stream, index=False)
+    response = StreamingResponse(iter([stream.getvalue()]),
+                                 media_type="text/csv")
+    response.headers["Content-Disposition"] = "attachment; filename=clusters.csv"
+    return response
@app.get("/cluster/cluster_id", tags=["data"])
def get_cluster_by_id(cluster_id: int):
df = read_dataframe('clusters.csv')
if df.shape[1] >= cluster_id:
cluster = list(df.iloc[:, cluster_id])
return {"cluster": cluster}
raise HTTPException(status_code=404, detail="Not found")
+@app.post('/login', tags=["login"])
def login(user: User, Authorize: AuthJWT = Depends()):
    if user.username != "test" or user.password != "test":
        raise HTTPException(status_code=401, detail="Bad username or password")
@@ -130,14 +151,14 @@ def login(user: User, Authorize: AuthJWT = Depends()):
    refresh_token = Authorize.create_refresh_token(subject=user.username)
    return {"access_token": access_token, "refresh_token": refresh_token}
-@app.post('/refresh', dependencies=[Depends(JWTBearer())])
+@app.post('/refresh', dependencies=[Depends(JWTBearer())], tags=["login"])
def refresh(Authorize: AuthJWT = Depends()):
    Authorize.jwt_refresh_token_required()
    current_user = Authorize.get_jwt_subject()
    new_access_token = Authorize.create_access_token(subject=current_user)
    return {"access_token": new_access_token}
-@app.post('/predict_complete', response_model=OutputData)
+@app.post('/predict_complete', response_model=OutputData, tags=["ai"])
def predict_complete(incoming_data: InputData):
    data = incoming_data.dict()
    df = prepare_data(data, 'complete')
@@ -146,15 +167,20 @@ def predict_complete(incoming_data: InputData):
    prediction = model.predict(df)
    probability = model.predict_proba(df).max()
+    cluster_id = int(df.iloc[0]['Cluster'])
+    clusters = read_dataframe('clusters.csv')
+    cluster_values = list(clusters.iloc[:, cluster_id])
    shap_values = get_shap_values(model, X_test=df)
    return {
        'Prediction': int(prediction[0]),
        'Probability': float(probability),
+        'ClusterId': int(df.iloc[0]['Cluster']),
+        'ClusterValues': cluster_values,
        'ShapValues': shap_values
    }
-@app.post('/predict_compliance', response_model=OutputData)
+@app.post('/predict_compliance', response_model=OutputData, tags=["ai"])
def predict_compliance(incoming_data: InputData):
    data = incoming_data.dict()
    df = prepare_data(data, 'compliance')
@@ -163,15 +189,20 @@ def predict_compliance(incoming_data: InputData):
    prediction = model.predict(df)
    probability = model.predict_proba(df).max()
+    cluster_id = int(df.iloc[0]['Cluster'])
+    clusters = read_dataframe('clusters.csv')
+    cluster_values = list(clusters.iloc[:, cluster_id])
    shap_values = get_shap_values(model, X_test=df)
    return {
        'Prediction': int(prediction[0]),
        'Probability': float(probability),
+        'ClusterId': int(df.iloc[0]['Cluster']),
+        'ClusterValues': cluster_values,
        'ShapValues': shap_values
    }
-@app.post('/predict_fall', response_model=OutputData)
+@app.post('/predict_fall', response_model=OutputData, tags=["ai"])
def predict_fall(incoming_data: InputData):
    data = incoming_data.dict()
    df = prepare_data(data, 'fall')
@@ -180,11 +211,16 @@ def predict_fall(incoming_data: InputData):
    prediction = model.predict(df)
    probability = model.predict_proba(df).max()
+    cluster_id = int(df.iloc[0]['Cluster'])
+    clusters = read_dataframe('clusters.csv')
+    cluster_values = list(clusters.iloc[:, cluster_id])
    shap_values = get_shap_values(model, X_test=df)
    return {
        'Prediction': int(prediction[0]),
        'Probability': float(probability),
+        'ClusterId': int(df.iloc[0]['Cluster']),
+        'ClusterValues': cluster_values,
        'ShapValues': shap_values
    }
@@ -257,6 +293,13 @@ def read_csv(filename:str) -> any:
        reader = csv.reader(f)
        csv_dict = {rows[0]: rows[1] for rows in reader}
    return csv_dict
+def read_dataframe(filename: str) -> pd.DataFrame:
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    converters = {str(i): str for i in range(0, 20)}  # cluster columns are named '0'..'19'
+    df = pd.read_csv(f'{dir_path}/data/{filename}',
+                     converters=converters)
+    return df
if __name__ == "__main__":
    uvicorn.run(app, port=8000, host="0.0.0.0")
\ No newline at end of file
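A hedged client-side sketch for the new `/cluster` endpoint (host, port, and variable names here are assumptions, not part of the commit):
``` python
import io
import requests
import pandas as pd

# Fetch the streamed CSV attachment and keep the aid codes as strings
resp = requests.get("http://localhost:8000/cluster")
resp.raise_for_status()
clusters = pd.read_csv(io.StringIO(resp.text), dtype=str)
print(clusters.shape)  # expected: one column per cluster, e.g. (n_rows, 20)
```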
@@ -7,13 +7,13 @@ from typing import List
from kmodes import kmodes
from tools import file_reader, file_writer, preprocessor
-USE_ATS_NAMES = True
+USE_ATS_NAMES = False
def main():
    df = file_reader.read_csv(cfg.INTERIM_DATA_DIR, 'screenings.csv',
                              converters={'CitizenId': str})
-    df = preprocessor.split_cat_columns(df, col='Ats', tag='Ats',
+    df = preprocessor.split_cat_columns(df, col_to_split='Ats', tag='Ats',
                                         resolution=cfg.ATS_RESOLUTION)
    if USE_ATS_NAMES:
@@ -40,8 +40,7 @@ def main():
                                             model.cluster_centroids_)]))
    file_writer.write_joblib(model, cfg.CLUSTERS_DIR, 'km.joblib')
-    for i in range(0, 20, 3):
-        file_writer.write_csv(cluster_centroids.iloc[:, i:i+3], cfg.INTERIM_DATA_DIR, f'cluster_centroids_{i}.csv')
+    file_writer.write_csv(cluster_centroids, cfg.INTERIM_DATA_DIR, 'cluster_centroids.csv')
    file_writer.write_csv(clusters, cfg.INTERIM_DATA_DIR, 'cl.csv')
if __name__ == '__main__':
......
@@ -186,7 +186,6 @@ def make_fall_emb():
    file_writer.write_csv(df, cfg.PROCESSED_DATA_DIR, 'fall_emb.csv')

def make_fall_test_emb():
-    case = 'Fall'
    ex = {str(i)+'Ex': str for i in range(1, cfg.EX_RESOLUTION+1)}
    ats = {str(i)+'Ats': str for i in range(1, cfg.ATS_RESOLUTION+1)}
    converters = {**ex, **ats}
@@ -194,13 +193,27 @@ def make_fall_test_emb():
                               f'fall_test.csv',
                               converters=converters)
-    emb_cols = df.filter(regex='((\d+)[Ats|Ex])\w+', axis=1)
-    n_numerical_cols = df.shape[1] - emb_cols.shape[1] - 1
+    ats_cols = [str(i)+'Ats' for i in range(1, cfg.ATS_RESOLUTION+1)]
+    ex_cols = [str(i)+'Ex' for i in range(1, cfg.EX_RESOLUTION+1)]
+    df_ats_to_enc = df.filter(regex='Fall|((\d+)[Ats])\w+', axis=1)
+    df_ats_to_enc = df_ats_to_enc.drop(['NumberFalls'], axis=1)
-    df_to_enc = df.iloc[:, n_numerical_cols:]
+    df_ex_to_enc = df.filter(regex='Fall|((\d+)[Ex])\w+', axis=1)
+    df_ex_to_enc = df_ex_to_enc.drop(['NumberFalls'], axis=1)
+    ats_enc = encode_dataframe(df_ats_to_enc, 'Fall')
+    ex_enc = encode_dataframe(df_ex_to_enc, 'Fall')
+    df = df.drop(ats_cols + ex_cols, axis=1)
+    df = pd.concat([df.drop('Fall', axis=1), ats_enc, ex_enc, df.pop('Fall')], axis=1)
+    file_writer.write_csv(df, cfg.PROCESSED_DATA_DIR, 'fall_test_emb.csv')

+def encode_dataframe(df_to_enc, case):
+    target_name = case
+    train_ratio = 0.9
+    X_train, X_val, y_train, y_val, labels = preprocessor.prepare_data_for_embedder(df_to_enc,
+                                                                                    target_name,
+                                                                                    train_ratio)
@@ -225,12 +238,9 @@ def make_fall_test_emb():
    network.save_labels(labels)
    network.make_visualizations_from_network(extension='png')
-    emb_cols = df.filter(regex='((\d+)[Ats|Ex])\w+', axis=1)
-    n_numerical_cols = df.shape[1] - emb_cols.shape[1] - 1
-    embedded_df = df.iloc[:, n_numerical_cols:df.shape[1]-1]
-    for index in range(embedded_df.shape[1]):
-        column = embedded_df.columns[index]
+    df_to_enc = df_to_enc.drop('Fall', axis=1)
+    for index in range(df_to_enc.shape[1] - 1):
+        column = df_to_enc.columns[index]
        labels_column = labels[index]
        embeddings_column = embedded_weights[index]
        pca = PCA(n_components=1)
@@ -238,11 +248,11 @@ def make_fall_test_emb():
        y_array = np.concatenate(Y)
        mapping = dict(zip(labels_column.classes_, y_array))
        file_writer.write_mapping(mapping,
-                                 Path.joinpath(cfg.PROCESSED_DATA_DIR, 'embeddings'),
-                                 f'fall_test_{column}.csv')
-        df[column] = df[column].replace(to_replace=mapping)
-    file_writer.write_csv(df, cfg.PROCESSED_DATA_DIR, 'fall_test_emb.csv')
+                                 Path.joinpath(cfg.PROCESSED_DATA_DIR, 'embeddings'),
+                                 f'fall_test_{column}.csv')
+        df_to_enc[column] = df_to_enc[column].replace(to_replace=mapping)
+    return df_to_enc
if __name__ == "__main__":
    main()
\ No newline at end of file
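`encode_dataframe` ends by compressing each column's learned embedding vectors to one scalar per category via one-component PCA. A toy sketch of that step with made-up weights (the real labels and weights come from the trained embedding network):
``` python
import numpy as np
from sklearn.decomposition import PCA

levels = ["SkalBrugeArmlæn", "StåMedSamledeFødder", "0"]  # example category labels
embeddings = np.random.rand(3, 4)                         # stand-in for embedded_weights[index]
Y = PCA(n_components=1).fit_transform(embeddings)         # one scalar per category level
mapping = dict(zip(levels, np.concatenate(Y)))            # level -> scalar, as in the original
```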
@@ -33,7 +33,7 @@ def make_complete_case(df, clusters):
    df = preprocessor.replace_ats_values(df)
    file_writer.write_csv(df, cfg.PROCESSED_DATA_DIR, f'complete.csv')

def make_compliance_case(df, clusters):
    df['Cluster'] = clusters['Cluster']
+    df = preprocessor.split_cat_columns(df, col_to_split='Ats', tag='Ats',
......