1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
|
#!/usr/bin/env python3
from math import prod
from typing import NamedTuple
# Remaining transmission as a string of "0"/"1" characters. main() assigns
# it from the input file and solve() consumes packets from its front.
data: str
class Packet(NamedTuple):
pass
class Packet(NamedTuple):
version: int
type: int
value: int
subpackets: list[Packet]
def calculate(self) -> int:
f = lambda p: p.calculate()
# START PART 1
return self.version + sum(map(f, self.subpackets))
# END PART 1 START PART 2
match self.type:
case 0:
return sum(map(f, self.subpackets))
case 1:
return prod(map(f, self.subpackets))
case 2:
return min(map(f, self.subpackets))
case 3:
return max(map(f, self.subpackets))
case 4:
return self.value
case 5:
return f(self.subpackets[0]) > f(self.subpackets[1])
case 6:
return f(self.subpackets[0]) < f(self.subpackets[1])
case 7:
return f(self.subpackets[0]) == f(self.subpackets[1])
# END PART 2
def solve() -> Packet:
    """Decode one packet from the front of the module-level bit string ``data``.

    ``data`` must hold the transmission as ``"0"``/``"1"`` characters. The
    decoded packet (including all nested subpackets) is returned, and ``data``
    is rebound to the bits that follow it.

    Parsing walks an integer cursor over the string instead of re-slicing it
    after every field, so each bit is copied only once.

    Raises:
        ValueError: if ``data`` ends in the middle of a packet.
    """
    global data
    bits = data

    def read(pos: int, width: int) -> int:
        """Return the ``width``-bit unsigned integer starting at ``pos``."""
        end = pos + width
        if end > len(bits):
            raise ValueError(f"truncated packet: need bits {pos}..{end}, have {len(bits)}")
        return int(bits[pos:end], 2)

    def parse(pos: int) -> tuple[Packet, int]:
        """Parse the packet starting at ``pos``; return it and the next position."""
        version = read(pos, 3)
        type_id = read(pos + 3, 3)
        pos += 6
        if type_id == 4:
            # Literal: 5-bit groups; a leading 1 means more groups follow,
            # and the low 4 bits of each group are appended to the value.
            value = 0
            while True:
                group = read(pos, 5)
                pos += 5
                value = (value << 4) | (group & 0b1111)
                if not group & 0b10000:
                    return Packet(version, type_id, value, []), pos
        length_type = read(pos, 1)
        pos += 1
        subpackets = []
        if length_type == 0:
            # Next 15 bits: total number of bits occupied by the subpackets.
            length = read(pos, 15)
            pos += 15
            end = pos + length
            while pos < end:
                child, pos = parse(pos)
                subpackets.append(child)
        else:
            # Next 11 bits: number of immediately contained subpackets.
            count = read(pos, 11)
            pos += 11
            for _ in range(count):
                child, pos = parse(pos)
                subpackets.append(child)
        return Packet(version, type_id, 0, subpackets), pos

    packet, consumed = parse(0)
    data = bits[consumed:]
    return packet
def main() -> None:
    """Read the hex transmission from ``input``, decode it and print the result.

    Each hex digit is expanded to exactly four bits, so an odd number of
    digits and leading zero nibbles are handled; whitespace is ignored.
    """
    global data
    with open("input", "r", encoding="utf-8") as f:
        hex_digits = "".join(f.read().split())
    # Per-digit expansion: bytes.fromhex would reject an odd digit count.
    data = "".join(f"{int(digit, 16):04b}" for digit in hex_digits)
    print(solve().calculate())
# Entry point: only run main() when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|