#!/usr/bin/env python
import pdfplumber
import numpy as np
import pandas as pd
import cv2
import tesserocr #wrapper for tesseract, https://github.com/sirfz/tesserocr
import csv
import sys
from PIL import ImageDraw, ImageOps
import skimage.draw
import skimage.transform
import skimage.feature
import skimage.filters
import matplotlib.pyplot as plt
import matplotlib
import math
import argparse

def read_cl():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with the attributes ``page`` (int or None),
        ``debug`` (bool) and ``file`` (str or None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--page", type=int, help="page number (one indexed)")
    parser.add_argument("--debug", action="store_true", help="print line information")
    #the help text used to be a copy of the --debug one
    parser.add_argument("-f", "--file", help="path of the PDF to read (defaults to stdin)")
    return parser.parse_args()


#TODO: change to singleton
#module-level Tesseract API handle: created once when the module is imported and
#reused by image_bbox_text for every cell that is read
#(the trailing comment keeps the constructor arguments / context-manager form tried earlier)
OCR_API = tesserocr.PyTessBaseAPI() #psm=tesserocr.PSM.OSD_ONLY, oem=tesserocr.OEM.LSTM_ONLY) as api:

def get_pdf_lines(page):
    """Collect the axis-aligned segments drawn on a page.

    kinda adapted from here: https://scraperwiki.com/2012/06/pdf-table-extraction-of-a-table/

    Image tools and pdfplumber cropping use the top left as the origin, *but*
    the y0/y1 of pdfplumber lines and rectangles are measured from the bottom
    left, so every y is replaced by ``page_height - y`` to move to a top-left
    origin system.

    Args:
        page: object exposing ``height``, ``lines`` and ``rects``, where the
            lines and rects are mappings with "x0", "y0", "x1" and "y1".

    Returns:
        De-duplicated list of segments ``((x0, y0), (x1, y1))`` made of floats.
        Horizontal and vertical lines are returned as they are; each rect
        contributes its two horizontal edges when wider than 2.0 and its two
        vertical edges when taller than 2.0. Sloped lines are reported on
        stdout and skipped. The page objects are not modified.
    """
    page_height = float(page.height)
    coord_names = ("x0", "y0", "x1", "y1")

    def flip(y):
        #bottom-left origin -> top-left origin
        return page_height - y

    segments = []
    for line in page.lines:
        x0, y0, x1, y1 = (float(line[name]) for name in coord_names)
        if x0 == x1 or y0 == y1:
            segments.append(((x0, flip(y0)), (x1, flip(y1))))
        else:
            #this branch used to reference an undefined name and raised NameError
            print("sloped line", line)

    for rect in page.rects:
        x0, y0, x1, y1 = (float(rect[name]) for name in coord_names)
        if x1 - x0 > 2.0:
            #the two horizontal edges
            segments.append(((x0, flip(y0)), (x1, flip(y0))))
            segments.append(((x0, flip(y1)), (x1, flip(y1))))
        if y1 - y0 > 2.0:
            #the two vertical edges
            segments.append(((x0, flip(y0)), (x0, flip(y1))))
            segments.append(((x1, flip(y0)), (x1, flip(y1))))

    return list(set(segments))

def is_col(s):
    """Return True when segment ``s`` is close to vertical.

    A segment counts as a column when its extent in y is more than ten times
    its extent in x.
    """
    start, end = s[0], s[1]
    x_extent = abs(end[0] - start[0])
    y_extent = abs(end[1] - start[1])
    return y_extent > 10 * x_extent

def is_row(s):
    """Return True when segment ``s`` is close to horizontal.

    A segment counts as a row when ten times its extent in y is still smaller
    than its extent in x.
    """
    start, end = s[0], s[1]
    x_extent = abs(end[0] - start[0])
    y_extent = abs(end[1] - start[1])
    return 10 * y_extent < x_extent

def filter_axial_segments(segments):
    """Keep only the segments that are (nearly) vertical or (nearly) horizontal."""
    axial = []
    for segment in segments:
        if is_col(segment) or is_row(segment):
            axial.append(segment)
    return axial

def point_to_segment(pt, segment, clamp=True):
    """Return the distance from a point to a segment (or to its full line).

    http://paulbourke.net/geometry/pointlineplane/

    Args:
        pt: the point, as (x, y).
        segment: pair of (x, y) endpoints.
        clamp: when True the nearest point is restricted to the segment itself;
            when False it may lie anywhere on the infinite line through it.

    Returns:
        The distance as a float. For a segment whose endpoints coincide this
        is the distance to that single point (this used to divide by zero).
    """
    p1, p2 = segment
    dx = p2[0] - p1[0]
    dy = p2[1] - p1[1]
    #squared length computed directly rather than via a square root that is squared again
    length_sq = dx * dx + dy * dy
    if length_sq == 0:
        #degenerate segment: there is no direction to project onto
        return math.hypot(pt[0] - p1[0], pt[1] - p1[1])

    #u is the position of the projection of pt along the segment: 0 at p1, 1 at p2
    u = ((pt[0] - p1[0]) * dx + (pt[1] - p1[1]) * dy) / length_sq
    if clamp:
        u = min(max(u, 0), 1) #cap at 0 and 1
    nearest_pt = (p1[0] + u * dx, p1[1] + u * dy)
    return math.hypot(pt[0] - nearest_pt[0], pt[1] - nearest_pt[1])

def point_to_point(p1, p2):
    """Return the Euclidean distance between two (x, y) points."""
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    squared = dx**2 + dy**2
    return squared ** 0.5

def segment_distance(s1, s2):
    """Return a measure of how far apart two segments are.

    The measure is meant to be high for two collinear segments with a lot of
    space between them, and for two long segments that form a right angle.

    Each endpoint is measured against the opposite segment twice: clamped
    (nearest point on the segment itself, ie unextended) and unclamped
    (nearest point on the full line). For two segments to be similar at least
    two endpoints must be close to the opposite segment and at least three
    endpoints must be close to the opposite line, hence the second-smallest
    clamped distance and the third-smallest unclamped distance are used.
    """
    a, b = s1
    c, d = s2
    pairings = ((a, s2), (b, s2), (c, s1), (d, s1))

    to_segment = sorted(point_to_segment(pt, seg) for pt, seg in pairings)
    to_line = sorted(point_to_segment(pt, seg, False) for pt, seg in pairings)

    second_closest_to_segment = to_segment[1]
    third_closest_to_line = to_line[2]
    return max(second_closest_to_segment, third_closest_to_line)

def merge_two_segments(s1, s2):
    """Merge two similar segments into one.

    Of each pair of matching endpoints the one that is "farther apart", ie
    farther from the other segment, is kept.
    """
    a, b = s1
    c, d = s2

    dir1 = (b[0] - a[0], b[1] - a[1])
    dir2 = (d[0] - c[0], d[1] - c[1])
    pointing_opposite = dir1[0]*dir2[0] + dir1[1]*dir2[1] < 0
    if pointing_opposite:
        #flip the second segment so that a pairs with c and b pairs with d
        c, d = d, c

    #(distance, point) tuples: max picks the larger distance first
    start_candidates = [(point_to_segment(a, s2), a), (point_to_segment(c, s1), c)]
    end_candidates = [(point_to_segment(b, s2), b), (point_to_segment(d, s1), d)]
    _, start = max(start_candidates)
    _, end = max(end_candidates)

    return (start, end)

def combine_segments(segments, margin):
    """Remove similar segments by merging them.

    Two segments are merged when their segment_distance is below ``margin``.
    Passes over the segments are repeated until a pass no longer reduces the
    number of segments.
    """
    current = segments
    while True:
        merged = []
        for candidate in current:
            absorbed = False
            for idx, kept in enumerate(merged):
                if segment_distance(candidate, kept) < margin:
                    absorbed = True
                    merged[idx] = merge_two_segments(candidate, kept)
            if not absorbed:
                merged.append(candidate)
        if len(merged) == len(current):
            #nothing was merged in this pass
            return merged
        current = merged

def segments_to_components(segments):
    """Group the segments into connected components.

    Two segments are connected when segment_intersect_margin finds an
    intersection with a margin of 10. Naive n^2 for now, though there's
    apparently an ~n^4/3 method.

    Returns:
        List of sets of segments.
    """
    components = []
    for seg in segments:
        home = None
        for idx, group in enumerate(components):
            touches = any(segment_intersect_margin(seg, member, 10) for member in group)
            if not touches:
                continue
            if home is None:
                home = idx
            else:
                #seg bridges two groups: fold this one into the first match
                components[home].update(group)
                components[idx] = set()
        if home is None:
            components.append({seg})
        else:
            components[home].add(seg)
        #clear out the groups emptied by a fold
        components = [group for group in components if group]
    return components

def component_to_cells(component):
    """Turn the segments of one connected component into table cells.

    Args:
        component: iterable of segments ((x0, y0), (x1, y1)).

    Returns:
        Dict mapping (row_idx, col_idx) -- the indices of the row and column
        segment that cross at a cell's corner -- to a dict with "x", "y"
        (the corner) and "w", "h" (the distance to the next accepted
        intersection along the row and along the column). Empty when no cell
        is found.
    """
    #columns ordered by the x of their first endpoint, rows by the y of their first endpoint
    all_cols = sorted([s for s in component if is_col(s)], key = lambda x: x[0][0])
    all_rows = sorted([s for s in component if is_row(s)], key = lambda x: x[0][1])

    #mappings from intersections to lines and vice-versa
    rows_to_int = {}
    cols_to_int = {}
    int_to_lines = {}

    #intersections lying within 10 units of an end of the row / column they are on:
    #left/right refer to the ends of the row, top/bot to the ends of the column
    #(top is the smaller y)
    int_top = set()
    int_bot = set()
    int_left = set()
    int_right = set()

    for row_idx, row in enumerate(all_rows):
        for col_idx, col in enumerate(all_cols):
            rows_to_int.setdefault(row_idx,[])
            cols_to_int.setdefault(col_idx,[])
            #either the intersection point (a tuple) or None
            intersect = segment_intersect_margin(row, col, 10)
            if intersect:
                row_x = sorted([row[0][0],row[1][0]])
                col_y = sorted([col[0][1],col[1][1]])
                int_x = intersect[0]
                int_y = intersect[1]
                if int_x < row_x[0] + 10:
                    int_left.add(intersect)
                if int_x > row_x[1] - 10:
                    int_right.add(intersect)
                if int_y < col_y[0] + 10:
                    int_top.add(intersect)
                if int_y > col_y[1] - 10:
                    int_bot.add(intersect)

                rows_to_int[row_idx].append(intersect)
                cols_to_int[col_idx].append(intersect)
                #keyed by the point itself: a later pair producing the same point overwrites this entry
                int_to_lines[intersect] = [row_idx,col_idx]

    cells = {}
    for intersect in int_to_lines:
        row_idx, col_idx = int_to_lines[intersect]
        #intersections further along the same row, skipping the ones at the bottom end of
        #their column or at the left end of the row
        #NOTE(review): presumably so that the column found actually continues downwards -- confirm
        next_cols = sorted([x for x in rows_to_int[row_idx] if x[0] > intersect[0] and x not in int_bot and x not in int_left], key = lambda x: x[0])
        #intersections further down the same column, skipping the ones at the right end of
        #their row or at the top end of the column
        next_rows = sorted([x for x in cols_to_int[col_idx] if x[1] > intersect[1] and x not in int_top and x not in int_right], key = lambda x: x[1])
        if next_cols and next_rows:
            #the nearest accepted neighbour in each direction fixes the cell's width and height
            cells[(row_idx,col_idx)] = {"x": intersect[0], "y": intersect[1], "w": next_cols[0][0] - intersect[0], "h": next_rows[0][1] - intersect[1]}
    return cells

def on_segment(p, q, r):
    """Return True when q lies inside the bounding box of p and r.

    For three collinear points this means that q is on the segment pr.
    """
    x_lo, x_hi = min(p[0], r[0]), max(p[0], r[0])
    y_lo, y_hi = min(p[1], r[1]), max(p[1], r[1])
    within_x = x_lo <= q[0] <= x_hi
    within_y = y_lo <= q[1] <= y_hi
    if within_x and within_y:
        return True
    return False

def ccw(p, q, r):
    """Return the sign (-1, 0 or 1) of the turn made by the path p -> q -> r.

    0 means that the three points are collinear.
    """
    pq_x, pq_y = q[0] - p[0], q[1] - p[1]
    qr_x, qr_y = r[0] - q[0], r[1] - q[1]
    turn = pq_x * qr_y - pq_y * qr_x
    return np.sign(turn)

def segment_intersect(line1, line2):
    """Return True when the two segments share at least one point.

    Orientation test: in the general case the segments cross when the
    endpoints of each one lie on different sides of the other. The remaining
    cases are an endpoint that is collinear with the other segment and lies
    on it.
    """
    a1, b1 = line1
    a2, b2 = line2

    o1 = ccw(a1, b1, a2)
    o2 = ccw(a1, b1, b2)
    o3 = ccw(a2, b2, a1)
    o4 = ccw(a2, b2, b1)

    #no collinearity
    if o1 != o2 and o3 != o4:
        return True

    #special cases, as (orientation, segment start, tested point, segment end):
    #the orientation is 0 when the tested point is collinear with the segment
    collinear_cases = (
        (o1, a1, a2, b1), #a2 on segment a1b1
        (o2, a1, b2, b1), #b2 on segment a1b1
        (o3, a2, a1, b2), #a1 on segment a2b2
        (o4, a2, b1, b2), #b1 on segment a2b2
    )
    for orientation, start, point, end in collinear_cases:
        if orientation == 0 and on_segment(start, point, end):
            return True

    return False

def subtract_points(p1, p2):
    """Return the vector from p2 to p1, ie the coordinate-wise p1 - p2."""
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return (dx, dy)

def cross_product(p1, p2):
    """Return the z component of the cross product of two 2D vectors."""
    x1, y1 = p1[0], p1[1]
    x2, y2 = p2[0], p2[1]
    return x1*y2 - y1*x2

def segment_intersect2(segment1, segment2):
    """Return a point shared by the two segments, or None when there is none.

    from: https://github.com/pgkelley4/line-segments-intersect/blob/master/js/line-segments-intersect.js
    via: https://stackoverflow.com/questions/563198/how-do-you-detect-where-two-line-segments-intersect

    For collinear overlapping segments the first endpoint found to lie on the
    other segment is returned.
    """
    p1, p2 = segment1
    q1, q2 = segment2

    r = subtract_points(p2, p1)
    s = subtract_points(q2, q1)
    offset = subtract_points(q1, p1)

    u_numerator = cross_product(offset, r)
    denominator = cross_product(r, s)

    if denominator == 0:
        if u_numerator != 0:
            #segments are parallel
            return None
        #all four points collinear: look for an endpoint lying on the other segment
        endpoint_checks = (
            (p1, q1, q2),
            (p2, q1, q2),
            (q1, p1, p2),
            (q2, p1, p2),
        )
        for point, start, end in endpoint_checks:
            if on_segment(start, point, end):
                return point
        #not overlapping
        return None

    #u and t are the positions of the crossing along segment2 and segment1
    u = u_numerator / denominator
    t = cross_product(offset, s) / denominator

    if not (0 <= u <= 1 and 0 <= t <= 1):
        return None
    return (p1[0] + t*r[0], p1[1] + t*r[1])


def extend_segment(segment, ratio=None, length=None):
    """Extend (or shrink) a segment around its midpoint.

    Args:
        segment: pair of (x, y) endpoints.
        ratio: factor applied to the length of the segment.
        length: amount added to the total length of the segment; takes
            precedence over ``ratio`` when both are given.

    Returns:
        The new segment as a pair of (x, y) tuples. A segment with no length
        can't be extended by an amount and keeps its coordinates.

    Raises:
        ValueError: when neither ``ratio`` nor ``length`` is given (this used
            to be a bare ``raise``, which fails with an unrelated RuntimeError).
    """
    p1, q1 = segment
    #compared against None so that an explicit length (or ratio) of 0 is honoured
    if length is not None:
        segment_length = math.hypot(q1[0] - p1[0], q1[1] - p1[1])
        if segment_length == 0:
            ratio = 1 #can't extend a segment with no length
        else:
            ratio = (segment_length + length) / segment_length
    elif ratio is None:
        raise ValueError("extend_segment needs either ratio or length")

    mid = ((p1[0] + q1[0]) / 2, (p1[1] + q1[1]) / 2)
    p1 = ((p1[0] - mid[0]) * ratio + mid[0], (p1[1] - mid[1]) * ratio + mid[1])
    q1 = ((q1[0] - mid[0]) * ratio + mid[0], (q1[1] - mid[1]) * ratio + mid[1])
    return (p1, q1)

def segment_intersect_margin(segment1, segment2, margin):
    """Intersect two segments after lengthening both by ``margin``.

    Returns:
        The intersection point found by segment_intersect2 on the lengthened
        segments, or None when they do not meet.
    """
    longer1 = extend_segment(segment1, length=margin)
    longer2 = extend_segment(segment2, length=margin)
    point = segment_intersect2(longer1, longer2)
    #TODO: remove once confident about segment_intersect2 working
    assert(1*segment_intersect(longer1,longer2) == 1*(point is not None))
    return point

def line_detection3(pil_img):
    """Detect segments by scanning every pixel column and every pixel row.

    this time checking directly for vertical and horizontal lines: the edge
    pixels along each column / row are fed to LineDetectionMarkov and every
    run of "on" states becomes a segment.

    Side effects: prints each detected segment and opens a matplotlib window
    showing the edge image and the detected segments.

    Args:
        pil_img: PIL image.

    Returns:
        List of segments ((x0, y0), (x1, y1)) in pixel coordinates.
    """
    image = np.array(ImageOps.invert(pil_img.convert("L")))
    edges = skimage.feature.canny(image, 2, 1, 25)
    edges = skimage.filters.gaussian(edges, sigma=1) #blur slightly -- currently rounding errors causing issues
    height, width = image.shape

    #the (x, y) coordinates of every pixel column followed by every pixel row
    lines = []
    for i in range(width):
        lines.append([(i,j) for j in range(height)])
    for j in range(height):
        lines.append([(i,j) for i in range(width)])

    markov = LineDetectionMarkov()

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    ax = axes.ravel()

    ax[0].imshow(edges, cmap=matplotlib.cm.gray)
    ax[0].set_title('Input image')
    ax[0].set_axis_off()

    ax[1].imshow(image, cmap=matplotlib.cm.gray)
    ax[1].set_ylim((image.shape[0], 0))
    ax[1].set_axis_off()
    ax[1].set_title('Detected lines')

    segments = []

    def record(seg_start, seg_end):
        print("detected: ", (seg_start, seg_end))
        ax[1].plot([seg_start[0],seg_end[0]], [seg_start[1],seg_end[1]], 'g')
        segments.append((seg_start, seg_end))

    for l in lines:
        pixel_vals = [1 * (edges[x[1]][x[0]] > 0.3) for x in l]
        line_state_vals = markov.get_mle_states(pixel_vals)
        #reset for every scan line: a run that reached the end of the previous line
        #used to leave its start behind and leak into this one
        start = None
        last = None
        for state,coord in zip(line_state_vals, l):
            if state == 1:
                if start is None:
                    start = coord
            elif state == 0:
                if start is not None:
                    record(start, last)
                    start = None
            else:
                raise RuntimeError(f"unexpected markov state: {state!r}")
            last = coord
        if start is not None:
            #the run was still "on" when the scan line ended
            record(start, last)

    plt.tight_layout()
    plt.show()

    return segments


def line_detection2(pil_img):
    """Detect segments with a straight-line Hough transform cleaned up by a markov model.

    Side effects: prints each detected segment and opens a matplotlib window
    showing the edge image, the hough lines and the detected segments.

    Args:
        pil_img: PIL image.

    Returns:
        List of segments ((x0, y0), (x1, y1)) in pixel coordinates.

    Raises:
        RuntimeError: when a hough line does not cross the image border at
            exactly two points, or when the markov model returns an unknown state.
    """
    #HoughLinesP doesn't work as well as I'd like
    #in particular it seems to leave out some solid vertical + horizontal lines
    #also once I increased maxLineGap, which ought to only increase the number of output
    #lines, and it dropped some of my lines (why?)

    #idea: post-process hough line detection with a simple markov model:
    #every hough line has segments that are "on" and segments that are "off"
    #(if the entire line is "off" then it was a false positive)
    #there's some cost of transitioning between off and on, some cost of
    #seeing an edge pixel when "off" or not seeing an edge pixel when "on"
    #and then you find the maximum likelihood sequence of states
    #which generates a set of segments


    # https://scikit-image.org/docs/dev/auto_examples/edges/plot_line_hough_transform.html
    # Classic straight-line Hough transform
    # Set a precision of 0.5 degree.
    image = np.array(ImageOps.invert(pil_img.convert("L")))
    height, width = image.shape
    edges = skimage.feature.canny(image, 2, 1, 25)

    #blurring seems to create too many outputs from houghlines
    #edges = skimage.filters.gaussian(edges, sigma=1) #blur slightly -- currently rounding errors causing issues

    tested_angles = np.linspace(-np.pi / 2, np.pi / 2, 360, endpoint=False)
    h, theta, d = skimage.transform.hough_line(edges, theta=tested_angles)


    # Generating figure 1
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    ax = axes.ravel()

    ax[0].imshow(edges, cmap=matplotlib.cm.gray)
    ax[0].set_title('Input image')
    ax[0].set_axis_off()

    ax[1].imshow(image, cmap=matplotlib.cm.gray)
    ax[1].set_ylim((image.shape[0], 0))
    ax[1].set_axis_off()
    ax[1].set_title('Detected lines')

    markov = LineDetectionMarkov()

    segments = []
    for _, angle, dist in zip(*skimage.transform.hough_line_peaks(h, theta, d)):
        (x0, y0) = dist * np.array([np.cos(angle), np.sin(angle)])
        slope = np.tan(angle + np.pi/2)

        #compute intersections with each edge of the image
        pt1 = (0, y0 - x0*slope)
        pt2 = (width-1, y0 - (x0-width+1)*slope)
        pt3 = (x0 - y0/(slope+1e-10), 0)
        pt4 = (x0 - (y0-height+1)/(slope+1e-10), height-1)

        endpoints = []
        if 0 <= pt1[1] <= height:
            endpoints.append(pt1)
        if 0 <= pt2[1] <= height:
            endpoints.append(pt2)
        if 0 <= pt3[0] <= width:
            endpoints.append(pt3)
        if 0 <= pt4[0] <= width:
            endpoints.append(pt4)

        endpoints = list(set(endpoints)) #should handle the literal corner case
        if len(endpoints) != 2:
            #was a bare raise, which fails with "No active exception to reraise"
            raise RuntimeError(f"expected a hough line to cross the image border at 2 points, got {len(endpoints)}")

        start, end = endpoints

        ax[1].plot([endpoints[0][0],endpoints[1][0]], [endpoints[0][1],endpoints[1][1]], 'g')

        #reverse coords and round for skimage.draw.line below
        line_floor = ((math.floor(start[1]), math.floor(start[0])), (math.floor(end[1]), math.floor(end[0])))
        line_ceil = ((math.ceil(start[1]), math.ceil(start[0])), (math.ceil(end[1]), math.ceil(end[0])))
        line_round = ((round(start[1]), round(start[0])), (round(end[1]), round(end[0])))

        #https://stackoverflow.com/a/57345267
        discrete_line_floor = skimage.draw.line(*line_floor[0], *line_floor[1])
        discrete_line_ceil = skimage.draw.line(*line_ceil[0], *line_ceil[1])
        discrete_line_round = skimage.draw.line(*line_round[0], *line_round[1])

        #a pixel counts as "on" when either the floor or the ceil rasterization hits an edge
        pixel_vals = [1*(p0>0.3 or p1>0.3) for p0, p1 in zip(edges[discrete_line_floor],edges[discrete_line_ceil])]
        line_state_vals = markov.get_mle_states(pixel_vals)

        first_new = len(segments) #the segments found on this hough line start here
        start = None
        last = None
        line_coords = list(zip(*discrete_line_round))
        line_coords = [(x[1],x[0]) for x in line_coords] #flip coords to return to (x,y) format
        for state,coord in zip(line_state_vals, line_coords):
            if state == 1:
                if start is None:
                    start = coord
            elif state == 0:
                if start is not None:
                    segments.append((start, last))
                    start = None
            else:
                raise RuntimeError(f"unexpected markov state: {state!r}")
            last = coord
        if start is not None:
            #the run was still "on" at the end of the line
            segments.append((start, last))

        #report only the new segments: looping over the whole list here re-printed
        #and re-plotted every earlier segment once per hough line
        for s in segments[first_new:]:
            print("detected: ",s)
            ax[1].plot([s[0][0],s[1][0]], [s[0][1],s[1][1]], 'g')

    plt.tight_layout()
    plt.show()

    return segments


class BasicMarkov():
    """Hidden markov model over a small set of discrete states.

    Args:
        transition_probs: ``transition_probs[old][new]`` is the log-probability
            of moving from state ``old`` to state ``new``.
        emission_probs: ``emission_probs[state][obs]`` is the log-probability
            of seeing ``obs`` while in ``state``; its keys define the states.
    """
    def __init__(self, transition_probs, emission_probs):
        self.transition_probs = transition_probs
        self.emission_probs = emission_probs

    def get_mle_states(self, obs):
        """Return the most likely sequence of states for the observations.

        Every state is equally likely before the first observation. Ties are
        broken in favour of the larger state.

        Args:
            obs: iterable of observations (keys of the emission tables).

        Returns:
            List with one state per observation; empty for no observations.
        """
        states = list(self.emission_probs)
        n = len(states)
        #log-prob of the most likely history ending in each state
        probs = {s: math.log(1/n) for s in states}
        #one dict per observation, mapping each state to its best predecessor. Keeping
        #back-pointers instead of a full history list per state avoids copying the
        #history at every step, which made the search quadratic in len(obs)
        back_pointers = []

        for x in obs:
            new_probs = {}
            step_pointers = {}
            for s_new in states:
                ml_prob, ml_state = max(
                    (probs[s_old] + self.transition_probs[s_old][s_new] + self.emission_probs[s_new][x], s_old)
                    for s_old in states
                )
                new_probs[s_new] = ml_prob
                step_pointers[s_new] = ml_state
            probs = new_probs
            back_pointers.append(step_pointers)

        _, ml_state = max((probs[state], state) for state in probs)

        #walk the back-pointers from the best final state to the first observation
        history = []
        for step_pointers in reversed(back_pointers):
            history.append(ml_state)
            ml_state = step_pointers[ml_state]
        history.reverse()
        return history

class LineDetectionMarkov(BasicMarkov):
    """Two-state model for line detection: state 1 is "on" a line, state 0 is "off".

    Observations are 0 / 1 per pixel. All tables hold log-probabilities.
    """
    def __init__(self):
        stay = math.log(0.999)
        switch = math.log(0.001)
        transitions = {
            0: {0: stay, 1: switch},
            1: {0: switch, 1: stay},
        }
        #emissions[state][observed pixel value]
        emissions = {
            0: {0: math.log(0.3), 1: math.log(0.7)},
            1: {0: math.log(0.1), 1: math.log(0.9)},
        }
        super().__init__(transitions, emissions)


def line_detection(pil_img, debug_lines=True, threshold=1000, max_line_gap=80):
    """Detect line segments in an image with canny + cv2.HoughLinesP.

    Args:
        pil_img: PIL-style image (https://stackoverflow.com/a/39925278).
        debug_lines: when True, write the edge image to /tmp/edges-50-150.jpg
            and the image with the detected segments drawn on it to
            /tmp/lines.jpg.
        threshold: passed to HoughLinesP both as ``threshold`` and as
            ``minLineLength``.
        max_line_gap: passed to HoughLinesP as ``maxLineGap``.

    Returns:
        List of segments ((x0, y0), (x1, y1)) with plain int pixel
        coordinates; empty when no line is detected.
    """
    #Couldn't get very good line detection going, notes here:
    #(1) Regular HoughLines isn't ideal -- it basically searches for lines in the
    #document with the most points on them, even if the points aren't consecutive. That
    #means it'll pick up more false positives in large documents / when the table is smaller

    #(2) HoughLinesP doesn't work as well as I'd like, though it's decent
    #in particular it seems to leave out some solid vertical + horizontal lines
    #also once I increased maxLineGap, which ought to only increase the number of output
    #lines, and it dropped some of my lines (why?)

    #(3) Standard best practice seems to be to run an edge detection algorithm before
    #running HoughLines eg canny. This can be better than just running the image itself
    #(inverted) in case of eg darker headers which came up in the first few I tried.
    #Dropbox notes that canny didn't work great for them:
    #https://dropbox.tech/machine-learning/fast-and-accurate-document-detection-for-scanning
    #and it sounds like they rolled their own edge detection:
    #https://news.ycombinator.com/item?id=12257394
    #Better edge detection could be the best area for improvement. How about running a set
    #of local line filters in various directions?

    #(4) With HoughLines/HoughLinesP it's sometimes recommended to blur before
    #running the detection, to help with off-by-one errors.
    #However, this adds a lot of points to the document and then the HoughLinesP
    #output can blow up. Seems a little worse on net? Maybe you could make it work

    #Idea #1: post-process regular hough line detection with a simple markov model:
    #every hough line has segments that are "on" and segments that are "off"
    #(if the entire line is "off" then it was a false positive)
    #there's some cost of transitioning between off and on, some cost of
    #seeing an edge pixel when "off" or not seeing an edge pixel when "on"
    #and then you find the maximum likelihood sequence of states
    #which generates a set of segments

    #Idea #2: forget HoughLines entirely and individually check every
    #vertical and horizontal line with the markov model from idea #1
    #This was very slow and still not great because of the imperfect
    #canny edge detection

    #For now sticking with the basic HoughLinesP version.
    #If I want to improve this I would try to build some better edge
    #detection. Also maybe use scikit-image for houghlines instead of
    #cv2, which feels a little outdated

    #https://stackoverflow.com/a/14140796
    #convert to cv2 (BGR). Forcing RGB first keeps the channel flip valid for
    #grayscale and RGBA inputs too
    cv2_img = np.array(pil_img.convert("RGB"))[:, :, ::-1].copy()

    #blurring is often recommended, don't really think it helps
    #though, turning it off for now:
    # kernel = np.ones((5, 5), np.float32) / 25
    # cv2_img = cv2.filter2D(cv2_img, -1, kernel)

    #edge detection may be the best place for improvement, see notes above
    #the alternative is to just invert the image and use that, but it seemed to work worse
    edges = cv2.Canny(cv2_img, 50, 150, apertureSize=3)

    if debug_lines:
        cv2.imwrite('/tmp/edges-50-150.jpg',edges)

    #quick summary of HoughLinesP since it was tough to find on the web
    #Naive line detection: try to look for lines in the image that contain a lot of edge pixels.
    #Idea: convert to a dual space and then conduct the search. In the dual space all points
    #are converted to lines/curves, and we look for the points that are hit by the most lines/curves.

    #params (not 100% on this):
    #rho = pixel granularity of which lines to search for -- 1 is standard
    #theta = angle granularity (in radians) of which lines to search for -- pi/180 is standard
    #threshold = minimum number of points that must lie on the line
    #minLineLength = minimum length of the line segment to be included
    #maxLineGap = maximum number of spaces in the middle of the line for inclusion
    raw_ = cv2.HoughLinesP(
        image=edges,
        rho=1,
        theta=np.pi/180,
        threshold=threshold,
        lines=np.array([]),
        minLineLength=threshold,
        maxLineGap=max_line_gap
    )

    if raw_ is None:
        segments = [] #no lines detected
    else:
        #each entry of raw_ wraps one [x0, y0, x1, y1] row. The numpy integers are turned
        #into plain ints because cv2.line below may reject numpy scalars as points
        segments = [
            ((int(row[0][0]), int(row[0][1])), (int(row[0][2]), int(row[0][3])))
            for row in raw_
        ]

    if debug_lines:
        for seg_start, seg_end in segments:
            cv2.line(cv2_img, seg_start, seg_end, (0, 0, 255), 3, cv2.LINE_AA)
        cv2.imwrite('/tmp/lines.jpg',cv2_img)

    return segments

def filter_low_intersections(segments):
    """Drop the segments that don't intersect at least three others.

    Every segment is also tested against itself, hence the threshold of 4.
    Removing a segment can leave others below the threshold, so the filter is
    repeated until a pass removes nothing.

    Returns:
        List of the surviving segments.
    """
    while True:
        survivors = []
        for seg in segments:
            hits = sum(
                1 for other in segments
                if segment_intersect_margin(seg, other, 10) is not None
            )
            if hits >= 4:
                survivors.append(seg)
        if len(survivors) == len(segments):
            return survivors
        segments = survivors


def component_to_df(component, bbox_reader, debug_fn=None):
    """Read one connected component of segments as a table.

    Args:
        component: the segments of the component.
        bbox_reader: function pulling the text from a given bounding box
            (a dict with "x", "y", "w" and "h").
        debug_fn: optional function called with the filtered segments once
            the component is known to hold a table.

    Returns:
        pandas DataFrame with one row per table row and the cell texts as
        values ("" where no cell was found), or None when the component does
        not form a table.
    """
    component = filter_low_intersections(component) #remove lines with only a couple intersections
    cells = component_to_cells(component)

    if not cells:
        return None

    row_cnt = max(row_idx for row_idx, _ in cells) + 1
    col_cnt = max(col_idx for _, col_idx in cells) + 1

    if (row_cnt - 1) * (col_cnt - 1) <= 1:
        return None

    if debug_fn:
        debug_fn(component)

    #collect all the rows first and build the frame once: DataFrame.append copied the
    #frame for every row and no longer exists in pandas 2.0
    rows = []
    for row_idx in range(row_cnt):
        row = []
        for col_idx in range(col_cnt):
            box = cells.get((row_idx, col_idx))
            row.append(bbox_reader(box) if box is not None else "")
        rows.append(row)
    return pd.DataFrame(rows)

def image_bbox_text(pil_img, bbox):
    """OCR the part of an image that lies inside a bounding box.

    Args:
        pil_img: PIL image.
        bbox: dict with "x", "y" (top-left corner), "w" and "h", in pixels.

    Returns:
        The text returned by the module-level OCR_API for that rectangle.
    """
    OCR_API.SetImage(pil_img)
    #the boxes are built from computed intersection points and so hold floats,
    #while the rectangle is given in whole pixels
    left, top, width, height = (int(round(bbox[key])) for key in ("x", "y", "w", "h"))
    OCR_API.SetRectangle(left, top, width, height)
    return OCR_API.GetUTF8Text()

def pdf_image_to_dfs(pdf_page, pdf_image, debug=False):
    """Extract the tables found inside an image embedded in a PDF page.

    Args:
        pdf_page: pdfplumber page holding the image.
        pdf_image: the image object, a mapping with "x0", "y0", "x1", "y1".
        debug: passed on to image_to_dfs.

    Returns:
        Whatever image_to_dfs returns for the rendered image.
    """
    page_height = pdf_page.height
    #pdfplumber objects have origin in the bottom left, while pdfplumber crop and image tools
    #have origin in the top left -> use page_height - y when cropping. also see get_pdf_lines
    left = pdf_image['x0']
    top = page_height - pdf_image['y1']
    right = pdf_image['x1']
    bottom = page_height - pdf_image['y0']
    cropped_page = pdf_page.crop((left, top, right, bottom))

    #PIL image: note if changing resolution from 150 you may need to tweak parameters in the houghlines function call
    rendered = cropped_page.to_image(resolution=150)
    return image_to_dfs(rendered.original, debug)

def image_to_dfs(pil_img, debug=False):
    """Extract the tables drawn in an image.

    Args:
        pil_img: PIL image.
        debug: when True, the segments of every table are drawn on the image
            and shown, waiting for Enter before continuing.

    Returns:
        The generator of DataFrames produced by segments_to_dfs.
    """
    detected = line_detection(pil_img)

    def show_segments(segments):
        canvas = ImageDraw.Draw(pil_img)
        for seg in segments:
            canvas.line(seg, width=0, fill=(0,255,0,255))
        pil_img.show()
        input("Press Enter to continue...")

    def read_cell(bbox):
        #cell text comes from OCR on the image
        return image_bbox_text(pil_img, bbox)

    return segments_to_dfs(detected, read_cell, debug_fn=show_segments if debug else None)

def pdf_bbox_text(pdf_page, bbox):
    """Return the text of a PDF page that lies inside a bounding box.

    Args:
        pdf_page: page object with a ``crop`` method taking
            (x0, top, x1, bottom).
        bbox: dict with "x", "y" (top-left corner), "w" and "h".

    Returns:
        The extracted text, or "" when the crop yields no text, so that the
        result is always a string like the one of image_bbox_text.
    """
    left, top = bbox["x"], bbox["y"]
    cropped_page = pdf_page.crop((left, top, left + bbox["w"], top + bbox["h"]))
    return cropped_page.extract_text() or ""

def page_to_dfs(pdf_page, debug=False):
    """Extract the tables drawn with lines and rectangles on a PDF page.

    Args:
        pdf_page: pdfplumber page.
        debug: when True, the segments of every table are drawn on a rendering
            of the page and shown, waiting for Enter before continuing.

    Returns:
        The generator of DataFrames produced by segments_to_dfs.
    """
    page_segments = get_pdf_lines(pdf_page)

    def show_segments(segments):
        rendering = pdf_page.to_image().original
        canvas = ImageDraw.Draw(rendering)
        for seg in segments:
            canvas.line(seg, width = 0, fill=(0,255,0,255))
        rendering.show()
        input("Press Enter to continue...")

    def read_cell(bbox):
        #cell text comes straight from the PDF text layer
        return pdf_bbox_text(pdf_page, bbox)

    return segments_to_dfs(page_segments, read_cell, debug_fn=show_segments if debug else None)

def segments_to_dfs(segments, bbox_reader, debug_fn=None):
    """Yield one DataFrame per table formed by the segments.

    The segments are reduced to the axial ones, similar segments are merged
    (margin of 10) and the result is split into connected components; every
    component that component_to_df can read as a table is yielded.

    Args:
        segments: segments ((x0, y0), (x1, y1)).
        bbox_reader: function pulling the text from a given bounding box.
        debug_fn: passed on to component_to_df.
    """
    axial = filter_axial_segments(segments)
    merged = combine_segments(axial, 10)

    for component in segments_to_components(merged):
        table = component_to_df(component, bbox_reader, debug_fn=debug_fn)
        if table is None:
            continue
        yield table

if __name__ == "__main__":
    #script entry point: write every table found in the PDF to /tmp/pg<page>_table<n>.csv
    import io

    args = read_cl()
    DEBUG = args.debug

    if args.file:
        source = args.file
    else:
        #sys.stdin is a text stream; the PDF bytes are read from the underlying binary
        #stream and buffered so that the parser gets a seekable file object
        source = io.BytesIO(sys.stdin.buffer.read())

    with pdfplumber.open(source) as pdf_obj:
        if args.page is not None:
            #--page is one indexed while pdf_obj.pages is a zero indexed list
            if not 1 <= args.page <= len(pdf_obj.pages):
                sys.exit(f"--page must be between 1 and {len(pdf_obj.pages)}")
            numbered_pages = [(args.page, pdf_obj.pages[args.page - 1])]
        else:
            numbered_pages = list(enumerate(pdf_obj.pages, start=1))

        #pg_num is the real page number, also when a single page was requested
        for pg_num, page in numbered_pages:
            df_cnt = 0

            #tables drawn with pdf lines / rects first, then the ones inside embedded images
            image_tables = (
                df
                for image in page.images
                for df in pdf_image_to_dfs(page, image, DEBUG)
            )
            for tables in (page_to_dfs(page, DEBUG), image_tables):
                for df in tables:
                    df_cnt += 1
                    filename = f"/tmp/pg{pg_num}_table{df_cnt}.csv"
                    print(f"Writing pg {pg_num}, table {df_cnt} to {filename}")
                    df.to_csv(filename, header=False, index=False)
