"""
    Created Mar 2014
    This file is part of VisualCNA
    Copyright (2014) Prakash Chandra Rathi and Daniel Mulnaes

    VisualCNA is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    VisualCNA is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
#####################################################################
# -*- coding: utf-8 -*-
#
# Authors: Prakash Chandra Rathi and Daniel Mulnaes
# Heinrich Heine University, Duesseldorf
# Institute for Pharmaceutical and Medicinal Chemistry
# Universitaetsstr. 1 40225 Duesseldorf
# Germany
#####################################################################


# Standard python libraries
import os

# External Modules
from Bio import AlignIO
from Bio.PDB import PPBuilder, Selection
import numpy as np
try:
    import pyFirst
except:
    pass

# VisualCNA Modules
import network, errors

# TODO: Consider removing the class and keeping only module-level functions, so that methods can import read_data directly instead of accessing it as an object of visualCNA.
# This will require fixing the cyclic call chain read_data > network > methods > read_data.

class Read_data:
    def read_gi(self, data_file, index_type_tag, hb_cut_offs):
        def myround(x, prec, base):
            return np.round(base*np.round(x/base), prec)
        standard_colnames={"li": ["No", "Resi", "Chain", "PercInd1", "PercInd2", "RigInd", "RicInd"],
                           "gi": ["No", "Egy", "Temp", "CCE1", "CCE2", "ROP1", "ROP2", "ROP3", "MCS"]}
        round_columns={"li": ["PercInd1", "PercInd2", "RigInd", "RicInd"], "gi": ["Egy"]}
        try:
            data=np.genfromtxt(data_file, names=True)
            if all(x in data.dtype.names for x in standard_colnames[index_type_tag]):
                if index_type_tag=="gi":
                    # get selected rows that correspond to hb_cutoffs
                    row_indices=[i for i, egy in enumerate(data["Egy"]) if egy in hb_cut_offs]
                    data=data[row_indices]
                for colname in round_columns[index_type_tag]:
                    data[colname]=myround(data[colname], 2, hb_cut_offs[0]-hb_cut_offs[1])
                return data
            else:
                return None
        except:
            return None

    def read_stbmap(self, stbmap_file):
        if 'neighbor' in stbmap_file:
            neighbor_stbmap_file=stbmap_file
            all_stbmap_file=neighbor_stbmap_file.replace('_neighbor', '')
        else:
            all_stbmap_file=stbmap_file
            neighbor_stbmap_file=stbmap_file.replace('_stability_map', '_neighbor_stability_map')

        # check if both files exist
        if not os.path.isfile(neighbor_stbmap_file):
            errors.input_error(12, (neighbor_stbmap_file))
            return None
        if not os.path.isfile(all_stbmap_file):
            errors.input_error(12, (all_stbmap_file))
            return None

        # If both files are available
        try:
            data=np.genfromtxt(neighbor_stbmap_file)
            dim=data.shape[0]
            data[np.triu_indices(dim)]=np.genfromtxt(all_stbmap_file)[np.triu_indices(dim)]
            return data
        except:
            return None

    def read_unf_nuc(self, data_file, cna_method, pdb_obj=None):
        """Read unfolding-nuclei data and check it against the loaded structure.

        Parameters:
            data_file: for ``'ensemble'`` anything ``np.genfromtxt`` accepts
                (a table with a header line); for ``'single'`` the path of a
                text file made of ``...Unfolding..._<TypeN>`` header lines,
                each followed by residue lines whose 2nd, 3rd and 4th
                whitespace-separated fields are residue name, residue number
                and chain id.
            cna_method: ``'ensemble'`` or ``'single'``.
            pdb_obj: structure wrapper; ``res_objects`` is read in ensemble
                mode and ``model`` in single mode. It is required in practice
                even though it defaults to ``None``.

        Returns:
            A structured array with the columns ``No``, ``RES_NAME``,
            ``RES_ID``, ``CHAIN`` plus one column per unfolding-nuclei type
            (single mode: 100 if the residue is listed for that type, else 0;
            ensemble mode: the table exactly as read from the file).
            ``None`` if the file holds none of ``Type1``..``Type4``, if it
            does not match the structure, or if ``cna_method`` is neither of
            the two supported values.
        """
        nucleus_types = ["Type1", "Type2", "Type3", "Type4"]

        if cna_method == 'ensemble':
            data = np.genfromtxt(data_file, dtype=None, names=True)
            if not any(x in data.dtype.names for x in nucleus_types):
                errors.input_error(14, (data_file, "unfolding nuclei"))
                return None
            # The table must list exactly the residues of the structure, in order.
            pdb_res_names = [res.resname for res in pdb_obj.res_objects]
            unf_resnames = list(data["RES_NAME"])
            if not pdb_res_names == unf_resnames:
                errors.input_error(18, (data_file,))
                return None
            return data

        if cna_method != 'single':
            return None

        # Parse the file into {type: [(resname, resid, chain), ...]}.
        unf_nuc_initial = {}
        u_type = None
        with open(data_file, 'r') as f:
            for line in f:
                if 'Unfolding' in line:
                    u_type = line.rstrip().split('_')[-1]
                    unf_nuc_initial[u_type] = []
                    continue
                fields = line.split()
                # Skip lines that precede the first header or are too short.
                if u_type is None or len(fields) < 4:
                    continue
                try:
                    res_id = int(fields[2])
                except ValueError:
                    # Not a residue line (non-numeric residue number).
                    continue
                unf_nuc_initial[u_type].append((fields[1], res_id, fields[3]))

        if not any(x in unf_nuc_initial for x in nucleus_types):
            errors.input_error(14, (data_file, "unfolding nuclei"))
            return None

        # Every listed residue must exist in the structure with the same name.
        for unf_type in unf_nuc_initial:
            for res_name, res_id, chain_id in unf_nuc_initial[unf_type]:
                try:
                    matches = pdb_obj.model[chain_id][res_id].resname == res_name
                except Exception:
                    # Unknown chain/residue (or unusable pdb_obj) counts as a mismatch.
                    matches = False
                if not matches:
                    errors.input_error(18, (data_file,))
                    return None

        # Convert the dictionary to a structured array with one row per residue.
        unf_nuc_types = sorted(unf_nuc_initial.keys())
        # Sets give constant-time membership tests in the per-residue loop.
        members = {t: set(unf_nuc_initial[t]) for t in unf_nuc_types}
        unf_nuc_new = []
        for i, res in enumerate(Selection.unfold_entities(pdb_obj.model, 'R')):
            residue_key = (res.resname, res.id[1], res.parent.id)
            flags = [100 if residue_key in members[t] else 0 for t in unf_nuc_types]
            unf_nuc_new.append(tuple([i+1, res.resname, res.id[1], res.parent.id] + flags))

        dtype = [('No', '<i8'), ('RES_NAME', '|S3'), ('RES_ID', '<i8'), ('CHAIN', '|S1')]
        dtype.extend((unf_nuc_type, '<i8') for unf_nuc_type in unf_nuc_types)
        return np.array(unf_nuc_new, dtype=dtype)

    def read_transitions(self, data_file, cna_method):
        colnames=['CCE1', 'CCE1_t', 'CCE2_spline_e', 'CCE2_spline_t', 'ROP1', 'ROP1_t',
                        'ROP2', 'ROP2_t', 'ROP3_t', 'ROP3', 'CCE2', 'CCE2_t']

        dtype=[(name, '>f4') for name in colnames]
        try:
            f=file(data_file, "r")
            if cna_method=="single":
                for line in f.readlines():
                    if not line.startswith("#"):
                        data=line.split()[1:]
                        data=[float(i) if i!="nan" else None for i in data]
                        break
            elif cna_method=="ensemble":
                for line in f.readlines():
                    if line.startswith("#Median"):
                        data=line.split()[1:]
                        data=[float(i)  if i!="nan" else None for i in data]
                        break
            f.close()
            return np.array(tuple(data), dtype=dtype)
        except:
            data=[None]*len(colnames)
            return np.array(tuple(data), dtype=dtype)

    def read_msa(self, msa_file, pdb_obj, chain_id, base_seq_name):
        """Compute per-residue conservation of a base sequence from an alignment.

        The alignment format is taken from the file extension. The sequence
        whose id equals ``base_seq_name`` must, with gaps removed, be
        identical to the sequence of the first peptide that ``PPBuilder``
        builds for chain ``chain_id`` of the first model of
        ``pdb_obj.structure``. All other sequences are used for the statistics.

        Parameters:
            msa_file: path of the alignment; the text after the last ``.`` is
                passed to ``AlignIO.read`` as the format name.
            pdb_obj: structure wrapper providing ``structure`` and ``pdb_id``.
            chain_id: id of the chain to compare against.
            base_seq_name: id of the reference sequence inside the alignment.

        Returns:
            ``(all_aa_frequencies, base_seq_conservations)``. For every
            non-gap position of the base sequence the first list holds the
            percentage of each symbol of ``'-ACDEFGHIKLMNPQRSTVWY'`` among the
            other sequences, the second the percentage of the base sequence's
            own residue. ``(None, None)`` on any reported input error.
        """
        msa_format = msa_file.split('.')[-1]
        try:
            # ``with`` closes the alignment file even if parsing fails.
            with open(msa_file) as handle:
                msa = AlignIO.read(handle, msa_format)
        except Exception:
            errors.input_error(24, (msa_file, msa_format))
            return None, None

        # Get the reference sequence from the pdb structure.
        aa_list = '-ACDEFGHIKLMNPQRSTVWY'
        aa_to_index = dict(zip(aa_list, range(len(aa_list))))
        ppb = PPBuilder()
        model = pdb_obj.structure[0]
        try:
            chain = model[chain_id]
        except KeyError:
            # Previously this message was built by a no-op expression that
            # itself raised TypeError; print it and return as intended.
            print('Error: Chain id %s is not found in %s.pdb' % (chain_id, pdb_obj.pdb_id))
            return None, None
        pdb_seq = str(ppb.build_peptides(chain)[0].get_sequence())

        seq_count = len(msa)
        seq_length = msa.get_alignment_length()
        # Collect the residues of all non-base sequences in one flat list.
        all_seq_list = []
        base_seq = ''
        for sequence in msa:
            if sequence.id == base_seq_name:
                base_seq = str(sequence.seq)
                if base_seq.replace('-', '') != pdb_seq:
                    errors.input_error(8)
                    return None, None
            else:
                all_seq_list += list(str(sequence.seq))

        if base_seq == '':
            errors.input_error(23, (base_seq_name, msa_file))
            return None, None
        del msa

        # Character matrix of shape (number of other sequences, alignment length).
        msa_matrix = np.char.array(all_seq_list)
        if len(msa_matrix) != (seq_count-1)*seq_length:
            errors.input_error(9)
            return None, None
        msa_matrix = msa_matrix.reshape(seq_count-1, seq_length)

        # Conservation analysis, one alignment column per base-sequence residue.
        other_count = float(seq_count-1)
        all_aa_frequencies = []
        base_seq_conservations = []
        for i, aa_base_seq in enumerate(base_seq):
            if aa_base_seq == '-':
                continue
            aa_frequencies = [0]*len(aa_list)
            for aa in msa_matrix[:, i]:
                # Symbols outside aa_list are ignored.
                if aa in aa_to_index:
                    aa_frequencies[aa_to_index[aa]] += 1
            all_aa_frequencies.append([aa_frequency/other_count*100 for aa_frequency in aa_frequencies])
            base_seq_conservations.append(aa_frequencies[aa_to_index[aa_base_seq]]/other_count*100)

        del msa_matrix
        return all_aa_frequencies, base_seq_conservations

    def read_constraint_file(self, parent, data_file):
        """Build a constraint network from a constraint file.

        Only lines starting with ``COV``, ``SRING``, ``HBOND`` or ``HPHOBES``
        that have at least four whitespace-separated fields are used. Fields
        2 and 3 are the atom ids, field 4 the number of bars. If a line has
        more than seven fields, field 8 (``SB``/``HB``/``CUSTOM``) selects the
        constraint type and lines with any other value there are skipped.
        Otherwise the type is the line prefix, with ``COV`` refined to
        ``DBRIDGE`` (both atoms named ``SG``) or ``LIG_COV`` (fields 5 and 6
        both ``HETATM``). Field 5 is read as the energy for ``HBOND``,
        ``SBRIDGE`` and ``HPHOBES`` constraints.

        If ``parent.cna_exe`` is set, constraint types that do not occur in
        the file are added from the network computed for ``parent.pdb_file``.

        Parameters:
            parent: object providing ``pdb_obj`` (with ``atom_ids`` and
                ``atom_details``), ``cna_exe`` and ``pdb_file``.
            data_file: path of the constraint file.

        Returns:
            The populated network object, or ``None`` (after
            ``errors.input_error(19, ...)``) if an atom id is unknown or the
            atoms of a ``COV`` line are more than 3.0 apart.

        Raises:
            ValueError: if an atom id, bar count or energy is not numeric.
            IndexError: if a line that needs an energy has no fifth field.
        """
        type_by_tag = {'SB': 'SBRIDGE', 'HB': 'HBOND', 'CUSTOM': 'CUSTOM'}
        net_obj = network.Network(parent)
        with open(data_file, 'r') as handle:
            for line in handle:
                if not line.startswith(('COV', 'SRING', 'HBOND', 'HPHOBES')):
                    continue
                split_line = line.split()
                if len(split_line) < 4:
                    continue

                atom_id_1 = int(split_line[1])
                atom_id_2 = int(split_line[2])
                # NOTE(review): ``(data_file)`` is a plain string, not a
                # 1-tuple - confirm which form errors.input_error expects.
                if atom_id_1 not in parent.pdb_obj.atom_ids or atom_id_2 not in parent.pdb_obj.atom_ids:
                    errors.input_error(19, (data_file))
                    return None

                if len(split_line) > 7:
                    ctype = type_by_tag.get(split_line[7])
                    if ctype is None:
                        continue
                elif split_line[0] == 'COV':
                    # Atoms further apart than the cutoff cannot be covalently bound.
                    x1, y1, z1 = parent.pdb_obj.atom_details[atom_id_1][5:]
                    x2, y2, z2 = parent.pdb_obj.atom_details[atom_id_2][5:]
                    distance = ((x1-x2)**2+(y1-y2)**2+(z1-z2)**2)**0.5
                    if distance > 3.0:
                        errors.input_error(19, (data_file))
                        return None
                    if parent.pdb_obj.atom_details[atom_id_1][1] == 'SG' and parent.pdb_obj.atom_details[atom_id_2][1] == 'SG':
                        ctype = 'DBRIDGE'
                    elif len(split_line) > 5 and split_line[4] == split_line[5] == 'HETATM':
                        # Length guard: short COV lines used to raise IndexError here.
                        ctype = 'LIG_COV'
                    else:
                        ctype = 'COV'
                else:
                    ctype = split_line[0]

                bars = int(split_line[3])
                if ctype in ('HBOND', 'SBRIDGE', 'HPHOBES'):
                    energy = float(split_line[4])
                else:
                    energy = None
                net_obj.add(ctype, bars, energy, atom_id_1, atom_id_2)

        # Merge with the network derived from the pdb file for constraint
        # types that are missing in the constraint file.
        if parent.cna_exe is not None:
            file_network_types = set(net_obj.constraints[c]['type'] for c in net_obj.constraints)
            merge_network_obj = self.read_pdb_network(parent, parent.pdb_file)
            merge_constraints = {c: merge_network_obj.constraints[c]
                                 for c in merge_network_obj.constraints
                                 if merge_network_obj.constraints[c]['type'] not in file_network_types}
            for c in merge_constraints:
                net_obj.constraints[c] = merge_constraints[c]
        return net_obj

    def read_pdb_network(self, parent, pdb_file):
        """Build the constraint network of a pdb file with pyFirst.

        Covalent interactions, hydrogen bonds, hydrophobic tethers and
        stacked rings are queried from a ``pyFirst.PyFirst`` object created
        with the largest value of ``parent.hp_cut_offs``. Covalent
        constraints that involve ligand atoms are replaced by the ligand
        network returned by ``parent.pdb_obj.prepare_ligands_for_analysis()``
        when that network is non-empty.

        Parameters:
            parent: object providing ``cna_exe``, ``hp_cut_offs`` and
                ``pdb_obj``.
            pdb_file: path of the pdb file to analyse.

        Returns:
            The populated network object, or ``None`` (after
            ``errors.attention(11)``) if ``parent.cna_exe`` is ``None``.
        """
        if parent.cna_exe is None:
            errors.attention(11)
            return
        first_obj = pyFirst.PyFirst(pdb_file, max(parent.hp_cut_offs))

        net_obj = network.Network(parent)
        covs = first_obj.getCovalentInteractions()
        # Modify the covalent network for ligands.
        lig_network, lig_ids = parent.pdb_obj.prepare_ligands_for_analysis()
        if lig_network:
            # Delete the ligand covalent constraints identified by FIRST.
            # Iterate over a snapshot of the keys: deleting while iterating a
            # live key view raises RuntimeError on Python 3.
            for key in list(covs.keys()):
                if key[0] in lig_ids or key[1] in lig_ids:
                    del covs[key]
            # Add the covalent constraints identified by openbabel.
            for key in lig_network.keys():
                covs[key] = lig_network[key]
        hbonds = first_obj.getHydrogenBonds()
        hphobes = first_obj.getHydrophobicTethers()
        srings = first_obj.getStackedRing()

        atom_details = parent.pdb_obj.atom_details
        for c in covs:
            atom_id_1 = min(c)
            atom_id_2 = max(c)
            if atom_details[atom_id_1][1] == 'SG' and atom_details[atom_id_2][1] == 'SG':
                ctype = 'DBRIDGE'
            elif atom_id_1 in lig_ids or atom_id_2 in lig_ids:
                ctype = 'LIG_COV'
            else:
                ctype = 'COV'
            net_obj.add(ctype, covs[c], None, atom_id_1, atom_id_2)

        for c in hbonds:
            # Hydrogen bonds keep the atom order delivered by pyFirst.
            atom_id_1 = c[0]
            atom_id_2 = c[1]
            donor = first_obj.getDonorAtom(atom_id_1)
            bars = first_obj.getNumberOfBars(atom_id_1, atom_id_2)
            if first_obj.isSaltBridge(donor, atom_id_1, atom_id_2):
                ctype = 'SBRIDGE'
            else:
                ctype = 'HBOND'
            net_obj.add(ctype, bars, hbonds[c], atom_id_1, atom_id_2)

        for c in hphobes:
            atom_id_1 = min(c)
            atom_id_2 = max(c)
            bars = first_obj.getNumberOfBars(atom_id_1, atom_id_2)
            net_obj.add('HPHOBES', bars, hphobes[c], atom_id_1, atom_id_2)

        for c in srings:
            atom_id_1 = min(c)
            atom_id_2 = max(c)
            bars = first_obj.getNumberOfBars(atom_id_1, atom_id_2)
            # Stacked rings store the atom distance in the energy slot.
            dist = first_obj.getDistance(atom_id_1, atom_id_2)
            net_obj.add('SRING', bars, dist, atom_id_1, atom_id_2)

        return net_obj

# Entry point reserved for ad-hoc testing; running this file directly
# currently does nothing.
if __name__=="__main__":
    pass
