#!python
# -*- coding: utf-8 -*-

# =============================================================================
#  Version: 1.0 (December 27, 2019)
#  Author: Pierre Alexis (pierre@lxs.be)
#
#  Contributors:
#   Markus Lindström
#   Gilles Degols
#
# =============================================================================
#  Copyright (c) 2019. Pierre Alexis (pierre@lxs.be).
# =============================================================================

import html
import os
import sys
from unidecode import unidecode
from multiprocessing import Process, current_process
import zmq
import time
from math import ceil
import re

# =============================================================================
# CONFIGURATION 
# =============================================================================

NUM_SOURCES = 4 # Number of "source" processes; each reads one equal-sized byte range of the XML dump and extracts the articles
NUM_WORKERS = 8 # Number of "worker" processes that extract the words from the articles
UNIDECODE_RESULT = True # Strip accents in the result set? ("château" becomes "chateau")
KEEP_ACCENTED_IN_RESULT = False # With UNIDECODE_RESULT, also keep the accented form? (we keep both "château" and "chateau"); ignored when UNIDECODE_RESULT is False, since accented words are then always kept
MIN_WORD_LENGTH = 4 # Minimal length (in characters, measured on the raw token) of a word to keep
MAX_WORD_LENGTH = 12 # Maximal length (in characters, measured on the raw token) of a word to keep
EXTRACT_ONLY_TITLE = True # True: extract words from page titles only; False: from page text only

# =============================================================================
# END OF CONFIGURATION 
# =============================================================================

# ZeroMQ TCP ports. Binding sockets listen on all interfaces ("*"); connecting
# sockets use localhost.
# Control proxy (XSUB -> XPUB): sources publish their REGISTER/DONE messages on
# the IN port; workers subscribe on the OUT port.
PROXY_BASE_PORT = 4444
CTL_PROXY_IN_PORT = PROXY_BASE_PORT
CTL_PROXY_OUT_PORT  = PROXY_BASE_PORT + 1
# Data proxy (PULL -> PUSH): sources push articles to the IN port; workers pull
# them from the OUT port.
DATA_PROXY_IN_PORT = PROXY_BASE_PORT + 2
DATA_PROXY_OUT_PORT  = PROXY_BASE_PORT + 3

# Sink: workers push their word lists to the DATA port and their REGISTER/DONE
# messages to the CTL port.
SINK_BASE_PORT = 5555
SINK_DATA_PORT = SINK_BASE_PORT
SINK_CTL_PORT  = SINK_BASE_PORT + 1

# Stats: the stats process binds a SUB socket on STATS_BASE_DATA_PORT; sources
# and the sink publish their progress to it. In this file STATS_BASE_PORT is
# only used to derive STATS_BASE_DATA_PORT.
STATS_BASE_PORT = 8888
STATS_BASE_DATA_PORT = STATS_BASE_PORT + 1



class XMLWikipediaDumpArticleExtractor():
  """Read one byte range ("chunk") of a Wikipedia XML dump and forward its articles.

  An "article" is the <title> element of a page when EXTRACT_ONLY_TITLE is set,
  otherwise its <text> element. Every article that *starts* inside the chunk is
  sent on out_sock_data as one string: its stripped lines joined by newlines.
  An article that straddles the end of the chunk is still read to its end; the
  tail of an article that began before the chunk is skipped (it belongs to the
  previous chunk). Read progress is published on out_sock_stats as
  'CTL SOURCE <instance> READ_PROGRESS <percent>'.

  Args:
    file: text file object on the dump; read() seeks it to start_offset.
    start_offset: position where the chunk begins (passed to file.seek).
    end_offset: position where the chunk ends.
    out_sock_data: socket the articles are sent on (send_string).
    out_sock_stats: socket the progress messages are sent on (send_string).
    instance: index of this source, embedded in the progress messages.
  """

  def __init__(self, file, start_offset, end_offset, out_sock_data, out_sock_stats, instance):
    self.file = file
    self.start_offset = start_offset
    self.end_offset = end_offset
    self.chunk_size = self.end_offset - self.start_offset
    self.out_sock_data = out_sock_data
    self.out_sock_stats = out_sock_stats
    self.instance = instance
    self.sent_article_count = 0
    self.size_already_read = 0

    self.out_sock_stats.send_string(f'CTL SOURCE {self.instance} READ_PROGRESS 0')
    self.last_stats_sending_time = time.time()

  def read(self):
    """Send every article that starts in the chunk, then report 100% progress."""
    self.file.seek(self.start_offset)  # move to the start of the chunk
    self.sent_article_count = 0
    self.size_already_read = 0

    if EXTRACT_ONLY_TITLE:
      starting_tag, ending_tag = '<title', 'title>'
    else:
      starting_tag, ending_tag = '<text', 'text>'

    article = []
    inside_article = False

    for raw_line in self.file:
      # Past the end of the chunk and between articles: this chunk is done.
      if self.size_already_read >= self.chunk_size and not inside_article:
        break
      # Count the encoded size of the raw line (not the characters of the
      # stripped line) so progress uses the same unit as the byte offsets.
      self.size_already_read += len(raw_line.encode('utf-8'))

      line = raw_line.strip()
      if line.startswith(starting_tag):
        inside_article = True
        article = []
      if not inside_article:
        # Outside the wanted element, including the tail of an article that
        # started before this chunk: a truncated fragment must not be sent.
        continue

      article.append(line)
      if line.endswith(ending_tag):
        self._send('\n'.join(article))
        inside_article = False
      elif line.startswith(starting_tag) and line.endswith('/>'):
        # Self-closing element such as <text bytes="0" />: nothing to extract.
        inside_article = False

    self.out_sock_stats.send_string(f'CTL SOURCE {self.instance} READ_PROGRESS 100')

  def _read_progress(self):
    """Return the percentage of the chunk read so far, capped at 100.

    The cap is needed because the last article may run past the chunk end.
    """
    if self.chunk_size <= 0:
      return 100.0
    return min(100.0, 100 * self.size_already_read / self.chunk_size)

  def _send(self, article):
    """Forward one article and, at most about once a second, the read progress."""
    self.out_sock_data.send_string(article)
    self.sent_article_count += 1

    now = time.time()
    if now - self.last_stats_sending_time > 1:
      self.out_sock_stats.send_string(f'CTL SOURCE {self.instance} READ_PROGRESS {self._read_progress()}')
      self.last_stats_sending_time = now

class WikipediaArticleWordExtractor():
  """Turn the articles received from the sources into sets of candidate words.

  Articles arrive as strings on in_sock_data; the sources' REGISTER/DONE
  control messages arrive on in_sock_ctl. Words are accumulated in
  self.result_set and, whenever SEND_CHUNK_SIZE of them have been collected
  (and once more at shutdown), normalised by clean_result_set() and sent as a
  JSON list on out_sock_data.
  """

  SEND_CHUNK_SIZE = 50000

  # How long (ms) to keep waiting for late articles once every source is done.
  SHUTDOWN_GRACE_MS = 1000

  # Characters replaced by a space before an article is split into tokens.
  toRemove = "{}'(),.:;|\"*?!"
  _REMOVE_TABLE = str.maketrans(dict.fromkeys(toRemove, ' '))

  # Markup stripped from an article. The patterns are non-greedy so that the
  # words between two templates/links on the same line are kept. '.' does not
  # match a newline, so each pattern only acts within a single line.
  _TEMPLATE_RE = re.compile(r'\{\{.*?\}\}')  # {{templates}}
  _BRACKET_RE = re.compile(r'\[.*?\]')       # [links] and [[wiki links]]
  _TAG_RE = re.compile(r'<.*?>')             # <xml/html tags>

  def __init__(self, in_sock_data, in_sock_ctl, out_sock_data):
    self.in_sock_data = in_sock_data
    self.in_sock_ctl = in_sock_ctl
    self.out_sock_data = out_sock_data
    self.name = current_process().name
    self.result_set = set()

  def processArticle(self):
    """Process articles until every registered source is done and none is left.

    The control and data sockets are polled without blocking; when no article
    is waiting the loop sleeps for one second. It stops once at least one
    source has registered, every registered source has sent DONE, and no
    article arrives during a further SHUTDOWN_GRACE_MS.
    """
    source_count = 0
    at_least_one_source = False

    while True:
      # At most one control message is handled per iteration.
      try:
        control = self.in_sock_ctl.recv_string(flags=zmq.NOBLOCK)
      except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
          raise
      else:
        _, message_type, _ = control.split()
        if message_type == 'REGISTER':
          source_count += 1
          at_least_one_source = True
        elif message_type == 'DONE':
          source_count -= 1

      try:
        content = self.in_sock_data.recv_string(flags=zmq.NOBLOCK)
      except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
          raise
        # No article queued right now.
        if source_count == 0 and at_least_one_source:
          # DONE travels on another socket than the articles and can overtake
          # the last ones: give them a chance to arrive before leaving.
          # NOTE(review): this narrows the race, it does not close it.
          if self.in_sock_data.poll(self.SHUTDOWN_GRACE_MS):
            continue
          print(f'{self.name} is shutting down.')
          break
        time.sleep(1)
        continue

      self.result_set.update(self._extract_words(content))
      if len(self.result_set) >= self.SEND_CHUNK_SIZE:
        self._send_result_set()

    if self.result_set:
      self._send_result_set()

  def _extract_words(self, content):
    """Strip the markup from one article and return its tokens within the length limits."""
    # The dump text is XML-escaped and the wiki markup may itself contain HTML
    # entities, hence two passes (e.g. "&amp;lt;" -> "&lt;" -> "<").
    content = html.unescape(html.unescape(content))
    content = self._TEMPLATE_RE.sub('', content)
    content = self._BRACKET_RE.sub('', content)
    content = self._TAG_RE.sub('', content)
    content = content.translate(self._REMOVE_TABLE)
    return [token for token in content.split() if MIN_WORD_LENGTH <= len(token) <= MAX_WORD_LENGTH]

  def _send_result_set(self):
    """Normalise the accumulated words, send them as a JSON list and start a new set."""
    self.clean_result_set()
    self.out_sock_data.send_json(list(self.result_set))
    self.result_set = set()

  def clean_result_set(self):
    """Normalise self.result_set in place and return the new set.

    Only words made of Latin letters (see is_alpha_latin) are kept. A
    hyphenated word such as "château-fort" is kept when every part is
    Latin-alphabetic. Each part that satisfies MIN_WORD_LENGTH..MAX_WORD_LENGTH
    ("château", "fort") is kept on its own too. Words are lower-cased. With
    UNIDECODE_RESULT the accent-stripped form is stored. The accented form is
    stored as well when KEEP_ACCENTED_IN_RESULT is set, and always when
    UNIDECODE_RESULT is off.
    """
    cleaned_set = set()
    for word in self.result_set:
      if '-' in word:
        parts = word.split('-')
        if not all(self.is_alpha_latin(part) for part in parts):
          continue
        for part in parts:
          # Parts obey the same length limits as whole tokens.
          if MIN_WORD_LENGTH <= len(part) <= MAX_WORD_LENGTH:
            self._add_normalised(cleaned_set, part)
        self._add_normalised(cleaned_set, word)
      elif self.is_alpha_latin(word):
        self._add_normalised(cleaned_set, word)

    self.result_set = cleaned_set
    return cleaned_set

  @staticmethod
  def _add_normalised(target, word):
    """Add the lower-cased (and, per configuration, accent-stripped) form(s) of word to target."""
    lowered_word = word.lower()
    if UNIDECODE_RESULT:
      target.add(unidecode(lowered_word))
    if not UNIDECODE_RESULT or KEEP_ACCENTED_IN_RESULT:
      target.add(lowered_word)

  def is_alpha_latin(self, s):
    """Return True if s is non-empty, purely alphabetic, and within U+0000..U+017F.

    That range covers Basic Latin, Latin-1 Supplement and Latin Extended-A.
    """
    return s.isalpha() and all(ord(letter) <= 0x17f for letter in s)
          
  
def stats():
  """Print aggregated progress every 5 seconds; loops until the process is terminated.

  Binds a SUB socket on STATS_BASE_DATA_PORT, subscribed to 'CTL' messages of
  the form 'CTL <process> <instance> <data> <value>', where <data> is:
    - READ_PROGRESS: percentage of its chunk read by source <instance>;
    - TOTAL_WORDS: number of distinct words currently held by the sink.
  """
  context = zmq.Context()
  in_sock = context.socket(zmq.SUB)
  # This socket only receives, so the receive high-water mark is the one that applies.
  in_sock.setsockopt(zmq.RCVHWM, 1000)
  in_sock.setsockopt(zmq.SUBSCRIBE, b'CTL')
  in_sock.bind(f'tcp://*:{STATS_BASE_DATA_PORT}')

  total_words = 0
  last_total_words = 0
  read_progress = {i: 0 for i in range(NUM_SOURCES)}
  last_stats_display_time = time.time()

  while True:
    try:
      control = in_sock.recv_string(flags=zmq.NOBLOCK)
    except zmq.ZMQError as e:
      if e.errno != zmq.EAGAIN:
        raise
      time.sleep(1)  # nothing received: wait before polling again
    else:
      _, _, instance, data, value = control.split()
      if data == 'READ_PROGRESS':
        read_progress[int(instance)] = float(value)
      elif data == 'TOTAL_WORDS':
        total_words = int(value)

    if time.time() - last_stats_display_time > 5:
      # Mean progress over all sources (guarded against NUM_SOURCES == 0).
      global_read_progress = sum(read_progress.values()) / max(len(read_progress), 1)
      print(f'Input read progress: {global_read_progress:.2f}% - Found words: {total_words} (+ {total_words - last_total_words})')
      last_total_words = total_words
      last_stats_display_time = time.time()
  
  
def ctl_proxy():
  """Forward control messages from the sources to the workers (XSUB -> XPUB).

  Sources publish to CTL_PROXY_IN_PORT; workers subscribe on CTL_PROXY_OUT_PORT.
  zmq.proxy() blocks. In this program the process is stopped by run() calling
  terminate(), so the cleanup below is only reached if the proxy ever returns.
  """
  context = zmq.Context()
  in_sock = context.socket(zmq.XSUB)
  # The inbound side receives the forwarded messages: limit its receive queue.
  in_sock.setsockopt(zmq.RCVHWM, 1000)
  in_sock.bind(f'tcp://*:{CTL_PROXY_IN_PORT}')
  out_sock = context.socket(zmq.XPUB)
  out_sock.setsockopt(zmq.SNDHWM, 1000)
  out_sock.bind(f'tcp://*:{CTL_PROXY_OUT_PORT}')
  zmq.proxy(in_sock, out_sock)
  in_sock.close()
  out_sock.close()
  context.term()
  print("End of CTL Proxy")

def data_proxy():
  """Forward articles from the sources to the workers (PULL -> PUSH).

  Sources push to DATA_PROXY_IN_PORT; workers pull from DATA_PROXY_OUT_PORT.
  zmq.proxy() blocks. In this program the process is stopped by run() calling
  terminate(), so the cleanup below is only reached if the proxy ever returns.
  """
  context = zmq.Context()
  in_sock = context.socket(zmq.PULL)
  # A PULL socket only receives, so the receive high-water mark is the one that applies.
  in_sock.setsockopt(zmq.RCVHWM, 1000)
  in_sock.bind(f'tcp://*:{DATA_PROXY_IN_PORT}')
  out_sock = context.socket(zmq.PUSH)
  out_sock.setsockopt(zmq.SNDHWM, 1000)
  out_sock.bind(f'tcp://*:{DATA_PROXY_OUT_PORT}')
  zmq.proxy(in_sock, out_sock)
  in_sock.close()
  out_sock.close()
  context.term()
  print("End of Data Proxy")
      
      
def worker(instance):
  """Run one worker process: extract words from articles and forward them to the sink.

  Articles come from the data proxy, and the sources' REGISTER/DONE messages
  come from the control proxy. Word lists go to the sink's data port. This
  worker's own REGISTER/DONE messages go to the sink's control port.
  """
  context = zmq.Context()
  # Incoming articles (from the data proxy)
  in_sock_data = context.socket(zmq.PULL)
  in_sock_data.setsockopt(zmq.RCVHWM, 10000)
  in_sock_data.connect(f'tcp://localhost:{DATA_PROXY_OUT_PORT}')
  # Incoming control messages of the sources (from the control proxy)
  in_sock_ctl = context.socket(zmq.SUB)
  in_sock_ctl.setsockopt(zmq.SUBSCRIBE, b'CTL')
  in_sock_ctl.connect(f'tcp://localhost:{CTL_PROXY_OUT_PORT}')
  # Outgoing word lists (to the sink)
  out_sock_data = context.socket(zmq.PUSH)
  out_sock_data.setsockopt(zmq.SNDHWM, 10000)
  out_sock_data.connect(f'tcp://localhost:{SINK_DATA_PORT}')
  # Outgoing control messages (to the sink)
  out_sock_ctl = context.socket(zmq.PUSH)
  out_sock_ctl.connect(f'tcp://localhost:{SINK_CTL_PORT}')
  # FIXME: sleep a little bit to ensure connection is done
  time.sleep(1)

  out_sock_ctl.send_string(f'CTL REGISTER {instance}')
  try:
    extractor = WikipediaArticleWordExtractor(in_sock_data, in_sock_ctl, out_sock_data)
    extractor.processArticle()
  finally:
    # Announce DONE even if extraction failed: the sink waits until every
    # registered worker has sent DONE, so a missing DONE would hang it forever.
    out_sock_ctl.send_string(f'CTL DONE {instance}')
    out_sock_data.close()
    out_sock_ctl.close()
    in_sock_ctl.close()
    in_sock_data.close()
    context.term()
  
def sink(filename, shutdown_grace_ms=1000):
  """Collect the word lists sent by the workers and write them, sorted, to a file.

  The output goes to '<filename>_title_output.txt' or
  '<filename>_content_output.txt' depending on EXTRACT_ONLY_TITLE, one word per
  line. The sink stops once at least one worker has registered, every
  registered worker has sent DONE, and no data arrives during a further
  shutdown_grace_ms milliseconds. After each received list it publishes the
  current word count to the stats process.
  """
  context = zmq.Context()
  in_sock_data = context.socket(zmq.PULL)
  # Set before bind so the option is in effect for every incoming connection.
  in_sock_data.setsockopt(zmq.RCVHWM, 0)
  in_sock_data.bind(f'tcp://*:{SINK_DATA_PORT}')
  in_sock_ctl = context.socket(zmq.PULL)
  in_sock_ctl.bind(f'tcp://*:{SINK_CTL_PORT}')
  # Creating the socket for stats
  out_sock_stats = context.socket(zmq.PUB)
  out_sock_stats.connect(f'tcp://localhost:{STATS_BASE_DATA_PORT}')

  result_set = set()
  worker_count = 0
  at_least_one_worker = False

  while True:
    try:
      control = in_sock_ctl.recv_string(flags=zmq.NOBLOCK)
    except zmq.ZMQError as e:
      if e.errno != zmq.EAGAIN:
        raise
    else:
      _, message_type, _ = control.split()
      if message_type == 'REGISTER':
        worker_count += 1
        at_least_one_worker = True
      elif message_type == 'DONE':
        worker_count -= 1
      continue  # drain every pending control message before reading data

    try:
      words = in_sock_data.recv_json(flags=zmq.NOBLOCK)
    except zmq.ZMQError as e:
      if e.errno != zmq.EAGAIN:
        raise
      # No message in queue
      if worker_count == 0 and at_least_one_worker:
        # DONE and the word lists travel on different sockets, so a final
        # list may still be in flight: wait a little before giving up.
        if in_sock_data.poll(shutdown_grace_ms):
          continue
        break
      time.sleep(1)
    else:
      # update() grows the set in place; union() copied it for every chunk.
      result_set.update(words)
      out_sock_stats.send_string(f'CTL SINK 0 TOTAL_WORDS {len(result_set)}')

  print("Sorting the results and saving to output file...")
  result = sorted(result_set)
  output_kind = 'title' if EXTRACT_ONLY_TITLE else 'content'
  output_file_name = f'{filename}_{output_kind}_output.txt'
  with open(output_file_name, 'w', encoding='utf-8', newline='') as out:
    out.writelines(f'{word}\n' for word in result)

  print(f'Total number of words found: {len(result)}')

  in_sock_data.close()
  in_sock_ctl.close()
  out_sock_stats.close()
  context.term()

  
def _line_aligned_offset(filename, offset):
  """Return the first line-start byte position at or after offset in filename.

  A returned position is 0 or just after a b'\\n' byte, so it never falls
  inside a UTF-8 multi-byte character. Offsets past the end of the file are
  returned unchanged.
  """
  if offset <= 0:
    return 0
  with open(filename, 'rb') as raw:
    # Start one byte early: if that byte is the newline ending the previous
    # line, offset itself is already a line start and must not be skipped.
    raw.seek(offset - 1)
    raw.readline()
    return max(offset, raw.tell())


def source(instance, filename):
  """Run source number `instance`: read its 1/NUM_SOURCES share of the dump.

  The articles are pushed to the data proxy and the progress is published to
  the stats process. REGISTER and DONE are published on the control proxy, and
  DONE is sent even if reading fails.
  """
  context = zmq.Context()
  # Creating the socket for data
  out_sock_data = context.socket(zmq.PUSH)
  out_sock_data.setsockopt(zmq.SNDHWM, 10000)
  out_sock_data.connect(f'tcp://localhost:{DATA_PROXY_IN_PORT}')
  # Creating the socket for control
  out_sock_ctl = context.socket(zmq.PUB)
  out_sock_ctl.connect(f'tcp://localhost:{CTL_PROXY_IN_PORT}')
  # Creating the socket for stats
  out_sock_stats = context.socket(zmq.PUB)
  out_sock_stats.connect(f'tcp://localhost:{STATS_BASE_DATA_PORT}')
  # FIXME: sleep a little bit to ensure connection is done
  time.sleep(2)

  # Send control message to workers
  out_sock_ctl.send_string(f'CTL REGISTER {instance}')
  try:
    # newline='' keeps line endings untranslated, so the byte size of each raw
    # line matches the real byte offsets in the file.
    with open(filename, 'r', encoding='utf-8', newline='') as file:
      file_size = os.fstat(file.fileno()).st_size
      chunk_size = ceil(file_size / NUM_SOURCES)
      end_offset = min(file_size, chunk_size * (instance + 1))
      # Start on a line boundary: seeking a UTF-8 text file into the middle of
      # a multi-byte character makes the next read raise UnicodeDecodeError.
      # NOTE(review): text-mode seek() officially expects a tell() cookie; a
      # plain byte offset at a character boundary works with CPython's
      # TextIOWrapper — confirm if porting to another interpreter.
      start_offset = min(_line_aligned_offset(filename, chunk_size * instance), end_offset)

      xmlParser = XMLWikipediaDumpArticleExtractor(file, start_offset, end_offset, out_sock_data, out_sock_stats, instance)
      xmlParser.read()
  finally:
    # Always tell the workers this source is finished, or they never stop.
    out_sock_ctl.send_string(f'CTL DONE {instance}')
    out_sock_data.close()
    out_sock_ctl.close()
    out_sock_stats.close()
    context.term()
  
def run(filename):
  """Run the whole pipeline on `filename` and wait until the sink has written its output.

  Started processes: control proxy, data proxy, stats printer, sink,
  NUM_WORKERS workers and NUM_SOURCES sources. The proxies and the stats
  printer never exit on their own. They are terminated and reaped once the
  sink is done, or when waiting is interrupted.
  """
  start_time = time.time()

  print("Setting up the workers...")
  services = [
    Process(target=ctl_proxy, name='CTL Proxy'),
    Process(target=data_proxy, name='Data Proxy'),
    Process(target=stats, name='Stats'),
  ]
  for service in services:
    service.start()

  try:
    sink_process = Process(target=sink, args=(filename,), name='Sink')
    sink_process.start()

    workers = [Process(target=worker, args=(i,), name=f'Worker-{i}') for i in range(NUM_WORKERS)]
    for worker_process in workers:
      worker_process.start()

    sources = [Process(target=source, args=(i, filename), name=f'Source-{i}') for i in range(NUM_SOURCES)]
    for source_process in sources:
      source_process.start()

    # Shutdown cascades: sources finish reading, workers then drain the
    # articles, and finally the sink writes the output file.
    for process in sources + workers + [sink_process]:
      process.join()
  finally:
    for service in services:
      service.terminate()
    for service in services:
      service.join()  # reap the terminated processes (no zombies)

  print("The extraction took {:.2f} seconds.".format(time.time() - start_time))

if __name__ == '__main__':
  # Expect exactly one argument: the path of the Wikipedia XML dump.
  if len(sys.argv) != 2:
    sys.exit(f'Usage: {sys.argv[0]} <wikipedia-dump.xml>')
  run(sys.argv[1])