#!/usr/bin/env python3

from string import ascii_lowercase

import openpyxl

# Input workbook; the sheets "Direct links" and "Indirect links" are read below.
wb = openpyxl.load_workbook("Linkages to social goals.xlsx")

# Collect the set of direct links per participant.  In each row the first cell
# holds the participant number and the remaining cells hold a chain of codes;
# every pair of consecutive codes in the chain is one (from, to) link.
ws_direct = wb["Direct links"]
direct_per_participant = {}
for row in ws_direct.iter_rows(min_row=2):  # row 1 is skipped (presumably a header)
    # Skip rows that contain no value at all (e.g. formatted-but-empty rows at
    # the end of the sheet); int(None) would otherwise abort the whole run.
    # A row that has data but no participant number still raises, so malformed
    # input is not silently dropped.
    if all(cell.value is None for cell in row):
        continue
    participant = int(row[0].value)
    raw = [cell.value for cell in row[1:]]
    # Trailing lowercase letters (as in "12a") are stripped; only the number counts.
    data = [int(str(r).rstrip(ascii_lowercase)) for r in raw if r is not None]
    links = list(zip(data[:-1], data[1:]))
    # setdefault registers the participant even when the row yields no link.
    direct_per_participant.setdefault(participant, set()).update(links)

# Collect the set of indirect links per participant.  In each row the first
# cell holds the participant number, the second cell the reference code, and
# every remaining non-empty cell a code that forms a (reference, code) link.
ws_indirect = wb["Indirect links"]
indirect_per_participant = {}
for row in ws_indirect.iter_rows(min_row=2):  # row 1 is skipped (presumably a header)
    # Skip rows that contain no value at all (e.g. formatted-but-empty rows at
    # the end of the sheet); they would otherwise abort the run in int(...).
    # A row that has data but lacks the participant or reference still raises,
    # so malformed input is not silently dropped.
    if all(cell.value is None for cell in row):
        continue
    participant = int(row[0].value)
    # Trailing lowercase letters (as in "12a") are stripped; only the number counts.
    reference = int(str(row[1].value).rstrip(ascii_lowercase))
    raw = [cell.value for cell in row[2:]]
    data = [int(str(r).rstrip(ascii_lowercase)) for r in raw if r is not None]
    links = [(reference, d) for d in data]
    # setdefault registers the participant even when the row yields no link.
    indirect_per_participant.setdefault(participant, set()).update(links)

# A link that a participant already stated directly is not counted again as an
# indirect link for that same participant.
for participant, indirect in indirect_per_participant.items():
    # The two dicts are filled from different sheets, so a participant may be
    # missing from the direct one; treat that as "no direct links" rather than
    # failing with a KeyError.
    indirect.difference_update(direct_per_participant.get(participant, set()))

# Number of participants that mentioned each link directly.  The per-participant
# containers are sets, so one participant contributes at most 1 per link.
direct_links = {}
for participant_links in direct_per_participant.values():
    for link in participant_links:
        direct_links[link] = direct_links.get(link, 0) + 1

# Same tally for the indirect links.
indirect_links = {}
for participant_links in indirect_per_participant.values():
    for link in participant_links:
        indirect_links[link] = indirect_links.get(link, 0) + 1

# Merge both tallies into link -> [direct count, indirect count].
counts = {}
for link, n_direct in direct_links.items():
    counts[link] = [n_direct, 0]
for link, n_indirect in indirect_links.items():
    # Links seen only indirectly start from a zero direct count.
    counts.setdefault(link, [0, 0])[1] = n_indirect

# Every code that occurs at either end of a link is a vertex.
vertices = {code for link in counts for code in link}
n_vertices = max(vertices) + 1

wb_out = openpyxl.Workbook()

# "matrix" sheet: the first row and first column carry the codes 1..max code;
# cell (from + 1, to + 1) holds "<direct>,<indirect>" for the link from -> to.
ws_matrix = wb_out.active
ws_matrix.title = "matrix"
for code in range(1, n_vertices):
    ws_matrix.cell(row=1, column=code + 1, value=f"{code}")
    ws_matrix.cell(row=code + 1, column=1, value=f"{code}")

for link in sorted(counts):
    source, target = link
    n_direct, n_indirect = counts[link]
    ws_matrix.cell(row=source + 1, column=target + 1, value=f"{n_direct},{n_indirect}")

# Grand totals over all links; used as denominators for the percentages.
td = sum(pair[0] for pair in counts.values())
ti = sum(pair[1] for pair in counts.values())

def _percent(part, whole):
    """Return part/whole as a percentage with one decimal, or "n/a" if whole is 0."""
    # A zero total (e.g. no indirect links at all) has no meaningful percentage;
    # dividing would abort the script before the workbook is saved.
    if not whole:
        return "n/a"
    return f"{100*part/whole:.1f}"


# One sheet per threshold, listing the links whose combined (direct + indirect)
# count reaches the threshold, followed by a three-line summary.
for threshold in [1, 3, 4, 5]:
    ws_counts = wb_out.create_sheet(f"counts_{threshold}")
    row = 1
    sd = 0  # direct mentions among the listed links
    si = 0  # indirect mentions among the listed links
    for (l, r), (nd, ni) in sorted(counts.items()):
        if nd + ni >= threshold:
            sd += nd
            si += ni
            ws_counts.cell(row=row, column=1, value=f"{l}")
            ws_counts.cell(row=row, column=2, value=f"{r}")
            ws_counts.cell(row=row, column=3, value=f"{nd},{ni}")
            row += 1

    # Summary: label in column 1, direct/indirect/all values in columns 5-7,
    # starting one row below the first free row (leaving a blank separator row).
    summary = (
        ("count (direct/indirect/all)", f"{sd}", f"{si}", f"{sd+si}"),
        ("total (direct/indirect/all)", f"{td}", f"{ti}", f"{td+ti}"),
        (
            "percentage (direct/indirect/all)",
            _percent(sd, td),
            _percent(si, ti),
            _percent(sd + si, td + ti),
        ),
    )
    for offset, (label, *values) in enumerate(summary, start=1):
        ws_counts.cell(row=row+offset, column=1, value=label)
        for column, value in enumerate(values, start=5):
            ws_counts.cell(row=row+offset, column=column, value=value)

wb_out.save("Implication matrix.xlsx")
