DevQuiz スライドパズル 解答
本日、Google Developer Day 2011の参加資格をかけたDevQuizが終了しました。
（このプログラムは遅く、MacBook Pro（Core 2 Duo 2.4GHz）で6x6の問題を解かせても1日に200問以下しか解けません。そこで実際にはAmazon EC2を使い、計算時間の遅さを補いました。）
# DevQuiz 2011 slide-puzzle solver.
#
# Usage: python <this script> [problems_file] [output_file]
#
# The problems file starts with a line of move limits (read but not used by
# this solver), then the number of problems, then one problem per line in the
# form "width,height,cells".  Problems are solved in a multiprocessing pool and
# the answers (one move string per line, empty when no solution was found) are
# written to the output file in input order.
import sys
import os
import re
import sqlite3
import multiprocessing

# One SQLite database per problem is kept here, next to this script.
WORK_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'work')
if not os.path.isdir(WORK_DIR):
    os.makedirs(WORK_DIR)

DEBUG = False


def vadd(va, vb):
    """Return the component-wise sum of two equal-length tuples."""
    return tuple(a + b for a, b in zip(va, vb))


def vsub(va, vb):
    """Return the component-wise difference ``va - vb``."""
    return tuple(a - b for a, b in zip(va, vb))


def vabs(v):
    """Return the *squared* Euclidean length of ``v`` (no square root)."""
    return sum(x * x for x in v)


class InputSequence(object):
    """Immutable move string together with its per-direction move counts.

    ``counts`` is the tuple ``(l, r, u, d)``; the dominance checks below
    compare sequences by these counts only.
    """

    def __init__(self, l=0, r=0, u=0, d=0, seq=''):
        self.l = l
        self.r = r
        self.u = u
        self.d = d
        self.counts = (self.l, self.r, self.u, self.d)
        self.seq = seq

    def __str__(self):
        return self.seq

    def append(self, item):
        """Return a new sequence extended by one move.

        ``item`` must be one of 'L', 'R', 'U', 'D'; anything else raises
        RuntimeError.  ``self`` is not modified.
        """
        if item == 'L':
            return InputSequence(self.l + 1, self.r, self.u, self.d, self.seq + 'L')
        elif item == 'R':
            return InputSequence(self.l, self.r + 1, self.u, self.d, self.seq + 'R')
        elif item == 'U':
            return InputSequence(self.l, self.r, self.u + 1, self.d, self.seq + 'U')
        elif item == 'D':
            return InputSequence(self.l, self.r, self.u, self.d + 1, self.seq + 'D')
        raise RuntimeError('unknown move: %r' % (item,))

    def is_superior(self, other):
        """True if self uses no more moves than ``other`` in every direction
        and the counts are not identical (i.e. self strictly dominates)."""
        return (all(s <= o for s, o in zip(self.counts, other.counts))
                and self.counts != other.counts)

    def is_inferior(self, other):
        """True if self uses at least as many moves as ``other`` in every
        direction (equal counts included), i.e. self is no better."""
        return all(s >= o for s, o in zip(self.counts, other.counts))


class State(object):
    """A board position.

    ``state`` and ``final_state`` are lists of rows, each row a list of
    one-character cells: '0' is the blank, '=' is a wall, other characters are
    tiles.  ``blank`` is the ``(x, y)`` position of '0'.
    """

    def __init__(self, w, h, blank, state, final_state):
        self.width = w
        self.height = h
        self.blank = blank
        self.state = state
        self.final_state = final_state
        self._norm = None  # cache for norm()

    def __str__(self):
        """Serialise as "width,height,cells" (same format parse_string reads)."""
        cells = ''.join(''.join(row) for row in self.state)
        return ','.join([str(self.width), str(self.height), cells])

    def norm(self):
        """Heuristic distance from the goal; lower means closer.

        For every non-wall cell (the blank included) this adds the squared
        distance between its current and goal positions, weighted by
        ``(width-1 - goal_x) + (height-1 - goal_y)`` so that pieces whose goal
        lies toward the top-left count more.  The value is cached.
        """
        if self._norm is not None:
            return self._norm
        goal = {}
        for y, row in enumerate(self.final_state):
            for x, ch in enumerate(row):
                if ch != '=':
                    goal[ch] = (x, y)
        corner = (self.width - 1, self.height - 1)
        total = 0
        for y, row in enumerate(self.state):
            for x, ch in enumerate(row):
                if ch != '=':
                    target = goal[ch]
                    total += vabs(vsub((x, y), target)) * sum(vsub(corner, target))
        self._norm = total
        return total

    @staticmethod
    def parse_string(string):
        """Build a State from "width,height,cells" (trailing whitespace ok).

        Raises ValueError if the cell count does not match width*height or
        if there is no blank ('0').
        """
        width, height, state_str = string.rstrip().split(',')
        width = int(width)
        height = int(height)
        if len(state_str) != width * height:
            raise ValueError('expected %d cells, got %d: %r'
                             % (width * height, len(state_str), string))
        blank_pos = state_str.find('0')
        if blank_pos < 0:
            # find() returns -1, which would silently map to a bogus position.
            raise ValueError('no blank cell (0) in problem: %r' % (string,))

        def pos2coord(pos):
            return (pos % width, pos // width)

        blank = pos2coord(blank_pos)
        walls = [pos2coord(m.start()) for m in re.finditer('=', state_str)]
        final_state = State.get_final_state(width, height, walls)
        cells = list(state_str)
        state = [cells[i * width:(i + 1) * width] for i in range(height)]
        return State(width, height, blank, state, final_state)

    def is_finished(self):
        """True if the board matches the goal layout."""
        return self.state == self.final_state

    @staticmethod
    def get_final_state(width, height, walls):
        """Return the goal board: '1'-'9', then 'A'-'Z' in row-major order,
        '0' in the last cell, and '=' at every ``(x, y)`` in ``walls``.

        Raises ValueError for boards with more cells than there are symbols.
        """
        chars = ([chr(ord('1') + i) for i in range(9)]
                 + [chr(ord('A') + i) for i in range(26)])
        if width * height > len(chars) + 1:
            # Without this check the rows would silently come out too short.
            raise ValueError('board %dx%d too large (at most %d cells supported)'
                             % (width, height, len(chars) + 1))
        chars = chars[:width * height - 1] + ['0']
        result = [chars[i * width:(i + 1) * width] for i in range(height)]
        for x, y in walls:
            result[y][x] = '='
        return result

    def next_states(self):
        """Return ``[(State, move_char), ...]`` for every legal blank move,
        in L, R, U, D order."""
        result = []
        directions = [('L', (-1, 0)), ('R', (1, 0)), ('U', (0, -1)), ('D', (0, 1))]
        for direction_str, direction in directions:
            moved = self.move(direction)
            if moved:
                result.append((moved, direction_str))
        return result

    def move(self, direction):
        """Move the blank by ``direction``; return the new State, or False if
        that would leave the board or hit a wall."""
        next_x, next_y = vadd(self.blank, direction)
        if (0 <= next_x < self.width and 0 <= next_y < self.height
                and self.state[next_y][next_x] != '='):
            return self.swap(next_x, next_y)
        return False

    def swap(self, next_x, next_y):
        """Return a new State with the blank swapped into ``(next_x, next_y)``;
        ``self`` is left unchanged."""
        x, y = self.blank
        state = [list(row) for row in self.state]  # copy rows, not just the outer list
        state[y][x], state[next_y][next_x] = self.state[next_y][next_x], self.state[y][x]
        return State(self.width, self.height, (next_x, next_y), state, self.final_state)


class Player(object):
    """Pruned breadth-first search for one problem, backed by SQLite.

    Every reached (state, move sequence) pair is stored in the ``history``
    table of ``WORK_DIR/<problem>.sqlite``.  A new pair is skipped when a
    stored sequence for the same state uses no more moves than it in every
    direction; stored sequences the new one strictly dominates are deleted.
    When a layer has more than ``max_histories`` entries it is trimmed to the
    ones with the lowest ``State.norm()``.  At most ``max_depth`` layers are
    expanded.
    """

    def __init__(self, problem, max_depth, max_histories):
        self.problem = problem
        self.max_depth = max_depth
        self.max_histories = max_histories
        self.finished = False
        self.db = None  # set early so __del__ is safe if connect() fails
        db_path = os.path.join(WORK_DIR, problem + '.sqlite')
        db_exists = os.path.exists(db_path)
        self.db = sqlite3.connect(db_path)
        if not db_exists:
            self.initialize_db(self.db)

    def initialize_db(self, db):
        """Create the ``history`` table and its lookup indexes."""
        sql_table_create = """CREATE TABLE history(
            state TEXT,
            finished INTEGER,
            left INTEGER,
            right INTEGER,
            up INTEGER,
            down INTEGER,
            seq TEXT);"""
        db.execute(sql_table_create)
        db.execute('CREATE INDEX history_state_index ON history(state);')
        db.execute('CREATE INDEX history_seq_index ON history(seq);')

    def __del__(self):
        """After a successful search drop non-final rows, commit and VACUUM;
        always close the connection.  Work of an unfinished search done since
        the last commit is not committed here."""
        db = getattr(self, 'db', None)
        if db is None:
            return
        if getattr(self, 'finished', False):
            db.execute('DELETE FROM history WHERE finished=0;')
            db.commit()
            db.execute('VACUUM;')
        db.close()

    def insert(self, state, input_sequence, finished):
        """Record one (state, sequence) pair; not committed immediately."""
        sql = 'INSERT INTO history values (?,?,?,?,?,?,?);'
        finished_flag = 1 if finished else 0
        self.db.execute(sql, (str(state), finished_flag,
                              input_sequence.l, input_sequence.r,
                              input_sequence.u, input_sequence.d,
                              str(input_sequence)))

    def delete(self, state, input_sequence):
        """Remove the rows matching this exact state and move string."""
        sql = 'DELETE FROM history WHERE state=? AND seq=?;'
        self.db.execute(sql, (str(state), str(input_sequence)))

    def start(self):
        """Run the search from the problem's initial board."""
        state = State.parse_string(self.problem)
        input_sequence = InputSequence()
        already_solved = state.is_finished()
        self.insert(state, input_sequence, already_solved)
        if already_solved:
            # Nothing to search; the empty sequence is the answer.
            self.finished = True
        histories = [(state, input_sequence)]
        for i in range(self.max_depth):
            if self.finished:
                break
            if len(histories) > self.max_histories:
                histories.sort(key=lambda history: history[0].norm())
                histories = histories[:self.max_histories]
            if DEBUG:
                print('i = %d, len(histories) = %d' % (i, len(histories)))
            histories = list(self.get_next_histories(histories))
            if i % 10 == 0:
                self.db.commit()

    def get_next_histories(self, histories):
        """Expand one BFS layer.

        Records every non-dominated successor; yields only the unfinished
        ones and sets ``self.finished`` when a goal board is reached.
        """
        for state, input_sequence in histories:
            for next_state, direction in state.next_states():
                next_input_sequence = input_sequence.append(direction)
                if self.check_duplication(next_state, next_input_sequence):
                    continue
                finished = next_state.is_finished()
                self.insert(next_state, next_input_sequence, finished)
                if finished:
                    self.finished = True
                else:
                    yield (next_state, next_input_sequence)

    def check_duplication(self, next_state, next_input_sequence):
        """Return True if ``next_state`` was already reached no worse.

        Stored sequences for the same state that the new one strictly
        dominates are deleted as a side effect.
        """
        # Materialise the rows first: the loop may DELETE from the table being
        # read, and an early return must not leave a SELECT cursor half-read.
        for duplication in list(self.get_duplications(next_state)):
            if next_input_sequence.is_inferior(duplication):
                return True
            if next_input_sequence.is_superior(duplication):
                self.delete(next_state, duplication)
        return False

    def get_duplications(self, state):
        """Yield the stored InputSequences that reached ``state``."""
        cursor = self.db.cursor()
        cursor.execute('SELECT left,right,up,down,seq FROM history WHERE state=?',
                       (str(state),))
        for row in cursor:
            yield InputSequence(int(row[0]), int(row[1]), int(row[2]), int(row[3]), row[4])

    def get_finished(self, limit):
        """Yield up to ``limit`` stored sequences that reach the goal."""
        cursor = self.db.cursor()
        cursor.execute('SELECT left,right,up,down,seq FROM history '
                       'WHERE finished=1 LIMIT ? OFFSET 0;', (limit,))
        for row in cursor:
            yield InputSequence(int(row[0]), int(row[1]), int(row[2]), int(row[3]), row[4])

    def get_result(self):
        """Return a solving move string, searching only if the database has
        none yet; print "Fail: ..." and return '' when nothing is found."""
        finished = list(self.get_finished(1))
        if not finished:
            self.start()
            finished = list(self.get_finished(1))
            if not finished:
                print('Fail: %s' % self.problem)
                return ''
        return str(finished[0])


def slide_puzzle(problem):
    """Solve one "width,height,cells" problem; '' means no solution found."""
    return Player(problem, max_depth=100, max_histories=10000).get_result()


def main():
    """Command-line entry point; see the header comment for file formats."""
    if len(sys.argv) < 3:
        print('usage: python %s [problems_file] [output_file]' % __file__)
        return 0
    with open(sys.argv[1], 'r') as fin, open(sys.argv[2], 'w') as fout:
        fin.readline()  # move-limit line; not used by this solver
        num_problems = int(fin.readline().rstrip())
        problems = [line.rstrip() for line in fin.readlines() if line.strip()]
        if len(problems) != num_problems:
            raise ValueError('expected %d problems, found %d'
                             % (num_problems, len(problems)))
        pool = multiprocessing.Pool(processes=multiprocessing.cpu_count() + 1)
        try:
            answers = pool.map(slide_puzzle, problems)
            pool.close()
        except BaseException:
            # Don't leave worker processes running after a failure.
            pool.terminate()
            raise
        finally:
            pool.join()
        fout.write('\n'.join(answers))
        fout.write('\n')
    return 0


if __name__ == '__main__':
    sys.exit(main())