#!/usr/bin/python3
# Copyright 2024-2025 Chris Hofstaedtler <chris@hofstaedtler.name>
# SPDX-License-Identifier: GPL-3.0+
#
# Use `black --line-length 120 bin/reform-mcu-tool` to reformat.
#
# Installation
#
# Install `python3-usb1` as a pre-requisite. No other dependencies should be necessary.
#
# Do not run shellcheck on this file
# shellcheck disable=SC1071

_DOC = """
Tool to interact with Microcontrollers used by MNT Research in Reform projects.

It can currently talk to the pocket sysctl firmware, and to the RP2040 bootrom.
"""

_EPILOG = """
Example usage:

To reboot the pocket sysctl into the bootrom:

  $ sudo reform-mcu-tool bootsel pocket-sysctl-1.0

To reset the RP2040 bootrom back into the application:

  $ sudo reform-mcu-tool reset rp2040-boot

"""

import argparse
import struct
import sys
from dataclasses import dataclass
from typing import Tuple

try:
    import usb1
except ModuleNotFoundError as except_inst:
    print('E: Library "usb1" not found, please install python3-usb1.', file=sys.stderr)
    sys.exit(4)


# USB vendor IDs.
USB_VID_PIDCODES = 0x1209  # pid.codes
USB_VID_RASPBERRY = 0x2E8A  # Raspberry Pi

# USB product IDs of the supported firmwares and bootroms.
USB_PID_MNT_POCKET_REFORM_INPUT_10 = 0x6D06
USB_PID_MNT_POCKET_REFORM_SYSCTL_10 = 0x6D07
USB_PID_RASPBERRY_RP2040_BOOT = 0x0003
USB_PID_RASPBERRY_RP2350_BOOT = 0x000F

# Bit flags describing how a target is driven.
IS_APP = 1 << 0  # application firmware, reset through the vendor reset interface
IS_RP_BOOTROM = 1 << 1  # RP bootrom, reset through the PICOBOOT interface

# TARGET name (as accepted on the command line) -> (vendor ID, product ID, flags).
MCU_TYPES = {
    "pocket-input-1.0": (USB_VID_PIDCODES, USB_PID_MNT_POCKET_REFORM_INPUT_10, IS_APP),
    "pocket-sysctl-1.0": (USB_VID_PIDCODES, USB_PID_MNT_POCKET_REFORM_SYSCTL_10, IS_APP),
    "rp2040-boot": (USB_VID_RASPBERRY, USB_PID_RASPBERRY_RP2040_BOOT, IS_RP_BOOTROM),
    "rp2350-boot": (USB_VID_RASPBERRY, USB_PID_RASPBERRY_RP2350_BOOT, IS_RP_BOOTROM),
}

# The reset interface is a vendor-specific (class 0xFF) interface with this subclass/protocol and no endpoints.
RESET_INTERFACE_SUBCLASS = 0
RESET_INTERFACE_PROTOCOL = 1

# bRequest values of the class control request sent to the reset interface.
RESET_REQUEST_BOOTSEL = 1
RESET_REQUEST_FLASH = 2

# PICOBOOT protocol values.
PICOBOOT_MAGIC = 0x431FD10B  # first word of every PICOBOOT command
PC_REBOOT = 0x2  # command id of the reboot command
PICOBOOT_IF_RESET = 0x41  # presumably the PICOBOOT interface-reset control request -- confirm before use


@dataclass(kw_only=True)
class BOSDescriptor:
    """Binary Device Object Store"""

    bLength: int
    bDescriptorType: int
    wTotalLength: int
    bNumDeviceCaps: int
    data: bytearray

    _header_format = "<BBHB"
    DESCRIPTOR_TYPE = 0xF

    @classmethod
    def parse(cls, data: bytearray) -> "BOSDescriptor":
        header_length = struct.calcsize(cls._header_format)
        bLength, bDescriptorType, wTotalLength, bNumDeviceCaps = struct.unpack(
            cls._header_format, data[0:header_length]
        )
        if bDescriptorType != cls.DESCRIPTOR_TYPE:
            raise TypeError(
                f"Descriptor returned by USB device is of type {bDescriptorType} instead of {cls.DESCRIPTOR_TYPE}"
            )

        return cls(
            bLength=bLength,
            bDescriptorType=bDescriptorType,
            wTotalLength=wTotalLength,
            bNumDeviceCaps=bNumDeviceCaps,
            data=data[header_length:],
        )


@dataclass(kw_only=True)
class BOSDeviceCapability:
    bLength: int
    bDescriptorType: int
    bDevCapabilityType: int
    data: bytearray

    _header_format = "<BBB"
    DESCRIPTOR_TYPE = 0x10  # DEVICE CAPABILITY

    @classmethod
    def parse(cls, data: bytearray) -> "BOSDeviceCapability":
        header_length = struct.calcsize(cls._header_format)
        bLength, bDescriptorType, bDevCapabilityType = struct.unpack(cls._header_format, data[0:header_length])
        if bDescriptorType != cls.DESCRIPTOR_TYPE:
            raise TypeError(
                f"Descriptor returned by USB device is of type {bDescriptorType} instead of {cls.DESCRIPTOR_TYPE}"
            )

        return cls(
            bLength=bLength,
            bDescriptorType=bDescriptorType,
            bDevCapabilityType=bDevCapabilityType,
            data=data[header_length : header_length + bLength],
        )


@dataclass(kw_only=True)
class PlatformCapabilityDescriptor:
    """Defines a device capability specific to a particular platform/operating system"""

    bLength: int  # total length of the capability, including the 20-byte header
    bDescriptorType: int  # BOSDeviceCapability.DESCRIPTOR_TYPE (0x10)
    bDevCapabilityType: int  # DEV_CAPABILITY_TYPE (0x05)
    # uint8_t bReserved (read but discarded)
    PlatformCapabilityUUID: bytes  # 16 raw bytes identifying the platform, e.g. DS20Descriptor.UUID
    data: bytearray  # bLength - 20 bytes of platform-specific payload

    _header_format = "<BBBB16s"
    DEV_CAPABILITY_TYPE = 0x05

    @classmethod
    def parse(cls, data: bytearray) -> "PlatformCapabilityDescriptor":
        """Parse a platform capability descriptor at the start of *data*.

        *data* may continue past this capability; only the first ``bLength`` bytes are used.

        Raises TypeError if the bytes are not a platform capability (wrong descriptor or capability
        type). Raises struct.error if they are one but are too short for the 20-byte header.
        """
        header_length = struct.calcsize(cls._header_format)
        # Check the common 3-byte capability header before unpacking the full 20-byte header. Other
        # capability types can be shorter than 20 bytes and must be rejected with TypeError, not struct.error.
        bLength, bDescriptorType, bDevCapabilityType = struct.unpack_from("<BBB", data)
        if bDescriptorType != BOSDeviceCapability.DESCRIPTOR_TYPE:
            raise TypeError(f"Descriptor is of type {bDescriptorType} instead of {BOSDeviceCapability.DESCRIPTOR_TYPE}")
        if bDevCapabilityType != cls.DEV_CAPABILITY_TYPE:
            raise TypeError(f"DevCapability is of type {bDevCapabilityType} instead of {cls.DEV_CAPABILITY_TYPE}")

        # Fields: bLength, bDescriptorType, bDevCapabilityType, bReserved, PlatformCapabilityUUID.
        PlatformCapabilityUUID = struct.unpack_from(cls._header_format, data)[4]

        return cls(
            bLength=bLength,
            bDescriptorType=bDescriptorType,
            bDevCapabilityType=bDevCapabilityType,
            PlatformCapabilityUUID=PlatformCapabilityUUID,
            # bLength includes the header, so the payload is data[header_length:bLength]; bytes beyond
            # bLength belong to the next capability in the BOS descriptor.
            data=data[header_length:bLength],
        )


@dataclass(kw_only=True)
class DS20Descriptor:
    dwVersion: int
    wLength: int
    bVendorCode: int
    bAltEnumCmd: int

    _format = "<IHBB"

    UUID = bytes([0x63, 0xEC, 0x0A, 0x01, 0x74, 0xF5, 0xCD, 0x52, 0x9D, 0xDA, 0x28, 0x52, 0x55, 0x0D, 0x94, 0xF0])

    @classmethod
    def parse(cls, data: bytearray) -> "DS20Descriptor":
        dwVersion, wLength, bVendorCode, bAltEnumCmd = struct.unpack(cls._format, data)
        return cls(dwVersion=dwVersion, wLength=wLength, bVendorCode=bVendorCode, bAltEnumCmd=bAltEnumCmd)


def device_reset(handle: usb1.USBDeviceHandle, reset_interface: int, reset_request: int):
    """Send *reset_request* as a class control request to interface *reset_interface*.

    The interface is claimed first. The interface is released again only if the request completes
    without an I/O or pipe error. Such an error is expected: the MCU resets and drops off the bus.
    """
    handle.claimInterface(reset_interface)
    request_type = usb1.TYPE_CLASS | usb1.RECIPIENT_INTERFACE
    try:
        handle.controlWrite(request_type, reset_request, 0, reset_interface, b"", timeout=2000)
    except (usb1.USBErrorIO, usb1.USBErrorPipe):
        # Expected. MCU has reset and vanishes from USB. Exact error appears to be timing/fw-dependent.
        return
    handle.releaseInterface(reset_interface)


def picoboot_reset(handle: usb1.USBDeviceHandle, picoboot_interface: usb1.USBInterfaceSetting):
    """Send a PICOBOOT PC_REBOOT command through *picoboot_interface*.

    Steps:
    - Claim the interface and clear the halt condition on both of its endpoints.
    - Write a 32-byte command (16-byte header plus 16-byte argument block) to the first (OUT) endpoint.
    - Read from the second (IN) endpoint to complete the exchange; the returned data is not inspected.
    - Release the interface.

    Raises ValueError if the device accepts fewer than 32 bytes.
    """
    interface_number = picoboot_interface.getNumber()
    handle.claimInterface(interface_number)

    endpoint_out = picoboot_interface[0].getAddress()
    handle.clearHalt(endpoint_out)
    endpoint_in = picoboot_interface[1].getAddress()
    handle.clearHalt(endpoint_in)

    # Reboot arguments: pc, sp, delay in ms. pc = sp = 0 presumably requests a regular boot rather
    # than a jump to an explicit address -- confirm against the RP2040 datasheet (PICOBOOT section).
    arguments = struct.pack("<LLL", 0, 0, 500)
    # Header: magic, token (1), command id, argument length, reserved (0), transfer length (0).
    header = struct.pack("<LLBBHL", PICOBOOT_MAGIC, 1, PC_REBOOT, len(arguments), 0, 0)
    # Zero-pad the 12 argument bytes to 16 so the whole command is 32 bytes (checked below).
    command = header + arguments + bytes(4)

    written = handle.bulkWrite(endpoint_out, command, timeout=3000)
    if written != 32:
        raise ValueError(f"Expected to send picoboot_cmd of size 32, but sent {written}")

    handle.bulkRead(endpoint_in, 1, timeout=10000)

    handle.releaseInterface(interface_number)


def find_reset_interface(device: usb1.USBDevice) -> int | None:
    """Return the number of the first reset interface of *device*, or None if it has none.

    A reset interface is a vendor-specific (class 0xFF) setting with RESET_INTERFACE_SUBCLASS,
    RESET_INTERFACE_PROTOCOL and no endpoints.
    """
    candidates = (
        setting.getNumber()
        for setting in device.iterSettings()
        if setting.getClass() == 0xFF
        and setting.getSubClass() == RESET_INTERFACE_SUBCLASS
        and setting.getProtocol() == RESET_INTERFACE_PROTOCOL
        and setting.getNumEndpoints() == 0
    )
    return next(candidates, None)


def find_picoboot_interface(device: usb1.USBDevice) -> usb1.USBInterfaceSetting | None:
    """Return the first PICOBOOT-like interface setting of *device*, or None.

    A match is a vendor-specific (class 0xFF) setting with exactly two endpoints, where the first is
    an OUT endpoint (address bit 7 clear) and the second an IN endpoint (address bit 7 set).
    """
    for setting in device.iterSettings():
        # Check class and endpoint count before touching the endpoints. Settings with fewer than two
        # endpoints would otherwise make setting[0] / setting[1] fail.
        if setting.getClass() != 0xFF or setting.getNumEndpoints() != 2:
            continue
        out_address = setting[0].getAddress()
        in_address = setting[1].getAddress()
        if out_address & 0x80 == 0 and in_address & 0x80 == 0x80:
            return setting
    return None


def _read_bos_descriptor(handle: usb1.USBDeviceHandle) -> None | BOSDescriptor:
    """Read and parse the device's BOS descriptor with standard GET_DESCRIPTOR requests.

    The 5-byte header is fetched first to learn wTotalLength; then the whole descriptor is fetched.

    Returns None when any of the following happens:
    - the device stalls the request (BOS unsupported);
    - the header read is short;
    - wTotalLength is smaller than the header;
    - the full read does not return exactly wTotalLength bytes.

    BOSDescriptor.parse raises TypeError if the device answers with another descriptor type. Other
    USB errors propagate.
    """
    request_type = usb1.TYPE_STANDARD | usb1.RECIPIENT_DEVICE
    request_value = BOSDescriptor.DESCRIPTOR_TYPE << 8  # descriptor type in the high byte, index 0
    timeout = 1000
    header_length = 5  # bLength, bDescriptorType, wTotalLength (2 bytes), bNumDeviceCaps

    # Read descriptor header first
    try:
        bos_data = handle.controlRead(
            request_type, usb1.REQUEST_GET_DESCRIPTOR, request_value, 0, header_length, timeout=timeout
        )
    except usb1.USBErrorPipe:
        return None  # unsupported
    if len(bos_data) < header_length:
        return None  # short read: not enough bytes to parse the header

    descriptor = BOSDescriptor.parse(bos_data)
    if descriptor.wTotalLength < header_length:
        return None  # inconsistent: total length cannot be smaller than the header itself

    # Read full descriptor now that we know how long it is
    bos_data = handle.controlRead(
        request_type, usb1.REQUEST_GET_DESCRIPTOR, request_value, 0, descriptor.wTotalLength, timeout=timeout
    )
    if len(bos_data) != descriptor.wTotalLength:
        # USB device returned too little or too much data
        return None

    return BOSDescriptor.parse(bos_data)


def _read_bos_ds20_version(handle: usb1.USBDeviceHandle) -> int | None:
    """Return the version number published through a fwupd DS20 BOS capability, or None.

    Walks the device capabilities in the BOS descriptor, looking for a platform capability whose UUID
    is DS20Descriptor.UUID (see https://github.com/fwupd/fwupd/blob/main/docs/ds20.md). For a usable
    one, the quirk data is fetched with a vendor control request and parsed as ``key=value`` lines.
    Its ``Version`` value is returned as an int when ``VersionFormat`` is ``number``.

    Returns None when:
    - there is no BOS descriptor;
    - no usable DS20 capability exists;
    - a capability length is malformed;
    - the quirk request returns no data.

    USB errors from the control transfers propagate.
    """
    bos_descriptor = _read_bos_descriptor(handle)
    if bos_descriptor is None:
        return None

    capabilities = bos_descriptor.data
    offset = 0
    while offset < len(capabilities):
        device_cap = BOSDeviceCapability.parse(capabilities[offset:])
        if device_cap.bLength < 3:
            # Shorter than its own 3-byte header: malformed. A zero bLength would otherwise loop forever.
            return None
        # Hand the parsers exactly this capability's bytes, not the remainder of the BOS descriptor.
        cap_bytes = capabilities[offset : offset + device_cap.bLength]
        offset += device_cap.bLength

        if device_cap.bDevCapabilityType != PlatformCapabilityDescriptor.DEV_CAPABILITY_TYPE:
            continue  # not a platform capability; such capabilities may be shorter than a platform header
        try:
            platform_desc = PlatformCapabilityDescriptor.parse(cap_bytes)
        except (TypeError, struct.error):
            continue

        if platform_desc.PlatformCapabilityUUID != DS20Descriptor.UUID:
            continue
        try:
            ds20_desc = DS20Descriptor.parse(platform_desc.data)
        except struct.error:
            continue  # payload does not have the expected DS20 layout
        if ds20_desc.dwVersion < 0x0001090E:
            # Minimum version is 0x0001090e (1.9.14).
            continue
        if ds20_desc.bAltEnumCmd != 0:
            continue
        if ds20_desc.bVendorCode == 0:
            # Probably invalid, will read something we do not want.
            continue

        # Read quirk data from device (in UTF-8).
        response = handle.controlRead(
            usb1.TYPE_VENDOR | usb1.RECIPIENT_DEVICE,
            ds20_desc.bVendorCode,
            0x00,  # wValue per ds20.md
            0x07,  # wIndex per ds20.md
            ds20_desc.wLength,
            timeout=2000,
        )
        if not response:
            return None

        quirk_data = {}
        for line in response.decode("utf-8", errors="ignore").splitlines():
            line = line.strip()
            if line.startswith(";") or "=" not in line:
                continue  # comment or not a key=value line
            key, _, value = line.partition("=")
            quirk_data[key] = value

        version = quirk_data.get("Version")
        if version is None:
            continue
        version_format = quirk_data.get("VersionFormat")
        if version_format != "number":
            # Warnings go to stderr so they do not mix with the device listing on stdout.
            print(f"W: VersionFormat {version_format} not understood, ignoring", file=sys.stderr)
            continue
        try:
            return int(version)
        except ValueError:
            print(f"W: Version {version} is not a number, ignoring", file=sys.stderr)

    return None


def read_bos_ds20_version(device: usb1.USBDevice) -> int | None:
    """Best-effort read of the numeric Version advertised in the device's BOS DS20 quirk data.

    Opens *device*, delegates to _read_bos_ds20_version, and closes the handle again if it was opened.

    Returns None instead of raising when:
    - the device cannot be opened (e.g. missing permissions);
    - a USB transfer fails;
    - the descriptors are malformed.

    This lets callers that list several devices carry on.
    """
    try:
        handle = device.open()
    except usb1.USBError:
        return None  # nothing was opened, so there is nothing to close
    try:
        return _read_bos_ds20_version(handle)
    except (usb1.USBError, struct.error, TypeError):
        # TypeError and struct.error are how the descriptor parsers report unexpected or truncated data.
        return None
    finally:
        handle.close()


def action_bootsel(args, device: usb1.USBDevice):
    """Reboot an application-mode MCU into BOOTSEL mode via its reset interface.

    Returns 0 on success, 2 if the selected TARGET is already a bootrom, and 1 if *device* has no
    reset interface.
    """
    flags = MCU_TYPES[args.target][2]
    if flags & IS_RP_BOOTROM:
        print("E: Device is already in bootrom.", file=sys.stderr)
        return 2

    reset_interface = find_reset_interface(device)
    if reset_interface is None:
        print("E: Could not find Reset USB Interface.", file=sys.stderr)
        return 1

    serial = device.getSerialNumber()
    print(f"I: Resetting device with serial {serial} into BOOTSEL")
    handle = device.open()
    device_reset(handle, reset_interface, RESET_REQUEST_BOOTSEL)
    print(f"I: You may now use: $ picotool info --ser {serial}")
    return 0


def _reset_application(device: usb1.USBDevice) -> int:
    """Reboot an application-mode MCU via its reset interface; return 1 if that interface is missing."""
    reset_interface = find_reset_interface(device)
    if reset_interface is None:
        print("E: Could not find Reset USB Interface.", file=sys.stderr)
        return 1

    print("I: Resetting device")
    handle = device.open()
    device_reset(handle, reset_interface, RESET_REQUEST_FLASH)
    return 0


def _reset_bootrom(device: usb1.USBDevice) -> int:
    """Reboot an RP bootrom via PICOBOOT; return 1 if no PICOBOOT interface is found."""
    picoboot_interface = find_picoboot_interface(device)
    if picoboot_interface is None:
        print("E: Could not find PICOBOOT USB Interface.", file=sys.stderr)
        return 1

    print("I: Resetting bootrom into application")
    handle = device.open()
    picoboot_reset(handle, picoboot_interface)
    return 0


def action_reset(args, device: usb1.USBDevice):
    """Reboot the selected TARGET into application mode.

    Application firmware is reset through its reset interface; an RP bootrom through PICOBOOT.
    Returns 0 on success (or when the target has neither flag), 1 if the needed interface is missing.
    """
    flags = MCU_TYPES[args.target][2]
    if flags & IS_APP:
        return _reset_application(device)
    if flags & IS_RP_BOOTROM:
        return _reset_bootrom(device)
    return 0


def action_list(args, usb_context: usb1.USBContext):
    """Print one line for every attached USB device whose VID/PID appears in MCU_TYPES.

    Each line shows the target name, USB IDs, serial number and bus/address. The DS20 Version is
    appended when one can be read. Always returns 0.
    """
    for device in usb_context.getDeviceIterator(skip_on_error=True):
        vid = device.getVendorID()
        pid = device.getProductID()

        matches = (name for name, (mcu_vid, mcu_pid, _) in MCU_TYPES.items() if mcu_vid == vid and mcu_pid == pid)
        mcu_name = next(matches, None)
        if mcu_name is None:
            continue

        line = (
            f"Target {mcu_name} ID {vid:04x}:{pid:04x} Serial# {device.getSerialNumber()} "
            f"USB bus {device.getBusNumber()} address {device.getDeviceAddress()}"
        )
        version = read_bos_ds20_version(device)
        print(f"{line} Version {version}" if version else line)

    return 0


def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Each subcommand stores its handler in ``args.func``. The ``bootsel`` and ``reset`` subcommands
    also take a TARGET from MCU_TYPES. Without a subcommand, the help text is printed and the
    program exits.
    """
    parser = argparse.ArgumentParser(
        prog="reform-mcu-tool", description=_DOC, epilog=_EPILOG, formatter_class=argparse.RawTextHelpFormatter
    )
    subparsers = parser.add_subparsers(help="Action to execute")

    targeted_actions = (
        ("bootsel", "Reboot MCU into BOOTSEL mode", action_bootsel),
        ("reset", "Reboot MCU into application mode", action_reset),
    )
    for command, help_text, handler in targeted_actions:
        subparser = subparsers.add_parser(command, help=help_text)
        subparser.set_defaults(func=handler)
        subparser.add_argument(
            "target",
            choices=MCU_TYPES.keys(),
            metavar="TARGET",
            help=f"Target device to operate on. Choices: {', '.join(MCU_TYPES.keys())}",
        )

    subparsers.add_parser("list", help="List USB devices matching known VID/PIDs").set_defaults(func=action_list)

    args = parser.parse_args()
    if not hasattr(args, "func"):
        parser.print_help()
        parser.exit()
    return args


def run(args, usb_context: usb1.USBContext) -> int:
    """Dispatch the parsed command line to its action and return the exit code.

    Subcommands with a TARGET (bootsel, reset) receive the device returned by
    getByVendorIDAndProductID for that target's VID/PID. If no such device is found, 1 is returned.
    Other subcommands (list) receive the USB context itself.
    """
    if "target" not in args:
        return args.func(args, usb_context)

    (vid, pid, _) = MCU_TYPES[args.target]
    device = usb_context.getByVendorIDAndProductID(vid, pid, skip_on_error=True)
    if not device:
        # Print IDs in hex, the usual notation for USB IDs and the same format `list` uses.
        print(f"E: USB device with Vendor-ID {vid:04x} Product-ID {pid:04x} not found.", file=sys.stderr)
        return 1

    print(
        f"I: Found {device.getManufacturer()} {device.getProduct()} "
        f"on bus {device.getBusNumber()} address {device.getDeviceAddress()}"
    )
    return args.func(args, device)


def main() -> int:
    """Parse the command line, run the chosen action inside a USB context, and return the exit code.

    A USBErrorAccess is turned into a hint about root privileges and exit code 3.
    """
    args = parse_args()

    try:
        with usb1.USBContext() as usb_context:
            exit_code = run(args, usb_context)
    except usb1.USBErrorAccess:
        print("E: Accessing USB devices failed. -- Maybe you need to be root / use sudo.", file=sys.stderr)
        exit_code = 3
    return exit_code


if __name__ == "__main__":
    # Equivalent to sys.exit(main()): the exit code returned by main() becomes the process status.
    raise SystemExit(main())
