# -*- coding: utf-8 -*-
"""
Created on Fri Apr 22 12:58:22 2022

@author: annik
"""

## Imports
import os, numpy, pandas
from Bio.PDB import *
from sklearn.metrics import confusion_matrix
from sklearn import metrics
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import RepeatedKFold
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
from collections import Counter
import random
import xgboost as xgb
import seaborn as sns

# Set directory: every input/output path in this script is built by plain
# string concatenation onto ``base_root`` (the folder containing this file,
# with a trailing '/').
full_path = os.path.dirname(os.path.abspath(__file__))
# str.replace returns a new string, so the result has to be re-assigned;
# otherwise Windows backslashes would silently be kept.
full_path = full_path.replace("\\","/")
base_root = full_path+str('/')

class set_directory:
    """Hold the sub-directory paths used below ``base_root``.

    Paths are built by plain string concatenation, so ``base_root`` is
    expected to already end with a path separator. Instantiating the class
    also makes sure that the ``images`` output folder exists.

    Attributes set on the instance:
    - python_script: ``base_root + 'python_scripts/'``
    - repos:         ``base_root + 'repository/data/'``
    - base_name:     the fixed string ``'MDR3'``
    - MDR3_dir:      ``base_root + 'MDR3/'``
    """
    def __init__(self, base_root):

        self.python_script = base_root+'python_scripts/'
        self.repos = base_root+'repository/data/'
        self.base_name = 'MDR3'
        self.MDR3_dir = base_root+'MDR3/'

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test and makes repeated construction safe.
        os.makedirs(base_root+'/images', exist_ok=True)
        
class Preprocessing:
    def __init__(self):
        """Remember the location of the input dataset (``final_dataset.csv`` below ``base_root``)."""
        self.csv_file = '{}final_dataset.csv'.format(base_root)
        
    # check distributions of amino acids of wildtype and mutation as well as positions, choose a testset similar to entire dataset
    def setting_up(self):
        '''
        Reads in the final dataset file (self.csv_file), replaces every 99999 error
        value by 0 (in all columns) and writes the cleaned dataframe back to the
        same file. Returns a tuple of:
        - dataset (dataframe)
        - list of mutations (non-missing entries of the 'Mutation' column, duplicates kept)
        - dictionary of mutation effect (mutation as key, 'Pathogenicity' of its
          first occurrence as value)
        - list of the 20 amino acid one-letter codes as reference
        '''
        df = pandas.read_csv(self.csv_file, sep = ',')
        #replace the error values
        df.replace(to_replace = 99999, value = 0, inplace=True)
        #save the dataframe again
        df.to_csv(self.csv_file, sep = ',', index = False)

        mutations = [x for x in df['Mutation'].to_list() if str(x) != 'nan']

        # Map each mutation to the pathogenicity of its first occurrence in a
        # single pass (a boolean-mask lookup per mutation would be quadratic).
        mut_effect = {}
        for mut, effect in zip(df['Mutation'], df['Pathogenicity']):
            if str(mut) != 'nan':
                # setdefault keeps the value of the first row seen for a mutation
                mut_effect.setdefault(mut, effect)

        # reference amino acid list
        amino_acids = ['R', 'H', 'K', 'D', 'E', 'S', 'T', 'N', 'Q', 'C', 'G', 'P', 'A', 'V', 'I', 'L', 'M', 'F', 'Y', 'W']

        return(df, mutations, mut_effect, amino_acids)

    def get_ratio(self, datapoints):
        ''' 
        Call with list of datapoints (e.g. mutations). Each datapoint is read as a
        string whose first character is the wildtype amino acid, whose last
        character is the variant amino acid and whose middle part is the position.
        Returns a tuple of:
        - dictionary of wildtype amino acids with percentage over datapoints as values
        - dictionary of mutant amino acids with percentage over datapoints as values
        - dictionary of positions classified in bins of 100 (label e.g. '700-800')
          with percentage over datapoints as values; only occupied bins are listed,
          in ascending order
        - dictionary of mutation effect as returned by setting_up()
        '''
        # NOTE: setting_up() re-reads and re-writes the dataset file on every call.
        _, _, mut_effect, amino_acids = self.setting_up()

        total = len(datapoints)

        # split every datapoint into wildtype residue, variant residue and position
        wt_aa = [point[0] for point in datapoints]
        var_aa = [point[-1] for point in datapoints]
        pos_all = [point[1:-1] for point in datapoints]

        def percentages(letters):
            # percentage of each reference amino acid; absent ones stay at 0
            counts = Counter(letters)
            return {aa: (round(counts[aa] * 100 / total, 2) if aa in counts else 0)
                    for aa in amino_acids}

        wt_aa_dict = percentages(wt_aa)
        var_aa_dict = percentages(var_aa)

        # Bin the positions in steps of 100. The label is derived from the bin
        # number itself: indexing a fixed label list by the order in which bins
        # were first seen would attach counts to the wrong labels.
        bin_counts = Counter(int(float(pos) // 100) for pos in pos_all)
        pos_ratio = {}
        for bin_index in sorted(bin_counts):
            lower = bin_index * 100
            label = '{}-{}'.format(lower, lower + 100)
            pos_ratio[label] = round(bin_counts[bin_index] * 100 / total, 2)

        return(wt_aa_dict, var_aa_dict, pos_ratio, mut_effect)

    def rmsd(self, test, target):
        ''' 
        Takes in two dictionaries and converts the values to lists. 
        Attention: Make sure that the keys are ordered the same way as it is not checked!
        Returns the root-mean-square deviation between the two value lists.
        Raises ValueError if the lists differ in length or are empty.
        '''
        testing = list(test.values())
        targeting = list(target.values())
        # Fail loudly: printing and carrying on would hit the return statement
        # with no value assigned.
        if len(testing) != len(targeting):
            raise ValueError('Please input lists of same lengths')
        if not testing:
            raise ValueError('Cannot calculate the RMSD of empty inputs')

        squared_deviations = [(x-y) ** 2 for x,y in zip(testing, targeting)]
        return(numpy.sqrt(1/len(testing) * sum(squared_deviations)))

    def get_testset(self, datapoint):
        ''' 
        Input: list of mutations (datapoint).
        Draws 40 distinct mutations from the input list into a testset, using a fixed
        random seed (28). The first 10 are drawn purely at random; every further
        candidate is only kept if
        - neither class (mut_effect 0 / mut_effect 1) exceeds half of the testset size, and
        - the RMSD between the testset and the overall distribution (from get_ratio)
          does not get worse for both the wildtype and the variant amino acids; if only
          one of the two gets worse, an increase of up to 0.1 is tolerated.
        No plotting is done here (see diff_visualization).
        Returns a tuple of:
        - testset (list of mutations in the testset)
        - wildtype amino acid distribution of the testset, keys sorted alphabetically
        - wildtype amino acid distribution of the input list, keys sorted alphabetically
        - variant amino acid distribution of the testset, keys sorted alphabetically
        - variant amino acid distribution of the input list, keys sorted alphabetically

        NOTE(review): both while loops only end once enough distinct mutations were
        accepted; with fewer than 10 (first loop) or 40 (second loop) distinct entries
        in the input, or too few acceptable candidates, they do not terminate.
        '''
        # get ratio of original datapoints set
        wt_aa_dict, var_aa_dict, orig_positions, mut_effect = self.get_ratio(datapoint)

        # work on a copy of the input list; candidates are drawn from it
        # (NOTE(review): nothing is ever removed from this copy)
        datapoints = list(datapoint)

        size_testset = 40
        # at most half of the testset may belong to one class
        max_class = size_testset / 2
        testset = []
        # fixed seed so that the same testset is drawn on every run
        seed = 28 
        random.seed(seed)

        # first draw 10 mutations randomly into the test set
        while len(testset) < 10:
            new_testpoint = random.choice(datapoints) 
            if new_testpoint not in testset:
                testset.append(new_testpoint) 

        while len(testset) < size_testset:
            # get the old testset distribution dictionary for comparison
            old_testset, old_testset_var, _, _ = self.get_ratio(testset)
            new_testpoint = random.choice(datapoints) 
            if new_testpoint not in testset:
                testset.append(new_testpoint) 
            # count the class members of the candidate testset
            # (presumably 0 = benign, 1 = pathogenic - TODO confirm)
            num_ben = 0
            num_pat = 0
            for item in testset:
                if mut_effect[item] == 0: 
                    num_ben += 1
                elif mut_effect[item] == 1:
                    num_pat += 1
                else:
                    # any other label stops the counting for this round
                    break
            if num_ben > max_class or num_pat > max_class:
                # class limit exceeded: reject the candidate
                testset.remove(new_testpoint) 
            else:
                # compare the distance to the overall distribution before and after adding the candidate
                testset_dict, testset_dict_var, _, _ = self.get_ratio(testset)
                old_rmsd = self.rmsd(old_testset, wt_aa_dict)
                new_rmsd = self.rmsd(testset_dict, wt_aa_dict)
                old_rmsd_var = self.rmsd(old_testset_var, var_aa_dict)
                new_rmsd_var = self.rmsd(testset_dict_var, var_aa_dict)
                if (old_rmsd >= new_rmsd) and (old_rmsd_var >= new_rmsd_var):
                    # neither distribution got worse: keep the candidate
                    pass
                elif (old_rmsd < new_rmsd) and (old_rmsd_var < new_rmsd_var):
                    # both distributions got worse: reject the candidate
                    testset.remove(new_testpoint)
                elif (old_rmsd >= new_rmsd) and (old_rmsd_var < new_rmsd_var):
                    # only the variant distribution got worse: tolerate up to 0.1
                    if (new_rmsd_var - old_rmsd_var) > 0.1:
                        testset.remove(new_testpoint)
                elif (old_rmsd < new_rmsd) and (old_rmsd_var >= new_rmsd_var):
                    # only the wildtype distribution got worse: tolerate up to 0.1
                    if (new_rmsd - old_rmsd) > 0.1:
                        testset.remove(new_testpoint)
                else:
                    # only reachable if a comparison is undecidable (e.g. NaN values)
                    print('ERROR: error in obtaining the testset')   
    
        # sorting the dictionaries by alphabet
        wt_testset = {}
        for key in sorted(testset_dict.keys()):
            wt_testset[key] = testset_dict[key]
        
        wt_dataset = {}
        for key in sorted(wt_aa_dict.keys()):
            wt_dataset[key] = wt_aa_dict[key]
            
        var_testset = {}
        for key in sorted(testset_dict_var.keys()):
            var_testset[key] = testset_dict_var[key]
        
        var_dataset = {}
        for key in sorted(var_aa_dict.keys()):
            var_dataset[key] = var_aa_dict[key]
       
        return(testset, wt_testset, wt_dataset, var_testset, var_dataset)
    
    def diff(self, testing, targeting):
        '''
        Returns the signed difference testing - targeting.
        '''
        return testing - targeting
    
    def diff_visualization(self, wt_testset, wt_dataset, var_testset, var_dataset):
        '''
        Plots, per amino acid, the difference (testset minus dataset) of the wildtype
        and of the variant amino acid distributions, saves the figure to
        images/Agreement_datasets.png below base_root and shows it.
        The x tick labels are taken from the keys of wt_testset.
        '''
        residue_labels = list(wt_testset.keys())
        index = range(len(residue_labels))

        diff_wt = [self.diff(wt_testset[aa], wt_dataset[aa]) for aa in wt_testset.keys()]
        diff_var = [self.diff(var_testset[aa], var_dataset[aa]) for aa in var_testset.keys()]

        fig, ax = plt.subplots(1, figsize=(9,6), dpi=150)
        ax.set_xticks(index)
        ax.set_xticklabels(residue_labels)
        ax.set_xlim([-1.2, 20.2])
        ax.set_ylim([-6, 6])
        ax.set_ylabel('% Difference in amino acid distribution', fontsize=14)

        # zero line first, then the two marker series
        ax.plot(index, [0] * len(index), 'k--')
        ax.plot(index, diff_wt, 'g*', label='Wildtype amino acid')
        ax.plot(index, diff_var, 'bo', label = 'Variant amino acid')

        # thin dashed stem from the zero line to every marker
        for series in (diff_wt, diff_var):
            for ind, value in enumerate(series):
                ax.plot( [ind, ind], [0, value], 'k--', linewidth=.4)

        # shade the upper and the lower half of the plot in two greys
        range_fill = range(-1,21)
        n_fill = 22
        ax.fill_between(range_fill, [5.8] * n_fill, color='.8')
        ax.fill_between(range_fill, [-5.8] * n_fill, color='.9')

        for label in (ax.get_xticklabels() + ax.get_yticklabels()):
            label.set_fontsize(12)

        plt.tight_layout()
        plt.savefig(base_root+'images/Agreement_datasets.png', dpi=300)
        plt.show()
        
    def Run(self):
        '''
        Runs the preprocessing: reads the dataset, selects the testset, plots the
        agreement of the distributions and writes final_dataset_for_ML.csv (below
        base_root) in which the testset rows are placed at the end, in testset order.
        Returns the testset (list of mutations).
        '''
        # setting matplotlib defaults
        plt.rcParams.update(plt.rcParamsDefault)

        df, mutations, mut_effect, amino_acids = self.setting_up()
        testset, wt_testset, wt_dataset, var_testset, var_dataset = self.get_testset(mutations)
        self.diff_visualization(wt_testset, wt_dataset, var_testset, var_dataset)

        # index of the first row belonging to each testset mutation
        target_rows = [df[df['Mutation'] == mutation].index.values.tolist()[0]
                       for mutation in testset]

        # all other rows first, then the testset rows
        held_out = set(target_rows)
        remaining_rows = [row for row in df.index if row not in held_out]
        training_part = df.iloc[remaining_rows, :]
        testing_part = df.iloc[target_rows, :]
        dataset = pandas.concat([training_part, testing_part]).reset_index(drop=True)

        dataset.to_csv(base_root+'final_dataset_for_ML.csv', sep = ',', index = False)
        return(testset)


class Visualize_corr:
    """Spearman correlation analysis of the columns of final_dataset_for_ML.csv."""

    def __init__(self):
        self.csv_file = '{}final_dataset_for_ML.csv'.format(base_root)

    def correlation(self):
        '''
        Reads the dataset, drops the 'Mutation' column and computes the Spearman
        correlation matrix of the remaining columns. The matrix is drawn as an
        annotated heatmap and saved to images/Corr_matrix_spearman.png below base_root.
        Returns:
        - the correlation matrix (dataframe)
        - the root mean square of the entries below the diagonal
        '''
        features = pandas.read_csv(self.csv_file, sep=',')
        features = features.drop(columns='Mutation')

        # setting matplotlib defaults
        plt.rcParams.update(plt.rcParamsDefault)

        # heatmap of the correlation matrix, values rounded to two decimals
        fig, ax = plt.subplots(1, figsize=(20,20), dpi=300)
        corr = features.corr(method='spearman')
        sns.set(font_scale=2.5)
        heatmap = sns.heatmap(corr.round(2), vmin=-1, vmax=1, annot=True, fmt='.2f',
                              cmap='coolwarm', annot_kws={'size':16})
        heatmap.set_xticklabels(heatmap.get_xmajorticklabels(), fontsize= 20)
        heatmap.set_yticklabels(heatmap.get_ymajorticklabels(), fontsize= 20)

        plt.tight_layout()
        plt.savefig(base_root+'images/Corr_matrix_spearman.png', dpi=300)

        # root mean square of the strictly lower triangle (every pair counted once,
        # diagonal excluded): how far the correlations deviate from zero
        matrix = corr.values
        rows, cols = numpy.tril_indices(matrix.shape[0], -1)
        below_diagonal = matrix[rows, cols]
        rms = numpy.sqrt(numpy.mean(numpy.square(below_diagonal)))

        return(corr, rms)

class Method:
    def __init__(self):
        """Store the data locations below ``base_root`` and reset matplotlib to its defaults."""
        # setting matplotlib defaults
        plt.rcParams.update(plt.rcParamsDefault)

        self.repos = '{}repository/data/'.format(base_root)
        self.csv_file = '{}final_dataset_for_ML.csv'.format(base_root)

    def roc_curve(self, y_true, y_prob, thresholds):
        '''
        Input: y_true values, y_predicted values and thresholds as list.
        Returns:
        - false-positive rate value(s) for given thresholds in list
        - true-positive rate value(s) for given thresholds in list
        '''
        fpr = []
        tpr = []

        for threshold in thresholds:

            y_pred = numpy.where(y_prob >= threshold, 1, 0)

            fp = numpy.sum((y_pred == 1) & (y_true == 0))
            tp = numpy.sum((y_pred == 1) & (y_true == 1))

            fn = numpy.sum((y_pred == 0) & (y_true == 1))
            tn = numpy.sum((y_pred == 0) & (y_true == 0))

            fpr.append(fp / (fp + tn))
            tpr.append(tp / (tp + fn))

        return [fpr, tpr]

    def plot_confusion_matrix(self, cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True, export=None):

        """
        Draw a confusion matrix on a new 8x6 figure.

        Arguments
        ---------
        cm:           confusion matrix (numpy array), e.g. from
                      sklearn.metrics.confusion_matrix

        target_names: class names used as tick labels on both axes, for example
                      [0, 1] or ['high', 'medium', 'low']; None leaves the
                      ticks untouched

        title:        accepted for compatibility with the sklearn example this
                      is based on; no title is drawn by this implementation

        cmap:         matplotlib colormap for the cells, 'Blues' if None

        normalize:    If False, the cells are annotated with the raw numbers
                      If True, the cells are annotated with row-wise proportions
                      (the cell colours always show the raw counts)

        export:       path relative to base_root to save the figure to;
                      nothing is saved if None

        The x axis label additionally reports accuracy (trace / total) and
        misclassification rate (1 - accuracy).

        Citation
        --------
        http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

        """

        import matplotlib.pyplot as plt
        import numpy as np

        accuracy = np.trace(cm) / np.sum(cm).astype('float')
        misclass = 1 - accuracy

        if cmap is None:
            cmap = plt.get_cmap('Blues')

        plt.figure(figsize=(8, 6))
        # the image is drawn before any normalisation, i.e. from the raw counts
        plt.imshow(cm, interpolation='nearest', cmap=cmap)

        if target_names is not None:
            tick_marks = np.arange(len(target_names))
            plt.xticks(tick_marks, target_names, rotation=45)
            plt.yticks(tick_marks, target_names)

        # pick annotation format and the white/black text threshold once
        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            cell_format = "{:0.4f}"
            thresh = cm.max() / 1.5
        else:
            cell_format = "{:,}"
            thresh = cm.max() / 2

        for row in range(cm.shape[0]):
            for col in range(cm.shape[1]):
                value = cm[row, col]
                # white text on dark cells, black text on light cells
                text_color = "white" if value > thresh else "black"
                plt.text(col, row, cell_format.format(value),
                         horizontalalignment="center",
                         color=text_color)

        plt.tight_layout()
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.ylabel('True label', fontsize=14)
        plt.xlabel('Predicted label\naccuracy={:0.2f}; misclass={:0.2f}'.format(accuracy, misclass), fontsize=14)

        if export is not None:
            plt.savefig(base_root+export)


    def MachineLearning(self):
        '''
        Trains and evaluates an XGBoost classifier on the dataset in self.csv_file.
        - The last 40 rows are used as testset, all rows before as training set.
          Features are all columns except the first and the last one, the label is
          the last column.
        - The training set is resampled with SMOTE.
        - The classifier is checked in a repeated k-fold (3 splits, 5 repeats) on the
          resampled training data; accuracy is printed and a ROC curve drawn per fold.
        - The classifier is then refit on the entire resampled training data and its
          ROC curve on the testset is added. The figure is saved to images/XGBoost.png.
        - Confusion matrices on the entire dataset (training rows included) and on the
          testset are saved to images/CM_dataset and images/CM_testset.
        Returns:
        - XGB (classifier fit on the resampled training data)
        - X_train, X_test, y_train, y_test
        - X_sm, y_sm (resampled training data)
        - df (the dataset as read in)
        '''
        # read in the dataset
        df = pandas.read_csv(self.csv_file, sep=',')
        
        ## Define training and testset
        # the last 40 rows are the held-out testset
        # (presumably the rows moved to the end by Preprocessing.Run - verify)
        X_train = df.iloc[:-40, 1:-1]
        y_train = df.iloc[:-40, -1]

        X_test = df.iloc[-40:, 1:-1].reset_index(drop=True)
        y_test = df.iloc[-40:, -1].reset_index(drop=True)

        # resample the training data only; the testset stays untouched
        sm = SMOTE(random_state=42)
        X_sm, y_sm = sm.fit_resample(X_train, y_train)
        
        # define classifier xgboost
        XGB = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss', max_depth=3, subsample=0.6, learning_rate= 0.02) 
               
        # check performance in repeated k fold 
        k = 3
        n_repeat = 5
        rkf = RepeatedKFold(n_splits = k, n_repeats = n_repeat, random_state= 241285)
        
        plt.rc('font', size=16)
        plt.rc('axes', labelsize=16)
        
        fig, ax = plt.subplots(1, figsize= (12, 7), dpi=300)
        # NOTE(review): the folds are split on the already resampled data, so the
        # validation folds can contain resampled points
        for i, (train_index, test_index) in enumerate(rkf.split(X_sm, y_sm)):
            X_training, X_te = X_sm.iloc[train_index, :], X_sm.iloc[test_index, :]
            # NOTE(review): y_sm is indexed by label here, X_sm by position (iloc);
            # this only agrees if y_sm has a default 0..n-1 index - confirm
            y_training, y_te = y_sm[train_index], y_sm[test_index]

            # XGBoost
            XGB.fit(X_training, y_training)
            print("XGB score accuracy: ", XGB.score(X_te, y_te))

            # one thin dotted ROC curve per fold on the shared axes
            viz = RocCurveDisplay.from_estimator(XGB, X_te, y_te, ax = ax, name = 'k fold {}'.format(i+1), response_method='predict_proba', 
                                 color= 'k', linestyle=(0,(1,1)), alpha= 0.6, lw=1)
            
        # diagonal as reference for a random classifier
        ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='black', label='Chance', alpha=.8)
        
        # final model: refit on the entire resampled training data, evaluate on the testset
        XGB.fit(X_sm, y_sm)
        RocCurveDisplay.from_estimator(XGB, X_test, y_test, ax = ax, name='XGBoost', 
                       response_method='predict_proba', lw=2, color='k', linestyle='solid')
        
        # place the legend outside of the axes
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_xlabel('False Positive Rate', fontsize=16)
        ax.set_ylabel('True Positive Rate', fontsize=16)
        
        plt.tight_layout()
        plt.savefig(base_root+'images/XGBoost.png')
        plt.show()
        
        # Plot confusion matrix
        # the matrix on the entire dataset includes the rows the model was trained on
        cm_xgb = confusion_matrix(df.iloc[:, -1], XGB.predict(df.iloc[:, 1:-1]))
        cm_x_test = confusion_matrix(y_test, XGB.predict(X_test))
        self.plot_confusion_matrix(cm_xgb, [0, 1], normalize=False, 
                                   title='XGBoost on entire dataset', cmap=None, export='images/CM_dataset')
        self.plot_confusion_matrix(cm_x_test, [0, 1], normalize=False, 
                                   title='XGBoost on testset', cmap=None, export='images/CM_testset')
        plt.show()
         
       
        return(XGB, X_train, X_test, y_train, y_test, X_sm, y_sm, df)
    

    def feature_visualization(self, XGB, X_sm, y_sm): 
        '''
        Input: fitted classifier XGB and the (resampled) training data X_sm, y_sm.
        - Plots tree based and permutation feature importance of XGB on X_sm, y_sm
          (images/Feature_importance_original.png).
        - Re-reads the dataset from self.csv_file, drops four hard-coded feature columns
          and trains a second classifier (XGBsel) the same way as in MachineLearning:
          last 40 rows as testset, SMOTE on the training rows, repeated k-fold check,
          refit on the entire resampled training data.
        - Saves ROC curves (images/Vasor_ROC.png), feature importances
          (images/Feature_importance_Vasor.png) and confusion matrices on the entire
          dataset and on the testset (images/Vasor_conf_mat_*.png).
        - Writes the predictions for the entire dataset to
          self.repos + 'Vasor_predictions_dataset.csv' and saves the model to
          base_root + 'Vasor_model.txt'.
        Returns:
        - XGBsel (classifier fit on the reduced feature set)
        - X_sm_new, y_sm_new (resampled training data with the reduced feature set)
        - xgboost_df (dataframe with Mutation, Probability, Classification, True_Pathogenicity)
        - result_xgb (permutation importance result of XGBsel)
        '''
        from sklearn.inspection import permutation_importance
    
        # store permutation for xgb on training data - see how much the model relies on each feature during training
        result_xgb = permutation_importance(XGB, X_sm, y_sm, n_repeats= 10, random_state= 42) 
        perm_sorted_idx_xgb = result_xgb.importances_mean.argsort()

        # store tree based feature importance for xgb on training data
        tree_importance_sorted_idx_xgb = numpy.argsort(XGB.feature_importances_)
        # bar centres at 0.5, 1.5, ... so that each bar sits in the middle of its slot
        tree_indices_xgb = numpy.arange(0, len(XGB.feature_importances_)) + 0.5
        
        # left: tree based importance as bars, right: permutation importance as boxplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,8), constrained_layout=True)
        ax1.barh(tree_indices_xgb, XGB.feature_importances_[tree_importance_sorted_idx_xgb], height=0.7)
        ax1.set_yticks(tree_indices_xgb)
        ax1.set_yticklabels(X_sm.columns[tree_importance_sorted_idx_xgb])
        ax1.set_ylim((0, len(XGB.feature_importances_)))
        ax1.set_title('Tree based feaure importance')
        ax2.boxplot(result_xgb.importances[perm_sorted_idx_xgb].T, vert=False, labels=X_sm.columns[perm_sorted_idx_xgb])
        ax2.set_title('Permutation importance')
        fig.suptitle('Feature importance in XGBoost model')
        
        plt.savefig(base_root+'images/Feature_importance_original.png', dpi=300)
        plt.show()
        
        # removing the 4 least important features
        # (the list is hard-coded, it is not derived from the importances computed above)
        col_out = ['RSA', 'I_Mutant_stsign', 'ss_effect', 'I_Mutant_deltaG']
        # read in the dataset
        df = pandas.read_csv(self.csv_file, sep=',')
        df.drop(axis=1, columns=col_out, inplace= True)
        
        ## Define training and testset
        # same split as in MachineLearning: the last 40 rows are the testset
        X_train = df.iloc[:-40, 1:-1]
        y_train = df.iloc[:-40, -1]

        X_test = df.iloc[-40:, 1:-1].reset_index(drop=True)
        y_test = df.iloc[-40:, -1].reset_index(drop=True)

        # resample the training data only
        sm = SMOTE(random_state=42)
        X_sm_new, y_sm_new = sm.fit_resample(X_train, y_train)
    
        # define classifier xgboost for selected features
        XGBsel = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss', max_depth=3, subsample=0.6, learning_rate= 0.02)
        
        # check performance in repeated k fold 
        k = 3
        n_repeat = 5
        rkf = RepeatedKFold(n_splits = k, n_repeats = n_repeat, random_state= 241285)
        
        plt.rc('axes', labelsize=14)
        
        fig, ax = plt.subplots(1, figsize=(13, 7), dpi=300)
        for i, (train_index, test_index) in enumerate(rkf.split(X_sm_new, y_sm_new)):
            X_training, X_te = X_sm_new.iloc[train_index, :], X_sm_new.iloc[test_index, :]
            # NOTE(review): label based indexing of y_sm_new, position based (iloc) for
            # X_sm_new; this only agrees for a default 0..n-1 index - confirm
            y_training, y_te = y_sm_new[train_index], y_sm_new[test_index]
            
            # XGBoost
            XGBsel.fit(X_training, y_training)
            print("XGBsel score accuracy: ", XGBsel.score(X_te, y_te))

            # one thin ROC curve per fold on the shared axes
            viz = RocCurveDisplay.from_estimator(XGBsel, X_te, y_te, ax = ax, response_method='predict_proba', 
                                 drop_intermediate=False, name = 'k fold {}'.format(i+int(1)), alpha= 0.6, lw=1, color= 'k')

        # diagonal as reference for a random classifier
        ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='black', label='Chance', alpha=.8)
                
        # final model: refit on the entire resampled training data, evaluate on the testset
        XGBsel.fit(X_sm_new, y_sm_new)      
        RocCurveDisplay.from_estimator(XGBsel, X_test, y_test, ax = ax, response_method='predict_proba', name='Vasor', lw=4, 
                       color='green', linestyle= 'solid')
       
        # place the legend outside of the axes
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(base_root+'images/Vasor_ROC.png', dpi=300)
        plt.show()
        
        # show feature importance plot again after selection
        result_xgb = permutation_importance(XGBsel, X_sm_new, y_sm_new, n_repeats= 10, random_state= 42)
        perm_sorted_idx_xgb = result_xgb.importances_mean.argsort()
        
        # store tree based feature importance for xgbsel on training data
        tree_importance_sorted_idx_xgb = numpy.argsort(XGBsel.feature_importances_)
        tree_indices_xgb = numpy.arange(0, len(XGBsel.feature_importances_)) + 0.5
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,8), constrained_layout=True)
        ax1.barh(tree_indices_xgb, XGBsel.feature_importances_[tree_importance_sorted_idx_xgb], height=0.7)
        ax1.set_yticks(tree_indices_xgb)
        ax1.set_yticklabels(X_sm_new.columns[tree_importance_sorted_idx_xgb])
        ax1.set_ylim((0, len(XGBsel.feature_importances_)))
        ax1.set_title('Tree based feaure importance')
        ax2.boxplot(result_xgb.importances[perm_sorted_idx_xgb].T, vert=False, labels=X_sm_new.columns[perm_sorted_idx_xgb])
        ax2.set_title('Permutation importance')
        fig.suptitle('Feature importance after feature selection')
        plt.savefig(base_root+'images/Feature_importance_Vasor.png', dpi=300)
        plt.show()
        
        # Plot confusion matrix
        # the matrix on the entire dataset includes the rows the model was trained on
        cm_xgbsel = confusion_matrix(df.iloc[:, -1], XGBsel.predict(df.iloc[:, 1:-1]))
        self.plot_confusion_matrix(cm_xgbsel, [0, 1], normalize=False, title='Vasor on entire dataset', cmap=None)
        plt.tight_layout()
        plt.savefig(base_root+'images/Vasor_conf_mat_entire_dataset.png', dpi=300)
        
        cm_xgbseltest = confusion_matrix(y_test, XGBsel.predict(X_test))
        self.plot_confusion_matrix(cm_xgbseltest, [0, 1], normalize=False, title='Vasor on testset', cmap=None)
        plt.tight_layout()
        plt.savefig(base_root+'images/Vasor_conf_mat_testset.png', dpi=300)
        plt.show()
        
        # Save a dataframe with XGBoost Prediction value, classification and true value
        xgboost_df = pandas.DataFrame(columns = ['Mutation', 'Probability', 'Classification', 'True_Pathogenicity'])
        # NOTE(review): missing mutations are filtered out here but not in the other
        # columns, so the column lengths only agree if no mutation is missing
        xgboost_df['Mutation'] = [x for x in df['Mutation'].tolist() if str(x) != 'nan']
        # probability of class 1 (second column of predict_proba)
        xgboost_df['Probability'] = XGBsel.predict_proba(df.iloc[:, 1:-1])[:,1]
        xgboost_df['Classification'] = XGBsel.predict(df.iloc[:, 1:-1])
        xgboost_df['True_Pathogenicity'] = [x for x in df['Pathogenicity']]
        
        xgboost_df.to_csv(self.repos+'Vasor_predictions_dataset.csv', sep=',', index=False)
        
        XGBsel.save_model(base_root+'Vasor_model.txt')
        
        return(XGBsel, X_sm_new, y_sm_new, xgboost_df, result_xgb)
    

    def extract_mutpred(self):
        '''
        Wrapper to read in MutPred2 predictions (precalculated from standalone version and written to .out file)
        Automatically reads in base_root + 'MutPred_pred/MutPred_pred_MDR3.out' and retrieves,
        for every variant of the dataset in self.csv_file, the 'MutPred2 score' of the first
        row whose 'Substitution' equals the variant.
        Dataset rows with a missing mutation or pathogenicity and variants without a
        MutPred2 prediction are left out, so that each score stays in the row of its
        own variant.
        -------
        Returns dataframe with columns Mutation, Probability and True_Pathogenicity.
        Writes dataframe also to self.repos + 'dataframes/MutPred_predictions.csv'
        '''

        mutpred = pandas.read_csv(base_root+'MutPred_pred/MutPred_pred_MDR3.out')

        # get the dataframe of variants submitted for reference list of variants
        df = pandas.read_csv(self.csv_file)

        # score of the first occurrence of every substitution
        first_score = {}
        for substitution, score in zip(mutpred['Substitution'], mutpred['MutPred2 score']):
            first_score.setdefault(substitution, score)

        # Build complete rows in one pass. Collecting variants, scores and labels
        # in separate lists would shift the scores onto the wrong variants as soon
        # as one variant has no prediction or one label is missing.
        rows = []
        for variant, pathogenicity in zip(df['Mutation'], df['Pathogenicity']):
            if str(variant) == 'nan' or str(pathogenicity) == 'nan':
                continue
            if variant in first_score:
                rows.append((variant, first_score[variant], pathogenicity))

        new_df = pandas.DataFrame(rows, columns=['Mutation', 'Probability', 'True_Pathogenicity'])
        cols = ['Probability', 'True_Pathogenicity']
        new_df[cols] = new_df[cols].apply(pandas.to_numeric, errors= 'coerce')

        new_df.to_csv(self.repos+'dataframes/MutPred_predictions.csv', sep = ',', index = False)
        return(new_df)
    
    
    def compare_predictors(self, XGBsel):

        df = pandas.read_csv(self.csv_file, sep = ',')
        
        col_out = ['RSA', 'I_Mutant_stsign', 'ss_effect', 'I_Mutant_deltaG']
        df_sel = df.drop(axis=1, columns=col_out)
        
        ponp2 = pandas.read_csv(self.repos+'dataframes/PONP2_predictions.csv')
        eve = pandas.read_csv(self.repos+'dataframes/EVE_predictions.csv')
        polyphen = pandas.read_csv(self.repos+'dataframes/Polyphen2_predictions.csv')
        mutpred = self.extract_mutpred()
        
        # ROC curves
        threshold = numpy.arange(0.0, 1.0, 0.002)
        
        fig, ax = plt.subplots(1, squeeze=True, figsize= (9, 6))

        fprs_xgbsel, tprs_xgbsel = self.roc_curve(df_sel.loc[:, 'Pathogenicity'], XGBsel.predict_proba(df_sel.iloc[:, 1:-1])[:,1], threshold)
        fpr_xgbsel, tpr_xgbsel = self.roc_curve(df_sel.loc[:, 'Pathogenicity'], XGBsel.predict_proba(df_sel.iloc[:, 1:-1])[:,1], [0.5])
        auc_xgbsel = metrics.roc_auc_score(df_sel.loc[:, 'Pathogenicity'], XGBsel.predict_proba(df_sel.iloc[:, 1:-1])[:,1])
        ax.plot(fprs_xgbsel, tprs_xgbsel, linestyle='solid', color='green', label = 'Vasor (AUC = %.2f)' % auc_xgbsel, linewidth=4)
                
        fprs_eve, tprs_eve = self.roc_curve(eve.loc[:, 'True_Pathogenicity'], eve.loc[:, 'Probability'], threshold)
        fpr_eve, tpr_eve = self.roc_curve(eve.loc[:, 'True_Pathogenicity'], eve.loc[:, 'Probability'], [0.5])
        auc_eve = metrics.roc_auc_score(eve.loc[:, 'True_Pathogenicity'], eve.loc[:, 'Probability'])
        ax.plot(fprs_eve, tprs_eve,  color='black', label='EVE (AUC = %.2f)' % auc_eve, linewidth=4) 

        fprs_poly, tprs_poly = self.roc_curve(polyphen.loc[:, 'True_Pathogenicity'], polyphen.loc[:, 'Probability'], threshold)
        fpr_poly, tpr_poly = self.roc_curve(polyphen.loc[:, 'True_Pathogenicity'], polyphen.loc[:, 'Probability'], [0.5])
        fprs_poly.append(0)
        tprs_poly.append(0)
        auc_poly = metrics.roc_auc_score(polyphen.loc[:, 'True_Pathogenicity'], polyphen.loc[:, 'Probability'])
        ax.plot(fprs_poly, tprs_poly, color='blue', label='Polyphen2 (AUC = %.2f)' % auc_poly, linewidth=4) 
        
        fprs_p2, tprs_p2 = self.roc_curve(ponp2.loc[:, 'True_Pathogenicity'], ponp2.loc[:, 'Probability'], threshold)
        fpr_p2, tpr_p2 = self.roc_curve(ponp2.loc[:, 'True_Pathogenicity'], ponp2.loc[:, 'Probability'], [0.5])
        auc_p2 = metrics.roc_auc_score(ponp2.loc[:, 'True_Pathogenicity'], ponp2.loc[:, 'Probability'])
        ax.plot(fprs_p2, tprs_p2, color='grey', label='PONP2 (AUC = %.2f)' % auc_p2, linewidth=4) 
        
        # MutPred2
        fprs_mp, tprs_mp = self.roc_curve(mutpred.loc[:, 'True_Pathogenicity'], mutpred.loc[:, 'Probability'], threshold)
        auc_mp = metrics.roc_auc_score(mutpred.loc[:, 'True_Pathogenicity'], mutpred.loc[:, 'Probability'])
        ax.plot(fprs_mp, tprs_mp, color='red', label='MutPred2 (AUC = %.2f)' % auc_mp, linewidth=4) 
        
        ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='0.6', label='Chance', alpha=.8)
        ax.set_xlabel('False Positive Rate', fontsize=16)
        ax.set_ylabel('True Positive Rate', fontsize=16)
        plt.tight_layout()
        plt.savefig(base_root+'images/ROC_predictors.png', dpi=300)
        plt.show()
        
        # Precision recall curves
        fig, ax = plt.subplots(figsize= (9, 6))

        precision, recall, thresholds = precision_recall_curve(df_sel.iloc[:, -1], XGBsel.predict_proba(df_sel.iloc[:, 1:-1])[:, 1],pos_label=1)
        aucpr_xgbsel = metrics.average_precision_score(df_sel.iloc[:, -1], XGBsel.predict_proba(df_sel.iloc[:, 1:-1])[:, 1])
        ax.plot(recall, precision, label='Vasor (AUC-PR = %.2f)' % aucpr_xgbsel, color='green', linestyle='solid', linewidth=4)
        
        precision, recall, thresholds = precision_recall_curve(eve.iloc[:, -1], eve.iloc[:, -3])
        aucpr_eve = metrics.average_precision_score(eve.iloc[:, -1], eve.iloc[:, -3])
        ax.plot(recall, precision, color='black', label='EVE (AUC-PR = %.2f)' % aucpr_eve, linewidth=4)    
        precision, recall, thresholds = precision_recall_curve(polyphen.iloc[:, -1], polyphen.iloc[:, -3])
        aucpr_poly = metrics.average_precision_score(polyphen.iloc[:, -1], polyphen.iloc[:, -3])
        ax.plot(recall, precision, color='blue', label='Polyphen2 (AUC-PR = %.2f)' % aucpr_poly, linewidth=4)
        
        precision, recall, thresholds = precision_recall_curve(ponp2.iloc[:, -1], ponp2.iloc[:, -3])
        aucpr_p2 = metrics.average_precision_score(ponp2.iloc[:, -1], ponp2.iloc[:, -3])
        ax.plot(recall, precision, color='grey', label='PONP2 (AUC-PR = %.2f)' % aucpr_p2, linewidth=4)
        
        precision, recall, thresholds = precision_recall_curve(mutpred.iloc[:, -1], mutpred.iloc[:, -2])
        aucpr_mp = metrics.average_precision_score(mutpred.iloc[:, -1], mutpred.iloc[:, -2])
        ax.plot(recall, precision, color='red', label='MutPred2 (AUC-PR = %.2f)' % aucpr_mp, linewidth=4) 
        
        baseline = len(df[df['Pathogenicity'] == 1]) / len(df)
        ax.plot([0, 1], [baseline, baseline], linestyle='--', color='0.6', label='Chance')
        
        ax.set_xlabel('Recall', fontsize=16)
        ax.set_ylabel('Precision', fontsize=16)       
        ax.set_ylim([-0.05, 1.05])
        plt.tight_layout()
        plt.savefig(base_root+'images/Precision-Recall_predictors.png', dpi=300)
        plt.show()

        # Plot number of coverage of each predictor as bars
        models = ['Vasor', 'EVE', 'Polyphen2', 'PONP2', 'MutPred2'] 
        
        num_xgbsel = len(df_sel) / len(df_sel) * 100
        num_eve = len(eve) /len(df) * 100
        num_poly = len(polyphen) / len(df) * 100
        num_p2 = len(ponp2) / len(df) *100
        num_mp = len(mutpred) /len(df) * 100
        heights = [num_xgbsel, num_eve, num_poly, num_p2, num_mp] 
        fig, ax = plt.subplots(figsize= (4, 6))
        ax.bar(models, heights, color= ['green', '0.7', '0.7', '0.7', '0.7'])
        ax.set_ylabel('Coverage of variants in dataset [%]', fontsize=16)

        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(base_root+'images/Coverage_predictors.png', dpi=300)
        plt.show()

      
    def performance_table(self, testset, dataset_size=364, testset_size=40):
        '''
        Computes performance values of Vasor and of the reference predictors
        (EVE, Polyphen2, PONP2, MutPred2), once for all variants found in the
        prediction files and once restricted to the variants of the testset.
        MutPred2 probabilities are converted to classes with a cutoff of 0.5.

        Parameters
        ----------
        testset : list-like of mutation names; rows whose 'Mutation' entry is
            contained in it are used for the testset table
        dataset_size : number of variants in the entire dataset, denominator
            of the coverage in the first table (default: 364)
        testset_size : number of variants in the testset, denominator of the
            coverage in the testset table (default: 40)

        Returns:
        - performance table of all variants (one row per predictor)
        - performance table of the testset variants (one row per predictor)
        Both tables are rounded to two decimals and are written to
        'Performance_values.csv' and 'Performance_values_testset.csv' in self.repos.
        '''
        models = ['Vasor', 'EVE', 'Polyphen2', 'PONP2', 'MutPred2']
        # NOTE: the column 'NPR' holds the negative predictive value; the label
        # is kept as it is because it is part of the exported file format
        columns = ['Recall', 'Specificity', 'Precision', 'NPR', 'Accuracy', 'F1-Score', 'MCC',
                   'TP', 'FN', 'TN', 'FP', 'Coverage']

        vasor_df = pandas.read_csv(self.repos+'Vasor_predictions_dataset.csv', sep = ',')
        ponp2 = pandas.read_csv(self.repos+'dataframes/PONP2_predictions.csv', sep = ',')
        eve = pandas.read_csv(self.repos+'dataframes/EVE_predictions.csv', sep = ',')
        polyphen = pandas.read_csv(self.repos+'dataframes/Polyphen2_predictions.csv', sep = ',')
        mutpred = pandas.read_csv(self.repos+'dataframes/MutPred_predictions.csv', sep = ',')
        # MutPred2 - cutoff at 0.5
        mutpred['Classification'] = numpy.where(mutpred['Probability'] >= 0.5, 1, 0)

        # (dataframe, position of true label column, position of predicted class column)
        # in the same order as models; for MutPred2 the appended 'Classification'
        # column is the last one, so the true label moves to the second to last position
        predictors = [(vasor_df, -1, -2),
                      (eve, -1, -2),
                      (polyphen, -1, -2),
                      (ponp2, -1, -2),
                      (mutpred, -2, -1)]

        def compare(y_true, y_pred, n_variants):
            '''
            Input: y_true values, y_predicted values, number of variants the
            coverage refers to

            Returns:
            - List of performance values
              (tpr, tnr, ppv, npv, acc, f1, mcc, tp, fn, tn, fp, cov)
            '''
            fp = numpy.sum((y_pred == 1) & (y_true == 0))
            tp = numpy.sum((y_pred == 1) & (y_true == 1))

            fn = numpy.sum((y_pred == 0) & (y_true == 1))
            tn = numpy.sum((y_pred == 0) & (y_true == 0))

            tpr = (tp / (tp + fn)) # sensitivity or recall
            tnr = (tn / (tn + fp)) # specificity or selectivity
            ppv = (tp / (tp + fp)) # precision or positive predictive value
            npv = (tn / (tn + fn)) # negative predictive value
            acc = ((tp + tn) / (tp + tn + fp + fn)) # accuracy
            f1 = ((2*tp) / ((2*tp) + fp + fn)) # F1 score
            mcc = ((tp * tn - fp * fn) / numpy.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))  # Matthews correlation coefficient
            cov = sum((fp, tp, fn, tn)) / n_variants * 100 # coverage: all predictions divided by number of variants

            return [tpr, tnr, ppv, npv, acc, f1, mcc, tp, fn, tn, fp, cov]

        def build_table(frames, n_variants):
            '''
            Collects the performance values of all predictors in one dataframe.
            '''
            rows = [compare(frame.iloc[:, true_col], frame.iloc[:, pred_col], n_variants)
                    for frame, true_col, pred_col in frames]
            table = pandas.DataFrame(rows, columns=columns, index=models)
            # round() returns a new dataframe, the result has to be kept
            return table.round(decimals=2)

        # performance values on all variants
        new_df = build_table(predictors, dataset_size)
        new_df.to_csv(self.repos+'Performance_values.csv', sep = ',', index=True)

        #######################################################################
        ## Performance values on testset ##
        #######################################################################

        # only keep the variants which are part of the testset
        predictors_testset = [(frame[frame['Mutation'].isin(testset)], true_col, pred_col)
                              for frame, true_col, pred_col in predictors]

        # coverage refers to the size of the testset here
        testset_performance = build_table(predictors_testset, testset_size)
        testset_performance.to_csv(self.repos+'Performance_values_testset.csv', sep = ',', index=True)

        return(new_df, testset_performance)
    
    def histogram_probabilities(self, XGBsel, df, X_sm_new, y_sm_new, X_test, y_test,
                                n_real_train=324, n_test=40):
        '''
        Function to create histograms of the probability value distributions for
        benign and pathogenic variants, once for the real variants only and once
        including the SMOTE datapoints. Afterwards the highest benign value, the
        lowest pathogenic value and some percentiles are printed.
        The dataframes which are passed in are not modified.

        Parameters
        ----------
        XGBsel : ML model (needs to offer predict_proba)
        df : entire dataset; the first column is skipped, the last column is
            used as label (0 = benign, 1 = pathogenic)
        X_sm_new : training features including the SMOTE datapoints
        y_sm_new : labels belonging to X_sm_new
        X_test : features of the testset (still containing the columns which
            were removed by the feature selection)
        y_test : labels belonging to X_test
        n_real_train : number of rows at the start of X_sm_new which are real
            datapoints; all following rows are treated as SMOTE points (default: 324)
        n_test : number of testset rows at the end of the combined data (default: 40)
        '''
        col_out = ['RSA', 'I_Mutant_stsign', 'ss_effect', 'I_Mutant_deltaG']
        bins = numpy.linspace(0.0, 1.0, num=100)
        ticks = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

        def finish_figure(ax, filename):
            '''
            Adds legend, labels and ticks, then saves and shows the current figure.
            '''
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left')
            ax.set_xlabel('Probability of pathogenicity', fontsize=16)
            ax.set_ylabel('Number of variants', fontsize=16)
            ax.set_xticks(ticks)
            ax.set_xticklabels(ticks)
            plt.tight_layout()
            plt.savefig(base_root+'images/'+filename, dpi=300)
            plt.show()

        # drop on a copy, an inplace drop would alter the dataframe of the caller
        # and make a second call of this function fail
        data = df.drop(columns=col_out)

        X = data.iloc[:, 1:-1]
        y = data.iloc[:, -1]
        benign = (y == 0).to_numpy()
        pathogenic = (y == 1).to_numpy()

        prediction = XGBsel.predict_proba(X)[:,1]

        # Create histogram
        fig, ax = plt.subplots(figsize=(12,6), dpi=300)
        ax.hist(prediction[benign], bins=bins, label='Benign', edgecolor='k')
        ax.hist(prediction[pathogenic], bins=bins, label='Pathogenic', edgecolor='k',
                alpha=0.6, color='red')
        finish_figure(ax, 'Probability_pathogenicity.png')

        # include SMOTE values
        # pandas.concat replaces Series.append / DataFrame.append (removed in pandas 2.0)
        y_all = pandas.concat([y_sm_new, y_test], ignore_index=True)
        X_all = pandas.concat([X_sm_new, X_test.drop(columns=col_out)], ignore_index=True)

        preds_sm = XGBsel.predict_proba(X_all)[:,1]

        # Differentiate between real datapoints and SMOTE points:
        # [real training points | SMOTE points | testset]
        test_start = len(preds_sm) - n_test
        smote = preds_sm[n_real_train:test_start].copy()
        dataset = preds_sm[:n_real_train].copy()
        testset = preds_sm[test_start:].copy()

        smote_y = y_all.iloc[n_real_train:test_start].copy()
        dataset_y = y_all.iloc[:n_real_train].copy()
        testset_y = y_all.iloc[test_start:].copy()

        real_points = numpy.append(dataset, testset)
        real_points_y = pandas.concat([dataset_y, testset_y], ignore_index=True)

        # Create histogram
        fig, ax = plt.subplots(figsize=(12,6), dpi=300)
        ax.hist(real_points[(real_points_y == 0).to_numpy()], bins=bins, label='Benign', edgecolor='k')
        ax.hist([real_points[(real_points_y == 1).to_numpy()], smote], bins=bins, stacked=True,
                label=['Pathogenic', 'Pathogenic - SMOTE'], edgecolor='k', alpha=0.6, color=['red', 'orange'])
        finish_figure(ax, 'Probability_pathogenicity_SMOTE.png')

        #### Obtaining the values of percentiles and highest benign predicted value/lowest pathogenic
        # For the real datapoints only
        ben_high = prediction[benign].max()
        path_low = prediction[pathogenic].min()

        # percentiles
        ben_74 = numpy.percentile(prediction[benign], 74)
        path_75 = numpy.percentile(prediction[pathogenic], (100-75))
        path_60 = numpy.percentile(prediction[pathogenic], (100-60))

        print('---'*15,
              '\nHighest benign value is {0:.2f}'.format(ben_high), 
              '\nLowest pathogenic value is {0:.2f}\n'.format(path_low),
              '---'*15,
              '\n74 % of benign variants are below value of {0:.2f}'.format(ben_74),
              '\n75 % of pathogenic variants are greater value than {0:.2f}'.format(path_75),
              '\n60 % of pathogenic variants are greater value than {0:.2f}'.format(path_60),
              '\nLowest SMOTE value is {0:.2f}'.format(smote[(smote_y == 1).to_numpy()].min()))
        
class Predict:
    '''
    Applies the stored Vasor model to all possible MDR3 variants, visualizes the
    predicted probabilities as heatmap and offers a small installation test.
    '''
    def __init__(self):
        self.repos = base_root+'repository/data/'
        self.csv_file = self.repos+'dataframes/MDR3_all_variants.csv'

        # setting matplotlib defaults
        plt.rcParams.update(plt.rcParamsDefault)

    @staticmethod
    def _load_model():
        '''
        Loads the trained model from 'Vasor_model.txt' and returns it.
        '''
        XGB = xgb.XGBClassifier()
        XGB.load_model(base_root+'Vasor_model.txt')
        return XGB

    def predicting_all(self):
        '''
        Predicts all variants listed in self.csv_file with the stored model.

        Returns:
        - dataframe with all features, the prediction and the probability of
          pathogenicity (written to 'MDR3_all_variants_Predictions.csv')
        - dataframe with mutation, prediction and both class probabilities only
          (written to 'MDR3_Predictions_all.csv')
        '''
        df = pandas.read_csv(self.csv_file, sep=',')
        #replace the error values
        df.replace(to_replace = 99999, value = 0, inplace=True)
        new_df = df.rename(columns={'>P21439' : 'Mutation'})

        XGB = self._load_model()

        # match column name of FINAL dataset csv file with repository csv file
        match_columns = pandas.read_csv(base_root+'final_dataset_for_ML.csv', sep=',')
        col_out = ['RSA', 'I_Mutant_stsign', 'ss_effect', 'I_Mutant_deltaG']
        match_columns.drop(axis=1, columns=col_out, inplace= True)

        final_data = new_df.reindex(columns = match_columns.columns.tolist())
        final_data.drop(columns='Pathogenicity', inplace=True)

        # every column except the mutation name has to be numeric for the model
        cols = [x for x in final_data.columns.tolist() if str(x) != 'Mutation']
        final_data[cols] = final_data[cols].apply(pandas.to_numeric, errors= 'coerce')

        # first column is the mutation name, all other columns are features;
        # the probabilities are computed once and used for both classes
        features = final_data.iloc[:, 1:]
        probabilities = XGB.predict_proba(features)

        # df with only mutations, pred and prob
        short_df = pandas.DataFrame({'Mutation': final_data['Mutation'].tolist(),
                                     'Prediction': XGB.predict(features),
                                     'Probability_of_Pathogenicity': probabilities[:, 1],
                                     'Probability_of_benign': probabilities[:, 0]})

        final_data['Prediction'] = short_df['Prediction'].tolist()
        final_data['Probability_of_Pathogenicity'] = short_df['Probability_of_Pathogenicity'].tolist()

        final_data.to_csv(self.repos+'MDR3_all_variants_Predictions.csv', sep=',', index=False)
        short_df.to_csv(self.repos+'MDR3_Predictions_all.csv', sep=',', index=False)

        return (final_data, short_df)

    def heatmap_all(self):
        '''
        Creates a heatmap of the probability of pathogenicity of every possible
        substitution (rows: amino acids, columns: wildtype residue + position).
        Needs 'MDR3_Predictions_all.csv', which is written by predicting_all.
        The rows of the plotted table are sorted by their average probability.

        Returns:
        - unsorted table as read back from 'Heatmap_table.csv'
          (first column 'aa', then one column per position)

        Raises a KeyError if a substitution is missing in the predictions file.
        '''
        df = pandas.read_csv(self.repos+'MDR3_Predictions_all.csv', sep=',')

        # read_fwf uses the first line of the fasta file as header, the sequence
        # is taken from the first row below it
        # NOTE(review): assumes the whole sequence is stored in one line - confirm
        mdr3_file = pandas.read_fwf(base_root+'MDR3/MDR3.fasta')
        mdr3 = mdr3_file.iloc[0, 0]

        wt_pos = [wt+str(pos) for pos, wt in enumerate(mdr3, start=1)]

        amino_acids = ['R', 'H', 'K', 'D', 'E', 'S', 'T', 'N', 'Q', 'C', 'G', 'P', 'A', 'V', 'I', 'L', 'M', 'F', 'Y', 'W']

        # one lookup table instead of scanning the whole dataframe for each cell;
        # the first entry is kept if a mutation is listed more than once
        first_entries = df.drop_duplicates(subset='Mutation', keep='first')
        probability_of = dict(zip(first_entries['Mutation'],
                                  first_entries['Probability_of_Pathogenicity']))

        # dataframe with wt_pos as columns and amino_acids as rows,
        # the wildtype amino acid of each position is set to 0
        values = {}
        for var in wt_pos:
            values[var] = [0 if aa == var[0] else probability_of[var+aa] for aa in amino_acids]
        map_aa = pandas.DataFrame(values, index=amino_acids, columns=wt_pos)

        map_aa = map_aa.apply(pandas.to_numeric)
        map_aa.to_csv(self.repos+'dataframes/Heatmap_table.csv', sep=',')

        map_aa = pandas.read_csv(self.repos+'dataframes/Heatmap_table.csv', sep=',')
        map_aa.rename(columns={'Unnamed: 0' : 'aa'}, inplace=True)

        # storing the average pathogenicity of each amino acid over entire row in dict
        # zeros are not taken into account for the average
        without_zeros = map_aa.replace(0, numpy.nan)
        aa_vul = {}
        for ind in map_aa.index:
            aa_vul[map_aa.iloc[ind, 0]] = without_zeros.iloc[[ind], 1:].mean(axis=1, numeric_only=True, skipna=True).values[0]
        # sort it according to lowest to highest
        sorted_aa = sorted(aa_vul.items(), key = lambda x: x[1])

        # reorder the rows of the table accordingly
        row_order = [aa for aa, value in sorted_aa]
        new_map = map_aa.set_index('aa').loc[row_order].reset_index()

        new_map.to_csv(self.repos+'dataframes/Heatmap_table_sorted.csv', sep=',', index=False)
        heat = pandas.read_csv(self.repos+'dataframes/Heatmap_table_sorted.csv', sep=',', index_col=0)

        fig, ax = plt.subplots(figsize=(30,9), dpi=300)
        cbar_ax = fig.add_axes([.905, .05, .005, .82])
        sns.heatmap(heat, vmin=0, vmax=1, cmap="coolwarm", ax=ax, cbar_ax = cbar_ax)
        plt.savefig(base_root+'images/Heatmap_MDR3.png')

        return(map_aa)

    def install_test(self, final_data):
        '''
        Predicts the variants listed in 'Install_compare.csv' with the local
        installation and stores the probabilities in the column 'Local_install'
        next to the values which are already in that file. The result is written
        to 'Installation_Test.csv'.

        Parameters
        ----------
        final_data : first dataframe returned by predicting_all (mutation name
            as first column, prediction and probability as last two columns)
        '''
        # load in precalculated values of the variants, compare to calculation here and write to csv file
        test_df = pandas.read_csv(self.repos+'dataframes/Install_compare.csv', sep=';')

        XGB = self._load_model()

        for var in test_df['Variant'].tolist():
            row = final_data.loc[final_data.Mutation == var]
            # skip the mutation name and the two prediction columns at the end
            pred = XGB.predict_proba(row.iloc[:, 1:-2])[:, 1]

            test_df.loc[test_df['Variant'] == var, 'Local_install'] = pred

        test_df.to_csv(base_root+'Installation_Test.csv', sep=',', index=False)
        

# Runs the complete workflow when the file is executed as a script. The order of
# the calls matters, as later steps use the return values of the earlier ones
# and files which were written by them.
if __name__ == '__main__':
    # creates the images folder below base_root if it does not exist yet
    set_directory(base_root)
    preprocess = Preprocessing()
    # testset is passed on to performance_table further below
    testset = preprocess.Run()
    features = Visualize_corr()
    corr, rms = features.correlation()
    
    method = Method()
    XGB, X_train, X_test, y_train, y_test, X_sm, y_sm, df = method.MachineLearning()
    XGBsel, X_sm_new, y_sm_new, xgboost_df, result_xgb = method.feature_visualization(XGB, X_sm, y_sm)
    method.compare_predictors(XGBsel)
    perf_table, testset_performance = method.performance_table(testset)
    method.histogram_probabilities(XGBsel, df, X_sm_new, y_sm_new, X_test, y_test)
    
    predict = Predict()
    # predicting_all writes MDR3_Predictions_all.csv, which is read by heatmap_all
    final_data, short_df = predict.predicting_all()
    map_aa = predict.heatmap_all()
    # compares the local predictions of final_data to the precalculated values
    predict.install_test(final_data)
    

    
