Skip to content
Snippets Groups Projects
Commit f5b21330 authored by Christian Marius Lillelund's avatar Christian Marius Lillelund
Browse files

added config for api

parent 3b34312b
No related branches found
No related tags found
No related merge requests found
---
# Settings for api -------------------------------------------------
#
ats_resolution: 10
\ No newline at end of file
......@@ -6,6 +6,7 @@ import csv
import joblib
import pandas as pd
import io
from pathlib import Path
from typing import List, Optional
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
......@@ -15,8 +16,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
ATS_RESOLUTION = 10
import yaml
app = FastAPI(title='AIR API', version='0.1',
description='A simple API that classifies citizens based on data')
......@@ -165,6 +165,9 @@ def refresh(Authorize: AuthJWT = Depends()):
@app.post('/predict', response_model=OutputData, tags=["ai"])
def predict(incoming_data: InputData):
settings = load_settings('config.yaml')
ats_resolution = settings['ats_resolution']
data = incoming_data.dict()
incoming_ats = [x.strip(' ') for x in data['Ats'].split(",")]
......@@ -173,16 +176,16 @@ def predict(incoming_data: InputData):
if all(x in list(ats_df['ats_id']) for x in incoming_ats) != True:
raise HTTPException(status_code=400, detail="Ats not found, check ats list")
df = prepare_data(data)
arguments = generate_arguments(df)
df = prepare_data(data, ats_resolution)
arguments = generate_arguments(df, ats_resolution)
complete_model = read_joblib(f'complete_xgboost.joblib')
compliance_model = read_joblib(f'compliance_xgboost.joblib')
fall_model = read_joblib(f'fall_xgboost.joblib')
df_for_complete = add_embeddings(df.copy(), 'complete')
df_for_compliance = add_embeddings(df.copy(), 'compliance')
df_for_fall = add_embeddings(df.copy(), 'fall')
df_for_complete = add_embeddings(df.copy(), 'complete', ats_resolution)
df_for_compliance = add_embeddings(df.copy(), 'compliance', ats_resolution)
df_for_fall = add_embeddings(df.copy(), 'fall', ats_resolution)
complete_prob = complete_model.predict_proba(df_for_complete).flatten()[1]
fall_prob = fall_model.predict_proba(df_for_fall).flatten()[1]
......@@ -210,15 +213,15 @@ def predict(incoming_data: InputData):
'Arguments': arguments
}
def add_embeddings(df: pd.DataFrame, case: str) -> pd.DataFrame:
for i in range(1, ATS_RESOLUTION+1):
def add_embeddings(df: pd.DataFrame, case: str, ats_resolution: int) -> pd.DataFrame:
for i in range(1, ats_resolution+1):
embedding = read_embedding(f'{case}_{i}Ats.csv')
column = f'{i}Ats'
df[column] = df[column].replace(to_replace=embedding)
df[column] = pd.to_numeric(df[column])
return df
def generate_arguments(df: pd.DataFrame):
def generate_arguments(df: pd.DataFrame, ats_resolution: int):
arguments = list()
gender = 'kvinde' if df.iloc[0].Gender_Male == 0 else 'mand'
......@@ -228,7 +231,7 @@ def generate_arguments(df: pd.DataFrame):
arguments.append(f'Personen har {int(df.iloc[0].NumberAts)} hjælpemidle(r) i hjemmet')
ats_arguments = list()
for i in range(1, ATS_RESOLUTION+1):
for i in range(1, ats_resolution+1):
ats_name = get_ats_name_from_hmi(df.iloc[0][f'{i}Ats'])
if ats_name != "":
ats_arguments.append(f'Personen har et hjælpemiddel af typen {ats_name} som sit {i}. hjælpemiddel')
......@@ -237,25 +240,31 @@ def generate_arguments(df: pd.DataFrame):
arguments.extend(ats_arguments)
return arguments
def load_settings(file_name):
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(Path.joinpath(dir_path, file_name), 'r') as stream:
settings = yaml.safe_load(stream)
return settings
def get_ats_name_from_hmi(ats_id: str):
ats = read_csv('ats.csv')
if ats_id in ats:
return ats[ats_id]
return ""
def prepare_data(data: dict) -> pd.DataFrame:
def prepare_data(data: dict, ats_resolution: int) -> pd.DataFrame:
new_data = {k: [v] for k, v in data.items()}
new_data_df = pd.DataFrame.from_dict(new_data)
new_data_df['NumberAts'] = len(new_data_df['Ats'][0].split(","))
df = split_categorical_columns(new_data_df, col='Ats', tag='Ats',
resolution=ATS_RESOLUTION)
resolution=ats_resolution)
df['Gender_Male'] = float(any(df['Gender'] == 1))
df['Gender_Female'] = float(any(df['Gender'] == 0))
df = df.drop(['Gender'], axis=1)
cols_ats = [str(i)+'Ats' for i in range(1, ATS_RESOLUTION+1)]
cols_ats = [str(i)+'Ats' for i in range(1, ats_resolution+1)]
header_list = ['Gender_Male', 'Gender_Female', 'BirthYear',
'Cluster', 'LoanPeriod', 'NumberAts'] + cols_ats
df = df.reindex(columns=header_list)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment