#!/usr/bin/env python3
from math import prod
from typing import NamedTuple
# Remaining unparsed transmission as a string of '0'/'1' characters.
# main() fills it from the input file; solve() consumes it from the front.
data: str
class Packet(NamedTuple):
version: int
type: int
value: int
subpackets: list['Packet']
def calculate(self) -> int:
f = lambda p: p.calculate()
# START PART 1
return self.version + sum(map(f, self.subpackets))
# END PART 1 START PART 2
match self.type:
case 0:
return sum(map(f, self.subpackets))
case 1:
return prod(map(f, self.subpackets))
case 2:
return min(map(f, self.subpackets))
case 3:
return max(map(f, self.subpackets))
case 4:
return self.value
case 5:
return f(self.subpackets[0]) > f(self.subpackets[1])
case 6:
return f(self.subpackets[0]) < f(self.subpackets[1])
case 7:
return f(self.subpackets[0]) == f(self.subpackets[1])
# END PART 2
def _parse_packet(bits: str, pos: int) -> tuple[Packet, int]:
    """Decode the packet that starts at ``bits[pos]``.

    Args:
        bits: The transmission as a string of '0'/'1' characters.
        pos: Index of the first bit of the packet header.

    Returns:
        The decoded packet and the index of the first bit after it.

    Raises:
        IndexError, ValueError: If the bit string ends in the middle of a
            packet.
    """
    # Header: 3 bits of version followed by 3 bits of type ID.
    version = int(bits[pos:pos + 3], 2)
    type_id = int(bits[pos + 3:pos + 6], 2)
    pos += 6

    if type_id == 4:
        # Literal: 5-bit groups, each a continuation flag plus 4 payload
        # bits; the group whose flag is not "1" is the last one.
        groups = []
        while True:
            is_last = bits[pos] != "1"
            groups.append(bits[pos + 1:pos + 5])
            pos += 5
            if is_last:
                break
        return Packet(version, type_id, int("".join(groups), 2), []), pos

    # Operator: one length-type bit selects how the sub-packets are delimited.
    length_type = bits[pos]
    pos += 1
    subpackets = []
    if length_type == "0":
        # 15-bit field: total number of bits occupied by the sub-packets.
        total_bits = int(bits[pos:pos + 15], 2)
        pos += 15
        end = pos + total_bits
        while pos < end:
            child, pos = _parse_packet(bits, pos)
            subpackets.append(child)
    else:
        # 11-bit field: number of directly contained sub-packets.
        count = int(bits[pos:pos + 11], 2)
        pos += 11
        for _ in range(count):
            child, pos = _parse_packet(bits, pos)
            subpackets.append(child)
    return Packet(version, type_id, 0, subpackets), pos


def solve() -> Packet:
    """Decode one packet from the front of the global ``data`` bit string.

    The bits belonging to the packet (including all nested sub-packets) are
    removed from ``data``; whatever follows the packet is left in place.
    Parsing walks the string with an index instead of re-slicing it after
    every field, so the work is linear in the number of bits.

    Returns:
        The decoded packet tree.
    """
    global data
    packet, consumed = _parse_packet(data, 0)
    data = data[consumed:]
    return packet
def main() -> None:
    """Read the hex transmission from ``input``, decode it and print the answer.

    The file content is expanded into a string of bits (eight per byte) and
    stored in the module-level ``data``, from which ``solve`` decodes the
    outermost packet.
    """
    global data
    with open("input", "r", encoding="utf-8") as handle:
        hex_text = handle.read().strip()
    # Zero-pad every byte to 8 binary digits so leading zero bits survive.
    data = "".join(f"{byte:08b}" for byte in bytes.fromhex(hex_text))
    root = solve()
    print(root.calculate())
# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()