Rubyメソッドのアラウンドエイリアスについて

Rubyで末尾再帰最適化をする。を読んで思ったこと。
このエントリやメタプログラミングRubyでメソッドのアラウンドエイリアスを定義する際に、古いメソッドに別名をつけてから新しいメソッドを定義しています。以下はメタプログラミングRubyの例。

class String
  alias :real_length :length

  def length
    real_length > 5 ? 'long' : 'short'
  end
end

この方法の問題点はエイリアスしたメソッド(この例の場合real_length)が残ってしまうことです。メソッドオブジェクトを使えばエイリアスしたメソッドを残すことなくアラウンドエイリアスを行うことができると思ったので、実際にやってみました。

以下のプログラムは、Rubyで末尾再帰最適化をする。のプログラムをメソッドエイリアスではなくメソッドオブジェクトを使って書いた例です。

#!/usr/bin/ruby -Ku
# -*- coding: utf-8 -*-

class Module
  def tco(name)
    continue = []
    first = true
    arguments = nil
    original_method = self.instance_method(name)

    proc = lambda do |*args|
      if first
        first = false
        while true
          result = original_method.bind(self).call(*args)
          if result.equal? continue
            args = arguments
          else
            first = true
            return result
          end
        end
      else
        arguments = args
        continue
      end
    end
    define_method name, proc
  end
end

class Sum
  def sum1(n, acc=0)
    if n == 0
      acc
    else
      sum1(n-1, acc+n)
    end
  end

  def sum2(n, acc=0)
    if n == 0
      acc
    else
      sum2(n-1, acc+n)
    end
  end
  tco :sum2
end

if $0 == __FILE__
  o = Sum.new
  p o.sum2(100000)
  p o.sum1(100000)
end

実行結果

> ./tco.rb 
5000050000
./tco.rb:37:in `sum1': stack level too deep (SystemStackError)
	from ./tco.rb:37:in `sum1'
	from ./tco.rb:54
  original_method = self.instance_method(name)

の部分で元のメソッドのUnboundMethodオブジェクトを取り出して、

  result = original_method.bind(self).call(*args)

で元のメソッドを呼び出しています。

このエントリを書いてる途中に、rubikitchさんが同じことをもっと詳しく解説してくれているのを発見したので、リンクを張っておきます。

ついでに、末尾再帰最適化のPythonの例を読んだときに、ネタに対して、「これってマルチスレッドで動かないよね」と思ってしまったので、Pythonの末尾再帰最適化のマルチスレッド対応版を書いておきます。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

class tail_recursive(object):
  def __init__(self, func):
    self.func = func
    self.firstcall = True
    self.CONTINUE = object()

  def __call__(self, *args, **kwd):
    if self.firstcall:
      func = self.func
      CONTINUE = self.CONTINUE
      self.firstcall = False
      try:
        while True:
          result = func(*args, **kwd)
          if result is CONTINUE: # update arguments
            args, kwd = self.argskwd
          else: # last call
            return result
      finally:
        self.firstcall = True
    else: # return the arguments of the tail call
      self.argskwd = args, kwd
      return self.CONTINUE

import threading
class tail_recursive_mt(object):
  def __init__(self, func):
    self.func = func
    self.local = threading.local()
    self.CONTINUE = object()

  def __call__(self, *args, **kwd):
    if not hasattr(self.local, "firstcall") or self.local.firstcall:
      func = self.func
      CONTINUE = self.CONTINUE
      self.local.firstcall = False
      try:
        while True:
          result = func(*args, **kwd)
          if result is CONTINUE: # update arguments
            args, kwd = self.local.argskwd
          else: # last call
            return result
      finally:
        self.local.firstcall = True
    else: # return the arguments of the tail call
      self.local.argskwd = args, kwd
      return self.CONTINUE    

@tail_recursive
def sum(n, acc=0):
  if n == 0:
    return acc
  else:
    return sum(n-1, acc+n)

@tail_recursive_mt
def sum_mt(n, acc=0):
  if n == 0:
    return acc
  else:
    return sum_mt(n-1, acc+n)  

class SumTest(threading.Thread):
  def run(self):
    try:
      assert(sum(100000) == 5000050000)
    except AssertionError, e:
      print "AssertionError"
    else:
      print "OK"

class SumMtTest(threading.Thread):
  def run(self):
    try:
      assert(sum_mt(100000) == 5000050000)
    except AssertionError, e:
      print "AssertionError"
    else:
      print "OK"
    
if __name__ == '__main__':
  print "="*60
  print "Single Thread Version"
  print "="*60
  threads = [SumTest() for i in xrange(5)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()

  print "="*60
  print "Multi Thread Version"
  print "="*60
  threads = [SumMtTest() for i in xrange(5)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()

実行結果

> ./tco.py
============================================================
Single Thread Version
============================================================
AssertionError
AssertionError
AssertionError
AssertionError
OK
============================================================
Multi Thread Version
============================================================
OK
OK
OK
OK
OK