DevQuiz スライドパズル 解答
本日、Google Developer Day 2011の参加資格をかけたDevQuizが終了しました。
というわけで、今年一番の難問だったスライドパズルの回答を掲載します。
結果から先に言うと、5000問中4562問正解でした。
が、特別なアルゴリズムを使っているわけでもなく、計算時間もとてつもなくかかるので、はっきりいってあまり参考にはならないと思います。
(MacBook Pro Core2 Duo 2.4GHzを使って6x6の問題を解くのに、1日で200問以下という遅さです。実際にはAmazon EC2を使って計算時間の遅さをカバーしました。)
ちなみに、使用言語はPythonです。
基本アルゴリズム
基本アルゴリズムは幅優先探索です。
最初に、初期状態と空の入力列からなるタプル1つを要素にもつリストを用意します(変数名histories)。
次に、historiesの各要素から1手先の状態と入力列を計算し、それを再びリスト(histories)に格納します。
これを繰り返すことによって、解を探索していきます。
しかし、これだけではすぐにリストの長さが発散しメモリを使い尽くしてしまうので、重複の削除と枝刈りによってこれを回避します。
重複の削除
「右左右左」といった単純なループや同じ所をぐるぐる回るといったループを回避するために、過去の状態と入力列を全てデータベースに保存し、新しい解を得る度に過去の状態と比較を行います。
過去に同じ状態に辿りついたことがある場合は、そこで探索を終了します。
枝刈り
重複の削除を行ってもすぐに探索空間は発散してしまうので、枝刈りを行う必要があります。
histories中の各状態が最終状態にどれくらい近いかを評価し、より評価の良いものだけを残していきます。
この評価関数をどのように設計するかによって問題を解けるか解けないかが決定する訳ですが、評価関数をΣ距離×重みとすると、問題を解ける確率が飛躍的に高まりました。
ここで、「距離」はそのパネルの本来の位置までのユークリッド距離の二乗とし、「重み」は本来の位置が左上にあるパネルほど高い値をもつ数字です。
最終状態では動かせるマスが一番右下であるため、左上のパネルは早い段階で完成させる必要があると考え、このような評価関数を設定しました。
この評価関数は「壁」を考慮していないため、場合によっては局所最適に陥ってしまいうまく解けない問題もありましたが、最終的に約9割の問題が解けたことから、そこそこ良い評価関数ではないかと思っています。
高速化
高速化の工夫はほとんどありませんが、強いて挙げるとすれば、multiprocessingを使ってCPUのコアを使いきるような設計にしたことです。
multiprocessingは本当に便利なライブラリだと思います。
反省点
中途半端に問題が解けることを良い事に、遅いプログラムを最後まで使い続けたことが反省点です。
きちんとしたアルゴリズムを用いれば、同じ数の問題を数時間で解くことができることを知り、専門家の皆さんとの実力の違いを思い知りました。
ソース
import sys import os import re import sqlite3 import multiprocessing from contextlib import nested WORK_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'work') if not os.path.exists(WORK_DIR): os.mkdir(WORK_DIR) DEBUG = False def vadd(va, vb): return tuple((a + b for a, b in zip(va, vb))) def vsub(va, vb): return tuple((a - b for a, b in zip(va, vb))) def vabs(v): return sum([x*x for x in v]) class InputSequence(object): def __init__(self, l=0, r=0, u=0, d=0, seq=''): self.l = l self.r = r self.u = u self.d = d self.counts = (self.l, self.r, self.u, self.d) self.seq = seq def __str__(self): return self.seq def append(self, item): if item == 'L': return InputSequence(self.l+1, self.r, self.u, self.d, self.seq + 'L') elif item == 'R': return InputSequence(self.l, self.r+1, self.u, self.d, self.seq + 'R') elif item == 'U': return InputSequence(self.l, self.r, self.u+1, self.d, self.seq + 'U') elif item == 'D': return InputSequence(self.l, self.r, self.u, self.d+1, self.seq + 'D') raise RuntimeError def is_superior(self, other): return all([s <= o for s, o in zip(self.counts, other.counts)]) and (self.counts != other.counts) def is_inferior(self, other): return all([s >= o for s, o in zip(self.counts, other.counts)]) class State(object): def __init__(self, w, h, blank, state, final_state): self.width = w self.height = h self.blank = blank self.state = state self.final_state = final_state def __str__(self): state_str = ''.join([''.join(row) for row in self.state]) return ','.join([str(self.width), str(self.height), state_str]) def norm(self): if hasattr(self, '_norm'): return self._norm char2coord = dict() for x in xrange(self.width): for y in xrange(self.height): if self.final_state[y][x] != '=': char2coord[self.final_state[y][x]] = (x, y) self._norm = 0 for x in xrange(self.width): for y in xrange(self.height): if self.state[y][x] != '=': distance = vabs(vsub((x, y), char2coord[self.state[y][x]])) weight = sum(vsub((self.width-1, self.height-1), char2coord[self.state[y][x]])) self._norm += distance * weight return self._norm @staticmethod def parse_string(string): width, height, state_str = string.rstrip().split(',') width = int(width) height = int(height) assert(len(state_str) == width * height) def pos2coord(pos): return (pos % width, pos / width) blank = pos2coord(state_str.find('0')) walls = [pos2coord(m.start()) for m in re.compile('=').finditer(state_str)] final_state = State.get_final_state(width, height, walls) state = [] state_list = list(state_str) for i in xrange(height): state.append(state_list[i*width:(i+1)*width]) return State(width, height, blank, state, final_state) def is_finished(self): return self.state == self.final_state @staticmethod def get_final_state(width, height, walls): chars = [chr(ord('1') + i) for i in xrange(9)] + [chr(ord('A') + i) for i in xrange(26)] chars = chars[0:width*height-1] + ['0'] result = [] for i in xrange(height): result.append(chars[i*width:(i+1)*width]) for x, y in walls: result[y][x] = '=' return result def next_states(self): result = [] directions = [('L', (-1, 0)), ('R', (1, 0)), ('U', (0, -1)), ('D', (0, 1))] for direction_str, direction in directions: next = self.move(direction) if next: result.append((next, direction_str)) return result def move(self, direction): next_x, next_y = vadd(self.blank, direction) if (next_x >= 0 and next_y >= 0 and next_x < self.width and next_y < self.height and self.state[next_y][next_x] != '='): return self.swap(next_x, next_y) else: return False def swap(self, next_x, next_y): x, y = self.blank state = [list(row) for row in self.state] # deep copy state[y][x], state[next_y][next_x] = self.state[next_y][next_x], self.state[y][x] return State(self.width, self.height, (next_x, next_y), state, self.final_state) class Player(object): def __init__(self, problem, max_depth, max_histories): self.problem = problem self.max_depth = max_depth self.max_histories = max_histories self.finished = False db_path = os.path.join(WORK_DIR, problem + '.sqlite') db_exists = os.path.exists(db_path) self.db = sqlite3.connect(db_path) if not db_exists: self.initialize_db(self.db) def initialize_db(self, db): sql_table_create = """CREATE TABLE history( state TEXT, finished INTEGER, left INTEGER, right INTEGER, up INTEGER, down INTEGER, seq TEXT);""" db.execute(sql_table_create) db.execute('CREATE INDEX history_state_index ON history(state);') db.execute('CREATE INDEX history_seq_index ON history(seq);') def __del__(self): if self.finished == True: self.db.execute('DELETE FROM history WHERE finished=0;') self.db.commit() self.db.execute('VACUUM;') self.db.close() def insert(self, state, input_sequence, finished): sql = 'INSERT INTO history values (?,?,?,?,?,?,?);' if finished: finished_flag = 1 else: finished_flag = 0 self.db.execute(sql, (str(state), finished_flag, input_sequence.l, input_sequence.r, input_sequence.u, input_sequence.d, str(input_sequence)) ) def delete(self, state, input_sequence): sql = 'DELETE FROM history WHERE state=? AND seq=?;' self.db.execute(sql, ((str(state), str(input_sequence)))) def start(self): def norm_cmp(h1, h2): s1 = h1[0] s2 = h2[0] norm1 = s1.norm() norm2 = s2.norm() return norm1 - norm2 state = State.parse_string(self.problem) input_sequence = InputSequence() if state.is_finished() == True: self.insert(state, input_sequence, True) else: self.insert(state, input_sequence, False) histories = [(state, input_sequence)] for i in xrange(self.max_depth): if self.finished: break if len(histories) > self.max_histories: histories.sort(norm_cmp) histories = histories[0:self.max_histories] if DEBUG == True: print ('i = %d, len(histories) = %d' % (i, len(histories))) histories = list(self.get_next_histories(histories)) if i % 10 == 0: self.db.commit() def get_next_histories(self, histories): for state, input_sequence in histories: for next_state, direction in state.next_states(): next_input_sequence = input_sequence.append(direction) if not self.check_duplication(next_state, next_input_sequence): if next_state.is_finished() == True: self.insert(next_state, next_input_sequence, True) self.finished = True else: self.insert(next_state, next_input_sequence, False) yield (next_state, next_input_sequence) def check_duplication(self, next_state, next_input_sequence): for duplication in self.get_duplications(next_state): if next_input_sequence.is_inferior(duplication): return True elif next_input_sequence.is_superior(duplication): self.delete(next_state, duplication) return False def get_duplications(self, state): cursor = self.db.cursor() cursor.execute('SELECT left,right,up,down,seq FROM history WHERE state=?', (str(state), )) for row in cursor: yield InputSequence(int(row[0]), int(row[1]),int(row[2]),int(row[3]), row[4]) def get_finished(self, limit): cursor = self.db.cursor() cursor.execute('SELECT left,right,up,down,seq FROM history WHERE finished=1 LIMIT ? OFFSET 0;', (limit,)) for row in cursor: yield InputSequence(int(row[0]), int(row[1]),int(row[2]),int(row[3]), row[4]) def get_result(self): finished = list(self.get_finished(1)) if len(finished) == 0: self.start() finished = list(self.get_finished(1)) if len(finished) == 0: print 'Fail: %s' % self.problem return '' return str(finished[0]) def slide_puzzle(problem): return Player(problem, max_depth=100, max_histories=10000).get_result() def main(): if len(sys.argv) < 3: print 'usage: python %s [problems_file] [output_file]' % __file__ return 0 with nested(open(sys.argv[1], 'r'), open(sys.argv[2], 'w')) as (fin, fout): limits = [int(x) for x in fin.readline().rstrip().split()] num_problems = int(fin.readline().rstrip()) problems = [line.rstrip() for line in fin.readlines()] assert(len(problems) == num_problems) processes = multiprocessing.cpu_count() + 1 pool = multiprocessing.Pool(processes=processes) fout.write('\n'.join(pool.map(slide_puzzle, problems))) fout.write('\n') return 0 if __name__ == '__main__': sys.exit(main())