Rubyメソッドのアラウンドエイリアスについて
Rubyで末尾再帰最適化をする。を読んで思ったこと。
このエントリやメタプログラミングRubyでメソッドのアラウンドエイリアスを定義する際に、古いメソッドに別名をつけてから新しいメソッドを定義しています。以下はメタプログラミングRubyの例。
class String alias :real_length :length def length real_length > 5 ? 'long' : 'short' end end
この方法の問題点はエイリアスしたメソッド(この例の場合real_length)が残ってしまうことです。メソッドオブジェクトを使えばエイリアスしたメソッドを残すことなくアラウンドエイリアスを行うことができると思ったので、実際にやってみました。
以下のプログラムは、Rubyで末尾再帰最適化をする。のプログラムをメソッドエイリアスではなくメソッドオブジェクトを使って書いた例です。
#!/usr/bin/ruby -Ku # -*- coding: utf-8 -*- class Module def tco(name) continue = [] first = true arguments = nil original_method = self.instance_method(name) proc = lambda do |*args| if first first = false while true result = original_method.bind(self).call(*args) if result.equal? continue args = arguments else first = true return result end end else arguments = args continue end end define_method name, proc end end class Sum def sum1(n, acc=0) if n == 0 acc else sum1(n-1, acc+n) end end def sum2(n, acc=0) if n == 0 acc else sum2(n-1, acc+n) end end tco :sum2 end if $0 == __FILE__ o = Sum.new p o.sum2(100000) p o.sum1(100000) end
実行結果
> ./tco.rb 5000050000 ./tco.rb:37:in `sum1': stack level too deep (SystemStackError) from ./tco.rb:37:in `sum1' from ./tco.rb:54
original_method = self.instance_method(name)
の部分で元のメソッドのUnboundMethodオブジェクトを取り出して、
result = original_method.bind(self).call(*args)
で元のメソッドを呼び出しています。
このエントリを書いてる途中に、rubikitchさんが同じことをもっと詳しく解説してくれているのを発見したので、リンクを張っておきます。
ついでに、末尾再帰最適化のPythonの例を読んだときに、ネタに対して、「これってマルチスレッドで動かないよね」と思ってしまったので、Pythonの末尾再帰最適化のマルチスレッド対応版を書いておきます。
#!/usr/bin/env python # -*- coding: utf-8 -*- class tail_recursive(object): def __init__(self, func): self.func = func self.firstcall = True self.CONTINUE = object() def __call__(self, *args, **kwd): if self.firstcall: func = self.func CONTINUE = self.CONTINUE self.firstcall = False try: while True: result = func(*args, **kwd) if result is CONTINUE: # update arguments args, kwd = self.argskwd else: # last call return result finally: self.firstcall = True else: # return the arguments of the tail call self.argskwd = args, kwd return self.CONTINUE import threading class tail_recursive_mt(object): def __init__(self, func): self.func = func self.local = threading.local() self.CONTINUE = object() def __call__(self, *args, **kwd): if not hasattr(self.local, "firstcall") or self.local.firstcall: func = self.func CONTINUE = self.CONTINUE self.local.firstcall = False try: while True: result = func(*args, **kwd) if result is CONTINUE: # update arguments args, kwd = self.local.argskwd else: # last call return result finally: self.local.firstcall = True else: # return the arguments of the tail call self.local.argskwd = args, kwd return self.CONTINUE @tail_recursive def sum(n, acc=0): if n == 0: return acc else: return sum(n-1, acc+n) @tail_recursive_mt def sum_mt(n, acc=0): if n == 0: return acc else: return sum_mt(n-1, acc+n) class SumTest(threading.Thread): def run(self): try: assert(sum(100000) == 5000050000) except AssertionError, e: print "AssertionError" else: print "OK" class SumMtTest(threading.Thread): def run(self): try: assert(sum_mt(100000) == 5000050000) except AssertionError, e: print "AssertionError" else: print "OK" if __name__ == '__main__': print "="*60 print "Single Thread Version" print "="*60 threads = [SumTest() for i in xrange(5)] for t in threads: t.start() for t in threads: t.join() print "="*60 print "Multi Thread Version" print "="*60 threads = [SumMtTest() for i in xrange(5)] for t in threads: t.start() for t in threads: t.join()
実行結果
> ./tco.py ============================================================ Single Thread Version ============================================================ AssertionError AssertionError AssertionError AssertionError OK ============================================================ Multi Thread Version ============================================================ OK OK OK OK OK