#!/usr/bin/python

import argparse
import json
import subprocess
from collections import defaultdict


def get_task_data():
    """Export pending and waiting Taskwarrior tasks.

    Runs ``task ... export`` directly (no shell, no ``jq``) and projects each
    exported task onto the fields the rest of this script uses.

    Returns:
        list[dict]: one dict per exported task with exactly the keys ``uuid``,
        ``id``, ``depends``, ``description`` and ``status``. A field absent
        from the export is ``None``, which matches the previous ``jq``
        projection.

    Raises:
        subprocess.CalledProcessError: if ``task`` exits with a non-zero status.
        FileNotFoundError: if the ``task`` executable cannot be found.
        json.JSONDecodeError: if ``task`` does not print valid JSON.
    """
    # Passing an argument list avoids the shell. The previous pipeline
    # reported jq's exit status, so a failing `task` showed up only as a
    # confusing JSON decode error on empty input. The filter words are the
    # same ones the shell command used.
    command = [
        "task",
        "+PENDING",
        "or",
        "+WAITING",
        "-COMPLETED",
        "-DELETED",
        "export",
    ]
    output = subprocess.check_output(command)
    fields = ("uuid", "id", "depends", "description", "status")
    return [{field: task.get(field) for field in fields} for task in json.loads(output)]


def _field(task, key, default):
    """Return ``task[key]``, or ``default`` when the key is missing or None."""
    value = task.get(key)
    return default if value is None else value


def _dependency_uuids(depends):
    """Normalize a task's ``depends`` value to a list of uuid strings."""
    if not depends:
        return []
    if isinstance(depends, str):
        # Older Taskwarrior exports encode dependencies as one comma-separated
        # string; iterating it directly would yield single characters.
        return [uuid for uuid in depends.split(",") if uuid]
    return list(depends)


def parse_task_data(data):
    """Build lookup tables and the dependency graph from exported tasks.

    Args:
        data: iterable of task dicts (e.g. from ``get_task_data``). Each dict
            must contain ``uuid``. The other keys are optional and may be None.

    Returns:
        tuple ``(task_details, dependency_graph, root_tasks)``:
        - ``task_details``: uuid -> {"id", "description", "status"}, with
          "?", "No description" and "Unknown status" substituted for
          missing/None values;
        - ``dependency_graph``: defaultdict(list) mapping a task's uuid to the
          uuids of tasks in ``data`` that depend on it, in ``data`` order;
        - ``root_tasks``: set of uuids that depend on no task in ``data``.
    """
    data = list(data)  # iterated twice below
    dependency_graph = defaultdict(list)
    task_details = {}
    dependent_tasks = set()

    # First pass: learn every task that is part of the export.
    for task in data:
        task_details[task["uuid"]] = {
            "id": _field(task, "id", "?"),
            "description": _field(task, "description", "No description"),
            "status": _field(task, "status", "Unknown status"),
        }

    # Second pass: record edges, but only between tasks that are present.
    # A dependency on a task outside the export (e.g. completed or deleted)
    # would otherwise hide the dependent task: it would not be a root, and
    # its parent would never be rendered.
    for task in data:
        task_id = task["uuid"]
        for dependency in _dependency_uuids(task.get("depends")):
            if dependency in task_details:
                dependency_graph[dependency].append(task_id)
                dependent_tasks.add(task_id)

    root_tasks = set(task_details) - dependent_tasks
    return task_details, dependency_graph, root_tasks


def get_all_parents(task_id, dependency_graph):
    """Return every key of ``dependency_graph`` whose child list holds ``task_id``.

    ``dependency_graph`` maps a task to the tasks that depend on it, so the
    result is the list of tasks ``task_id`` depends on, in the graph's
    iteration order. This is a linear scan over all edge lists.
    """
    parents = []
    for parent, dependents in dependency_graph.items():
        if task_id not in dependents:
            continue
        parents.append(parent)
    return parents


def build_ascii_dag(
    task_id,
    task_details,
    dependency_graph,
    prefix="",
    is_last=True,
    show_id=True,
    visited=None,
):
    """Render the subtree rooted at ``task_id`` as a list of tree-drawing lines.

    Args:
        task_id: uuid of the node to render; must be a key of ``task_details``.
        task_details: uuid -> {"id", "description", "status"}.
        dependency_graph: uuid -> list of uuids that depend on it (children).
        prefix: indentation/guide characters inherited from the ancestors.
        is_last: whether this node is the last child of its parent; selects
            the connector glyph and the guide used below it.
        show_id: prepend "<id>: " to each task line when true.
        visited: uuids already on the current path. A child already in it is
            rendered as a "... (cycle detected)" line. The set passed in is
            mutated (``task_id`` is added to it).

    Returns:
        list[str]: one string per rendered line, this node's line first.
    """
    if visited is None:
        visited = set()

    connector = "└── " if is_last else "├── "
    if task_id in visited:
        return [f"{prefix}{connector}... (cycle detected)"]

    visited.add(task_id)

    task_info = task_details[task_id]
    # Taskwarrior ids are integers; formatting (not `+`) avoids a TypeError.
    id_part = f"{task_info['id']}: " if show_id else ""
    lines = [
        f"{prefix}{connector}{id_part}{task_info['description']} ({task_info['status']})"
    ]

    children = dependency_graph.get(task_id, [])
    child_prefix = prefix + ("    " if is_last else "│   ")
    last_index = len(children) - 1
    for idx, child in enumerate(children):
        # Each child gets its own copy so cycle detection is per path; shared
        # (diamond) dependencies are rendered under every parent.
        lines.extend(
            build_ascii_dag(
                child,
                task_details,
                dependency_graph,
                child_prefix,
                idx == last_index,
                show_id,
                visited.copy(),
            )
        )

    return lines


def render_dependency_dag(task_details, dependency_graph, root_tasks, show_id):
    """Render every dependency tree as a single string.

    Only roots that have at least one dependent are drawn; standalone tasks
    are omitted. Trees are ordered by number of direct dependents (most
    first), with ties broken by uuid so the output is the same on every run.
    Consecutive trees are separated by a blank line.

    Args:
        task_details: uuid -> {"id", "description", "status"}.
        dependency_graph: uuid -> list of uuids that depend on it (children).
        root_tasks: iterable of uuids to start trees from.
        show_id: prepend "<id>: " to each task line when true.

    Returns:
        str: the rendered trees with trailing whitespace stripped ("" if none).
    """
    dag_lines = []
    global_visited = set()

    def dfs(task_id, prefix="", is_last=True, visited=None):
        if visited is None:
            visited = set()

        connector = "└── " if is_last else "├── "
        if task_id in visited:
            # Show the back-edge instead of silently truncating the branch.
            dag_lines.append(f"{prefix}{connector}... (cycle detected)")
            return

        visited.add(task_id)
        global_visited.add(task_id)

        task_info = task_details[task_id]
        id_part = f"{task_info['id']}: " if show_id else ""
        dag_lines.append(
            f"{prefix}{connector}{id_part}{task_info['description']} ({task_info['status']})"
        )

        children = dependency_graph.get(task_id, [])
        child_prefix = prefix + ("    " if is_last else "│   ")
        last_index = len(children) - 1
        for idx, child in enumerate(children):
            # Per-path copy: cycle detection only, shared subtrees repeat.
            dfs(child, child_prefix, idx == last_index, visited.copy())

    def root_order(root):
        # Set iteration order of str uuids varies between interpreter runs
        # (hash randomization), so the uuid is used as a stable tie-break.
        return (-len(dependency_graph.get(root, [])), root)

    roots_with_children = [root for root in root_tasks if dependency_graph.get(root)]
    for root in sorted(roots_with_children, key=root_order):
        # Skip a root that an earlier tree already drew; this can only
        # happen if a caller passes roots that are also someone's dependents.
        if root not in global_visited:
            dfs(root)
            dag_lines.append("")

    return "\n".join(dag_lines).rstrip()


def main(args):
    """Fetch tasks from Taskwarrior and print their dependency tree.

    ``args`` is the parsed command line; only ``args.show_id`` is read.
    """
    task_details, dependency_graph, root_tasks = parse_task_data(get_task_data())
    print(
        render_dependency_dag(
            task_details,
            dependency_graph,
            root_tasks,
            show_id=args.show_id,
        )
    )


if __name__ == "__main__":
    # Command-line entry point: a single optional flag toggles task ids.
    cli = argparse.ArgumentParser(
        description="Generates a task dependency DAG for Taskwarrior tasks."
    )
    cli.add_argument(
        "--show-id",
        help="Include task IDs in the output.",
        default=False,
        action="store_true",
    )
    main(cli.parse_args())
