########
# MATRIX-VECTOR MULTIPLICATION
########

## MAPPER

#!/usr/bin/env python
import sys

def main():
    """Tag each input record and key it for the reducer.

    Reads comma-separated records from standard input and writes one
    tab-separated ``key<TAB>value`` line per recognised record:

    * ``row,col,value`` (matrix element) -> ``row<TAB>M,col,value``
    * ``index,value`` (vector element)   -> ``index<TAB>V,value``

    Blank lines and records with any other number of fields are skipped.
    """
    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        # Strip every field: whitespace around a comma would otherwise end up
        # inside the emitted key and make "0 " a different key from "0".
        parts = [part.strip() for part in line.split(',')]

        # A single parenthesised argument makes print() behave identically
        # under Python 2 (print statement) and Python 3 (print function).
        if len(parts) == 3:
            # Matrix element: row,col,value
            row, col, value = parts
            print(row + "\tM," + col + "," + value)
        elif len(parts) == 2:
            # Vector element: index,value
            index, value = parts
            print(index + "\tV," + value)

# Run the mapper only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


## REDUCER
#!/usr/bin/env python
import sys
from collections import defaultdict

def main():
    """Multiply the streamed matrix by the streamed vector, row by row.

    Reads tab-separated ``key<TAB>value`` lines from standard input, where
    the value is either ``M,col,value`` (a matrix element of row ``key``)
    or ``V,value`` (the vector element at index ``key``). Lines without a
    tab, and values with an unknown tag or too few fields, are ignored.

    Every row's dot product needs the whole vector, but each vector
    element arrives under its own key, possibly after the rows that need
    it. Rows are therefore buffered and only handed to ``process_row``
    once all input has been read.

    NOTE: this requires that every key reaches the same reducer process
    and that the buffered matrix fits in memory. A job that scales past
    that would have to join matrix entries with vector entries by column
    and sum the partial products per row in a second pass.
    """
    current_key = None
    matrix_elements = None  # Elements of the key group being read
    buffered_rows = []  # (row_key, elements) per key group, in arrival order
    vector_values = defaultdict(float)  # Vector values keyed by index string

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        parts = line.split('\t')
        if len(parts) < 2:
            continue

        key = parts[0]
        value = parts[1]

        if current_key != key:
            # New key group: start a fresh element list and remember it so
            # the row can be evaluated after the full vector is known.
            current_key = key
            matrix_elements = []
            buffered_rows.append((key, matrix_elements))

        # Parse the value
        if value.startswith('M,'):
            # Matrix element: M,col,value
            val_parts = value.split(',')
            if len(val_parts) >= 3:
                col = int(val_parts[1])
                matrix_val = float(val_parts[2])
                matrix_elements.append((col, matrix_val))

        elif value.startswith('V,'):
            # Vector element: V,value
            val_parts = value.split(',')
            if len(val_parts) >= 2:
                vector_values[current_key] = float(val_parts[1])

    # All vector elements are known now; emit one result per key group
    # that actually contained matrix elements.
    for row_key, elements in buffered_rows:
        if elements:
            process_row(row_key, elements, vector_values)


def process_row(row_key, matrix_elements, vector_values):
    """Print the dot product of one matrix row with the vector.

    Args:
        row_key: Key identifying the row; written as the first output field.
        matrix_elements: Iterable of ``(col, value)`` pairs for this row.
        vector_values: Mapping from column index *as a string* to the
            vector value at that index.

    Columns with no entry in ``vector_values`` contribute nothing to the
    sum. Writes ``row_key<TAB>result`` to standard output.
    """
    result = 0.0
    for col, matrix_val in matrix_elements:
        vector_key = str(col)
        # Membership test rather than indexing, so a missing column never
        # inserts a default entry into a defaultdict.
        if vector_key in vector_values:
            result += matrix_val * vector_values[vector_key]

    # A single parenthesised argument makes print() behave identically
    # under Python 2 (print statement) and Python 3 (print function).
    print(str(row_key) + "\t" + str(result))

# Run the reducer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

## Vector.txt
0,7.0
1,8.0
2,9.0

## Matrix.txt

0,0,1.0
0,1,2.0
0,2,3.0
1,0,4.0
1,1,5.0
1,2,6.0 

## COMMAND
# Make both streaming scripts executable.
chmod +x mapper.py
chmod +x reducer.py


# Launch the Hadoop Streaming job: read /matrix_input/input.txt, write the
# results under /matrix_output, with mapper.py and reducer.py as the map and
# reduce steps.
# NOTE(review): input.txt presumably holds the Matrix.txt and Vector.txt
# records together -- confirm how it is assembled before running.
hadoop jar /usr/lib/hadoop-mapreduce/hadoop-streaming.jar \
    -input /matrix_input/input.txt \
    -output /matrix_output \
    -mapper mapper.py \
    -reducer reducer.py \
    -file mapper.py \
    -file reducer.py

# Print the job output file part-00000.
hdfs dfs -cat /matrix_output/part-00000

