Commit bf4eb4cc authored by thecml
Browse files

updated code for new io, added models to git

parent 26f5f2f7
......@@ -3,7 +3,7 @@ import numpy as np
import pandas as pd
import paths as pt
from tools import file_reader, file_writer, data_loader
from utility import metrics
from utility.settings import load_settings
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import recall_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
......@@ -11,7 +11,8 @@ from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb
from pathlib import Path
import yaml
import shutil
from io import BytesIO
# Cases for which main() trains and stores a model, in processing order.
CASES = [
    "Complete",
    "Compliance",
    "Fall",
]

# Version tag embedded in the name of each processed dataset file.
DATASET_VERSION = "emb"
......@@ -19,29 +20,39 @@ DATASET_VERSION = 'emb'
def main():
    """Train one random forest per case and store it in ``pt.MODELS_DIR``.

    For every entry in ``CASES`` this loads the case's settings file from
    ``pt.CONFIGS_DIR``, reads the processed dataset
    ``<case>_<DATASET_VERSION>.csv`` from ``pt.PROCESSED_DATA_DIR``, fits a
    ``RandomForestClassifier`` on the prepared data, and writes the fitted
    model to ``<case>_rf.joblib``.
    """
    # Maps a case name to (file stem, loader class). Any case that is not
    # listed is handled as the fall case, matching the original if/elif/else.
    sources = {
        "Complete": ("complete", data_loader.CompleteDataLoader),
        "Compliance": ("compliance", data_loader.ComplianceDataLoader),
    }
    fallback = ("fall", data_loader.FallDataLoader)

    for case in CASES:
        stem, loader_cls = sources.get(case, fallback)

        settings = load_settings(pt.CONFIGS_DIR, f'{stem}.yaml')
        file_name = f'{stem}_{DATASET_VERSION}.csv'
        dl = loader_cls(pt.PROCESSED_DATA_DIR,
                        file_name,
                        settings).load_data()
        X, y = dl.prepare_data()

        # Fixed seed so repeated runs fit the same forest on the same data.
        model = RandomForestClassifier(n_estimators=200,
                                       class_weight="balanced",
                                       random_state=0)
        model.fit(X, y)

        # Serialize into memory before touching the target file, so that a
        # failure while dumping the model cannot leave a truncated file.
        outfile = BytesIO()
        file_writer.write_joblib(model, outfile)
        outfile.seek(0)

        file_path = pt.MODELS_DIR
        file_name = f'{case.lower()}_rf.joblib'
        with open(Path.joinpath(file_path, file_name), 'wb') as fd:
            shutil.copyfileobj(outfile, fd)
# Run the training only when this file is executed as a script.
if __name__ == '__main__':
    main()
\ No newline at end of file
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment