DevQuiz スライドパズル 解答

本日、Google Developer Day 2011の参加資格をかけたDevQuizが終了しました。
というわけで、今年一番の難問だったスライドパズルの回答を掲載します。
結果から先に言うと、5000問中4562問正解でした。
が、特別なアルゴリズムを使っているわけでもなく、計算時間もとてつもなくかかるので、はっきりいってあまり参考にはならないと思います。
(MacBook Pro Core2 Duo 2.4GHzを使って6x6の問題を解くのに、1日で200問以下という遅さです。実際にはAmazon EC2を使って計算時間の遅さをカバーしました。)
ちなみに、使用言語はPythonです。

基本アルゴリズム

基本アルゴリズム幅優先探索です。
最初に、初期状態と空の入力列からなるタプル1つを要素にもつリストを用意します(変数名histories)。
次に、historiesの各要素から1手先の状態と入力列を計算し、それを再びリスト(histories)に格納します。
これを繰り返すことによって、解を探索していきます。
しかし、これだけではすぐにリストの長さが発散しメモリを使い尽くしてしまうので、重複の削除と枝刈りによってこれを回避します。

重複の削除

「右左右左」といった単純なループや同じ所をぐるぐる回るといったループを回避するために、過去の状態と入力列を全てデータベースに保存し、新しい解を得る度に過去の状態と比較を行います。
過去に同じ状態に辿りついたことがある場合は、そこで探索を終了します。

枝刈り

重複の削除を行ってもすぐに探索空間は発散してしまうので、枝刈りを行う必要があります。
histories中の各状態が最終状態にどれくらい近いかを評価し、より評価の良いものだけを残していきます。

この評価関数をどのように設計するかによって問題を解けるか解けないかが決定する訳ですが、評価関数をΣ距離×重みとすると、問題を解ける確率が飛躍的に高まりました。
ここで、「距離」はそのパネルの本来の位置までのユークリッド距離の二乗とし、「重み」は本来の位置が左上にあるパネルほど高い値をもつ数字です。
最終状態では動かせるマスが一番右下であるため、左上のパネルは早い段階で完成させる必要があると考え、このような評価関数を設定しました。

この評価関数は「壁」を考慮していないため、場合によっては局所最適に陥ってしまいうまく解けない問題もありましたが、最終的に約9割の問題が解けたことから、そこそこ良い評価関数ではないかと思っています。

高速化

高速化の工夫はほとんどありませんが、強いて挙げるとすれば、multiprocessingを使ってCPUのコアを使いきるような設計にしたことです。
multiprocessingは本当に便利なライブラリだと思います。

反省点

中途半端に問題が解けることを良い事に、遅いプログラムを最後まで使い続けたことが反省点です。
きちんとしたアルゴリズムを用いれば、同じ数の問題を数時間で解くことができることを知り、専門家の皆さんとの実力の違いを思い知りました。

ソース

import sys
import os
import re
import sqlite3
import multiprocessing
from contextlib import nested

WORK_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'work')
if not os.path.exists(WORK_DIR):
  os.mkdir(WORK_DIR)
DEBUG = False

def vadd(va, vb):
  return tuple((a + b for a, b in zip(va, vb)))

def vsub(va, vb):
  return tuple((a - b for a, b in zip(va, vb)))

def vabs(v):
  return sum([x*x for x in v])

class InputSequence(object):
  def __init__(self, l=0, r=0, u=0, d=0, seq=''):
    self.l   = l
    self.r   = r
    self.u   = u
    self.d   = d
    self.counts = (self.l, self.r, self.u, self.d)
    self.seq = seq

  def __str__(self):
    return self.seq

  def append(self, item):
    if item == 'L':
      return InputSequence(self.l+1, self.r, self.u, self.d, self.seq + 'L')
    elif item == 'R':
      return InputSequence(self.l, self.r+1, self.u, self.d, self.seq + 'R')
    elif item == 'U':
      return InputSequence(self.l, self.r, self.u+1, self.d, self.seq + 'U')
    elif item == 'D':
      return InputSequence(self.l, self.r, self.u, self.d+1, self.seq + 'D')
    raise RuntimeError

  def is_superior(self, other):
    return all([s <= o for s, o in zip(self.counts, other.counts)]) and (self.counts != other.counts)

  def is_inferior(self, other):
    return all([s >= o for s, o in zip(self.counts, other.counts)])

class State(object):
  def __init__(self, w, h, blank, state, final_state):
    self.width  = w
    self.height = h
    self.blank = blank
    self.state = state
    self.final_state = final_state

  def __str__(self):
    state_str = ''.join([''.join(row) for row in self.state])
    return ','.join([str(self.width), str(self.height), state_str])

  def norm(self):
    if hasattr(self, '_norm'):
      return self._norm

    char2coord = dict()
    for x in xrange(self.width):
      for y in xrange(self.height):
        if self.final_state[y][x] != '=':
          char2coord[self.final_state[y][x]] = (x, y)
    self._norm = 0
    for x in xrange(self.width):
      for y in xrange(self.height):
        if self.state[y][x] != '=':
          distance = vabs(vsub((x, y), char2coord[self.state[y][x]]))
          weight = sum(vsub((self.width-1, self.height-1), char2coord[self.state[y][x]]))
          self._norm += distance * weight
    return self._norm

  @staticmethod
  def parse_string(string):
    width, height, state_str = string.rstrip().split(',')
    width  = int(width)
    height = int(height)
    assert(len(state_str) == width * height)

    def pos2coord(pos):
      return (pos % width, pos / width)    
    blank = pos2coord(state_str.find('0'))
    walls = [pos2coord(m.start()) for m in re.compile('=').finditer(state_str)]
    final_state = State.get_final_state(width, height, walls)
    
    state = []
    state_list = list(state_str)
    for i in xrange(height):
      state.append(state_list[i*width:(i+1)*width])
        
    return State(width, height, blank, state, final_state)

  def is_finished(self):
    return self.state == self.final_state

  @staticmethod
  def get_final_state(width, height, walls):
    chars = [chr(ord('1') + i) for i in xrange(9)] + [chr(ord('A') + i) for i in xrange(26)]
    chars = chars[0:width*height-1] + ['0']
    result = []
    for i in xrange(height):
      result.append(chars[i*width:(i+1)*width])
    for x, y in walls:
      result[y][x] = '='
    return result

  def next_states(self):
    result = []
    directions = [('L', (-1, 0)),
                  ('R', (1,  0)),
                  ('U', (0, -1)),
                  ('D', (0,  1))]
    for direction_str,  direction in directions:
      next = self.move(direction)
      if next:
        result.append((next, direction_str))
    return result
  
  def move(self, direction):
    next_x, next_y = vadd(self.blank, direction)
    if (next_x >= 0 and
        next_y >= 0 and 
        next_x < self.width and
        next_y < self.height and
        self.state[next_y][next_x] != '='):
      return self.swap(next_x, next_y)
    else:
      return False

  def swap(self, next_x, next_y):
    x, y = self.blank
    state = [list(row) for row in self.state] # deep copy
    state[y][x], state[next_y][next_x] = self.state[next_y][next_x], self.state[y][x]
    return State(self.width, self.height, (next_x, next_y), state, self.final_state)

class Player(object):
  def __init__(self, problem, max_depth, max_histories):
    self.problem = problem
    self.max_depth = max_depth
    self.max_histories = max_histories
    self.finished = False
    
    db_path = os.path.join(WORK_DIR, problem + '.sqlite')
    db_exists = os.path.exists(db_path)
    self.db = sqlite3.connect(db_path)
    if not db_exists:
      self.initialize_db(self.db)
      
  def initialize_db(self, db):
    sql_table_create = """CREATE TABLE history(
                            state    TEXT,
                            finished INTEGER,
                            left     INTEGER,
                            right    INTEGER,
                            up       INTEGER,
                            down     INTEGER,
                            seq      TEXT);"""
    db.execute(sql_table_create)
    db.execute('CREATE INDEX history_state_index ON history(state);')
    db.execute('CREATE INDEX history_seq_index ON history(seq);')

  def __del__(self):
    if self.finished == True:
      self.db.execute('DELETE FROM history WHERE finished=0;')
      self.db.commit()
      self.db.execute('VACUUM;')
    self.db.close()

  def insert(self, state, input_sequence, finished):
    sql = 'INSERT INTO history values (?,?,?,?,?,?,?);'
    if finished:      
      finished_flag = 1
    else:
      finished_flag = 0
    self.db.execute(sql, (str(state),
                          finished_flag,
                          input_sequence.l,
                          input_sequence.r,
                          input_sequence.u,
                          input_sequence.d,
                          str(input_sequence))
                    )
    
  def delete(self, state, input_sequence):
    sql = 'DELETE FROM history WHERE state=? AND seq=?;'
    self.db.execute(sql, ((str(state), str(input_sequence))))

  def start(self):
    def norm_cmp(h1, h2):
      s1 = h1[0]
      s2 = h2[0]
      norm1 = s1.norm()
      norm2 = s2.norm()
      return norm1 - norm2
    
    state = State.parse_string(self.problem)
    input_sequence = InputSequence()
    if state.is_finished() == True:
      self.insert(state, input_sequence, True)
    else:
      self.insert(state, input_sequence, False)
      histories = [(state, input_sequence)]

      for i in xrange(self.max_depth):
        if self.finished:
          break
        if len(histories) > self.max_histories:
          histories.sort(norm_cmp)
          histories = histories[0:self.max_histories]            
        if DEBUG == True:
          print ('i = %d, len(histories) = %d' %
                 (i, len(histories)))
        histories = list(self.get_next_histories(histories))
        if i % 10 == 0:
          self.db.commit()

  def get_next_histories(self, histories):
    for state, input_sequence in histories:
      for next_state, direction in state.next_states():
        next_input_sequence = input_sequence.append(direction)
        if not self.check_duplication(next_state, next_input_sequence):
          if next_state.is_finished() == True:
            self.insert(next_state, next_input_sequence, True)
            self.finished = True
          else:
            self.insert(next_state, next_input_sequence, False)
            yield (next_state, next_input_sequence)

  def check_duplication(self, next_state, next_input_sequence):
    for duplication in self.get_duplications(next_state):
      if next_input_sequence.is_inferior(duplication):
        return True
      elif next_input_sequence.is_superior(duplication):
        self.delete(next_state, duplication)
    return False

  def get_duplications(self, state):
    cursor = self.db.cursor()
    cursor.execute('SELECT left,right,up,down,seq FROM history WHERE state=?', (str(state), ))
    for row in cursor:
      yield InputSequence(int(row[0]), int(row[1]),int(row[2]),int(row[3]), row[4])

  def get_finished(self, limit):
    cursor = self.db.cursor()
    cursor.execute('SELECT left,right,up,down,seq FROM history WHERE finished=1 LIMIT ? OFFSET 0;', (limit,))
    for row in cursor:
      yield InputSequence(int(row[0]), int(row[1]),int(row[2]),int(row[3]), row[4])

  def get_result(self):
    finished = list(self.get_finished(1))
    if len(finished) == 0:
      self.start()
      finished = list(self.get_finished(1))
      if len(finished) == 0:
        print 'Fail: %s' % self.problem
        return ''
    return str(finished[0])

def slide_puzzle(problem):
  return Player(problem, max_depth=100, max_histories=10000).get_result()

def main():
  if len(sys.argv) < 3:
    print 'usage: python %s [problems_file] [output_file]' % __file__
    return 0
  with nested(open(sys.argv[1], 'r'), open(sys.argv[2], 'w')) as (fin, fout):
    limits = [int(x) for x in fin.readline().rstrip().split()]
    num_problems = int(fin.readline().rstrip())
    problems = [line.rstrip() for line in fin.readlines()]
    assert(len(problems) == num_problems)
    processes = multiprocessing.cpu_count() + 1
    pool = multiprocessing.Pool(processes=processes)
    fout.write('\n'.join(pool.map(slide_puzzle, problems)))
    fout.write('\n')
  return 0

if __name__ == '__main__':
  sys.exit(main())