import pandas as pd
import numpy as np
from pandas.tseries.offsets import DateOffset

def make_fall_label(df: pd.DataFrame, settings: dict) -> pd.DataFrame:
    """Add a binary ``Fall`` column marking rows followed by a registered fall.

    A record counts as a fall when its ``NeedsReason`` or ``PhysicsReason``
    contains the literal text "Fald/uheld". Every row of ``df`` is then
    labelled via ``annotate_falls``: 1 when the same ``CitizenId`` has such a
    record in the window that starts at the row's ``EndDate`` and spans
    ``settings['fall_period_months']`` months, otherwise 0.

    Args:
        df: Frame with ``CitizenId``, ``NeedsReason``, ``PhysicsReason`` and
            ``EndDate`` columns. The ``Fall`` column is added in place.
        settings: Must contain the key ``fall_period_months``.

    Returns:
        The same ``df`` object, now carrying the ``Fall`` column.
    """
    reason_columns = ['NeedsReason', 'PhysicsReason']

    # Work on an independent copy so nothing downstream writes into a slice
    # of the caller's frame.
    digi_falls = df[['CitizenId', 'NeedsReason', 'PhysicsReason', 'EndDate']].copy()

    # Only the free-text reason columns get the placeholder. Filling
    # CitizenId/EndDate with a string would turn a missing date into
    # unparseable text and make the date conversion below raise.
    for column in reason_columns:
        digi_falls[column] = digi_falls[column].fillna('Ingen')

    # Get DigiRehab falls (plain substring match, not a regular expression)
    is_fall = (digi_falls['NeedsReason'].str.contains("Fald/uheld", regex=False)
               | digi_falls['PhysicsReason'].str.contains("Fald/uheld", regex=False))
    digi_falls = digi_falls.loc[is_fall].copy()

    # Parse the fall dates once here instead of once per annotated row.
    digi_falls['EndDate'] = pd.to_datetime(digi_falls['EndDate'])

    # Make target by annotating falls in the risk period
    fall_period = settings['fall_period_months']
    df['Fall'] = df[['CitizenId', 'EndDate']].apply(
        lambda row: annotate_falls(row, digi_falls, fall_period), axis=1)

    return df

def make_complete_label(df: pd.DataFrame, settings: dict) -> pd.DataFrame:
    """Build the ``Complete`` label and keep only the baseline screenings.

    The frame is first passed through ``accumulate_screenings``. Rows with
    ``NumberScreening == 0`` are additionally flagged as baseline. ``Complete``
    is 1 for every row whose (``CitizenId``, ``NumberSplit``,
    ``NumberSession``) group has at least one ``HasCompletedSession`` value
    above zero, otherwise 0.

    Args:
        df: Screening frame; it is handed to ``accumulate_screenings``.
        settings: Settings forwarded to ``accumulate_screenings``.

    Returns:
        The rows where ``Baseline == 1``, with the index reset (the previous
        index is kept as a column, since ``reset_index`` is called without
        ``drop``).
    """
    def _any_positive(values):
        # 1 when the largest value in the group is above zero, else 0
        return 1 if np.max(values) > 0 else 0

    screenings = accumulate_screenings(df, settings)

    # The very first screening always counts as a baseline
    is_first_screening = screenings['NumberScreening'] == 0
    screenings.loc[is_first_screening, 'Baseline'] = 1

    # Flag every row of a session in which the citizen completed
    session_keys = ['CitizenId', 'NumberSplit', 'NumberSession']
    completed = screenings.groupby(session_keys)['HasCompletedSession']
    screenings['Complete'] = completed.transform(_any_positive)
    screenings = screenings.astype({"Complete": int})

    # Keep the baseline screenings only and renumber the rows
    baseline_rows = screenings.loc[screenings['Baseline'] == 1]
    return baseline_rows.reset_index()

def make_compliance_label(df: pd.DataFrame, settings: dict) -> pd.DataFrame:
    """Build the ``Compliance`` label for baseline screenings that completed.

    The frame is first passed through ``accumulate_screenings``. Rows with
    ``NumberScreening == 0`` are additionally flagged as baseline. For each
    (``CitizenId``, ``NumberSplit``, ``NumberSession``) group, ``Compliance``
    is 1 when any ``GotComplianceInSession`` value is above zero and
    ``Complete`` is 1 when any ``HasCompletedSession`` value is above zero.

    Args:
        df: Screening frame; it is handed to ``accumulate_screenings``.
        settings: Settings forwarded to ``accumulate_screenings``.

    Returns:
        The rows where ``Baseline == 1`` and ``Complete == 1``, with the
        index reset (the previous index is kept as a column, since
        ``reset_index`` is called without ``drop``).
    """
    def _any_positive(values):
        # 1 when the largest value in the group is above zero, else 0
        return 1 if np.max(values) > 0 else 0

    session_keys = ['CitizenId', 'NumberSplit', 'NumberSession']
    screenings = accumulate_screenings(df, settings)

    # The very first screening always counts as a baseline
    is_first_screening = screenings['NumberScreening'] == 0
    screenings.loc[is_first_screening, 'Baseline'] = 1

    # Did the citizen reach compliance at any point in the session?
    compliance = screenings.groupby(session_keys)['GotComplianceInSession']
    screenings['Compliance'] = compliance.transform(_any_positive)
    screenings = screenings.astype({"Compliance": int})

    # Did the citizen complete the session?
    completed = screenings.groupby(session_keys)['HasCompletedSession']
    screenings['Complete'] = completed.transform(_any_positive)
    screenings = screenings.astype({"Complete": int})

    # Narrow to baseline screenings first, then to completed sessions
    baseline_rows = screenings.loc[screenings['Baseline'] == 1]
    completed_rows = baseline_rows.loc[baseline_rows['Complete'] == 1]

    return completed_rows.reset_index()

def make_alarm_label(df: pd.DataFrame, settings: dict) -> pd.DataFrame:
    """Build the ``Alarm`` label and keep only the baseline screenings.

    The frame is first passed through ``accumulate_screenings``. Rows with
    ``NumberScreening == 0`` are additionally flagged as baseline. ``Alarm``
    is 1 for every row whose (``CitizenId``, ``NumberSplit``,
    ``NumberSession``) group has at least one ``GotAlarmInSession`` value
    above zero, otherwise 0.

    Args:
        df: Screening frame; it is handed to ``accumulate_screenings``.
        settings: Settings forwarded to ``accumulate_screenings``.

    Returns:
        The rows where ``Baseline == 1``, with the index reset (the previous
        index is kept as a column, since ``reset_index`` is called without
        ``drop``).
    """
    df = accumulate_screenings(df, settings)

    # Set first screening as baseline
    df.loc[df['NumberScreening'] == 0, 'Baseline'] = 1

    # Calculate if citizens got an alarm during session
    session_keys = ['CitizenId', 'NumberSplit', 'NumberSession']
    grp = df.groupby(session_keys)['GotAlarmInSession']
    df['Alarm'] = grp.transform(lambda x: 1 if np.max(x) > 0 else 0)
    df = df.astype({"Alarm": int})

    # Select only the baseline screenings
    df = df.loc[df['Baseline'] == 1]

    # Reset index
    return df.reset_index()

def annotate_falls(row: pd.Series, digi_db: pd.DataFrame, fall_period: int) -> int:
    """Return 1 if the citizen has a fall record inside the risk window.

    The window is inclusive at both ends: it starts at ``row['EndDate']`` and
    ends ``fall_period`` calendar months later.

    Args:
        row: Row holding ``CitizenId`` and ``EndDate``.
        digi_db: Fall records with ``CitizenId`` and ``EndDate`` columns.
            It is only read, never modified.
        fall_period: Length of the risk window in months.

    Returns:
        1 when at least one record in ``digi_db`` belongs to the citizen and
        falls inside the window, otherwise 0.
    """
    citizen_id = row['CitizenId']
    window_start = pd.Timestamp(row['EndDate'])
    window_end = window_start + DateOffset(months=fall_period)

    # Parse into a local Series instead of writing back into ``digi_db``:
    # this function runs once per row, and overwriting the caller's column
    # each time mutates an argument (and warns when it is a slice).
    fall_dates = pd.to_datetime(digi_db['EndDate'])

    in_window = ((digi_db['CitizenId'] == citizen_id)
                 & (fall_dates >= window_start)
                 & (fall_dates <= window_end))

    return int(in_window.any())

def accumulate_screenings(df: pd.DataFrame, settings: dict) -> pd.DataFrame:
    """Walk each citizen's screenings and derive per-row session columns.

    For every citizen the rows are visited in frame order while running sums
    of ``NumberWeeks`` and ``NumberTraining`` are kept. A session is complete
    on the row where both sums reach ``settings['threshold_weeks']`` and
    ``settings['threshold_training']``; the sums then restart and the session
    counter increases for the following rows.

    Columns written per row:
        NumberSession: session counter value when the row is visited.
        GotAlarmInSession: 1 if '222718' is one of the comma-separated
            entries of ``Ats``, else 0.
        HasCompletedSession, Baseline: 1 on the row that completes a
            session, else 0.
        GotComplianceInSession: 1 on a completing row whose
            ``MeanEvaluation`` is above 4, else 0.

    NOTE(review): citizens are taken from the (``CitizenId``,
    ``NumberSplit``) groups, but rows are selected by ``CitizenId`` only, so
    the running sums and the session counter carry across splits. That is the
    existing behaviour and is kept here; confirm whether per-split
    accumulation was intended.

    Args:
        df: Frame with ``CitizenId``, ``NumberSplit``, ``NumberWeeks``,
            ``NumberTraining``, ``MeanEvaluation`` and ``Ats`` columns.
            Modified in place.
        settings: Must contain ``threshold_weeks`` and ``threshold_training``.

    Returns:
        The same ``df`` object with the columns above added.
    """
    alarm_ats_code = '222718'
    threshold_weeks = settings['threshold_weeks']
    threshold_training = settings['threshold_training']

    # Collect the citizens up front, before the loop starts adding columns.
    # A citizen with several splits appears in several groups; processing the
    # same rows again gives the same values, so each citizen is handled once.
    citizen_ids = list(dict.fromkeys(
        key[0] for key, _ in df.groupby(['CitizenId', 'NumberSplit'])))

    for citizen_id in citizen_ids:
        citizen_df = df.loc[df['CitizenId'] == citizen_id]
        number_session = 0
        cumsum_weeks = 0
        cumsum_training = 0

        # Iterate the columns directly: Series.iteritems() no longer exists
        # in pandas 2.0 and later.
        rows = zip(citizen_df.index,
                   citizen_df['NumberWeeks'],
                   citizen_df['NumberTraining'],
                   citizen_df['MeanEvaluation'],
                   citizen_df['Ats'])

        for idx, weeks, training, mean_evaluation, ats in rows:
            df.loc[idx, 'NumberSession'] = number_session
            cumsum_weeks += weeks
            cumsum_training += training

            df.loc[idx, 'GotAlarmInSession'] = int(alarm_ats_code in ats.split(','))

            session_done = (cumsum_weeks >= threshold_weeks
                            and cumsum_training >= threshold_training)
            if session_done:
                # Start accumulating afresh for the next session
                cumsum_weeks = 0
                cumsum_training = 0
                number_session += 1

            df.loc[idx, 'HasCompletedSession'] = int(session_done)
            df.loc[idx, 'Baseline'] = int(session_done)
            df.loc[idx, 'GotComplianceInSession'] = int(
                session_done and mean_evaluation > 4)

    return df