import base64
import sys
from datetime import datetime, timezone
from time import sleep
from typing import Any

import serial

# One-byte protocol codes exchanged with the device over the serial link.

# Framing / status bytes.
ENDL = b"\x00"  # terminator written after each variable-length field
OK = b"\x01"
ERROR = b"\x02"

# Command bytes written by the host to start an operation.
# (DELETE_TOKEN, GET_TOKENS and EXIT are not sent by this script.)
HANDSHAKE = b"\x03"
SET_TIMESTAMP = b"\x0a"
ADD_TOKEN = b"\x14"
DELETE_TOKEN = b"\x1e"
GET_TOKENS = b"\x28"
WIPE_TOKENS = b"\x32"
EXIT = b"\xfe"

def loop_input(msg: str, valid_values: Any) -> str:
    """Prompt with *msg* until the user types one of *valid_values*.

    Candidates are compared by their ``str()`` form, so e.g. ``range(1, 6)``
    can be passed directly. The accepted text is returned exactly as typed;
    every rejected answer prints a hint listing the allowed values.
    """
    allowed = [str(candidate) for candidate in valid_values]

    answer = input(msg)
    while answer not in allowed:
        print(f"'{answer}' isn't a valid value, please enter a value in {allowed}")
        answer = input(msg)
    return answer

def b(n: int) -> bytes:
    """Wrap the integer *n* (0..255) as a one-byte ``bytes`` object.

    Out-of-range values raise ``ValueError`` from the ``bytes`` constructor.
    """
    single = (n,)
    return bytes(single)

def process_secret(secret: str) -> bytes:
    """Decode a Base32 OTP secret into its raw key bytes.

    OTP secrets are often displayed without ``=`` padding and split into
    space-separated groups (e.g. ``"JBSW Y3DP"``). All whitespace is removed,
    the value is padded to a multiple of 8 characters, and it is decoded
    case-insensitively.

    Args:
        secret: The Base32 secret as typed by the user.

    Returns:
        The decoded key bytes (empty for an empty secret).

    Raises:
        binascii.Error: If the secret is not valid Base32. This is a
            ``ValueError`` subclass.
    """
    # Internal whitespace would otherwise make b32decode reject the secret.
    compact = "".join(secret.split())
    # Base32 decodes in 8-character groups; restore the omitted padding.
    compact += "=" * (-len(compact) % 8)
    return base64.b32decode(compact, casefold=True)

def send_and_validate(value: bytes, conn):
    """Send *value* until the device echoes it back unchanged.

    Each attempt writes *value*, reads the device's reply with
    ``conn.read()``, and then writes ``OK`` if the reply matches or ``ERROR``
    if it does not. Attempts repeat with no limit until a match is seen.
    Every send and reply is printed as a list of byte values.

    NOTE(review): ``conn.read()`` is called without a size, which in pySerial
    reads one byte, so *value* is presumably expected to be a single byte.
    """
    echoed = None
    while echoed != value:
        print(f"Send {list(value)}")
        conn.write(value)
        echoed = conn.read()
        print(f"Received {list(echoed)}")
        conn.write(OK if echoed == value else ERROR)

def get_datetime_items() -> dict[str, bytes]:
    """Return the current UTC date and time as one-byte fields.

    Keys are ``year``, ``month``, ``day``, ``hours``, ``minutes`` and
    ``seconds``, in that insertion order. ``year`` is stored as an offset
    from 2000, so it only fits a single byte for years 2000-2255 (``b``
    raises ``ValueError`` otherwise). All fields come from one clock reading,
    so they are mutually consistent.
    """
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
    # timestamp yields the same field values.
    now = datetime.now(timezone.utc)

    return {
        "year": b(now.year - 2000),
        "month": b(now.month),
        "day": b(now.day),
        "hours": b(now.hour),
        "minutes": b(now.minute),
        "seconds": b(now.second),
    }

def _handshake(conn, timeout: float = 2.0) -> bool:
    """Send HANDSHAKE and return True if the device answers with OK.

    The port is opened without a read timeout, so an unresponsive device
    would otherwise block this read forever and the caller's failure
    message could never be shown. Only this read is bounded. Afterwards,
    ``conn.timeout`` is restored to ``None`` (blocking) for the rest of the
    session.
    """
    conn.write(HANDSHAKE)
    sleep(0.1)
    conn.timeout = timeout
    try:
        res = conn.read()  # b"" if the timeout expires
    finally:
        conn.timeout = None
    return res == OK


def _update_timestamp(conn) -> None:
    """Send the current UTC date/time to the device, one echoed field at a time."""
    conn.write(SET_TIMESTAMP)
    sleep(0.1)
    resp = conn.read()
    sleep(0.1)
    if resp == ERROR:
        print(f"Error in the communication: Error {resp}")
        return

    date_items = get_datetime_items()
    # Field order is part of the wire protocol.
    for field in ("year", "month", "day", "hours", "minutes", "seconds"):
        send_and_validate(date_items[field], conn)

    resp = conn.read()
    if resp != OK:
        print(f"Error in the communication: Error {resp}")
    else:
        print("Timestamp updated successfully!")


def _ask_token_name() -> bytes:
    """Prompt until the user gives a non-empty ASCII name; return it (at most 16 bytes)."""
    while True:
        name = input("Enter the name of the new token (16 chars max):\n>>> ").strip()[:16]
        try:
            encoded = name.encode("ascii")
        except UnicodeEncodeError:
            print("The name can only contain ASCII characters, try again")
            continue
        if encoded:
            return encoded
        print("The name can't be empty, try again")


def _ask_token_key() -> bytes:
    """Prompt until the user gives a valid Base32 secret of at most 32 characters.

    Whitespace inside the secret is ignored. Over-long secrets are rejected
    rather than truncated, because a truncated secret would silently produce
    wrong OTP codes.
    """
    while True:
        raw = input("Enter the OTP secret key (32 chars max):\n>>> ")
        secret = "".join(raw.split())
        if not secret:
            print("The secret key can't be empty, try again")
            continue
        if len(secret) > 32:
            print("The secret key is longer than 32 characters, try again")
            continue
        try:
            return process_secret(secret)
        except ValueError:
            # binascii.Error (raised for non-Base32 input) subclasses ValueError.
            print("The secret key isn't valid Base32, try again")


def _add_token(conn) -> None:
    """Ask for a token name and secret and store them on the device."""
    conn.write(ADD_TOKEN)
    sleep(0.1)
    resp = conn.read()
    if resp == ERROR:
        print("The memory of the device is full")
        return

    # Collect and validate both inputs before sending anything. Once ADD_TOKEN
    # has been accepted, the device waits for the name and key, so a crash here
    # would leave it out of sync.
    name = _ask_token_name()
    key = _ask_token_key()

    # NOTE(review): both fields are ENDL (0x00) terminated. A decoded key that
    # itself contains a 0x00 byte would presumably be cut short by the
    # firmware; confirm how the device parses it.
    for ch in name:
        conn.write(b(ch))
    conn.write(ENDL)
    for ch in key:
        conn.write(b(ch))
    conn.write(ENDL)

    resp = conn.read()
    if resp == ERROR:
        print("Error trying to add the token, try again")
    else:
        print("Token added successfully!")


def _wipe_tokens(conn) -> None:
    """Ask the device to erase every stored token and report the outcome."""
    conn.write(WIPE_TOKENS)
    sleep(0.1)
    # The device sends two bytes here; only the second is checked as the
    # result. NOTE(review): the meaning of the first byte is inferred only
    # from this read pattern; confirm against the firmware.
    conn.read()
    resp = conn.read()
    if resp == OK:
        print("All the tokens wiped successfully!")
    else:
        print(f"Error in the communication: Error {resp}")


def main(argv: list[str]):
    """Run the interactive device-management menu.

    Args:
        argv: Command-line arguments without the program name. The last
            element is used as the serial port (e.g. ``/dev/ttyUSB0``).

    Returns:
        0 when the user chooses EXIT. 1 if the port cannot be opened or the
        handshake fails.
    """
    port = argv[-1]

    try:
        conn = serial.Serial(port=port, baudrate=9600)
    except serial.SerialException as err:
        print(f"Could not open the serial port {port}: {err}")
        return 1

    # The context manager closes the port on every exit path.
    with conn:
        print("UP + Reset to enable the USB connection")
        input("Press Enter when the device is in USB mode...")

        if not _handshake(conn):
            print("A handshake could not be performed")
            print("Check the connection with your Arduino")
            return 1
        print("Handshake successfully performed")

        actions = {
            "1": _update_timestamp,
            "2": _add_token,
            "4": _wipe_tokens,
        }

        while True:
            print("What do you want to do?")
            print("1) Update timestamp")
            print("2) Add a new token")
            print("3) Remove a token")
            print("4) WIPE ALL THE TOKENS")
            print("5) EXIT")

            opt = loop_input(">>> ", range(1, 6))

            if opt == "5":
                return 0
            if opt == "3":
                # DELETE_TOKEN exists, but its wire format is not implemented
                # here. Say so instead of silently doing nothing.
                print("Removing a token is not supported yet")
                continue
            actions[opt](conn)

if __name__ == "__main__":
    cli_args = sys.argv
    # main() takes the serial port from the arguments after the program
    # name. With no arguments at all, print usage and fail.
    if len(cli_args) != 1:
        sys.exit(main(cli_args[1:]))

    print("You need to specify the Arduino's serial port")
    print("Example:")
    print(f"python {cli_args[0]} /dev/ttyUSB0")
    sys.exit(1)