#!/usr/bin/env python

__authors__ = "Martin Sandve Alnes"
__date__ = "2008-09-04 -- 2008-09-04"

import unittest

from sfc.codegeneration.codeformatting import CodeFormatter, gen_token_declarations, gen_token_definitions

# Expected output of CodeFormattingTest.test_switch: a switch over `facet`
# with three cases, each assigning three entries of `dofs`, followed by a
# default branch that throws. Note there is no trailing newline.
test_switch_result = "\n".join([
    "switch(facet)",
    "{",
    "case 0:",
    "    dofs[0] = 2;",
    "    dofs[1] = 0;",
    "    dofs[2] = 1;",
    "    break;",
    "case 1:",
    "    dofs[0] = 5;",
    "    dofs[1] = 3;",
    "    dofs[2] = 4;",
    "    break;",
    "case 2:",
    "    dofs[0] = 6;",
    "    dofs[1] = 7;",
    "    dofs[2] = 8;",
    "    break;",
    "default:",
    '    throw std::runtime_error("Invalid facet number.");',
    "}",
])

# Expected output of CodeFormattingTest.test_gen_tokens: the token
# declarations inside a braced block, then the token definitions emitted
# twice -- first at two indentation levels, then at one. No trailing newline.
gen_tokens_result = "\n".join([
    "{",
    "    double s1;",
    "    double s2;",
    "}",
    "        double s1 = e1;",
    "        double s2 = e2;",
    "    double s1 = e1;",
    "    double s2 = e2;",
])

# Expected output of CodeFormattingTest.test_functions: a declaration, a
# definition and a call of the inline const function `myfunction` taking
# three `double[3]` arguments named a, b and c.
functions_result = """inline void myfunction(double a[3], double b[3], double c[3]) const;

inline void myfunction(double a[3], double b[3], double c[3]) const
{
    // Empty body!
}

myfunction(a, b, c);
"""

class CodeFormattingTest(unittest.TestCase):
    """Tests for CodeFormatter and the gen_token_* helpers."""

    def _compare_codes(self, code, correct):
        """Assert that `code` equals `correct`.

        On mismatch, both strings are printed wrapped in triple quotes so the
        actual output can be inspected (or pasted into an expected-result
        constant) before the assertion fails.
        """
        if code != correct:
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print("Failure, got code:")
            print('"""%s"""' % code)
            print("but expecting:")
            print('"""%s"""' % correct)
        self.assertEqual(code, correct)

    def test_switch(self):
        """Generate a switch over three facets plus a throwing default."""
        code = CodeFormatter()
        code.begin_switch("facet")
        facet_dofs = [(2, 0, 1), (5, 3, 4), (6, 7, 8)]
        for i, dofs in enumerate(facet_dofs):
            code.begin_case(i)
            for j, d in enumerate(dofs):
                code += "dofs[%d] = %d;" % (j, d)
            code.end_case()
        # The default branch is written by hand with explicit indentation.
        code += "default:"
        code.indent()
        code += 'throw std::runtime_error("Invalid facet number.");'
        code.dedent()
        code.end_switch()
        code = str(code)

        self._compare_codes(code, test_switch_result)

    def test_gen_tokens(self):
        """Emit token declarations in a block, then definitions at two indents."""
        code = CodeFormatter()

        class MockObject:
            # Minimal stand-in for a symbol/expression; presumably the
            # gen_token_* helpers render tokens via printc() or str() -- confirm.
            def __init__(self, text):
                self._text = text
            def printc(self):
                return self._text
            def __str__(self):
                return self._text

        s1 = MockObject("s1")
        e1 = MockObject("e1")
        s2 = MockObject("s2")
        e2 = MockObject("e2")
        tokens = [(s1, e1), (s2, e2)]
        code.begin_block()
        code += gen_token_declarations(tokens)
        code.end_block()
        # Definitions are emitted once at indent level 2, then at level 1.
        code.indent()
        code.indent()
        code += gen_token_definitions(tokens)
        code.dedent()
        code += gen_token_definitions(tokens)
        code.dedent()
        code = str(code)

        self._compare_codes(code, gen_tokens_result)

    def test_functions(self):
        """Declare, define and call an inline const function with three args."""
        code = CodeFormatter()

        name = "myfunction"
        argnames = ["a", "b", "c"]
        # Use a distinct loop variable: under Python 2 a list comprehension
        # leaks its loop variable, and reusing `name` here used to silently
        # rebind the function name to "c".
        args = [("double", argname, "[3]") for argname in argnames]

        code.declare_function(name, args=args, const=True, inline=True)
        code.new_line("")

        body = "// Empty body!"
        code.define_function(name, args=args, const=True, inline=True, body=body)
        code.new_line("")

        code.call_function(name, args=argnames)
        code.new_line("")

        code = str(code)

        self._compare_codes(code, functions_result)

# TestCase classes defined in this module; presumably gathered by an external
# suite driver that is not visible here -- TODO confirm.
tests = [CodeFormattingTest]

# Run this module's tests when executed directly as a script.
if __name__ == "__main__":
    unittest.main()

