#!/usr/bin/python3

import sys
from random import shuffle
from collections import defaultdict
import argparse

argparser = argparse.ArgumentParser(
    description="Print a Graphviz digraph of the divisibility relation on 1..max.")
argparser.add_argument('--show-zero', '-0', action='store_true',
                       help="also draw a node 0 with a dotted edge toward every "
                            "node that no other node points to")
argparser.add_argument('max', type=int, nargs='?', default=None,
                       help="largest integer to include (defaults to the largest "
                            "--display-only value)")
argparser.add_argument('--display-only', type=str,
                       help="comma-separated list of integers; only these nodes "
                            "are drawn")

args = argparser.parse_args()
if args.display_only is not None:
    try:
        args.display_only = [int(s) for s in args.display_only.split(",")]
    except ValueError:
        # Report malformed input (e.g. "1,,2" or "a,b") as a usage error
        # instead of crashing with a raw ValueError traceback.
        argparser.error(
            f"--display-only expects comma-separated integers, got {args.display_only!r}")
    if args.max is None:
        args.max = max(args.display_only)

if args.max is None:
    # argparse's error() prints the usage line to stderr and exits with
    # status 2, which is the conventional CLI behaviour for a missing argument.
    argparser.error("missing required argument 'max'")

do_zero = args.show_zero

def filter(s):
    """Keep only the elements of *s* that divide no other element of *s*.

    An element ``i`` is discarded as soon as some different element ``j``
    of the working set satisfies ``j % i == 0``.  Returns a new ``set``;
    the argument itself is left untouched.  (Note that this module-level
    name shadows the builtin ``filter``.)
    """
    survivors = set(s)
    for cand in set(survivors):
        has_multiple = any(
            other != cand and other % cand == 0 for other in set(survivors)
        )
        if has_multiple:
            survivors.remove(cand)
    return survivors

def graph(n):
    """Build the divisibility (Hasse) diagram of the integers 2..n.

    Returns a ``defaultdict(list)`` mapping every integer ``k`` in ``2..n`` to
    the ``set`` of its proper divisors that divide no other proper divisor of
    ``k`` (as selected by ``filter``).  ``1`` has no proper divisors and is
    therefore not a key; for ``n < 2`` the result is empty.
    """
    g = defaultdict(list)
    # Sieve-style enumeration: for each candidate divisor j, step through its
    # proper multiples directly.  This is O(n log n) instead of trial-dividing
    # every pair (i, j).  It still appends each i's divisors in increasing
    # order and inserts the keys in increasing order of i (all via j == 1).
    for j in range(1, n // 2 + 1):
        for i in range(2 * j, n + 1, j):
            g[i].append(j)
    for k, v in g.items():
        # Reassigning existing keys does not resize the dict, so this is safe
        # while iterating.
        g[k] = filter(v)
    return g

# Restrict graph g, in place, to the nodes listed in display_only.
def apply_display_only(g, display_only):
    """Keep only the nodes of *g* listed in *display_only*, modifying *g* in place.

    Keys not in *display_only* are deleted, and every remaining edge
    collection is filtered to targets that are also in *display_only*.
    Edge collections are stored back as ``set`` objects, matching what
    ``graph()`` produces, so later code that compares them against sets
    (e.g. ``== {1}``) keeps working.

    Args:
        g: mapping of node -> iterable of target nodes; modified in place.
        display_only: iterable of nodes to keep.

    Returns:
        None.
    """
    # Build the lookup set once: O(1) membership instead of a list scan per test.
    keep = set(display_only)
    for k in [k for k in g if k not in keep]:
        del g[k]
    for k, vals in g.items():
        g[k] = {v for v in vals if v in keep}

# Palette kept for reference; nothing later in this script reads base_colors.
base_colors = [
    "#800000",
    "#008000",
    "#000080",
    "#800080",
]

# Graphviz preamble: ranks laid out right-to-left, circular nodes,
# transparent canvas.
print("digraph {")
print('  rankdir=RL; node [shape="circle"]; bgcolor="transparent";')

# Divisor graph for 1..max, optionally cut down to the requested nodes.
g = graph(args.max)
if args.display_only is not None:
    apply_display_only(g, args.display_only)

# indegree[node] = number of edges in g that end at node (0 if none).
indegree = defaultdict(int)
for target in (t for edges in g.values() for t in edges):
    indegree[target] += 1

def is_uninteresting_prime(n):
    """Return True if node *n* should be left out of the drawing.

    A node is skipped when no other node points at it (its in-degree is 0)
    and its only outgoing edge goes to 1.  In the divisor graph built by
    ``graph()`` such a node is a prime that no other node links to.

    Reads the module-level ``indegree`` and ``g`` mappings.
    """
    if indegree[n] > 0:
        return False
    # Compare as a set: edge collections may be sets (from graph()) or other
    # iterables such as lists, and ``[1] == {1}`` is False, which would
    # silently disable the skip.
    return set(g[n]) == {1}

# Edge colour and line style, keyed by the quotient k // target of an edge
# k -> target.  Quotients without an entry fall back to "gray" / "bold".
pcolors = {2: "black", 3: "#800000", 5: "#008000", 7: "#0000C0", 11: "#808080"}
pstyles = {2: "solid", 3: "dashed", 5: "dotted", 7: "tapered"}

for k, v in g.items():
    if is_uninteresting_prime(k):
        print("Skipping", k, file=sys.stderr)
        continue
    for target in sorted(v):
        # Edges in g always point from a number to one of its divisors
        # (see graph()), so integer division is exact.
        d = k // target
        color = pcolors.get(d, "gray")
        style = pstyles.get(d, "bold")
        print(f' {k} -> {target} [color="{color}" style="{style}"];')
    if do_zero and indegree[k] == 0:
        # Hang every node that nothing points at off a shared node 0, via an
        # invisible spacer node x{k}.
        print(f' x{k} [style="invis"]; x{k} -> {k} [style="invis"]')
        print(f' 0 -> x{k} [style="dotted"]')

print("}")
