aboutsummaryrefslogtreecommitdiff
path: root/2021/16/puzzles.py
blob: 75e29d1e9379144da1376fc0a53872247df88f1e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3

from math import prod
from typing import NamedTuple

# Unconsumed bit string made of "0"/"1" characters: set by main() and
# consumed front-to-back by solve(), which rebinds it as packets are parsed.
data: str


class Packet(NamedTuple):
	pass


class Packet(NamedTuple):
	version: int
	type: int
	value: int
	subpackets: list[Packet]

	def calculate(self) -> int:
		f = lambda p: p.calculate()

		# START PART 1
		return self.version + sum(map(f, self.subpackets))
		# END PART 1 START PART 2
		match self.type:
			case 0:
				return sum(map(f, self.subpackets))
			case 1:
				return prod(map(f, self.subpackets))
			case 2:
				return min(map(f, self.subpackets))
			case 3:
				return max(map(f, self.subpackets))
			case 4:
				return self.value
			case 5:
				return self.subpackets[0].calculate() > self.subpackets[1].calculate()
			case 6:
				return self.subpackets[0].calculate() < self.subpackets[1].calculate()
			case 7:
				return self.subpackets[0].calculate() == self.subpackets[1].calculate()
		# END PART 2


def solve() -> Packet:
	"""Parse one packet from the front of the global bit string `data`.

	Consumes exactly the bits belonging to that packet (including all of
	its nested subpackets) by rebinding `data` to the unconsumed remainder,
	and returns the decoded Packet.
	"""
	global data

	version, type_id = int(data[0:3], 2), int(data[3:6], 2)
	data = data[6:]

	if type_id == 4:
		# Literal value: 5-bit groups; a leading "1" means more groups follow,
		# and the low 4 bits of each group are concatenated into the number.
		groups = []
		more = True
		while more:
			more = data[0] == "1"
			groups.append(data[1:5])
			data = data[5:]
		return Packet(version, type_id, int("".join(groups), 2), [])

	length_type = data[0]
	data = data[1:]

	if length_type != "0":
		# Next 11 bits give the number of immediate subpackets.
		count = int(data[:11], 2)
		data = data[11:]
		return Packet(version, type_id, 0, [solve() for _ in range(count)])

	# Next 15 bits give the total bit length of the immediate subpackets;
	# keep parsing until that many bits have been consumed.
	total_bits = int(data[:15], 2)
	data = data[15:]
	start = len(data)
	children = []
	while start - len(data) < total_bits:
		children.append(solve())
	return Packet(version, type_id, 0, children)


def main() -> None:
	"""Load the hex transmission from ./input, decode it and print the result."""
	global data

	with open("input", "r", encoding="utf-8") as handle:
		hex_text = handle.read().strip()
	# Expand every byte into exactly eight binary digits (leading zeros kept).
	data = "".join(f"{byte:08b}" for byte in bytes.fromhex(hex_text))

	root = solve()
	print(root.calculate())


# Only run the solver when executed as a script, not when imported.
if __name__ == "__main__":
	main()