117 lines
2.7 KiB
Python
117 lines
2.7 KiB
Python
# Load the puzzle input as raw lines; each entry keeps its trailing newline,
# which the parser strips later.
with open('input.text', 'r') as handle:
    data: list[str] = list(handle)
|
|
|
|
class Operator:
    """Base class for the binary operators that can fill an equation slot.

    Subclasses override ``compute(x, y)``, which combines the value computed
    so far (``x``) with the next constant (``y``) and returns the new value.
    """

    def compute(self, x: int, y: int) -> int:
        """Combine ``x`` and ``y``; concrete operators must override this.

        Raises:
            NotImplementedError: always, on the base class, so a bare
                ``Operator`` in an equation fails loudly instead of with an
                opaque AttributeError.
        """
        raise NotImplementedError(f'{type(self).__name__} does not implement compute()')
|
|
|
|
class Add(Operator):
    """Addition operator: ``x + y``."""

    def compute(self, x: int, y: int) -> int:
        """Return the sum of the running value ``x`` and the constant ``y``."""
        total = x + y
        return total
|
|
|
|
class Multiply(Operator):
    """Multiplication operator: ``x * y``."""

    def compute(self, x: int, y: int) -> int:
        """Return the product of the running value ``x`` and the constant ``y``."""
        product = x * y
        return product
|
|
|
|
class Concat(Operator):
    """Digit-concatenation operator: writes ``y``'s decimal text after ``x``'s."""

    def compute(self, x: int, y: int) -> int:
        """Return the int spelled by the text of ``x`` followed by that of ``y``.

        For example ``compute(12, 345) == 12345``.
        """
        joined = f'{x}{y}'
        return int(joined)
|
|
|
|
class Equation:
|
|
result: int
|
|
constants: list[int]
|
|
|
|
def __init__(self, result: int, constants: list[int]):
|
|
self.result = result
|
|
self.constants = constants
|
|
|
|
def is_valid_eq(self, eq: list[int | Operator]) -> bool:
|
|
|
|
r = None
|
|
for i, token in enumerate(eq):
|
|
match token:
|
|
case Operator():
|
|
r = token.compute(r, eq[i + 1])
|
|
case int():
|
|
if not r:
|
|
r = token
|
|
|
|
return r == self.result
|
|
|
|
def is_valid(self, eq: list[int | Operator]) -> bool:
|
|
if not eq:
|
|
eq = []
|
|
for c in self.constants:
|
|
eq.append(c)
|
|
eq.append(' ')
|
|
eq = eq[:-1]
|
|
|
|
for i, c in enumerate(eq):
|
|
if c == ' ':
|
|
new_eq1 = eq.copy()
|
|
new_eq2 = eq.copy()
|
|
|
|
new_eq1[i] = Add()
|
|
new_eq2[i] = Multiply()
|
|
|
|
return self.is_valid(new_eq1) or self.is_valid(new_eq2)
|
|
|
|
return self.is_valid_eq(eq)
|
|
|
|
def is_valid2(self, eq: list[int | Operator]) -> bool:
|
|
if not eq:
|
|
eq = []
|
|
for c in self.constants:
|
|
eq.append(c)
|
|
eq.append(' ')
|
|
eq = eq[:-1]
|
|
|
|
for i, c in enumerate(eq):
|
|
if c == ' ':
|
|
new_eq1 = eq.copy()
|
|
new_eq2 = eq.copy()
|
|
new_eq3 = eq.copy()
|
|
|
|
new_eq1[i] = Add()
|
|
new_eq2[i] = Multiply()
|
|
new_eq3[i] = Concat()
|
|
|
|
return self.is_valid2(new_eq1) or self.is_valid2(new_eq2) or self.is_valid2(new_eq3)
|
|
|
|
return self.is_valid_eq(eq)
|
|
|
|
equations: list[Equation] = []

# Parse input: each relevant line looks like "<result>: <c1> <c2> ...".
# Lines without a ':' (e.g. blank lines) are skipped.
for line in data:
    head, sep, tail = line.partition(':')
    if not sep:
        continue
    result = int(head)
    # split() with no argument tolerates repeated spaces and the trailing
    # newline; split(' ') would yield '' entries and crash in int().
    constants = [int(n) for n in tail.split()]
    equations.append(Equation(result, constants))
|
|
|
|
# Part 1: add up the target of every equation that + and * alone can satisfy.
total = sum(
    equation.result
    for equation in equations
    if equation.is_valid(None)
)

print(f'Part 1: {total}')
|
|
|
|
# Part 2: same sum, but the digit-concatenation operator is also allowed.
total = sum(
    equation.result
    for equation in equations
    if equation.is_valid2(None)
)

print(f'Part 2: {total}')
|