#!/usr/bin/env python
#-*- coding: UTF-8 -*-

""" This program loads a underworlds world, and display its
3D scene.

Based on:
- pygame code from http://3dengine.org/Spectator_%28PyOpenGL%29
- http://www.lighthouse3d.com/tutorials
- http://www.songho.ca/opengl/gl_transform.html
- http://code.activestate.com/recipes/325391/
- ASSIMP's C++ SimpleOpenGL viewer
"""
import sys
import argparse

import logging
logger = logging.getLogger("underworlds.3d_viewer")
gllogger = logging.getLogger("OpenGL")
gllogger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)

import OpenGL
#OpenGL.ERROR_CHECKING=False
#OpenGL.ERROR_LOGGING = False
#OpenGL.ERROR_ON_COPY = True
#OpenGL.FULL_LOGGING = True
from OpenGL.GL import *
from OpenGL.error import GLError
from OpenGL.GLU import *
#from OpenGL.GLUT import *
from OpenGL.arrays import vbo
from OpenGL.GL import shaders

import pygame
import pygame.font
import pygame.image

import math, random
import numpy
from numpy import linalg

import underworlds
from underworlds.types import *
from underworlds.helpers.geometry import transform, get_scene_bounding_box, get_world_transform
from underworlds.helpers import transformations

ROTATION_180_X = numpy.array([[1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,1]], dtype=numpy.float32)

#rendering mode
BASE = "BASE"
COLORS = "COLORS"
SILHOUETTE = "SILHOUETTE"
HELPERS = "HELPERS"

FLAT_VERTEX_SHADER="""
#version 130

uniform mat4 u_viewProjectionMatrix;
uniform mat4 u_modelMatrix;

uniform vec4 u_materialDiffuse;

in vec3 a_vertex;

out vec4 v_color;

void main(void)
{
    v_color = u_materialDiffuse;
    gl_Position = u_viewProjectionMatrix * u_modelMatrix * vec4(a_vertex, 1.0);
}
"""


BASIC_VERTEX_SHADER="""
#version 130

uniform mat4 u_viewProjectionMatrix;
uniform mat4 u_modelMatrix;
uniform mat3 u_normalMatrix;
uniform vec3 u_lightPos;

uniform vec4 u_materialDiffuse;

in vec3 a_vertex;
in vec3 a_normal;

out vec4 v_color;

void main(void)
{
    // Now the normal is in world space, as we pass the light in world space.
    vec3 normal = u_normalMatrix * a_normal;

    float dist = distance(a_vertex, u_lightPos);

    // go to https://www.desmos.com/calculator/nmnaud1hrw to play with the parameters
    // att is not used for now
    float att=1.0/(1.0+0.8*dist*dist);

    vec3 surf2light = normalize(u_lightPos - a_vertex);
    vec3 norm = normalize(normal);
    float dcont=max(0.0,dot(norm,surf2light));

    float ambient = 0.3;
    float intensity = dcont + 0.3 + ambient;

    v_color = u_materialDiffuse  * intensity;

    gl_Position = u_viewProjectionMatrix * u_modelMatrix * vec4(a_vertex, 1.0);
}
"""

BASIC_FRAGMENT_SHADER="""
#version 130

in vec4 v_color;

void main() {
    gl_FragColor = v_color;
}
"""

GOOCH_VERTEX_SHADER="""
#version 130

// attributes
in vec3 a_vertex; // xyz - position
in vec3 a_normal; // xyz - normal

// uniforms
uniform mat4 u_modelMatrix;
uniform mat4 u_viewProjectionMatrix;
uniform mat3 u_normalMatrix;
uniform vec3 u_lightPos;
uniform vec3 u_camPos;

// output data from vertex to fragment shader
out vec3 o_normal;
out vec3 o_lightVector;

///////////////////////////////////////////////////////////////////

void main(void)
{
   // transform position and normal to world space
   vec4 positionWorld = u_modelMatrix * vec4(a_vertex, 1.0);
   vec3 normalWorld = u_normalMatrix * a_normal;

   // calculate and pass vectors required for lighting
   o_lightVector = u_lightPos - positionWorld.xyz;
   o_normal = normalWorld;

   // project world space position to the screen and output it
   gl_Position = u_viewProjectionMatrix * positionWorld;
}
"""

GOOCH_FRAGMENT_SHADER="""
#version 130

// data from vertex shader
in vec3 o_normal;
in vec3 o_lightVector;

// diffuse color of the object
uniform vec4 u_materialDiffuse;
// cool color of gooch shading
uniform vec3 u_coolColor;
// warm color of gooch shading
uniform vec3 u_warmColor;
// how much to take from object color in final cool color
uniform float u_alpha;
// how much to take from object color in final warm color
uniform float u_beta;

// output to framebuffer
out vec4 resultingColor;

///////////////////////////////////////////////////////////

void main(void)
{
   // normlize vectors for lighting
   vec3 normalVector = normalize(o_normal);
   vec3 lightVector = normalize(o_lightVector);
   // intensity of diffuse lighting [-1, 1]
   float diffuseLighting = dot(lightVector, normalVector);
   // map intensity of lighting from range [-1; 1] to [0, 1]
   float interpolationValue = (1.0 + diffuseLighting)/2;

   //////////////////////////////////////////////////////////////////

   // cool color mixed with color of the object
   vec3 coolColorMod = u_coolColor + vec3(u_materialDiffuse) * u_alpha;
   // warm color mixed with color of the object
   vec3 warmColorMod = u_warmColor + vec3(u_materialDiffuse) * u_beta;
   // interpolation of cool and warm colors according
   // to lighting intensity. The lower the light intensity,
   // the larger part of the cool color is used
   vec3 colorOut = mix(coolColorMod, warmColorMod, interpolationValue);

   //////////////////////////////////////////////////////////////////

   // save color
   resultingColor.rgb = colorOut;
   resultingColor.a = 1;
} 
"""

SILHOUETTE_VERTEX_SHADER="""
#version 130

in vec3 a_vertex; // xyz - position
in vec3 a_normal; // xyz - normal

uniform mat4 u_modelMatrix;
uniform mat4 u_viewProjectionMatrix;
uniform mat4 u_modelViewMatrix;
uniform vec4 u_materialDiffuse;
uniform float u_bordersize; // width of the border

out vec4 v_color;

void main(void){
   v_color = u_materialDiffuse;
   float distToCamera = -(u_modelViewMatrix * vec4(a_vertex, 1.0)).z;
   vec4 tPos   = vec4(a_vertex + a_normal * u_bordersize * distToCamera, 1.0);
   gl_Position = u_viewProjectionMatrix * u_modelMatrix * tPos;
}
"""
DEFAULT_CLIP_PLANE_NEAR = 0.001
DEFAULT_CLIP_PLANE_FAR = 1000.0

class DefaultCamera():
    def __init__(self, w, h, fov):
        self.name="default camera"
        self.type=CAMERA
        self.clipplanenear = DEFAULT_CLIP_PLANE_NEAR
        self.clipplanefar = DEFAULT_CLIP_PLANE_FAR
        self.aspect = w/h
        self.horizontalfov = fov * math.pi/180
        self.transformation = numpy.array([[ 0.68, -0.32, 0.65, 7.48],
                               [ 0.73,  0.31, -0.61, -6.51],
                               [-0.01,  0.89,  0.44,  5.34],
                               [ 0.,    0.,    0.,    1.  ]], dtype=numpy.float32)

        self.transformation = numpy.dot(self.transformation, ROTATION_180_X)
    def __str__(self):
        return self.name

class Underworlds3DViewer:

    base_name = "Underworlds 3D viewer"

    def __init__(self, ctx, world, w=1024, h=768):

        self.w = w
        self.h = h

        pygame.init()
        self.base_name = self.base_name + " <%s>" % world
        pygame.display.set_caption(self.base_name)
        pygame.display.set_mode((w,h), pygame.OPENGL | pygame.DOUBLEBUF)

        glClearColor(0.18,0.18,0.18,1.0)

        self.prepare_shaders()

        self.ctx = ctx
        self.world = ctx.worlds[world]

        self.scene = None
        self.meshes = {} # stores the OpenGL vertex/faces/normals buffers pointers

        self.node2colorid = {} # stores a color ID for each node. Useful for mouse picking and visibility checking
        self.colorid2node = {} # reverse dict of node2colorid

        self.currently_selected = None
        self.moving = False
        self.moving_situation = None

        self.default_camera = DefaultCamera(self.w,self.h,fov=70)
        self.cameras = [self.default_camera]

        self.current_cam_index = 0
        self.current_cam = self.default_camera
        self.set_camera_projection()

        self.load_world()

        # for FPS computation
        self.frames = 0
        #self.last_fps_time = glutGet(GLUT_ELAPSED_TIME)

        # user interactions
        self.focal_point = [0,0,0]
        self.is_rotating = False
        self.is_panning = False
        self.is_zooming = False

    def prepare_shaders(self):

        ### Base shader
        vertex = shaders.compileShader(BASIC_VERTEX_SHADER, GL_VERTEX_SHADER)
        fragment = shaders.compileShader(BASIC_FRAGMENT_SHADER, GL_FRAGMENT_SHADER)

        self.shader = shaders.compileProgram(vertex,fragment)

        self.set_shader_accessors( ('u_modelMatrix',
                                    'u_viewProjectionMatrix',
                                    'u_normalMatrix',
                                    'u_lightPos',
                                    'u_materialDiffuse'), 
                                    ('a_vertex',
                                      'a_normal'), self.shader)

        ### Flat shader
        flatvertex = shaders.compileShader(FLAT_VERTEX_SHADER, GL_VERTEX_SHADER)
        self.flatshader = shaders.compileProgram(flatvertex,fragment)

        self.set_shader_accessors( ('u_modelMatrix',
                                    'u_viewProjectionMatrix',
                                    'u_materialDiffuse',), 
                                    ('a_vertex',), self.flatshader)


        ### Silhouette shader
        silh_vertex = shaders.compileShader(SILHOUETTE_VERTEX_SHADER, GL_VERTEX_SHADER)
        self.silhouette_shader = shaders.compileProgram(silh_vertex,fragment)

        self.set_shader_accessors( ('u_modelMatrix',
                                    'u_viewProjectionMatrix',
                                    'u_modelViewMatrix',
                                    'u_materialDiffuse',
                                    'u_bordersize'  # width of the silhouette
                                    ), 
                                    ('a_vertex',
                                      'a_normal'), self.silhouette_shader)


        ### Gooch shader
        gooch_vertex = shaders.compileShader(GOOCH_VERTEX_SHADER, GL_VERTEX_SHADER)
        gooch_fragment = shaders.compileShader(GOOCH_FRAGMENT_SHADER, GL_FRAGMENT_SHADER)
        self.gooch_shader = shaders.compileProgram(gooch_vertex, gooch_fragment)

        self.set_shader_accessors( ('u_modelMatrix',
                                    'u_viewProjectionMatrix',
                                    'u_normalMatrix',
                                    'u_lightPos',
                                    'u_materialDiffuse',
                                    'u_coolColor',
                                    'u_warmColor',
                                    'u_alpha',
                                    'u_beta'
                                    ), 
                                    ('a_vertex',
                                      'a_normal'), self.gooch_shader)

    def set_shader_accessors(self, uniforms, attributes, shader):
        # add accessors to the shaders uniforms and attributes
        for uniform in uniforms:
            location = glGetUniformLocation( shader,  uniform )
            if location in (None,-1):
                raise RuntimeError('No uniform: %s (maybe it is not used '
                                   'anymore and has been optimized out by'
                                   ' the shader compiler)'%( uniform ))
            setattr( shader, uniform, location )

        for attribute in attributes:
            location = glGetAttribLocation( shader, attribute )
            if location in (None,-1):
                raise RuntimeError('No attribute: %s'%( attribute ))
            setattr( shader, attribute, location )

    def prepare_gl_buffers(self, id):

        meshes = self.meshes

        if id in meshes:
            # mesh already loaded. Fine
            return

        meshes[id] = {}

        # leave some time for new nodes to push their meshes
        if not self.ctx.has_mesh(id):
            logger.warning("Mesh ID %s is not available on the server... "
                           "waiting for it..." % id)
            while not self.ctx.has_mesh(id):
                time.sleep(0.01)

            logger.info("Mesh ID %s is now available. Getting it..." % id)

        logger.info("Fetching mesh <%s>" % id)
        mesh = self.ctx.mesh(id) # retrieve the mesh from the server

        # Fill the buffer for vertex and normals positions
        v = numpy.array(mesh.vertices, 'f')
        n = numpy.array(mesh.normals, 'f')

        #meshes[id]["vbo"] = glGenBuffers(1)
        #glBindBuffer(GL_ARRAY_BUFFER, meshes[id]["vbo"])
        #glBufferData(GL_ARRAY_BUFFER, 
        #            numpy.hstack((v,n)), # concatenate vertices and normals in one array
        #            GL_STATIC_DRAW)

        meshes[id]["vbo"] = vbo.VBO(numpy.hstack((v,n)))

        # Fill the buffer for vertex positions
        meshes[id]["faces"] = glGenBuffers(1)
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, meshes[id]["faces"])
        glBufferData(GL_ELEMENT_ARRAY_BUFFER, 
                    numpy.array(mesh.faces, dtype=numpy.int32),
                    GL_STATIC_DRAW)

        meshes[id]["nbfaces"] = len(mesh.faces)
        meshes[id]["material"] = {}
        meshes[id]["material"]["diffuse"] = mesh.diffuse

        # Unbind buffers
        glBindBuffer(GL_ARRAY_BUFFER,0)
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER,0)

    def get_rgb_from_colorid(self, colorid):
        r = (colorid >> 0) & 0xff
        g = (colorid >> 8) & 0xff
        b = (colorid >> 16) & 0xff

        return (r,g,b)

    def get_color_id(self):
        id = random.randint(0, 256*256*256)
        if id not in self.colorid2node:
            return id
        else:
            return self.get_color_id()

    def is_mesh_updated(self, node):
        if hasattr(node, "cad"):
            return node.glmeshes != node.cad
        elif hasattr(node, "lowres"):
            return node.glmeshes != node.lowres
        elif hasattr(node, "hires"):
            return node.glmeshes != node.hires
        else:
            return False

    def glize(self, node):

        node.selected = True if self.currently_selected and self.currently_selected == node else False

        if node.type == MESH:
            colorid = self.get_color_id()
            self.colorid2node[colorid] = node
            self.node2colorid[node] = colorid

            if hasattr(node, "cad"):
                node.glmeshes = node.cad
            elif hasattr(node, "lowres"):
                node.glmeshes = node.lowres
            elif hasattr(node, "hires"):
                node.glmeshes = node.hires
            else:
                raise StandardError("The node %s has no mesh available!" % node.name)

            for mesh in node.glmeshes:
                self.prepare_gl_buffers(mesh)

        elif node.type == CAMERA:
            logger.info("Added camera <%s>" % node.name)
            logger.info("Camera position: %.3f, %.3f, %.3f" % tuple(node.transformation[:,3][:3].tolist()))
            self.cameras.append(node)

    def load_world(self):
        logger.info("Preparing world <%s> for 3D rendering..." % self.world)

        scene = self.scene = self.world.scene
        nodes = scene.nodes
        for node in nodes:
            logger.info("Loading node <%s>" % node)
            self.glize(node)

        #log some statistics
        logger.info("  -> %d nodes" % len(nodes))

        scene_aabb = get_scene_bounding_box(scene)

        if not scene_aabb:
            logger.warning("Nothing to display. Leaving now.")
            sys.exit(0)

        self.bb_min, self.bb_max = scene_aabb
        logger.info("  -> scene bounding box:" + str(self.bb_min) + " - " + str(self.bb_max))

        logger.info("World <%s> ready for 3D rendering." % self.world)

    def cycle_cameras(self):

        self.current_cam_index = (self.current_cam_index + 1) % len(self.cameras)
        self.current_cam = self.cameras[self.current_cam_index]
        self.set_camera_projection(self.current_cam)
        logger.info("Switched to camera <%s>" % self.current_cam)

    def set_overlay_projection(self):
        glViewport(0,0,self.w,self.h)
        glMatrixMode(GL_PROJECTION)
        glLoadIdentity()
        glOrtho(0.0, self.w - 1.0, 0.0, self.h - 1.0, -1.0, 1.0)
        glMatrixMode(GL_MODELVIEW)
        glLoadIdentity()

    def set_camera_projection(self, camera = None):

        if not camera:
            camera = self.current_cam

        znear = camera.clipplanenear or DEFAULT_CLIP_PLANE_NEAR
        zfar = camera.clipplanefar or DEFAULT_CLIP_PLANE_FAR
        aspect = camera.aspect
        fov = camera.horizontalfov

        glMatrixMode(GL_PROJECTION)
        glLoadIdentity()

        # Compute gl frustrum
        tangent = math.tan(fov/2.)
        h = znear * tangent
        w = h * aspect

        # params: left, right, bottom, top, near, far
        glFrustum(-w, w, -h, h, znear, zfar)
        # equivalent to:
        #gluPerspective(fov * 180/math.pi, aspect, znear, zfar)

        self.projection_matrix = glGetFloatv( GL_PROJECTION_MATRIX).transpose()

        glMatrixMode(GL_MODELVIEW)
        glLoadIdentity()


    def render_colors(self):

        glEnable(GL_DEPTH_TEST)
        glDepthFunc(GL_LEQUAL)

        glPolygonMode(GL_FRONT_AND_BACK, GL_FILL)
        glEnable(GL_CULL_FACE)

        glUseProgram(self.flatshader)

        glUniformMatrix4fv( self.flatshader.u_viewProjectionMatrix, 1, GL_TRUE,
                            numpy.dot(self.projection_matrix,self.view_matrix))


        self.recursive_render(self.scene.rootnode, self.flatshader, mode=COLORS)

        glUseProgram( 0 )

    def get_hovered_node(self, mousex, mousey):
        """
        Attention: The performances of this method relies heavily on the size of the display!
        """
        
        # mouse out of the window?
        if mousex < 0 or mousex >= self.w or mousey < 0 or mousey >= self.h:
            return None

        self.render_colors()
        # Capture image from the OpenGL buffer
        buf = ( GLubyte * (3 * self.w * self.h) )(0)
        glReadPixels(0, 0, self.w, self.h, GL_RGB, GL_UNSIGNED_BYTE, buf)

        #Reinterpret the RGB pixel buffer as a 1-D array of 24bits colors
        a = numpy.ndarray(len(buf), numpy.dtype('>u1'), buf)
        colors = numpy.zeros(len(buf) / 3, numpy.dtype('<u4'))
        for i in range(3):
            colors.view(dtype='>u1')[i::4] = a.view(dtype='>u1')[i::3]

        colorid = colors[mousex + mousey * self.w]

        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

        if colorid in self.colorid2node:
            return self.colorid2node[colorid]


    def check_visibility(self):
        """
        Attention: The performances of this method relies heavily on the size of the display!
        """
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        self.render_colors()
        # Capture image from the OpenGL buffer
        buf = ( GLubyte * (3 * self.w * self.h) )(0)
        glReadPixels(0, 0, self.w, self.h, GL_RGB, GL_UNSIGNED_BYTE, buf)

        #Reinterpret the RGB pixel buffer as a 1-D array of 24bits colors
        a = numpy.ndarray(len(buf), numpy.dtype('>u1'), buf)
        colors = numpy.zeros(len(buf) / 3, numpy.dtype('<u4'))
        for i in range(3):
            colors.view(dtype='>u1')[i::4] = a.view(dtype='>u1')[i::3]

        colors = colors[numpy.nonzero(colors)] #discard black background
        
        if colors.any():
            bins = numpy.bincount(colors)
            ii = numpy.nonzero(bins)[0]

            for i in ii:
                print ("Node %s is visible (%d pix)" % (self.colorid2node[i], bins[i]))
        else:
            print("Nothing visible!")

        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

    def screenshot(self, filename = "screenshot.png"):

        logger.info("Taking a screenshot...")
        import Image
        # Capture image from the OpenGL buffer
        buffer = ( GLubyte * (3 * self.w * self.h) )(0)
        glReadPixels(0, 0, self.w, self.h, GL_RGB, GL_UNSIGNED_BYTE, buffer)

        # Use PIL to convert raw RGB buffer and flip the right way up
        image = Image.frombytes(mode="RGB", size=(self.w, self.h), data=buffer)
        image = image.transpose(Image.FLIP_TOP_BOTTOM)

        # Save image to disk
        image.save(filename)
        logger.info("...done. Image saved as <%s>." % filename)


    def render(self, wireframe = False, twosided = False):

        glEnable(GL_DEPTH_TEST)
        glDepthFunc(GL_LEQUAL)


        glPolygonMode(GL_FRONT_AND_BACK, GL_LINE if wireframe else GL_FILL)
        glDisable(GL_CULL_FACE) if twosided else glEnable(GL_CULL_FACE)

        self.render_grid()

        self.recursive_render(self.scene.rootnode, None, mode=HELPERS)

        ### First, the silhouette

        if False:
            shader = self.silhouette_shader

            #glDepthMask(GL_FALSE)
            glCullFace (GL_FRONT) # cull front faces

            glUseProgram(shader)
            glUniform1f( shader.u_bordersize, 0.01)

            glUniformMatrix4fv( shader.u_viewProjectionMatrix, 1, GL_TRUE,
                                numpy.dot(self.projection_matrix,self.view_matrix))


            self.recursive_render(self.scene.rootnode, shader, mode=SILHOUETTE)

            glUseProgram( 0 )

        ### Then, inner shading
        #glDepthMask(GL_TRUE)
        glCullFace (GL_BACK)

        use_gooch = False
        if use_gooch:
            shader = self.gooch_shader

            glUseProgram(shader)
            glUniform3f( shader.u_lightPos, -.5,-.5,.5 )

            ##### GOOCH specific
            glUniform3f( shader.u_coolColor, 159.0/255, 148.0/255, 255.0/255)
            glUniform3f( shader.u_warmColor, 255.0/255, 75.0/255, 75.0/255)
            glUniform1f( shader.u_alpha, .25)
            glUniform1f( shader.u_beta, .25)
            #########
        else:
            shader = self.shader
            glUseProgram(shader)
            glUniform3f( shader.u_lightPos, -.5,-.5,.5 )


        glUniformMatrix4fv( shader.u_viewProjectionMatrix, 1, GL_TRUE,
                            numpy.dot(self.projection_matrix,self.view_matrix))


        self.recursive_render(self.scene.rootnode, shader)

        glUseProgram( 0 )

    def draw_gizmo(self,x,y,z,size,orientation=(1,1,1)):
        """ Must be called inside a glBegin()/glEnd()!
        """

        # draw line for x axis
        glVertex3f(x, y, z)
        glVertex3f(x + orientation[0] * size, y, z)

        # draw line for y axis
        glVertex3f(x, y, z)
        glVertex3f(x, y + orientation[1] * size, z)

        # draw line for Z axis
        glVertex3f(x, y, z)
        glVertex3f(x, y, z + orientation[2] * size)


    def render_axis(self,
                    transformation = numpy.identity(4, dtype=numpy.float32),
                    bb=None,
                    label=None,
                    size = 0.2,
                    selected=False):
        m = transformation.transpose() # OpenGL row major

        glPushMatrix()
        glMultMatrixf(m)

        glLineWidth(3 if selected else 1)

        size = 2 * size if selected else size

        glBegin(GL_LINES)

        # draw line for x axis
        glColor3f(1.0, 0.0, 0.0)
        glVertex3f(0.0, 0.0, 0.0)
        glVertex3f(size, 0.0, 0.0)

        # draw line for y axis
        glColor3f(0.0, 1.0, 0.0)
        glVertex3f(0.0, 0.0, 0.0)
        glVertex3f(0.0, size, 0.0)

        # draw line for Z axis
        glColor3f(0.0, 0.0, 1.0)
        glVertex3f(0.0, 0.0, 0.0)
        glVertex3f(0.0, 0.0, size)

        glEnd()

        if label:
            self.showtext(label)

        # bounding box
        if bb:
            glLineWidth(2 if selected else 1)
            glBegin(GL_LINES)

            if selected:
                glColor3f(.5, .5, .5)
            else:
                glColor3f(.4, .4, .4)

            bb_min, bb_max = bb
            x_min, y_min, z_min = bb_min
            x_max, y_max, z_max = bb_max


            self.draw_gizmo(*bb_min,size=size)
            self.draw_gizmo(*bb_max,size=size,orientation=(-1,-1,-1))
            self.draw_gizmo(x_min,y_max,z_max,size=size,orientation=(1,-1,-1))
            self.draw_gizmo(x_min,y_min,z_max,size=size,orientation=(1,1,-1))
            self.draw_gizmo(x_max,y_min,z_max,size=size,orientation=(-1,1,-1))
            self.draw_gizmo(x_max,y_min,z_min,size=size,orientation=(-1,1,1))
            self.draw_gizmo(x_max,y_max,z_min,size=size,orientation=(-1,-1,1))
            self.draw_gizmo(x_min,y_max,z_min,size=size,orientation=(1,-1,1))



            glEnd()

        glPopMatrix();

    def render_camera(self, camera, transformation):

        m = transformation.transpose() # OpenGL row major

        aspect = camera.aspect

        u = 0.1 # unit size (in m)
        l = 3 * u # lenght of the camera cone
        f = 3 * u # aperture of the camera cone

        glPushMatrix()
        glMultMatrixf(m)

        glLineWidth(2)
        glBegin(GL_LINE_STRIP)

        glColor3f(.2, .2, .2)

        glVertex3f(u, u, -u)
        glVertex3f(u, -u, -u)
        glVertex3f(-u, -u, -u)
        glVertex3f(-u, u, -u)
        glVertex3f(u, u, -u)

        glVertex3f(u, u, 0.0)
        glVertex3f(u, -u, 0.0)
        glVertex3f(-u, -u, 0.0)
        glVertex3f(-u, u, 0.0)
        glVertex3f(u, u, 0.0)

        glVertex3f(f * aspect, f, l)
        glVertex3f(f * aspect, -f, l)
        glVertex3f(-f * aspect, -f, l)
        glVertex3f(-f * aspect, f, l)
        glVertex3f(f * aspect, f, l)

        glEnd()

        glBegin(GL_LINE_STRIP)
        glVertex3f(u, -u, -u)
        glVertex3f(u, -u, 0.0)
        glVertex3f(f * aspect, -f, l)
        glEnd()

        glBegin(GL_LINE_STRIP)
        glVertex3f(-u, -u, -u)
        glVertex3f(-u, -u, 0.0)
        glVertex3f(-f * aspect, -f, l)
        glEnd()

        glBegin(GL_LINE_STRIP)
        glVertex3f(-u, u, -u)
        glVertex3f(-u, u, 0.0)
        glVertex3f(-f * aspect, f, l)
        glEnd()

        glPopMatrix();


    def render_grid(self):

        glLineWidth(1)
        glColor3f(0.5, 0.5, 0.5)
        glBegin(GL_LINES)
        for i in range(-10,11):
            glVertex3f(i, -10.0, 0.0)
            glVertex3f(i, 10.0, 0.0)

        for i in range(-10, 11):
            glVertex3f(-10.0, i, 0.0)
            glVertex3f(10.0, i, 0.0)
        glEnd()


    def recursive_render(self, node, shader, mode = BASE, with_normals=True):
        """ Main recursive rendering method.
        """

        normals = with_normals
        ambient = True
        if mode == COLORS:
            normals = False
            ambient = False

        try:
            m = get_world_transform(self.scene, node)
        except AttributeError:
            #probably a new incoming node, that has not yet been converted to numpy
            self.glize(node)
            m = get_world_transform(self.scene, node)

        # HELPERS mode
        ###
        if mode == HELPERS:
            #if node.type == ENTITY:
            self.render_axis(m,
                             bb=node.aabb,
                             label=node.name if node != self.scene.rootnode else None,
                             selected=node.selected if hasattr(node, "selected") else False)

            if node.type == CAMERA:
                self.render_camera(node, m)

            for child in node.children:
                try:
                    self.recursive_render(self.scene.nodes[child], shader, mode)
                except KeyError as ke:
                    logger.warning("Node ID %s listed as child of %s, but it"
                                   " does not exist! Skipping it" % (child, repr(node)))
            return

        # Mesh rendering modes
        ###
        if node.type == MESH:

            # if the node has been recently turned into a mesh, we might not
            # have the mesh data yet. Likewise, the mesh might have been updated.
            if not hasattr(node, "glmeshes") or self.is_mesh_updated(self.scene.nodes[node.id]):
                self.glize(node)

            for id in node.glmeshes:

                stride = 24 # 6 * 4 bytes

                if node.selected and mode == SILHOUETTE:
                    glUniform4f( shader.u_materialDiffuse, 1.0, 0.0, 0.0, 1.0 )
                    glUniformMatrix4fv( shader.u_modelViewMatrix, 1, GL_TRUE,
                                        numpy.dot(self.view_matrix,m))

                else:
                    if mode == COLORS:
                            colorid = self.node2colorid[node]
                            r,g,b= self.get_rgb_from_colorid(colorid)
                            glUniform4f( shader.u_materialDiffuse, r/255.0,g/255.0,b/255.0,1.0 )
                    elif mode == SILHOUETTE:
                        glUniform4f( shader.u_materialDiffuse, .0, .0, .0, 1.0 )
                    else:
                        if node.selected:
                            diffuse = (1.0,0.0,0.0,1.0) # selected nodes in red
                        else:
                            diffuse = self.meshes[id]["material"]["diffuse"]
                        if len(diffuse) == 3: # RGB instead of expected RGBA
                            diffuse.append(1.0)
                        glUniform4f( shader.u_materialDiffuse, *diffuse )
                        #if ambient:
                        #    glUniform4f( shader.Material_ambient, *mat["ambient"] )


                if mode == BASE: # not in COLORS or SILHOUETTE
                    normal_matrix = linalg.inv(numpy.dot(self.view_matrix, m)[0:3,0:3]).transpose()
                    glUniformMatrix3fv( shader.u_normalMatrix, 1, GL_TRUE, normal_matrix )

                glUniformMatrix4fv( shader.u_modelMatrix, 1, GL_TRUE, m )

                vbo = self.meshes[id]["vbo"]
                vbo.bind()

                glEnableVertexAttribArray( shader.a_vertex )
                if normals:
                    glEnableVertexAttribArray( shader.a_normal )

                glVertexAttribPointer(
                    shader.a_vertex,
                    3, GL_FLOAT,False, stride, vbo
                )

                if normals:
                    glVertexAttribPointer(
                        shader.a_normal,
                        3, GL_FLOAT,False, stride, vbo+12
                    )

                glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.meshes[id]["faces"])
                glDrawElements(GL_TRIANGLES, self.meshes[id]["nbfaces"] * 3, GL_UNSIGNED_INT, None)


                vbo.unbind()
                glDisableVertexAttribArray( shader.a_vertex )

                if normals:
                    glDisableVertexAttribArray( shader.a_normal )


                glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, 0)

        for child in node.children:
            try:
                self.recursive_render(self.scene.nodes[child], shader, mode)
            except KeyError as ke:
                logger.warning("Node ID %s listed as child of %s, but it"
                                " does not exist! Skipping it" % (child, repr(node)))


    def switch_to_overlay(self):
        glPushMatrix()
        self.set_overlay_projection()

    def switch_from_overlay(self):
        self.set_camera_projection()
        glPopMatrix()

    def select_node(self, node):
        self.currently_selected = node
        for n in self.scene.nodes:
            if n is node:
                n.selected = True
            else:
                n.selected = False

    def loop(self):

        pygame.display.flip()

        if not self.process_events():
            return False # ESC has been pressed

        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

        # Compute FPS
        #gl_time = glutGet(GLUT_ELAPSED_TIME)
        #self.frames += 1
        #if gl_time - self.last_fps_time >= 1000:
        #    current_fps = self.frames * 1000 / (gl_time - self.last_fps_time)
        #    pygame.display.set_caption(self.base_name + " - %.0f fps" % current_fps)
        #    self.frames = 0
        #    self.last_fps_time = gl_time


        return True



    def process_events(self):


        LEFT_BUTTON=1
        MIDDLE_BUTTON=2
        RIGHT_BUTTON=3
        WHEEL_UP=4
        WHEEL_DOWN=5

        dx, dy = pygame.mouse.get_rel()
        mousex, mousey = pygame.mouse.get_pos()

        zooming_one_shot = False

        ok = True

        for evt in pygame.event.get():
            if evt.type == pygame.MOUSEBUTTONDOWN and evt.button == LEFT_BUTTON:
                hovered = self.get_hovered_node(mousex, self.h - mousey)
                if hovered:
                    if self.currently_selected and self.currently_selected == hovered:
                        self.select_node(None)
                    else:
                        logger.info("Node %s selected" % hovered)
                        self.select_node(hovered)
                else:
                    self.is_rotating=True
            if evt.type == pygame.MOUSEBUTTONUP and evt.button == LEFT_BUTTON:
                self.is_rotating = False

            if evt.type == pygame.MOUSEBUTTONDOWN and evt.button == MIDDLE_BUTTON:
                self.is_panning=True
            if evt.type == pygame.MOUSEBUTTONUP and evt.button == MIDDLE_BUTTON:
                self.is_panning = False

            if evt.type == pygame.MOUSEBUTTONDOWN and evt.button == RIGHT_BUTTON:
                self.is_zooming=True
            if evt.type == pygame.MOUSEBUTTONUP and evt.button == RIGHT_BUTTON:
                self.is_zooming = False

            if evt.type == pygame.MOUSEBUTTONDOWN and evt.button in [WHEEL_UP, WHEEL_DOWN]:
                zooming_one_shot=True
                self.is_zooming=True
                dy = -10 if evt.button == WHEEL_UP else 10

            if evt.type == pygame.KEYDOWN:
                ok = (ok and self.process_keystroke(evt.key, evt.mod))

        self.controls_3d(dx, dy, zooming_one_shot)

        return ok

    def process_keystroke(self, key, mod):

        # process arrow keys if an object is selected
        if self.currently_selected:
            up = 0
            strafe = 0

            if key == pygame.K_UP:
                up = 1
            if key == pygame.K_DOWN:
                up = -1
            if key == pygame.K_LEFT:
                strafe = -1
            if key == pygame.K_RIGHT:
                strafe = 1

            self.move_selected_node(up, strafe)

        if key == pygame.K_f:
            pygame.display.toggle_fullscreen()

        if key == pygame.K_s:
            self.screenshot()

        if key == pygame.K_v:
            self.check_visibility()

        if key == pygame.K_TAB:
            self.cycle_cameras()

        if key in [pygame.K_ESCAPE, pygame.K_q]:
            return False

        return True

    def controls_3d(self, dx, dy, zooming_one_shot = False):

        CAMERA_TRANSLATION_FACTOR = 0.01
        CAMERA_ROTATION_FACTOR = 0.01

        if not (self.is_rotating or self.is_panning or self.is_zooming):
            return

        current_pos = self.current_cam.transformation[:3, 3].copy()
        distance = numpy.linalg.norm(self.focal_point-current_pos)

        if self.is_rotating:
            """ Orbiting the camera is implemented the following way:

            - the rotation is split into a rotation around the *world* Z axis
              (controlled by the horizontal mouse motion along X) and a
              rotation around the *X* axis of the camera (pitch) *shifted to
              the focal origin* (the world origin for now). This is controlled
              by the vertical motion of the mouse (Y axis).

            - as a result, the resulting transformation of the camera in the
              world frame C' is:
                C' = (T · Rx · T⁻¹ · (Rz · C)⁻¹)⁻¹

              where:
                - C is the original camera transformation in the world frame,
                - Rz is the rotation along the Z axis (in the world frame)
                - T is the translation camera -> world (ie, the inverse of the
                  translation part of C
                - Rx is the rotation around X in the (translated) camera frame
            """

            rotation_camera_x = dy * CAMERA_ROTATION_FACTOR
            rotation_world_z = dx * CAMERA_ROTATION_FACTOR
            world_z_rotation = transformations.euler_matrix(0,0,rotation_world_z)
            cam_x_rotation = transformations.euler_matrix(rotation_camera_x,0,0)

            after_world_z_rotation = numpy.dot(world_z_rotation,self.current_cam.transformation)

            inverse_transformation = transformations.inverse_matrix(after_world_z_rotation)

            translation = transformations.translation_matrix(transformations.decompose_matrix(inverse_transformation)[3])
            inverse_translation = transformations.inverse_matrix(translation)

            new_inverse = numpy.dot(inverse_translation,inverse_transformation)
            new_inverse = numpy.dot(cam_x_rotation,new_inverse)
            new_inverse = numpy.dot(translation,new_inverse)

            self.current_cam.transformation = transformations.inverse_matrix(new_inverse).astype(numpy.float32)

        if self.is_panning:
            tx = -dx * CAMERA_TRANSLATION_FACTOR * distance
            ty = dy * CAMERA_TRANSLATION_FACTOR * distance
            cam_transform = transformations.translation_matrix((tx, ty, 0)).astype(numpy.float32)
            self.current_cam.transformation = numpy.dot(self.current_cam.transformation, cam_transform)

        if self.is_zooming:
            tz = dy * CAMERA_TRANSLATION_FACTOR * distance
            cam_transform = transformations.translation_matrix((0, 0, tz)).astype(numpy.float32)
            self.current_cam.transformation = numpy.dot(self.current_cam.transformation, cam_transform)


        if zooming_one_shot:
            self.is_zooming = False


        # Only broadcast the camera transformation if it is not
        # our 'local' default camera
        #if self.current_cam != self.default_camera:
        #    self.scene.nodes.update(self.current_cam)

        self.update_view_camera()


    def update_view_camera(self):

        # Update the camera position from the server if it is not
        # our 'local' camera
        if self.current_cam != self.default_camera:
            self.current_cam = self.scene.nodes[self.current_cam.id]

        if self.current_cam != self.default_camera:
            self.view_matrix = linalg.inv(get_world_transform(self.scene,self.current_cam))
        else:
            self.view_matrix = linalg.inv(self.current_cam.transformation)

        # Rotate by 180deg around X to have Z pointing backward (OpenGL convention)
        self.view_matrix = numpy.dot(ROTATION_180_X, self.view_matrix)

        glMatrixMode(GL_MODELVIEW)
        glLoadIdentity()
        glMultMatrixf(self.view_matrix.transpose())

    def move_selected_node(self, up, strafe):
        self.currently_selected.transformation[0][3] += strafe
        self.currently_selected.transformation[2][3] += up

        self.scene.nodes.update(self.currently_selected)


    def showtext(self, text, x=0, y=0, z=0, size=20):

        #TODO: alpha blending does not work...
        #glEnable(GL_BLEND)
        #glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)

        font = pygame.font.Font (None, size)
        textSurface = font.render(text, True, (10,10,10,255), 
                                              (255*0.18,255*0.18,255*0.18,0))
        textData = pygame.image.tostring(textSurface, "RGBA", True)
        glRasterPos3d(x,y,z)
        glDrawPixels(textSurface.get_width(),
                     textSurface.get_height(),
                     GL_RGBA, GL_UNSIGNED_BYTE, 
                     textData)

        #glDisable(GL_BLEND)

def main(world, host, port):

    with underworlds.Context("3D viewer", host=host, port=port) as ctx:
        app = Underworlds3DViewer(ctx, world=world)


        clock = pygame.time.Clock()

        while app.loop():

            app.update_view_camera()

            ## Main rendering
            app.render()

            ## GUI text display
            app.switch_to_overlay()
            app.showtext("world <%s> -- camera <%s>"% (app.world, str(app.current_cam)), 10, app.h - 30)
            if app.currently_selected:
                app.showtext("node <%s>"% app.currently_selected, 10, app.h - 50)
                pos = app.h - 70

                app.showtext("(%sm, %sm, %sm)" % (app.currently_selected.transformation[0,3],
                                                  app.currently_selected.transformation[1,3],
                                                  app.currently_selected.transformation[2,3]), 30, pos)
                pos -= 20
                for key, value in app.currently_selected.properties.items():

                    app.showtext("%s: %s" % (key, value), 30, pos)
                    pos -= 20

            app.switch_from_overlay()

            # Make sure we do not go over 30fps
            clock.tick(30)

        logger.info("Closing underworlds context...")

    logger.info("Quitting! Bye bye!")

#########################################################################
#########################################################################

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("world", help="underworlds world to display")
    parser.add_argument("--host", help="underworlds daemon host (default: localhost)", default="localhost")
    parser.add_argument("--port", help="underworlds daemon port (default: 50051)", default=50051, type=int)
    args = parser.parse_args()

    main(args.world, args.host, args.port)


