Hi, I am working on a macro that should generate a derivative calculation. For a function

    proc f1(x, y: float): float {.noSideEffect.} =

it generates

    proc f1_aad(x, y: float): float

which computes the derivative of f1 at the point (x, y).
The difficulty I am having is that Nim has so many ways to express the same thing. Internally, I don't care whether a function was called using command syntax, dot syntax, an infix operator or something else. I just want to know which function was called and whether it has a known derivative.
So I would like to simplify/normalize the AST of the proc body before walking it, so that all these possibilities collapse into a single representation, while all resolved symbols are kept. Is there a ready-made solution available for AST simplification?
You might be interested in my autograd library: nim-rmad.
It automatically computes the gradient of any function with respect to any of its inputs. It uses reverse-mode auto-differentiation behind the scenes.

    let ctx = newContext[float32]()

    let
      a = ctx.variable(1)
      b = ctx.variable(2)
      c = ctx.variable(-3)
      x = ctx.variable(-1)
      y = ctx.variable(3)

    proc forwardNeuron[T](a, b, c, x, y: T): T =
      let
        ax = a * x
        by = b * y
        axpby = ax + by
        axpbypc = axpby + c
        s = axpbypc.sigmoid()
      return s

    var s = forwardNeuron(a, b, c, x, y)
    echo s.value() # 0.8807970285415649

    let gradient = s.grad()
    echo gradient.wrt(a) # -0.1049936264753342
    echo gradient.wrt(b) # 0.3149808645248413
    echo gradient.wrt(c) # 0.1049936264753342
    echo gradient.wrt(x) # 0.1049936264753342
    echo gradient.wrt(y) # 0.2099872529506683

Hi Mratsim,
I have seen nim-rmad, but in my case the requirements are somewhat different. The derivative calculation should not slow down or change the interface of the original function. Hence the macro approach, which generates a separate set of functions to compute the derivatives.
The AST is pretty normalized out of the box. The trick is to use nnkCallKinds.

    import macros

    macro swapArgs(n: untyped): untyped =
      if n.kind in nnkCallKinds:
        let tmp = n[2]
        n[2] = n[1]
        n[1] = tmp
      result = n

    echo swapArgs(5 - 6)
    echo swapArgs(`-`(5, 6))

Dot calls are special because they are ambiguous: a.f(x) can really mean "call field 'f' in 'a' with argument 'x'" instead of "call f with a, x", and so we cannot normalize it before your macro gets to see it.
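
To make the nnkCallKinds idea a bit more concrete, here is a minimal sketch of pulling the callee name out of any call-like node, including the dot-call case. The names calleeName and calleeOf are hypothetical and only for illustration. It works on untyped AST and does not try to decide whether a.f(x) is a field call or a proc call; it only reports the name on the right of the dot. With a typed parameter you would see nnkSym nodes instead of nnkIdent, which the sketch also accepts.

```nim
import macros

proc calleeName(n: NimNode): string =
  ## Hypothetical helper: name of the called routine for a call-like node,
  ## or "" when the node is not a call or the callee is not a plain name.
  if n.kind notin nnkCallKinds or n.len == 0:
    return ""
  var callee = n[0]
  # a.f(x) is parsed as Call(DotExpr(a, f), x); take the name after the dot.
  if callee.kind == nnkDotExpr:
    callee = callee[1]
  # `-`(5, 6) is parsed with the operator wrapped in an AccQuoted node.
  if callee.kind == nnkAccQuoted and callee.len == 1:
    callee = callee[0]
  if callee.kind in {nnkIdent, nnkSym}:
    result = callee.strVal

macro calleeOf(n: untyped): untyped =
  ## Expands to a string literal holding the callee name of `n`.
  result = newLit(calleeName(n))

echo calleeOf(5 - 6)
echo calleeOf(`-`(5, 6))
echo calleeOf(a.f(x))  # a, f and x need not exist: the argument is untyped
```

A derivative-generating macro could then look this name up in a table of known derivatives while walking the proc body, and treat the nnkDotExpr case separately where the field-versus-call ambiguity matters.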