"""
    Created Mar 2014
    This file is part of VisualCNA
    Copyright (2014) Prakash Chandra Rathi and Daniel Mulnaes

    VisualCNA is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    VisualCNA is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
#####################################################################
# -*- coding: utf-8 -*-
#
# Authors: Prakash Chandra Rathi and Daniel Mulnaes
# Heinrich Heine University, Duesseldorf
# Institute for Pharmaceutical and Medicinal Chemistry
# Universitaetsstr. 1 40225 Duesseldorf
# Germany
#####################################################################

# Standard python libraries
import os, gc
from collections import Counter

# External modules
from pymol import cmd
from pymol.cgo import *
from Bio.PDB import *
import numpy as np
import scipy.cluster.hierarchy as sch

# VisualCNA Modules
import errors

class ShowNetwork:
    """Draw rigid-cluster decompositions and constraint networks in PyMOL.

    One PyMOL state is used per energy cut-off in ``parent.hb_cut_offs``
    (state numbers start at 1). Rigid clusters (RCs) are obtained from a
    stability map by single-linkage clustering, and the non-covalent
    constraints held by ``parent.net_obj`` are sorted into named PyMOL groups
    per state.
    """

    def __init__(self, parent, stb_file=None, gi_file=None):
        """Set up drawing state and, if a stability map is given, compute RCs.

        :param parent: owning VisualCNA object. This class reads its
            ``hb_cut_offs``, ``rd``, ``pdb_file``, ``pdb_obj``, ``net_obj``
            and ``pymol_state`` attributes and deletes ``parent.display``
            in :meth:`destroy`.
        :param stb_file: stability-map file read with
            ``parent.rd.read_stbmap``; ``None`` or ``''`` skips the
            rigid-cluster computation.
        :param gi_file: optional path of a ``*_global_indices.dat`` output
            file; when given, global indices are computed with CNA and
            written next to it (see :meth:`get_rcds`).
        """
        # Colours cycled through by rigid-cluster number (see draw_clusters).
        self.rc_colors=['0x0000b2', '0x57eb0f', '0xe75ab7', '0x23999b', '0xc19b44',
                        '0x9e85c8', '0x04c60b', '0xeb2760', '0x73bbe7', '0xa3aa2c',
                        '0xa75bc3', '0x9acdae', '0xda1809', '0x426ce9', '0xc1e88b',
                        '0xb936ad', '0x71c5b1', '0xb34e00', '0x270fec', '0x79e85b',
                        '0x9c2365', '0x45b2c2', '0xc9bc86', '0x6004c7', '0x28eb45',
                        '0xe77489', '0x2c6dac', '0xb2c35c', '0xc59acd', '0x09da74']
        self.parent=parent
        self.bond_thickness=0.1
        # (cut-off, state) pairs with 1-based state numbers. Materialised as a
        # list: it is iterated many times and len() is taken in make_groups,
        # whereas a Python 3 zip() iterator would be exhausted after the first
        # comprehension below.
        self.E_cut_off_states=list(zip(parent.hb_cut_offs, range(1, len(parent.hb_cut_offs)+1)))
        # state -> {atom_id: RC label}; filled by get_rcds, freed by make_groups.
        self.atom_id_clusters={state:{} for _, state in self.E_cut_off_states}
        # state -> {group name: [constraint labels]}; filled by make_groups.
        self.state_con_groups={state:{} for _, state in self.E_cut_off_states}
        # RC labels that reach min_rc_size in at least one state.
        self.rigidcluster_ids=[]
        # When set, draw_constraints aborts early unless called with force=True.
        self.kill_old_observers=False
        # Always defined so draw_clusters/destroy work without a stability map.
        self.min_rc_size=3
        self.resi_rc_matrix={}
        self.rc_labels=[]

        if stb_file not in ['', None]:
            self.get_rcds(parent.rd.read_stbmap(stb_file), gi_file)

    def get_rcds(self, stb_map, gi_file):
        """Compute rigid-cluster decompositions (RCDs) for every cut-off state.

        Fills ``self.atom_id_clusters[state]`` with ``atom_id -> RC label``.
        Labels are ordered by cluster size, counted in residues (1 = largest).
        Atoms in clusters smaller than ``self.min_rc_size`` residues are
        relabelled -1. Labels of clusters that reach the minimum size in any
        state are appended to ``self.rigidcluster_ids``.

        :param stb_map: square numpy array; row/column ``i`` corresponds to
            residue ``i+1`` in ``parent.pdb_obj.res_atom_id_list``. Presumably
            pairwise stability energies (negative = stable) -- TODO confirm
            against ``read_stbmap``.
        :param gi_file: if not ``None``, global indices are computed per state
            with CNA and written to the directory of this path, using its
            basename minus ``_global_indices.dat`` as the pdb id.
        """
        if gi_file is not None:
            # CNA modules; 'pdb' here is presumably CNA's structure reader
            # (it provides pdb.PDB), not the stdlib debugger.
            import pdb, network_analysis, output_results
            cna_pdb_obj=pdb.PDB(self.parent.pdb_file)

            class CNAOptions:
                res_dir=os.path.dirname(gi_file)
                aic_selection=False

            options=CNAOptions
            network=network_analysis.NetworkAnalysis()
            globInd=network_analysis.GlobalIndices(network, cna_pdb_obj, options)

        # Condensed (upper triangle, i<j) form of the map, as linkage expects.
        stb_upper=stb_map[np.triu_indices(stb_map.shape[0], 1)]
        # Distance = 1/(-value). Zeros become -0.001 first (distance 1000) to
        # avoid dividing by zero.
        stb_upper=np.where(stb_upper==0, -0.001, stb_upper)
        distances=(stb_upper*-1)**-1
        clustering_object=sch.linkage(distances, method='single')

        for E_cut, state in self.E_cut_off_states:
            if E_cut==0:
                E_cut=-0.001
            # Cutting at 1/(-E_cut) joins residues linked through a chain of
            # pairs whose value is <= E_cut (assuming negative map values).
            flat_labels=sch.fcluster(clustering_object, t=1.0/(E_cut*-1), criterion='distance')
            # residue number (1-based) -> raw fcluster label
            rcd_unordered={i+1:label for i, label in enumerate(flat_labels)}
            # Relabel by size: largest cluster -> 1, next -> 2, ...
            rcd_ressize=sorted(Counter(rcd_unordered.values()).items(), key=lambda t:t[1], reverse=True)
            rcd_labels={raw:i+1 for i, (raw, _) in enumerate(rcd_ressize)}
            rcd_new_ressize_dict={i+1:size for i, (_, size) in enumerate(rcd_ressize)}
            rcd_ordered={resi:rcd_labels[raw] for resi, raw in rcd_unordered.items()}

            # Propagate each residue's label to all of its atoms.
            state_clusters=self.atom_id_clusters[state]
            for resi, rc in rcd_ordered.items():
                for atom_id in self.parent.pdb_obj.res_atom_id_list[resi]:
                    if rc not in self.rigidcluster_ids and rcd_new_ressize_dict[rc]>=self.min_rc_size:
                        self.rigidcluster_ids.append(rc)
                    state_clusters[atom_id]=rc

            # Global indices for this cut-off. They see the size-ordered labels
            # before small clusters are flagged below. The FMD global index
            # cannot be computed correctly here and is not plotted.
            if gi_file is not None:
                network.prepare_network_analysis(E_cut, state_clusters, 1, 100)
                globInd.calculate_global_indices()

            # Clusters below the minimum size are treated as flexible: label -1.
            for atom_id in list(state_clusters):
                if rcd_new_ressize_dict[state_clusters[atom_id]]<self.min_rc_size:
                    state_clusters[atom_id]=-1

        if gi_file is not None:
            pdb_id=os.path.basename(gi_file).replace("_global_indices.dat", "")
            try:
                output_results.write_global_dilution_results(res_dir=options.res_dir, pdb_id=pdb_id, globInd=globInd, fnc_id=None)
            except TypeError:
                # FIXME: newer CNA versions take a file_name_prefix argument and
                # raise TypeError when it is omitted; older versions lack it.
                # Only the signature mismatch is retried, so genuine errors
                # (e.g. I/O) from the first call are no longer masked.
                output_results.write_global_dilution_results(res_dir=options.res_dir, pdb_id=pdb_id, globInd=globInd, file_name_prefix='', fnc_id=None)

    def draw_clusters(self):
        """Create one PyMOL object per rigid cluster, one object state per cut-off.

        In state ``s``, object ``RC<n>`` holds the atoms labelled ``n`` in
        ``self.atom_id_clusters[s]``. It is coloured from ``self.rc_colors``,
        shown as surface and left disabled. Object names are appended to
        ``self.rc_labels`` (once per state). Reports input error 7 and returns
        if no rigid cluster was found.
        """
        if not self.rigidcluster_ids:
            errors.input_error(7)
            return

        # Private temporary selection. A leading underscore hides it in the
        # PyMOL object panel and avoids deleting a user's 'test' object.
        tmp_sel='_vcna_rc_tmp'
        pdb_id=self.parent.pdb_obj.pdb_id
        for _, state in self.E_cut_off_states:
            # Invert atom -> RC once per state instead of rescanning all atoms
            # for every RC; atom order within each RC is preserved.
            rc_members={}
            for atom_id, rc in self.atom_id_clusters[state].items():
                rc_members.setdefault(rc, []).append(atom_id)

            for rc in self.rigidcluster_ids:
                if rc==-1:
                    continue
                rc_label='RC%d'%rc
                cmd.select_list(tmp_sel, pdb_id, rc_members.get(rc, []))
                cmd.create(rc_label, tmp_sel, 1, state)
                cmd.color(self.rc_colors[(rc-1)%len(self.rc_colors)], rc_label)
                cmd.show('surface', rc_label)
                cmd.disable(rc_label)
                self.rc_labels.append(rc_label)
        cmd.delete(tmp_sel)

    def make_groups(self):
        """Sort HPHOBES/HBOND/SBRIDGE constraints into named groups per state.

        Group names (HB = HBOND/SBRIDGE, HP = HPHOBES):

        * ``Broken``: HB whose energy is not below the state's cut-off.
        * ``HB_Breaking``: HB present in this state but with energy above the
          next state's cut-off. Never assigned in the last state.
        * ``HB_Flexible`` / ``HP_Flexible``: an atom lies in a flexible (-1)
          region.
        * ``HB_RC<n>`` / ``HP_RC<n>``: both atoms lie in rigid cluster ``n``.
          Atoms missing from the RC map count as cluster 0.
        * ``HB_Linking`` / ``HP_Linking``: the atoms lie in different clusters.

        Results go to ``self.state_con_groups[state][group] -> [labels]`` and
        the drawing order to ``self.group_order``. ``self.atom_id_clusters`` is
        deleted afterwards to free memory, so neither this method nor
        :meth:`draw_clusters` can run again afterwards.
        """
        group_labels=[]
        constraints=self.parent.net_obj.constraints
        n_states=len(self.E_cut_off_states)
        for E_cut, state in self.E_cut_off_states:
            clusters=self.atom_id_clusters.get(state, {})
            groups=self.state_con_groups[state]
            for (atom_id_1, atom_id_2), constraint in constraints.items():
                # Do not group covalent constraints
                if constraint['type'] not in ['HPHOBES', 'HBOND', 'SBRIDGE']:
                    continue
                try:
                    atom_1_rc=clusters[atom_id_1]
                    atom_2_rc=clusters[atom_id_2]
                except KeyError:
                    # Atom not covered by the RC map: treat both ends as cluster 0
                    atom_1_rc=0
                    atom_2_rc=0

                # Group hydrophobic tethers
                if constraint['type']=='HPHOBES':
                    if atom_1_rc==-1 or atom_2_rc==-1:
                        group_name='HP_Flexible'
                    elif atom_1_rc==atom_2_rc:
                        group_name='HP_RC'+str(atom_1_rc)
                        if group_name not in group_labels:
                            group_labels.append(group_name)
                    else:
                        group_name='HP_Linking'

                # Group hydrogen bonds and salt bridges still present at this cut-off
                elif constraint['energy']<E_cut:
                    # hb_cut_offs is 0-based and states are 1-based, so
                    # hb_cut_offs[state] is the next state's cut-off.
                    if state!=n_states and constraint['energy']>self.parent.hb_cut_offs[state]:
                        group_name='HB_Breaking'
                    elif atom_1_rc==-1 or atom_2_rc==-1:
                        group_name='HB_Flexible'
                    elif atom_1_rc==atom_2_rc:
                        group_name='HB_RC'+str(atom_1_rc)
                        if group_name not in group_labels:
                            group_labels.append(group_name)
                    else:
                        group_name='HB_Linking'
                else:
                    group_name='Broken'

                groups.setdefault(group_name, []).append(constraint['label'])

        # Fixed groups first, then per-cluster groups ordered by RC number.
        base_order=['Broken', 'HB_Breaking', 'HB_Flexible', 'HB_Linking', 'HP_Flexible', 'HP_Linking']
        rc_groups=sorted((int(name.split('_')[-1][2:]), name) for name in group_labels if name not in base_order)
        self.group_order=base_order+[name for _, name in rc_groups]

        # Free up memory by deleting the now redundant atom_id_clusters
        del self.atom_id_clusters
        gc.collect()

    def draw_cylinder(self, atom_id_1, atom_id_2, color, label, state):
        """Load a solid CGO cylinder between two atoms as object ``label``.

        Coordinates come from ``parent.pdb_obj.atom_details[atom_id][5:]``,
        which must hold exactly three values (x, y, z). ``color`` is any name
        accepted by ``cmd.get_color_tuple``.
        """
        (x1, y1, z1)=self.parent.pdb_obj.atom_details[atom_id_1][5:]
        (x2, y2, z2)=self.parent.pdb_obj.atom_details[atom_id_2][5:]
        (r, g, b)=cmd.get_color_tuple(color)
        # CGO: ALPHA (25.0) 1.0, then CYLINDER (9.0) with both end points,
        # radius, and the same RGB at both ends.
        cylinder=[25.0, 1.0, 9.0, x1, y1, z1, x2, y2, z2, self.bond_thickness, r, g, b, r, g, b]
        cmd.load_cgo(cylinder, label, state)

    def draw_dashed_cylinder(self, atom_id_1, atom_id_2, color, label, state):
        """Draw a dashed PyMOL distance object ``label`` between two atoms.

        The distance label is hidden and the dashes are coloured ``color``.
        ``state`` is accepted for symmetry with :meth:`draw_cylinder` but is
        not used.
        """
        cmd.distance(label, "id %d"%atom_id_1, "id %d"%atom_id_2)
        cmd.hide('labels', label)
        cmd.color(color, label)

    def draw_constraints(self, state, force=False):
        """Bring the drawn network in line with ``net_obj.state_to_net[state]``.

        For the ``'create'`` pairs:

        * DBRIDGE constraints get their bond removed (status < 1) or added.
        * Status 0 gets a solid cylinder and is marked status 1.
        * Status -1 gets a dashed distance object.

        For the ``'delete'`` pairs, the object is deleted and status 1 is reset
        to 0. Objects are then grouped using the groups of
        ``self.parent.pymol_state``.

        The current view is kept (auto_zoom is disabled while drawing). If
        ``self.kill_old_observers`` becomes true and ``force`` is false, the
        method returns immediately, leaving auto_zoom off and the view as is.
        """
        view=cmd.get_view()
        cmd.set('auto_zoom', 0)

        # Draw constraints in this state
        for (atom_id_1, atom_id_2) in self.parent.net_obj.state_to_net[state]['create']:
            if not force and self.kill_old_observers:
                return
            constraint=self.parent.net_obj.constraints[(atom_id_1, atom_id_2)]
            if constraint['type']=='DBRIDGE':
                if constraint['status']<1:
                    cmd.unbond('id %d'%atom_id_1, 'id %d'%atom_id_2)
                else:
                    cmd.bond('id %d'%atom_id_1, 'id %d'%atom_id_2)
            if constraint['status']==0:
                self.draw_cylinder(atom_id_1, atom_id_2, constraint['color'], constraint['label'], state=0)
                self.parent.net_obj.constraints[(atom_id_1, atom_id_2)]['status']=1
            elif constraint['status']==-1:
                self.draw_dashed_cylinder(atom_id_1, atom_id_2, constraint['color'], constraint['label'], state=0)

        # Delete constraints not in this state
        for (atom_id_1, atom_id_2) in self.parent.net_obj.state_to_net[state]['delete']:
            if not force and self.kill_old_observers:
                return
            constraint=self.parent.net_obj.constraints[(atom_id_1, atom_id_2)]
            if constraint['status']==1:
                self.parent.net_obj.constraints[(atom_id_1, atom_id_2)]['status']=0
            cmd.delete(constraint['label'])

        # Group constraints
        groups_to_delete=[]
        for group_name in self.group_order:
            if not force and self.kill_old_observers:
                return
            state_groups=self.state_con_groups[self.parent.pymol_state]
            if group_name!='Broken' and group_name in state_groups:
                group_labels=state_groups[group_name]
                # Skip groups whose members are all at status 0 (no cylinder drawn).
                group_broken=[self.parent.net_obj.constraints[self.parent.net_obj.label_to_id[l]]['status']==0 for l in group_labels]
                if not all(group_broken):
                    cmd.group(group_name, ' + '.join(group_labels))
            elif group_name not in state_groups and state!=0:
                groups_to_delete.append(group_name)
        # Delete groups that have no members in the current state
        for group_name in groups_to_delete:
            cmd.delete(group_name)

        cmd.set_view(view)
        cmd.set('auto_zoom', 1)

    def force_network_draw(self, labels):
        """Redraw the network for the given constraint labels, ignoring kill_old_observers.

        The first two characters of each label are mapped to a constraint type
        via ``net_obj.short_to_long``. HBOND/SBRIDGE/HPHOBES redraw the current
        PyMOL state; CUSTOM/DBRIDGE/SRING/HPHOBES redraw state 0. HPHOBES
        therefore triggers both.
        """
        short_to_long=self.parent.net_obj.short_to_long
        types={short_to_long[l[:2]] for l in labels}
        if any(t in types for t in ['HBOND', 'SBRIDGE', 'HPHOBES']):
            self.draw_constraints(self.parent.pymol_state, force=True)
        if any(t in types for t in ['CUSTOM', 'DBRIDGE', 'SRING', 'HPHOBES']):
            self.draw_constraints(0, force=True)

    def destroy(self):
        """Delete all PyMOL groups and RC objects created here, then detach from parent."""
        for groups in self.state_con_groups.values():
            for group in groups:
                try:
                    cmd.delete(group)
                except:  # best-effort teardown; PyMOL exceptions may not derive from Exception
                    pass
        for rc in self.rc_labels:
            try:
                cmd.delete(rc)
            except:  # best-effort teardown, as above
                pass

        del self.parent.display
