Commit bf4eb4cc authored by thecml
Browse files

updated code for new io, added models to git

parent 26f5f2f7
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
import pandas as pd import pandas as pd
import paths as pt import paths as pt
from tools import file_reader, file_writer, data_loader from tools import file_reader, file_writer, data_loader
from utility import metrics from utility.settings import load_settings
from sklearn.metrics import accuracy_score, precision_score from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import recall_score, average_precision_score from sklearn.metrics import recall_score, average_precision_score
from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import StratifiedKFold
...@@ -11,7 +11,8 @@ from sklearn.model_selection import train_test_split ...@@ -11,7 +11,8 @@ from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb import xgboost as xgb
from pathlib import Path from pathlib import Path
import yaml import shutil
from io import BytesIO
CASES = ["Complete", "Compliance", "Fall"] CASES = ["Complete", "Compliance", "Fall"]
DATASET_VERSION = 'emb' DATASET_VERSION = 'emb'
...@@ -19,29 +20,39 @@ DATASET_VERSION = 'emb' ...@@ -19,29 +20,39 @@ DATASET_VERSION = 'emb'
def main(): def main():
for case in CASES: for case in CASES:
if case == "Complete": if case == "Complete":
with open(Path.joinpath(pt.CONFIGS_DIR, "complete_emb.yaml"), 'r') as stream: settings = load_settings(pt.CONFIGS_DIR, "complete.yaml")
settings = yaml.safe_load(stream)
file_name = f'complete_{DATASET_VERSION}.csv' file_name = f'complete_{DATASET_VERSION}.csv'
dl = data_loader.CompleteDataLoader(file_name, settings).load_data() dl = data_loader.CompleteDataLoader(pt.PROCESSED_DATA_DIR,
file_name,
settings).load_data()
X, y = dl.prepare_data() X, y = dl.prepare_data()
elif case == "Compliance": elif case == "Compliance":
with open(Path.joinpath(pt.CONFIGS_DIR, "compliance_emb.yaml"), 'r') as stream: settings = load_settings(pt.CONFIGS_DIR, "compliance.yaml")
settings = yaml.safe_load(stream) file_name = f'compliance_{DATASET_VERSION}.csv'
file_name = f'compliance_{DATASET_VERSION}.csv' dl = data_loader.ComplianceDataLoader(pt.PROCESSED_DATA_DIR,
dl = data_loader.ComplianceDataLoader(file_name, settings).load_data() file_name,
settings).load_data()
X, y = dl.prepare_data() X, y = dl.prepare_data()
else: else:
with open(Path.joinpath(pt.CONFIGS_DIR, "fall_emb.yaml"), 'r') as stream: settings = load_settings(pt.CONFIGS_DIR, "fall.yaml")
settings = yaml.safe_load(stream)
file_name = f'fall_{DATASET_VERSION}.csv' file_name = f'fall_{DATASET_VERSION}.csv'
dl = data_loader.FallDataLoader(file_name, settings).load_data() dl = data_loader.FallDataLoader(pt.PROCESSED_DATA_DIR,
file_name,
settings).load_data()
X, y = dl.prepare_data() X, y = dl.prepare_data()
model = RandomForestClassifier(n_estimators=200, model = RandomForestClassifier(n_estimators=200,
class_weight="balanced", class_weight="balanced",
random_state=0) random_state=0)
model.fit(X, y) model.fit(X, y)
file_writer.write_joblib(model, pt.MODELS_DIR, f'{case.lower()}_rf.joblib')
file_path = pt.MODELS_DIR
file_name = f'{case.lower()}_rf.joblib'
with open(Path.joinpath(file_path, file_name), 'wb') as fd:
outfile = BytesIO()
file_writer.write_joblib(model, outfile)
outfile.seek(0)
shutil.copyfileobj(outfile, fd)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment