#!/usr/bin/env python

import psycopg2
import os
import time
import random
import math
import pprint
import sys
import unittest


class DBManager():
    """Benchmark helper running tree operations against a PostgreSQL database.

    The ``nodes`` table carries both an adjacency-list encoding
    (``id_parent``) and a nested-set ("nstree") encoding
    (``leftval``/``rightval``); attributes live in ``attributes`` and are
    named through ``attributesDef``.  Timings (in seconds, from
    ``time.time()``) of individual operations are collected in
    ``self.stats`` and can be written to disk with ``save_stats``.

    NOTE(review): the nested-set columns are not written anywhere in this
    class; presumably database triggers maintain them -- confirm.
    """

    def __init__(self, dbname, user, password, host):
        self.dbname = dbname
        self.user = user
        self.password = password
        self.host = host
        # Set by connect(); kept as None until then so close() is safe.
        self.conn = None
        self.curs = None
        self.t_insert = []
        self.t_move = []
        self.pp = pprint.PrettyPrinter(indent=4)
        self.stats = {}
        self.reset_stats()

    def reset_stats(self):
        """Empty every timing series ("insert", "move", "nstree", "adjlst")."""
        self.stats["insert"] = []
        self.stats["move"] = []
        self.stats["nstree"] = []
        self.stats["adjlst"] = []

    def save_stats(self, basename):
        """Save all non-empty statistics series, one file per series.

        Files go into a directory named after the current local time
        (``YYYYmmddHHMMSS``) and are called ``test-<series>-<basename>.out``;
        each line is ``<index> <seconds>``.
        """
        ts = time.strftime("%Y%m%d%H%M%S")
        # Two saves within the same second share the directory; don't crash.
        if not os.path.isdir(ts):
            os.mkdir(ts)
        for key, values in self.stats.items():
            if not values:
                continue
            path = os.path.join(ts, "test-%s-%s.out" % (key, basename))
            with open(path, "w") as f:
                for i, v in enumerate(values):
                    f.write("%s %s\n" % (i, v))

    def connect(self):
        """Open the connection and the cursor shared by all other methods."""
        # Keyword arguments rather than a hand-built DSN string: values
        # containing spaces or quotes would otherwise break DSN parsing.
        self.conn = psycopg2.connect(dbname=self.dbname, user=self.user,
                                     password=self.password, host=self.host)
        self.curs = self.conn.cursor()
        print("connected")

    def close(self):
        """Close the cursor and the connection, if they are open."""
        if self.curs is not None:
            self.curs.close()
            self.curs = None
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    def insert_node(self, id_parent, name):
        """Insert one node under ``id_parent`` and record the elapsed time."""
        t1 = time.time()
        self.curs.execute("INSERT INTO nodes (id_parent, name) VALUES (%s, %s)",
                          (id_parent, name))
        self.conn.commit()
        t2 = time.time()
        self.stats["insert"].append(t2 - t1)

    def delete_all_nodes(self):
        """Delete all attributes with nodes_id > 0 and all nodes with id > 0.

        First tries to drop trigger ``nodes_delete_tr``; if that fails, the
        aborted transaction is rolled back and the deletes proceed anyway.
        NOTE(review): presumably the trigger is dropped so the bulk delete
        does not fire it per row -- confirm.
        """
        print("del")
        try:
            self.curs.execute("DROP TRIGGER nodes_delete_tr ON nodes")
        except psycopg2.Error:
            # Trigger missing (e.g. already dropped): clear the failed
            # transaction so the following statements can run.
            self.conn.rollback()

        self.curs.execute("DELETE from attributes where nodes_id > 0")
        self.conn.commit()
        # Rows with id <= 0 (e.g. node 0 and -1) are intentionally kept.
        self.curs.execute("DELETE FROM nodes where id > 0")
        self.conn.commit()
        print("end del")

    def delete_all_attributes(self):
        """Delete every row from ``attributes``."""
        self.curs.execute("DELETE FROM attributes")
        self.conn.commit()

    def insert_nodes(self, num_nodes):
        """Insert ``num_nodes`` children under each current leaf.

        Leaves are nodes without children, excluding id -1.  If there are
        none, ``num_nodes`` nodes are inserted under node 0.
        """
        print("insert nodes")
        self.curs.execute("SELECT t1.id FROM nodes AS t1 LEFT JOIN nodes as t2 ON t1.id = t2.id_parent WHERE t2.id IS NULL AND t1.id != -1")
        leaves = self.curs.fetchall()
        if not leaves:
            for _ in range(num_nodes):
                self.insert_node(0, "test")
        else:
            for leaf in leaves:
                for _ in range(num_nodes):
                    self.insert_node(leaf[0], "test")

    def move_node(self, id, new_parent):
        """Re-parent node ``id`` under ``new_parent`` and record the time."""
        t1 = time.time()
        self.curs.execute("UPDATE nodes SET id_parent = %s WHERE id = %s",
                          (new_parent, id))
        self.conn.commit()
        t2 = time.time()
        self.stats["move"].append(t2 - t1)

    def move_middle_node_to_root(self):
        """Move the leaf chosen by ``get_middle_node`` directly under node 0."""
        self.move_node(self.get_middle_node(), 0)

    def attribute_query_leaf_nstree(self, node_id):
        """Merge attributes for one node, nested-set ("nstree") model.

        Reads the node's ``rightval``, then fetches in one query the
        attributes of all nodes with ``leftval < rightval <= their rightval``,
        ordered by ``leftval``; later rows overwrite earlier ones in the merged
        dict.  Only this second query is timed into ``stats["nstree"]``.

        Raises:
            IndexError: if ``node_id`` does not exist.
        """
        print("nstree")
        attrs = {}
        self.curs.execute("SELECT id, leftval, rightval FROM nodes WHERE id = %s",
                          (node_id,))
        rep = self.curs.fetchall()
        if not rep:
            raise IndexError("node %r not found" % (node_id,))
        rightval = rep[0][2]
        q = ("SELECT * FROM nodes N INNER JOIN attributes A ON A.nodes_id=N.id "
             "LEFT JOIN attributesDef AD ON A.attributesdef_id=AD.id "
             "WHERE N.leftval<%s AND N.rightval>=%s ORDER BY N.leftval")
        # NOTE(review): timing excludes the rightval lookup above, whereas the
        # adjacency-list variant times its whole ancestor walk -- confirm
        # this asymmetry is intended for the comparison.
        t1 = time.time()
        self.curs.execute(q, (rightval, rightval))
        for r in self.curs.fetchall():
            # Positional indices depend on the column layout of SELECT *;
            # presumably r[13] is AD.name and r[10] is A.value -- confirm.
            attrs[r[13]] = r[10]
        t2 = time.time()
        self.stats["nstree"].append(t2 - t1)

    def attribute_query_leaf_adjlst(self, node_id):
        """Merge attributes for one node, adjacency-list model.

        Walks ``id_parent`` links up to node 0 (one query per level), then
        loads each level's attributes from node 0 downwards so that deeper
        levels overwrite shallower ones.  The whole walk is timed into
        ``stats["adjlst"]``.
        """
        print("adjlst")
        hierarchy = [node_id]
        attrs = {}
        curr_id = node_id
        t1 = time.time()
        # Assumes every id_parent chain ends at node 0; a missing parent
        # raises IndexError on res[0].
        while curr_id != 0:
            self.curs.execute("SELECT id_parent FROM nodes WHERE id=%s",
                              (curr_id,))
            res = self.curs.fetchall()
            curr_id = res[0][0]
            hierarchy.append(curr_id)
        hierarchy.reverse()
        q = ("SELECT AD.name, A.value FROM nodes N "
             "INNER JOIN attributes A ON A.nodes_id=N.id "
             "LEFT JOIN attributesDef AD ON A.attributesdef_id=AD.id "
             "WHERE N.id=%s")
        for level in hierarchy:
            self.curs.execute(q, (level,))
            for name, value in self.curs.fetchall():
                attrs[name] = value
        t2 = time.time()
        self.stats["adjlst"].append(t2 - t1)

    def define_node_attribute(self, key, val, id):
        """Set attribute ``key`` to ``val`` on node ``id`` (insert or update).

        Raises:
            IndexError: if ``key`` has no row in ``attributesdef``.
        """
        self.curs.execute("SELECT attributesdef.id from attributesdef WHERE name=%s",
                          (key,))
        res = self.curs.fetchall()
        if not res:
            raise IndexError("attribute %r is not defined in attributesdef" % (key,))
        attr_id = res[0][0]
        # Is this attribute already set on this node?
        self.curs.execute(
            "SELECT attributesdef.id from attributes LEFT JOIN attributesdef "
            "ON attributesdef.id=attributes.attributesdef_id "
            "WHERE nodes_id=%s AND attributesdef.name=%s", (id, key))
        if not self.curs.fetchall():
            self.curs.execute(
                "INSERT INTO attributes (nodes_id,attributesdef_id,value) "
                "VALUES (%s,%s,%s)", (id, attr_id, val))
        else:
            self.curs.execute(
                "UPDATE attributes SET value = %s "
                "where nodes_id = %s AND attributesdef_id = %s",
                (val, id, attr_id))
        self.conn.commit()

    def get_middle_node(self):
        """Return the id of the middle leaf when leaves are sorted by id.

        Raises:
            IndexError: if the tree has no leaves.
        """
        # ORDER BY t1.id: t2.id is always NULL here, so ordering by it (as
        # before) left the row order, and thus the "middle", undefined.
        # NOTE(review): unlike insert_nodes, id -1 is not excluded -- confirm.
        self.curs.execute("SELECT t1.id FROM nodes AS t1 LEFT JOIN nodes as t2 ON t1.id = t2.id_parent WHERE t2.id IS NULL ORDER BY t1.id")
        rows = self.curs.fetchall()
        if not rows:
            raise IndexError("no leaf nodes found")
        return rows[len(rows) // 2][0]

class TestDBManager():
    """Drive DBManager through benchmark scenarios.

    Not a unittest.TestCase: the ``test*`` methods are called explicitly
    from the ``__main__`` block.  The constructor connects to the database.
    """

    def __init__(self, dbname, user, password, host):
        self.db = DBManager(dbname, user, password, host)
        self.db.connect()
        # Scenario parameters; each test* method sets them before common().
        self.nb_nodes = 0
        self.deep = 0
        self.nb_iter = 0

    def common(self):
        """Rebuild the tree, run the attribute queries and save timings.

        Builds ``deep`` levels, inserting about ``nb_nodes ** (1/deep)``
        children under every leaf per level, moves ``nb_iter`` middle leaves
        under node 0, sets a few attributes, then times ``nb_iter`` lookups
        with each tree model and saves the statistics.

        Raises:
            ValueError: if ``deep`` is not positive.
        """
        if self.deep <= 0:
            raise ValueError("deep must be a positive integer, got %r" % (self.deep,))
        # Start from empty series so a previous scenario's timings do not
        # leak into this scenario's output files.
        self.db.reset_stats()
        self.db.delete_all_nodes()
        self.db.delete_all_attributes()
        nb_child = int(round(math.pow(self.nb_nodes, 1 / float(self.deep))))
        for _ in range(self.deep):
            self.db.insert_nodes(nb_child)
        for _ in range(self.nb_iter):
            self.db.move_middle_node_to_root()

        self.db.define_node_attribute("SCREEN_01", "bidon", 0)
        self.db.define_node_attribute("SCREEN_02", "coucou", 0)
        node = self.db.get_middle_node()
        # Override one node-0 attribute on the queried node.
        self.db.define_node_attribute("SCREEN_01", "test", node)
        for _ in range(self.nb_iter):
            self.db.attribute_query_leaf_nstree(node)
            self.db.attribute_query_leaf_adjlst(node)
        self.db.save_stats("n%s_d%s" % (self.nb_nodes, self.deep))

    def test01(self):
        """Shallow tree: ~1000 nodes over 3 levels, 20 iterations."""
        self.nb_nodes = 1000
        self.deep = 3
        self.nb_iter = 20
        self.common()

    def test02(self):
        """Deep tree: ~1000-node target over 10 levels, 20 iterations."""
        self.nb_nodes = 1000
        self.deep = 10
        self.nb_iter = 20
        self.common()
    
if __name__ == "__main__":

    from optparse import OptionParser
    p = OptionParser()
    p.add_option("-d", dest="dbname", help="Database name")
    p.add_option("-u", dest="dbuser", help="Username")
    p.add_option("-p", dest="dbpass", help="Password")
    p.add_option("-l", dest="host", help="Hostname")

    (o, args) = p.parse_args()
    # All four connection options are mandatory.
    if o.dbname is None or o.dbuser is None or o.dbpass is None or o.host is None:
        p.print_help()
        sys.exit(1)

    t = TestDBManager(o.dbname, o.dbuser, o.dbpass, o.host)
    try:
        # test01 is currently disabled; uncomment to run it as well:
        # print("test01")
        # t.test01()
        print("test02")
        t.test02()
    finally:
        # Release the database connection even if a scenario fails.
        t.db.close()
    print("done")
    
    
    
