I am trying to write a macro that will memoize a function, possibly recursive (here, I am not taking into account more complex situations, like a function that would call itself via mutual recursion and so on).
My tentative implementation goes like this:
import macros, tables

proc memoize[A, B](f: proc(a: A): B): proc(a: A): B =
  var cache = initTable[A, B]()
  proc g(a: A): B =
    if cache.hasKey(a): return cache[a]
    else:
      let b = f(a)
      cache[a] = b
      return b
  return g

macro memo(e: expr): expr =
  let f = !"internalName" # genSym()
  let selfName = e[0].ident

  iterator pairs(it: NimNode): expr =
    var count = -1
    for x in it.children:
      count += 1
      yield (count, x)

  proc recSub(node) =
    for c in children(node):
      recSub(c)
    if node.kind == nnkIdent:
      if node.ident == selfName:
        node.ident = f

  for i, c in e:
    if i != 0:
      recSub(c)
  result = quote do:
    var internalName: proc (a: int): int
    `e`
    internalName = memoize(`selfName`)
    internalName
  echo toStrLit(result)

proc fib(n: int): int {.memo.} =
  if n <= 1: n
  else: fib(n - 1) + fib(n - 2)

when isMainModule:
  echo(fib(40))
Here memoize is a higher-order function that takes a function and returns its memoized version. memo is a macro that tries to substitute self-calls with calls to the memoized version.
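To illustrate why the higher-order function alone is not enough for recursive procs, here is a minimal sketch written in current Nim syntax. The names `calls` and `fastFib` are hypothetical and only serve the illustration. The wrapper caches top-level calls, but the recursive calls inside `fib` still go to the unmemoized proc, which is what the macro is meant to fix.

```nim
import tables

proc memoize[A, B](f: proc(a: A): B): proc(a: A): B =
  ## Returns a closure that remembers every result of `f` it has computed.
  var cache = initTable[A, B]()
  result = proc(a: A): B =
    if a notin cache:
      cache[a] = f(a)
    cache[a]

var calls = 0  # hypothetical counter: how often the plain fib body runs

proc fib(n: int): int =
  inc calls
  if n <= 1: n
  else: fib(n - 1) + fib(n - 2)  # self-calls bypass the cache

let fastFib = memoize(fib)

discard fastFib(10)
let firstRun = calls   # many calls: the recursion was not memoized
discard fastFib(10)    # served from the cache, so `calls` stays the same
echo "calls after first run: ", firstRun
echo "calls after second run: ", calls
```

Only the outermost call benefits from the cache, so the first evaluation is still exponential.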
The problem is that the symbol is not resolved, although dumping the toStrLit version of the returned function makes it look like it should be resolved correctly.
Other minor problems that I have are:
I'm not sure why your code is not working. At first glance, I think it should be OK.
I tried it another way. The code below will not win a beauty contest, but seems to work.
import tables, macros

macro memoize(e: expr): stmt =
  let
    procName = e.name()
    procBody = e.body()
    retType = e.params()[0]
    param1 = e.params()[1]
    n = param1[0]
    nType = param1[1]
    cache = ident($procName & "_cache")
    funName = ident($procName & "_original")
  quote do:
    var `cache` = initTable[`nType`,`retType`]()
    proc `procName`(`n` : `nType`) : `retType`
    proc `funName`(`n` : `nType`) : `retType` =
      `procBody`
    proc `procName`(`n` : `nType`) : `retType` =
      if not `cache`.hasKey(`n`):
        `cache`[`n`] = `funName`(`n`)
      return `cache`[`n`]

proc fib(n : int) : int {.memoize.} =
  if n < 2:
    n
  else:
    fib(n-1) + fib(n-2)

echo fib(44)
Your higher-order memoize function also works when it is combined with the macro below.
I also tried the genSym function, but it seemed to return the same symbol every time. It's unclear to me whether it is broken or I'm using it wrongly.
macro memo(e: expr): stmt =
  let
    procName = e.name()
    procBody = e.body()
    retType = e.params()[0]
    param1 = e.params()[1]
    n = param1[0]
    nType = param1[1]
    funName = ident($procName & "_original")
  quote do:
    var `procName` : proc(`n`:`nType`):`retType`
    proc `funName`(`n` : `nType`) : `retType` =
      `procBody`
    `procName` = memoize(`funName`)
@Araq: Thanks for pointing out the template + getAst way. That works nicely indeed. Below is an initial attempt at the memoize macro using this trick.
genSym does indeed work. I did not see a difference when printing out the AST and wrongly assumed the symbols were the same. But with the template it's no longer needed.
import tables, macros

macro memoize(e: expr): stmt =
  template memo(n, nType, retType, procName, procBody : expr): stmt =
    var cache = initTable[nType,retType]()
    proc procName(n : nType) : retType
    proc funName(n : nType) : retType {.gensym.} =
      procBody
    proc procName(n : nType) : retType =
      if not cache.hasKey(n):
        cache[n] = funName(n)
      return cache[n]
  let
    retType = e.params()[0]
    param = e.params()[1]
  getAst(memo(param[0], param[1], retType, e.name(), e.body()))

proc fib(n : int) : int {.memoize.} =
  if n < 2:
    n
  else:
    fib(n-1) + fib(n-2)

proc fac(n : int) : int {.memoize.} =
  if n < 2:
    1  # 0! and 1! are both 1
  else:
    n * fac(n-1)

echo fib(66)
echo fac(16)
Very interesting discussion! So, if I understand correctly, the reason why my initial attempt does not work is that the internal function is not available in the scope of the original one - due to quote generating a hygienic template - and I would need to put a {.gensym.} annotation to make it visible.
@wiffel: would you mind if I published your macro under nimble? It is just a few lines, but it is not trivial to get it right (as this thread shows) and it is a small utility that can often come in handy. Or, if you prefer, you could publish it yourself.
@andrea: I think your explanation is not completely correct. The {.gensym.} annotation works the other way around. It hides the definition from the calling context. {.inject.} is what you would need to make it visible. But - as far as I understand from the documentation - both annotations are only available in a template. In your macro, you use a var to reference the internalName proc. In a template, a var defaults to {.gensym.}, which hides it.
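The difference between the two annotations can be seen in a small sketch, written in current Nim syntax. The template and the names `hidden`, `visible`, and `demo` are hypothetical and not part of the thread's code.

```nim
template declareBoth() =
  var hidden = 1               # a `var` in a template is gensym'ed by default
  var visible {.inject.} = 2   # inject exposes the name to the calling scope

proc demo(): int =
  declareBoth()
  doAssert declared(visible)     # the injected name can be used here
  doAssert not declared(hidden)  # the gensym'ed name cannot be seen
  result = visible

echo demo()
```

This mirrors the situation in the original macro, where the `var internalName` produced by `quote do` is hidden from the body of the rewritten proc.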
@andrea: Feel free to publish this code example, but it's not a general solution yet. First, it only works with one parameter. Secondly, it does not work with a parameter type that isn't valid as a key for the table cache. E.g. it will not work with bigints out of the box. (Adding a hash function for BigInt fixes that.)
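As a rough sketch of what "adding a hash function" looks like, the example below uses a hypothetical `Point` type as a stand-in for a key type such as a bigint, because the actual BigInt library is not shown in this thread. The `manhattan` proc and its hand-written cache are also assumptions made for the illustration.

```nim
import tables, hashes

type Point = object   # hypothetical key type, standing in for e.g. a bigint
  x, y: int

proc hash(p: Point): Hash =
  ## Combine the hashes of the fields so Point can be used as a table key.
  var h: Hash = 0
  h = h !& hash(p.x)
  h = h !& hash(p.y)
  result = !$h

var cache = initTable[Point, int]()

proc manhattan(p: Point): int =
  ## Hand-memoized proc whose argument type relies on the hash above.
  if p notin cache:
    cache[p] = abs(p.x) + abs(p.y)
  cache[p]

echo manhattan(Point(x: 3, y: -4))
```

Once a `hash` overload is in scope for the key type, the table-based cache accepts it like any built-in type.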
Thank you for the clarification!
About the generality: I think it is not a big restriction that the key type has to be hashable. After all, to memoize a function, one has to remember the previous results, and usually a hashtable is the right place for this. An alternative would be a treemap, but even that would need an ordering on the keys anyway.
The fact that the solution only works for functions of a single argument is slightly more inconvenient, but one can always transform a function of multiple arguments into a function of a single tuple, and then memoize that (and possibly wrap it into a function of > 1 arguments for caller convenience).
What I find more annoying is the fact that I do not see a way to handle recursive functions where the recursion is not via a self call, but via a chain of calls through multiple functions. Even two functions that call one another fail to be memoized (although one can memoize just one of the two). I think I will publish this and put out a more general solution if some better idea comes up.
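The tuple transformation could look roughly like the sketch below. To keep it independent of the macro, the cache is written by hand here. The names `binomT` and `binom` are hypothetical, and the code assumes 0 <= k <= n.

```nim
import tables, hashes

var cache = initTable[(int, int), int]()

proc binomT(key: (int, int)): int =
  ## Single-argument form: the (n, k) pair itself is the cache key.
  if key in cache: return cache[key]
  let (n, k) = key
  if k == 0 or k == n:
    result = 1
  else:
    result = binomT((n - 1, k - 1)) + binomT((n - 1, k))
  cache[key] = result

proc binom(n, k: int): int =
  ## Two-argument wrapper for caller convenience.
  binomT((n, k))

echo binom(10, 3)
```

A single-parameter proc such as `binomT` has the shape the memoize macro in this thread expects, with the tuple type taking the place of `int`.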
import tables, macros

macro memoize(e: expr): stmt =
  template memo(n, nT, returnT, procName, procBody: expr): stmt =
    var cache = initTable[nT, returnT]()
    when not declared(procName):
      proc procName(n: nT): returnT
    proc fun(n: nT): returnT {.gensym.} = procBody
    proc procName(n: nT): returnT =
      if not cache.hasKey(n): cache[n] = fun(n)
      return cache[n]
  let
    returnT = e.params()[0]
    param = e.params()[1]
    (n, nT) = (param[0], param[1])
  getAst(memo(n, nT, returnT, e.name(), e.body()))

proc fib(n: int): int

proc fibA(n: int): int {.memoize.} =
  if n < 2: n else: fib(n-1) + fib(n-2)

proc fibB(n: int): int {.memoize.} =
  if n < 2: n else: fibA(n-1) + fib(n-2)

proc fib(n: int): int {.memoize.} =
  if n < 2: n else: fibA(n-1) + fibB(n-2)

echo fib(66)
I was about to investigate how to make it work with forward declarations, but you beat me to it!
About the package: can I use wiffel as author name?
... can I use wiffel as author name?
Sure (It's the same wiffel as the one you have for the nim implementation in your kmeans github repository :-) )