"""
    Created Mar 2014
    This file is part of VisualCNA
    Copyright (2014) Prakash Chandra Rathi and Daniel Mulnaes

    VisualCNA is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    VisualCNA is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
#####################################################################
# -*- coding: utf-8 -*-
#
# Authors: Prakash Chandra Rathi and Daniel Mulnaes
# Heinrich Heine University, Duesseldorf
# Institute for Pharmaceutical and Medicinal Chemistry
# Universitaetsstr. 1 40225 Duesseldorf
# Germany
#####################################################################


# Standard python libraries
import os, re, sys

# External Modules
from Bio.PDB import *

# openbabel is imported with the dlopen flags saved beforehand and restored
# afterwards, because importing it sets the DL loader flag, which in turn leads
# to a segfault if the scipy.linalg module is imported.
# sys.getdlopenflags() is not available on every platform, so it is looked up
# defensively; its absence must not prevent the import itself.
flags=getattr(sys, "getdlopenflags", lambda: None)()
try:
    import openbabel
except ImportError:
    # Best effort: the module stays importable without openbabel. The name
    # `openbabel` is then left undefined, so code that needs it fails on use.
    pass
finally:
    # Restore the loader flags whether or not the import succeeded.
    if flags is not None:
        sys.setdlopenflags(flags)

# VisualCNA Modules
import errors

class ReadPDB:
    """Parse a PDB file with Bio.PDB and build lookup tables for its first model.

    Attributes set by a successful ``__init__``:
        structure, model      -- parsed structure and its first model
        atom_objs, atom_ids   -- all atoms of the model and their serial numbers
        atom_id_range         -- (min, max) of the atom serial numbers
        atom_id_object        -- atom serial number -> atom object
        res_objects           -- non-HET residues that contain a 'CA' atom
        resi_serial_pymol_selector / pymol_selector_resi_serial
                              -- 1-based residue index <-> "chain/resi/" string
        atom_details, res_atom_id_list
                              -- see ``__get_atom_details``
    """

    # Amber residue names and the regular names they are replaced with.
    _AMBER_TO_REGULAR = {'HIE': 'HIS',
                         'HID': 'HIS',
                         'HIP': 'HIS',
                         'CYX': 'CYS'}

    # Heavy atoms every residue of the given name is expected to contain.
    _STANDARD_RESIDUE_ATOMS = {
        'ALA': ['C', 'CA', 'CB', 'N', 'O'],
        'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'NH1', 'NH2', 'O'],
        'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
        'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
        'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
        'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
        'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
        'GLY': ['C', 'CA', 'N', 'O'],
        'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
        'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
        'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
        'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
        'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
        'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
        'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
        'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
        'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
        'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'N', 'NE1', 'O'],
        'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', 'OH'],
        'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'],
        'HOH': ['O'],
        'WAT': ['O'],
    }

    # Hydrogen count below which a residue is reported as having too few hydrogens.
    _MINIMUM_HYDROGENS = {
        'ALA': 5, 'ARG': 13, 'ASP': 4, 'ASN': 6, 'CYS': 4, 'GLU': 6, 'GLN': 8, 'GLY': 3,
        'HIS': 7, 'ILE': 11, 'LEU': 11, 'LYS': 13, 'MET': 9, 'PHE': 9, 'PRO': 7, 'SER': 5,
        'THR': 7, 'TRP': 10, 'TYR': 9, 'VAL': 9, 'HOH': 2, 'WAT': 2,
    }

    # Atom names counted as hydrogens: optional leading digits, then 'H'.
    _HYDROGEN_NAME = re.compile('[0-9]*H[A-Z]*')

    # Bond-order labels that get 6 bars in the ligand network; all others get 5.
    _SIX_BAR_BOND_ORDERS = ("2", "3", "am", "ar")

    def __init__(self, pdb_file, pdb_id=None):
        """Parse ``pdb_file`` and populate the lookup attributes.

        pdb_id -- structure id handed to the parser; when None it is derived
                  from the file's base name (text before the first ".pdb").

        Any exception raised while parsing or building the tables is reported
        through ``errors.input_error(21, pdb_file)``; attributes assigned after
        the failing step are then missing from the instance.
        """
        self.pdb_file = pdb_file
        self.parser = PDBParser(PERMISSIVE=False, QUIET=True)
        self.pdbio = PDBIO()
        try:
            self.filename = pdb_file
            if pdb_id is None:
                self.pdb_id = os.path.basename(self.pdb_file).split(".pdb")[0]
            else:
                self.pdb_id = pdb_id
            self.structure = self.parser.get_structure(self.pdb_id, self.pdb_file)
            if len(self.structure) > 1:
                # Parenthesised single string: same output under the Python 2
                # print statement and the Python 3 print function.
                print("WARNING::Structure has multiple models. Only the first model will be considered")
            self.model = self.structure[0]
            self.atom_objs = Selection.unfold_entities(self.model, 'A')
            self.res_objects = [res for res in Selection.unfold_entities(self.model, 'R')
                                if res.has_id('CA') and not self._check_hetatm_entry(res['CA'])]
            self.atom_ids = [atom.get_serial_number() for atom in self.atom_objs]
            self.atom_id_range = (min(self.atom_ids), max(self.atom_ids))
            self.atom_id_object = dict(zip(self.atom_ids, self.atom_objs))
            self.resi_serial_pymol_selector = self.__get__resi_serial_pymol_selector()
            self.pymol_selector_resi_serial = {
                selector: serial
                for serial, selector in self.resi_serial_pymol_selector.items()}
            self.atom_details, self.res_atom_id_list = self.__get_atom_details()
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit still propagate.
            # NOTE(review): the file name is passed as a plain value, exactly as
            # before; confirm whether errors.input_error expects a tuple here.
            errors.input_error(21, pdb_file)

    def __get_atom_details(self):
        """Collect per-atom records and per-residue atom serial lists.

        Returns a tuple ``(atom_details, res_atom_id_list)``:
            atom_details     -- serial -> (record, name, resn, chain, resi, x, y, z),
                                record being "HETATM" for residues whose hetero
                                flag is not blank and "ATOM" otherwise
            res_atom_id_list -- 1-based running index -> serial numbers of all
                                atoms of the residue owning a "CA"-named atom
        """
        atom_details = {}
        res_atom_id_list = []
        for atom_obj in self.atom_objs:
            full_id = atom_obj.get_full_id()
            residue_id = full_id[3]
            record = "ATOM" if residue_id[0] == " " else "HETATM"
            serial = atom_obj.get_serial_number()
            name = full_id[4][0]
            resn = atom_obj.parent.get_resname()
            x, y, z = atom_obj.get_coord()
            atom_details[serial] = (record, name, resn, full_id[2], residue_id[1], x, y, z)
            # NOTE(review): any atom named "CA" opens an entry, including HETATM
            # ones, whereas res_objects skips HETATM residues -- confirm that
            # the two indexings are meant to differ.
            if name == "CA":
                res_atom_id_list.append(
                    [a.get_serial_number() for a in atom_obj.parent.child_list])
        return (atom_details, dict(enumerate(res_atom_id_list, 1)))

    def __get__resi_serial_pymol_selector(self):
        """Map a 1-based running index to "chain/resi/" for every non-HET residue."""
        selectors = []
        for resi in Selection.unfold_entities(self.model, 'R'):
            full_id = resi.get_full_id()
            if full_id[3][0] == " ":  # not HET
                selectors.append("%s/%d/" % (full_id[2], full_id[3][1]))
        return dict(enumerate(selectors, 1))

    def prepare_structure_for_anaylsis(self):
        """ Prepare structure to ensure a CNA conform format

        Amber residue names (HIE, HID, HIP, CYX) are replaced in place by the
        regular names. A disordered residue is reported through
        ``errors.input_error(10)`` and the process then exits with status 2.
        """
        for residue in Selection.unfold_entities(self.model, 'R'):
            if residue.resname in self._AMBER_TO_REGULAR:
                residue.resname = self._AMBER_TO_REGULAR[residue.resname]
            if residue.is_disordered():
                errors.input_error(10)
                sys.exit(2)

    def check_missing_atoms(self):
        """Classify the residues of the model by completeness.

        Returns a tuple of residue lists, in this order:
        ``(non_standard_residues, incomplete_residues, less_hydrogen_residues,
        no_hydrogen_residues, multiple_conf_residues)``.

        A residue can appear in several lists. Residues with an unknown name are
        non-standard; if they are not HET they are also incomplete when 'CA',
        'N' or 'C' is missing. Known residues are checked for disordered atoms,
        missing heavy atoms and hydrogen count.
        """
        non_standard_residues = []
        incomplete_residues = []
        no_hydrogen_residues = []
        less_hydrogen_residues = []
        multiple_conf_residues = []
        for res in Selection.unfold_entities(self.model, 'R'):
            atoms = res.child_dict
            expected_atoms = self._STANDARD_RESIDUE_ATOMS.get(res.resname)
            if expected_atoms is None:
                non_standard_residues.append(res)
                if res.get_full_id()[3][0] == " ":  # Not a HET res
                    if not all(atom_name in atoms for atom_name in ('CA', 'N', 'C')):
                        incomplete_residues.append(res)
                continue
            if any(atom.disordered_flag == 1 for atom in atoms.values()):
                multiple_conf_residues.append(res)
            if not all(atom_name in atoms for atom_name in expected_atoms):
                incomplete_residues.append(res)
            res_hydrogen_count = sum(
                1 for atom_name in atoms if self._HYDROGEN_NAME.match(atom_name))
            if res_hydrogen_count == 0:
                no_hydrogen_residues.append(res)
            elif res_hydrogen_count < self._MINIMUM_HYDROGENS[res.resname]:
                less_hydrogen_residues.append(res)
        return (non_standard_residues, incomplete_residues, less_hydrogen_residues,
                no_hydrogen_residues, multiple_conf_residues)

    def get_atom_object_from_id(self, atom_id):
        """Return the atom object whose serial number is ``atom_id``.

        Raises KeyError when no atom with that serial number was read.
        """
        return self.atom_id_object[atom_id]

    def is_ligand_atom(self, atom_id):
        """Return True when the atom's residue has a non-blank hetero flag."""
        my_res = self.get_atom_object_from_id(atom_id).get_parent()
        return my_res.get_full_id()[-1][0] != " "

    def _check_hetatm_entry(self, atom_obj):
        """Return True when ``atom_obj`` belongs to a residue with a non-blank hetero flag."""
        return atom_obj.get_full_id()[3][0] != " "

    def prepare_ligands_for_analysis(self, res_dir=False):
        """Build the bond network of every ligand in the first model.

        A ligand is a residue whose hetero flag is neither " " nor "W" and that
        has at least 7 atoms. Each ligand is written to a temporary
        "Lig_<resn>_<resi>_<chain>.pdb" file in the working directory, passed
        to ``get_ligand_network`` and the file is removed again.

        res_dir -- forwarded to ``get_ligand_network``.

        Returns ``(ligand_network, ligand_ids)``: a dict mapping
        (atom serial, atom serial) to the bar count of that bond, and the list
        of serial numbers of all ligand atoms.
        """
        ligand_network = {}
        ligand_ids = []
        # Only the first model is considered, as everywhere else in this class.
        for act_res in Selection.unfold_entities(self.model, "R"):
            residue = act_res.get_full_id()[3]
            if residue[0] in (" ", "W"):
                continue
            # keep the original atom serial numbers
            org_atom_ids = [atm.get_serial_number()
                            for atm in Selection.unfold_entities(act_res, "A")]
            if len(org_atom_ids) < 7:
                continue  # Ignore all molecules with less than 7 atoms
            ligand_ids.extend(org_atom_ids)
            resi = residue[1]
            resn = act_res.get_resname()
            chain = act_res.get_full_id()[2]
            ligand_name = "Lig_%s_%s_%s" % (resn, resi, chain)
            ligand_pdb = ligand_name + ".pdb"
            try:
                self.write_pdb_file(ligand_pdb, ligand=(resn, resi, chain))
                inst_network = self.get_ligand_network(ligand_name, res_dir)
            finally:
                # Never leave the temporary file behind, even on failure.
                if os.path.exists(ligand_pdb):
                    os.remove(ligand_pdb)
            # The network is keyed by 1-based atom indices within the ligand
            # file; translate them back to the original serial numbers.
            for (begin_idx, end_idx), bars in inst_network.items():
                ligand_network[(org_atom_ids[begin_idx - 1],
                                org_atom_ids[end_idx - 1])] = bars
        return ligand_network, ligand_ids

    def get_ligand_network(self, lig_name=False, res_dir=False):
        """Read "<lig_name>.pdb" with Open Babel and return its bond network.

        lig_name -- file name without the ".pdb" extension; the default of
                    False is not usable and fails on string concatenation.
        res_dir  -- when truthy, the molecule is also written there as
                    "<lig_name>.mol2".

        Returns a dict mapping (begin atom index, end atom index), as reported
        by Open Babel, to 6 for double, triple, amide and aromatic bonds and
        to 5 for every other bond.
        """
        in_format = "pdb"
        out_format = "mol2"
        obConversion = openbabel.OBConversion()
        obConversion.SetInAndOutFormats(in_format, out_format)
        mol = openbabel.OBMol()
        obConversion.ReadFile(mol, lig_name + "." + in_format)
        if res_dir:
            obConversion.WriteFile(mol, os.path.join(res_dir, "%s.%s" % (lig_name, out_format)))
        ligand_network = {}
        for obbond in openbabel.OBMolBondIter(mol):
            # Aromatic takes precedence over amide, which takes precedence over
            # the numeric bond order.
            if obbond.IsAromatic():
                bond_order = "ar"
            elif obbond.IsAmide():
                bond_order = "am"
            else:
                bond_order = str(obbond.GetBondOrder())
            bars = 6 if bond_order in self._SIX_BAR_BOND_ORDERS else 5
            bond_key = (obbond.GetBeginAtom().GetIdx(), obbond.GetEndAtom().GetIdx())
            ligand_network[bond_key] = bars
        return ligand_network

    def write_pdb_file(self, filename, ligand=None):
        """Write the structure, or a single ligand of it, to ``filename``.

        ligand -- None to write the first model of the structure, or a tuple
                  (resn, resi, chain) to write only the matching residue.

        NOTE(review): the previous docstring stated that the first alt loc of a
        disordered atom is selected; no atom-level filter is defined here, so
        that depends on PDBIO itself -- confirm.
        """
        class NotDisorderedFirstModel(Select):
            """Accept only the first model (id 0)."""
            def accept_model(self, model):
                return model.id == 0

        class LigandSelection(Select):
            """Accept only the residue matching ligand = (resn, resi, chain)."""
            def accept_residue(self, residue):
                return (residue.id[1] == ligand[1]
                        and residue.get_resname() == ligand[0]
                        and residue.get_full_id()[2] == ligand[2])

        selector = LigandSelection() if ligand else NotDisorderedFirstModel()
        io = PDBIO()
        io.set_structure(self.structure)
        # NOTE(review): conserve_atoms_number must be supported by the Bio.PDB
        # version in use -- confirm against the installed PDBIO.save signature.
        io.save(filename, select=selector, conserve_atoms_number=True)

if __name__=="__main__":
    # Script entry point: parse the PDB file named by the first command line
    # argument. A missing argument yields a usage message on stderr and a
    # non-zero exit status instead of an IndexError traceback.
    if len(sys.argv)<2:
        sys.exit("Usage: %s <pdb_file>"%os.path.basename(sys.argv[0]))
    ReadPDB(sys.argv[1])
