A small automatic differentiation (AD) library implemented in C++. This library uses C++ operator overloading to automatically construct a static computation graph from mathematical expressions in source code. The graph is evaluated in the forward pass, and the gradient is computed in the backward pass using reverse-accumulation (reverse-mode) AD. A detailed API reference is available on GitHub Pages.
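To make the idea concrete, here is a minimal, self-contained sketch of how operator overloading can build a graph that is then evaluated forward and differentiated backward. It is not the ezdiff source: the `sketch_graph` namespace and the names `Node`, `Var`, `Add`, `Mul`, `make_leaf`, and `grad_of_example` are hypothetical and chosen only for illustration. The sketch assumes nodes held by shared pointers, virtual dispatch, and recursive passes.

```cpp
#include <memory>
#include <utility>

namespace sketch_graph {  // hypothetical names, not the ezdiff API

struct Node;
using Var = std::shared_ptr<Node>;

// A leaf node: holds a value and accumulates the gradient flowing into it.
struct Node {
    float value = 0.0f;
    float grad = 0.0f;
    virtual ~Node() = default;
    virtual void forward() {}  // leaves already hold their value
    virtual void backward(float upstream) { grad += upstream; }
};

struct Add : Node {
    Var a, b;
    Add(Var l, Var r) : a(std::move(l)), b(std::move(r)) {}
    void forward() override {
        a->forward();
        b->forward();
        value = a->value + b->value;
    }
    void backward(float upstream) override {
        grad += upstream;
        a->backward(upstream);  // d(a + b)/da = 1
        b->backward(upstream);  // d(a + b)/db = 1
    }
};

struct Mul : Node {
    Var a, b;
    Mul(Var l, Var r) : a(std::move(l)), b(std::move(r)) {}
    void forward() override {
        a->forward();
        b->forward();
        value = a->value * b->value;
    }
    void backward(float upstream) override {
        grad += upstream;
        a->backward(upstream * b->value);  // d(a * b)/da = b
        b->backward(upstream * a->value);  // d(a * b)/db = a
    }
};

inline Var make_leaf(float v) {
    auto n = std::make_shared<Node>();
    n->value = v;
    return n;
}

// The overloaded operators do no arithmetic; they only record graph nodes.
inline Var operator+(const Var& a, const Var& b) { return std::make_shared<Add>(a, b); }
inline Var operator*(const Var& a, const Var& b) { return std::make_shared<Mul>(a, b); }

// Gradient of f(x) = x*x + 2x + 1 at x0, computed by the sketch.
inline float grad_of_example(float x0) {
    Var x = make_leaf(x0);
    Var f = x * x + make_leaf(2.0f) * x + make_leaf(1.0f);
    f->forward();
    f->backward(1.0f);  // seed the output with df/df = 1
    return x->grad;
}

}  // namespace sketch_graph
```

Calling `sketch_graph::grad_of_example(3.0f)` returns 8, matching the analytic derivative 2x + 2 at x = 3. This naive recursion revisits a shared subexpression once per path that reaches it, which is correct but wasteful.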
Installation
To build, test, and install this library, execute the following commands:
git clone git@github.com:mattrrubino/ezdiff.git
cd ezdiff
mkdir build && cd build
cmake ..
cmake --build .
ctest
sudo cmake --install .
Usage
To use this library, construct expressions using the Variable type. Each variable is assigned a value, and the expression built from the variables is what you differentiate. For example, suppose you want to compute the derivative of the following function:
$$f(x) = x^2 + 2x + 1$$
You can compute the derivative of this function with respect to $x$ at $x = 3$, where analytically $f'(3) = 2 \cdot 3 + 2 = 8$, using the following code:
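#include <iostream>
#include <ezd/ezd.h>

int main() {
    // Create the input variable with value 3 and build f(x) = x^2 + 2x + 1.
    Variable x = ezd::make_var(3.0f, 0.0f);
    Variable f = x * x + 2.0f * x + 1.0f;

    f->forward();    // evaluate the graph
    f->zero_grad();  // clear any previously accumulated gradients
    f->backward();   // propagate gradients from f back to x

    // Print df/dx at x = 3; analytically this is 2 * 3 + 2 = 8.
    std::cout << x->get_grad() << std::endl;
}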
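As a worked check on the example above, the backward pass can be traced by hand. The decomposition into intermediate values $v_1, v_2, v_3$ below is an assumption made for illustration. The graph the library actually builds may order or name its nodes differently, but the result is the same. The forward pass evaluates each node at $x = 3$:

$$
\begin{aligned}
v_1 &= x \cdot x = 9, \\
v_2 &= 2 \cdot x = 6, \\
v_3 &= v_1 + v_2 = 15, \\
f &= v_3 + 1 = 16.
\end{aligned}
$$

Writing $\bar{v} = \partial f / \partial v$, the backward pass seeds $\bar{f} = 1$ and propagates toward the input:

$$
\begin{aligned}
\bar{v}_3 &= \bar{f} = 1, \\
\bar{v}_1 &= \bar{v}_3 = 1, \qquad \bar{v}_2 = \bar{v}_3 = 1, \\
\bar{x} &= \bar{v}_1 \cdot (x + x) + \bar{v}_2 \cdot 2 = 6 + 2 = 8.
\end{aligned}
$$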
For more examples, see the test directory. In particular, the linear regression test shows how this library can be applied to machine learning.
Performance
This implementation prioritizes simplicity. It is readable and correct, but it is not performant. Specifically, it suffers from the following issues:
- Heap Allocations: Each node is stored in a separate heap allocation. Because the nodes are scattered throughout the heap, the forward and backward passes suffer from poor cache utilization. To remedy this, the nodes should be stored in a contiguous region of memory so that both passes are cache-friendly.
- Recursion: The forward and backward passes are implemented recursively, which adds significant function call overhead. To remedy this, the passes should be implemented iteratively.
- Polymorphism: The behavior of the forward and backward passes is selected polymorphically, which adds virtual function call overhead. To remedy this, the passes could dispatch on a node type tag with a switch.
- Smart Pointers: Memory is managed using shared pointers, which adds reference-counting overhead. To remedy this, raw pointers could be used with manual freeing.
- Sequential: This implementation is sequential and does not leverage GPUs or SIMD instructions. Parallelizing the code would make it faster for large graphs.
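Several of these remedies can be combined in a "tape" design. The sketch below is one possible approach under stated assumptions, not a planned or existing part of ezdiff: the `sketch_tape` namespace and the names `Op`, `Node`, `Tape`, and `grad_of_example` are hypothetical. Nodes live in one contiguous `std::vector`, operands are referenced by index instead of by shared pointer, both passes are plain loops, and a `switch` on a type tag replaces virtual calls. It does not address SIMD or GPU execution.

```cpp
#include <cstddef>
#include <vector>

namespace sketch_tape {  // hypothetical names, not the ezdiff API

enum class Op { Leaf, Add, Mul };

struct Node {
    Op op;
    std::size_t lhs, rhs;  // operand indices on the tape (unused for leaves)
    float value;
    float grad;
};

class Tape {
public:
    std::size_t leaf(float v) { return push(Op::Leaf, 0, 0, v); }
    std::size_t add(std::size_t a, std::size_t b) { return push(Op::Add, a, b, 0.0f); }
    std::size_t mul(std::size_t a, std::size_t b) { return push(Op::Mul, a, b, 0.0f); }

    // Operands are always pushed before the node that uses them, so
    // ascending index order is a valid evaluation order.
    void forward() {
        for (Node& n : nodes_) {
            switch (n.op) {
                case Op::Leaf: break;
                case Op::Add: n.value = nodes_[n.lhs].value + nodes_[n.rhs].value; break;
                case Op::Mul: n.value = nodes_[n.lhs].value * nodes_[n.rhs].value; break;
            }
        }
    }

    // Reverse accumulation: zero all gradients, seed the root with 1, then
    // walk the tape in descending index order so each node's gradient is
    // complete before it is pushed to its operands.
    void backward(std::size_t root) {
        for (Node& n : nodes_) n.grad = 0.0f;
        nodes_[root].grad = 1.0f;
        for (std::size_t i = root + 1; i-- > 0;) {
            const Node& n = nodes_[i];
            switch (n.op) {
                case Op::Leaf: break;
                case Op::Add:
                    nodes_[n.lhs].grad += n.grad;
                    nodes_[n.rhs].grad += n.grad;
                    break;
                case Op::Mul:
                    nodes_[n.lhs].grad += n.grad * nodes_[n.rhs].value;
                    nodes_[n.rhs].grad += n.grad * nodes_[n.lhs].value;
                    break;
            }
        }
    }

    float value(std::size_t i) const { return nodes_[i].value; }
    float grad(std::size_t i) const { return nodes_[i].grad; }

private:
    std::size_t push(Op op, std::size_t a, std::size_t b, float v) {
        nodes_.push_back(Node{op, a, b, v, 0.0f});
        return nodes_.size() - 1;
    }
    std::vector<Node> nodes_;  // one contiguous allocation for the whole graph
};

// Gradient of f(x) = x*x + 2x + 1 at x0, computed on the tape.
inline float grad_of_example(float x0) {
    Tape t;
    const std::size_t x = t.leaf(x0);
    const std::size_t two = t.leaf(2.0f);
    const std::size_t one = t.leaf(1.0f);
    const std::size_t f = t.add(t.add(t.mul(x, x), t.mul(two, x)), one);
    t.forward();
    t.backward(f);
    return t.grad(x);
}

}  // namespace sketch_tape
```

Calling `sketch_tape::grad_of_example(3.0f)` returns 8. Unlike a recursive traversal, each node is visited exactly once per pass even when a subexpression is shared. Operator overloading could be layered on top by wrapping a tape pointer and an index in a small handle type.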