aboutsummaryrefslogblamecommitdiff
path: root/2021/11/puzzles.py
blob: 19fb5346e891004ff38c6b033b29a068578adb7b (plain) (tree)
1
2
3
4
5
6
7
8
9
                      
                                           
 
                                                 

                   
                                                    
















                                                                                             
                                         








                                 
                                                               









                                                                    
                                                                             






                          
#!/usr/bin/env python3

from itertools import chain, count, product

from libaoc import map2d, matrix, read_int_matrix

rows, cols = -1, -1


def flash(grid: matrix[int], i: int, j: int) -> int:
	acc = 1
	grid[i][j] = -1

	for di, dj in ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)):
		ni = i + di
		nj = j + dj
		if 0 <= ni < rows and 0 <= nj < cols and grid[ni][nj] != -1:
			grid[ni][nj] += 1
			if grid[ni][nj] > 9:
				acc += flash(grid, ni, nj)

	return acc


def main() -> None:
	global rows, cols

	with open("input", "r", encoding="utf-8") as f:
		grid = read_int_matrix(f)

	rows = len(grid)
	cols = len(grid[0])

	acc = 0
	# START PART 1
	for _ in range(100):
	# END PART 1 START PART 2
	for step in count(1):
	# END PART 2
		grid = list(map2d(lambda n: n + 1, grid, list))

		for i, j in product(range(rows), range(cols)):
			if grid[i][j] > 9:
				acc += flash(grid, i, j)

		# START PART 2
		if sum(chain.from_iterable(grid)) == -(rows * cols):
			print(step)
			return
		# END PART 2

		grid = list(map2d(lambda n: 0 if n == -1 else n, grid, list))

	# START PART 1
	print(acc)
	# END PART 1


if __name__ == "__main__":
	main()