|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Creating Custom Function Nodes\n", |
| 8 | + "\n", |
| 9 | + "In this tutorial we do a deep dive into `FunctionNode` objects and how users can create their own nodes for various functions. Users should already be familiar with the concepts covered in the [Sampling Parameters notebook](https://lightcurvelynx.readthedocs.io/en/latest/notebooks/sampling.html)." |
| 10 | + ] |
| 11 | + }, |
| 12 | + { |
| 13 | + "cell_type": "code", |
| 14 | + "execution_count": null, |
| 15 | + "metadata": {}, |
| 16 | + "outputs": [], |
| 17 | + "source": [ |
| 18 | + "from lightcurvelynx.base_models import FunctionNode\n", |
| 19 | + "from lightcurvelynx.math_nodes.np_random import NumpyRandomFunc" |
| 20 | + ] |
| 21 | + }, |
| 22 | + { |
| 23 | + "cell_type": "markdown", |
| 24 | + "metadata": {}, |
| 25 | + "source": [ |
| 26 | + "## Function Node Overview\n", |
| 27 | + "\n", |
| 28 | + "Function nodes provide users the ability to wrap arbitrary computations during the parameter generation stage. The name `FunctionNode` is a bit of a misnomer as these nodes can wrap any Python `callable` object. For simplicity, we will use the term function throughout this notebook, but users should understand the more general behavior.\n", |
| 29 | + "\n", |
| 30 | + "The basic flow of the `FunctionNode` is wrapped in the base class's `compute` method:\n", |
| 31 | + "\n", |
| 32 | + " * Assemble the wrapped function's input values from the `GraphState` object (model parameters) and keyword arguments,\n", |
| 33 | + " * Call the wrapped function with the assembled input values,\n", |
| 34 | + " * Capture the function's output, and\n", |
| 35 | + " * Write those values to the `GraphState` object.\n", |
| 36 | + "\n", |
| 37 | + "By default, each function node stores its result in a parameter called `function_node_result`. Because model parameters are indexed by the combination of node name and parameter name, multiple nodes in a model will often generate their own `function_node_result` values. As we will see later in this notebook, we can override the output names to make them more descriptive.\n", |
| 38 | + "\n", |
| 39 | + "There are two ways to use the `FunctionNode` class: as a standalone wrapper or as a parent class.\n", |
| 40 | + "\n", |
| 41 | + "## FunctionNode as a Standalone Wrapper\n", |
| 42 | + "\n", |
| 43 | + "Users can wrap a function directly by passing the function and its arguments into the `FunctionNode` constructor. The resulting node calls the provided function and uses the function's return value as its output.\n", |
| 44 | + "\n", |
| 45 | + "As a concrete example, let's create a `FunctionNode` that wraps an existing function that computes y = m * x + b. We need to pass this function and values for each of its parameters to the constructor." |
| 46 | + ] |
| 47 | + }, |
| 48 | + { |
| 49 | + "cell_type": "code", |
| 50 | + "execution_count": null, |
| 51 | + "metadata": {}, |
| 52 | + "outputs": [], |
| 53 | + "source": [ |
| 54 | + "# This is the function we would like to wrap.\n", |
| 55 | + "def linear_eq_function(x, m, b):\n", |
| 56 | + "    \"\"\"Compute y = m * x + b\"\"\"\n", |
| 57 | + "    return m * x + b\n", |
| 58 | + "\n", |
| 59 | + "\n", |
| 60 | + "# This is how we wrap linear_eq_function.\n", |
| 61 | + "func_node = FunctionNode(\n", |
| 62 | + "    linear_eq_function,  # First argument is the function to call.\n", |
| 63 | + "    # The function's parameters are given as keyword arguments to the FunctionNode.\n", |
| 64 | + "    x=NumpyRandomFunc(\"uniform\", low=0.0, high=10.0),  # Random value\n", |
| 65 | + "    m=5.0,  # Constant value\n", |
| 66 | + "    b=-2.0,  # Constant value\n", |
| 67 | + ")" |
| 68 | + ] |
| 69 | + }, |
| 70 | + { |
| 71 | + "cell_type": "markdown", |
| 72 | + "metadata": {}, |
| 73 | + "source": [ |
| 74 | + "The first parameter of the function node is the function to evaluate, such as our linear equation above (`linear_eq_function`). Each input into that function **must** be included as a named parameter during the `FunctionNode` definition, such as `x`, `m`, and `b` above. If any of the input parameters are missing, the code will raise an error. The `FunctionNode` class handles all the internal bookkeeping: determining the names of the function's arguments, creating internal parameters, and assembling those arguments whenever the function is called.\n", |
| 75 | + "\n", |
| 76 | + "Here we provide constants for `m` and `b` so we use the same linear formulation for each sample. Only the value of `x` changes. However, we could have also used a whole tree of function nodes, including sampling functions, to set `m` and `b`. In that case it is important to remember that each of our results is a consistent sampling and computation over all the parameters in the model." |
| 77 | + ] |
| 78 | + }, |
| 79 | + { |
| 80 | + "cell_type": "code", |
| 81 | + "execution_count": null, |
| 82 | + "metadata": {}, |
| 83 | + "outputs": [], |
| 84 | + "source": [ |
| 85 | + "state = func_node.sample_parameters(num_samples=5)\n", |
| 86 | + "print(state)" |
| 87 | + ] |
| 88 | + }, |
| 89 | + { |
| 90 | + "cell_type": "markdown", |
| 91 | + "metadata": {}, |
| 92 | + "source": [ |
| 93 | + "As described above, both nodes (the NumPy sampler and the linear function) create `function_node_result` parameters to store their intermediate results.\n", |
| 94 | + "\n", |
| 95 | + "The nodes can be chained by using one `FunctionNode` as the value for a parameter of another. When a `FunctionNode` is passed as a parameter, LightCurveLynx automatically links that parameter to the `FunctionNode`'s `function_node_result` value. Below you can see that the input (`x`) of our increment function corresponds directly to the output (`function_node_result`) of the linear equation function." |
| 96 | + ] |
| 97 | + }, |
| 98 | + { |
| 99 | + "cell_type": "code", |
| 100 | + "execution_count": null, |
| 101 | + "metadata": {}, |
| 102 | + "outputs": [], |
| 103 | + "source": [ |
| 104 | + "def increment(x):\n", |
| 105 | + "    \"\"\"Increment x by 1.\"\"\"\n", |
| 106 | + "    return x + 1\n", |
| 107 | + "\n", |
| 108 | + "\n", |
| 109 | + "# This is how we wrap the increment function.\n", |
| 110 | + "inc_node = FunctionNode(\n", |
| 111 | + "    increment,  # First argument is the function to call.\n", |
| 112 | + "    # The function's parameters are given as keyword arguments to the FunctionNode.\n", |
| 113 | + "    x=func_node,  # Use the output of func_node as input to increment.\n", |
| 114 | + ")\n", |
| 115 | + "\n", |
| 116 | + "state = inc_node.sample_parameters(num_samples=5)\n", |
| 117 | + "print(state)" |
| 118 | + ] |
| 119 | + }, |
| 120 | + { |
| 121 | + "cell_type": "markdown", |
| 122 | + "metadata": {}, |
| 123 | + "source": [ |
| 124 | + "We can make the link between parameters explicit by using dot notation with the parameter name. The behavior is identical." |
| 125 | + ] |
| 126 | + }, |
| 127 | + { |
| 128 | + "cell_type": "code", |
| 129 | + "execution_count": null, |
| 130 | + "metadata": {}, |
| 131 | + "outputs": [], |
| 132 | + "source": [ |
| 133 | + "# This is how we wrap the increment function with an explicit parameter link.\n", |
| 134 | + "inc_node = FunctionNode(\n", |
| 135 | + "    increment,  # First argument is the function to call.\n", |
| 136 | + "    # The function's parameters are given as keyword arguments to the FunctionNode.\n", |
| 137 | + "    x=func_node.function_node_result,  # Explicitly named parameter\n", |
| 138 | + ")\n", |
| 139 | + "\n", |
| 140 | + "state = inc_node.sample_parameters(num_samples=5)\n", |
| 141 | + "print(state)" |
| 142 | + ] |
| 143 | + }, |
| 144 | + { |
| 145 | + "cell_type": "markdown", |
| 146 | + "metadata": {}, |
| 147 | + "source": [ |
| 148 | + "More realistically, users will want to wrap functions that perform complex astronomical calculations." |
| 149 | + ] |
| 150 | + }, |
| 151 | + { |
| 152 | + "cell_type": "markdown", |
| 153 | + "metadata": {}, |
| 154 | + "source": [ |
| 155 | + "## FunctionNode Subclasses\n", |
| 156 | + "\n", |
| 157 | + "When a function node needs to carry around additional data, users can create a subclass of the `FunctionNode` class. For example, computing the distance modulus from the redshift requires a cosmology. While we could load the cosmology each time the function is called, it is more efficient to load it once and reuse it across computations." |
| 158 | + ] |
| 159 | + }, |
| 160 | + { |
| 161 | + "cell_type": "code", |
| 162 | + "execution_count": null, |
| 163 | + "metadata": {}, |
| 164 | + "outputs": [], |
| 165 | + "source": [ |
| 166 | + "from astropy.cosmology import FlatLambdaCDM\n", |
| 167 | + "\n", |
| 168 | + "\n", |
| 169 | + "class DistModFromRedshift(FunctionNode):\n", |
| 170 | + "    \"\"\"A wrapper class for the _distmod_from_redshift() function.\n", |
| 171 | + "\n", |
| 172 | + "    Parameters\n", |
| 173 | + "    ----------\n", |
| 174 | + "    redshift : function or constant\n", |
| 175 | + "        The function or constant providing the redshift value.\n", |
| 176 | + "    H0 : constant\n", |
| 177 | + "        The Hubble constant.\n", |
| 178 | + "    Omega_m : constant\n", |
| 179 | + "        The matter density Omega_m.\n", |
| 180 | + "    **kwargs : dict, optional\n", |
| 181 | + "        Any additional keyword arguments.\n", |
| 182 | + "    \"\"\"\n", |
| 183 | + "\n", |
| 184 | + "    def __init__(self, redshift, H0=73.0, Omega_m=0.3, **kwargs):\n", |
| 185 | + "        # Validate the constants, then create the cosmology ONCE for all samples.\n", |
| 186 | + "        if not isinstance(H0, (int, float)) or not isinstance(Omega_m, (int, float)):\n", |
| 187 | + "            raise ValueError(\"H0 and Omega_m must be constants.\")\n", |
| 188 | + "        self.cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)\n", |
| 189 | + "\n", |
| 190 | + "        # Call the super class's constructor with the needed information.\n", |
| 191 | + "        super().__init__(\n", |
| 192 | + "            func=self._distmod_from_redshift,  # \"Function\" being wrapped\n", |
| 193 | + "            redshift=redshift,\n", |
| 194 | + "            **kwargs,\n", |
| 195 | + "        )\n", |
| 196 | + "\n", |
| 197 | + "    def _distmod_from_redshift(self, redshift):\n", |
| 198 | + "        \"\"\"Compute distance modulus given redshift and cosmology.\n", |
| 199 | + "\n", |
| 200 | + "        Parameters\n", |
| 201 | + "        ----------\n", |
| 202 | + "        redshift : float or numpy.ndarray\n", |
| 203 | + "            The redshift value(s).\n", |
| 204 | + "\n", |
| 205 | + "        Returns\n", |
| 206 | + "        -------\n", |
| 207 | + "        distmod : float or numpy.ndarray\n", |
| 208 | + "            The distance modulus (in mag)\n", |
| 209 | + "        \"\"\"\n", |
| 210 | + "        return self.cosmo.distmod(redshift).value" |
| 211 | + ] |
| 212 | + }, |
| 213 | + { |
| 214 | + "cell_type": "markdown", |
| 215 | + "metadata": {}, |
| 216 | + "source": [ |
| 217 | + "There are a few things to note from the implementation above.\n", |
| 218 | + "\n", |
| 219 | + "First, since the cosmology is created once per object, it is the same for every evaluation. Its parameters, `H0` and `Omega_m`, are fixed for all samples. Only the input `redshift` changes.\n", |
| 220 | + "\n", |
| 221 | + "Second, the \"function\" being wrapped by the function node is actually an object method. As we noted earlier, the `FunctionNode` can actually wrap any Python `callable` object. By wrapping an internal method, the computation has access to the object's attributes via `self`." |
| 222 | + ] |
| 223 | + }, |
| 224 | + { |
| 225 | + "cell_type": "markdown", |
| 226 | + "metadata": {}, |
| 227 | + "source": [ |
| 228 | + "## Supporting Multiple Outputs\n", |
| 229 | + "\n", |
| 230 | + "If the wrapped function produces multiple outputs, the user can assign names to each output via the `outputs` constructor argument. This argument takes a list of strings that is the same length as the number of outputs produced. Each result is stored separately in a correspondingly named parameter (instead of the default `function_node_result` parameter). These parameters are added automatically to the object." |
| 231 | + ] |
| 232 | + }, |
| 233 | + { |
| 234 | + "cell_type": "code", |
| 235 | + "execution_count": null, |
| 236 | + "metadata": {}, |
| 237 | + "outputs": [], |
| 238 | + "source": [ |
| 239 | + "# A function that returns two values.\n", |
| 240 | + "def _linear_pair(x, m, b):\n", |
| 241 | + "    \"\"\"Compute y1 = m * x + b and y2 = -1/m * x - b\"\"\"\n", |
| 242 | + "    return (m * x + b, -1.0 / m * x - b)\n", |
| 243 | + "\n", |
| 244 | + "\n", |
| 245 | + "# A function node that returns two values. The outputs are named \"y1\" and \"y2\".\n", |
| 246 | + "func_node2 = FunctionNode(\n", |
| 247 | + "    _linear_pair,  # First parameter is the function to call.\n", |
| 248 | + "    x=NumpyRandomFunc(\"uniform\", low=0.0, high=10.0),\n", |
| 249 | + "    m=5.0,\n", |
| 250 | + "    b=-2.0,\n", |
| 251 | + "    outputs=[\"y1\", \"y2\"],  # The output names.\n", |
| 252 | + ")\n", |
| 253 | + "\n", |
| 254 | + "print(func_node2.sample_parameters(num_samples=5))" |
| 255 | + ] |
| 256 | + }, |
| 257 | + { |
| 258 | + "cell_type": "markdown", |
| 259 | + "metadata": {}, |
| 260 | + "source": [ |
| 261 | + "The outputs can be referenced individually using the dot notation with their given name. Below we reimplement the increment function using just the `y2` output as the function's input." |
| 262 | + ] |
| 263 | + }, |
| 264 | + { |
| 265 | + "cell_type": "code", |
| 266 | + "execution_count": null, |
| 267 | + "metadata": {}, |
| 268 | + "outputs": [], |
| 269 | + "source": [ |
| 270 | + "# This is how we wrap the increment function using a named output.\n", |
| 271 | + "inc_node2 = FunctionNode(\n", |
| 272 | + "    increment,  # First argument is the function to call.\n", |
| 273 | + "    # The function's parameters are given as keyword arguments to the FunctionNode.\n", |
| 274 | + "    x=func_node2.y2,  # Use the named output.\n", |
| 275 | + ")\n", |
| 276 | + "\n", |
| 277 | + "print(inc_node2.sample_parameters(num_samples=10))" |
| 278 | + ] |
| 279 | + }, |
| 280 | + { |
| 281 | + "cell_type": "markdown", |
| 282 | + "metadata": {}, |
| 283 | + "source": [ |
| 284 | + "The named output is most often used in nodes that produce a combination of correlated values, such as (RA, Dec). See the [Sampling Object Positions notebook](https://lightcurvelynx.readthedocs.io/en/latest/notebooks/sampling_positions.html) for examples." |
| 285 | + ] |
| 286 | + }, |
| 287 | + { |
| 288 | + "cell_type": "markdown", |
| 289 | + "metadata": {}, |
| 290 | + "source": [ |
| 291 | + "## Randomization\n", |
| 292 | + "\n", |
| 293 | + "Care must be taken when creating new function nodes that use randomization. To be consistent, users will want the nodes to be completely random by default but able to use a provided random number generator. The difficulty is that `FunctionNode.compute()` does **not** pass the random number generator along to the wrapped function. It cannot, because not every wrapped function accepts a random number generator parameter.\n", |
| 294 | + "\n", |
| 295 | + "Instead, there are two supported approaches to enable random behavior.\n", |
| 296 | + "\n", |
| 297 | + "### Use Random Parameters (RECOMMENDED)\n", |
| 298 | + "\n", |
| 299 | + "Users can add new parameters in their class's constructor that correspond to the random values they would like to generate. For example, if we wanted to implement a *noisy* linear function, y = m * x + b + noise, we could add a noise parameter. We set this parameter using a `NumpyRandomFunc` or other random node. This approach takes care of the internal bookkeeping." |
| 300 | + ] |
| 301 | + }, |
| 302 | + { |
| 303 | + "cell_type": "code", |
| 304 | + "execution_count": null, |
| 305 | + "metadata": {}, |
| 306 | + "outputs": [], |
| 307 | + "source": [ |
| 308 | + "class NoisyLinear(FunctionNode):\n", |
| 309 | + "    \"\"\"A noisy linear function node.\"\"\"\n", |
| 310 | + "\n", |
| 311 | + "    def __init__(self, x, m, b, **kwargs):\n", |
| 312 | + "        # Create the noise node once. It is queried for a new value for each sample.\n", |
| 313 | + "        self.noise_func = NumpyRandomFunc(\"normal\", loc=0.0, scale=1.0)\n", |
| 314 | + "\n", |
| 315 | + "        # Call the super class's constructor with the needed information.\n", |
| 316 | + "        super().__init__(\n", |
| 317 | + "            func=self._noisy_linear_eq,  # \"Function\" being wrapped\n", |
| 318 | + "            x=x,\n", |
| 319 | + "            m=m,\n", |
| 320 | + "            b=b,\n", |
| 321 | + "            noise=self.noise_func,\n", |
| 322 | + "            **kwargs,\n", |
| 323 | + "        )\n", |
| 324 | + "\n", |
| 325 | + "    def _noisy_linear_eq(self, x, m, b, noise):\n", |
| 326 | + "        \"\"\"Compute y = m * x + b + noise.\"\"\"\n", |
| 327 | + "        return m * x + b + noise\n", |
| 328 | + "\n", |
| 329 | + "\n", |
| 330 | + "my_node = NoisyLinear(\n", |
| 331 | + "    x=NumpyRandomFunc(\"uniform\", low=0.0, high=10.0),\n", |
| 332 | + "    m=5.0,\n", |
| 333 | + "    b=-2.0,\n", |
| 334 | + ")\n", |
| 335 | + "print(my_node.sample_parameters(num_samples=5))" |
| 336 | + ] |
| 337 | + }, |
| 338 | + { |
| 339 | + "cell_type": "markdown", |
| 340 | + "metadata": {}, |
| 341 | + "source": [ |
| 342 | + "As you can see, the noise parameter is sampled first and then applied as though it were any other constant." |
| 343 | + ] |
| 344 | + }, |
| 345 | + { |
| 346 | + "cell_type": "markdown", |
| 347 | + "metadata": {}, |
| 348 | + "source": [ |
| 349 | + "### Custom Compute Function\n", |
| 350 | + "\n", |
| 351 | + "If users need more control over how the randomness is used, they can override the `compute` function, which does take a random number generator. However, the `compute` function contains other logic that will need to be replicated, including assembling the function's parameters and writing the results to the `GraphState`. We recommend this approach only for experienced users. For examples of this approach, see the code for the `NumpyRandomFunc` class itself." |
| 352 | + ] |
| 353 | + } |
| 354 | + ], |
| 355 | + "metadata": { |
| 356 | + "kernelspec": { |
| 357 | + "display_name": "lightcurvelynx (3.13.8)", |
| 358 | + "language": "python", |
| 359 | + "name": "python3" |
| 360 | + }, |
| 361 | + "language_info": { |
| 362 | + "codemirror_mode": { |
| 363 | + "name": "ipython", |
| 364 | + "version": 3 |
| 365 | + }, |
| 366 | + "file_extension": ".py", |
| 367 | + "mimetype": "text/x-python", |
| 368 | + "name": "python", |
| 369 | + "nbconvert_exporter": "python", |
| 370 | + "pygments_lexer": "ipython3", |
| 371 | + "version": "3.13.8" |
| 372 | + } |
| 373 | + }, |
| 374 | + "nbformat": 4, |
| 375 | + "nbformat_minor": 2 |
| 376 | +} |
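The compute flow listed in the notebook's Function Node Overview (assemble inputs, call the wrapped function, capture the output, write it to the state) can be illustrated with a small, library-free sketch. `ToyFunctionNode` and its dictionary-based state are hypothetical stand-ins invented for illustration. They are not the LightCurveLynx implementation, and they ignore sampling, random number generators, and the real `GraphState` API.

```python
class ToyFunctionNode:
    """Hypothetical stand-in for a function node (not LightCurveLynx code)."""

    def __init__(self, name, func, outputs=None, **inputs):
        self.name = name
        self.func = func
        # Mirror the default output name described in the notebook.
        self.outputs = outputs or ["function_node_result"]
        self.inputs = inputs

    def compute(self, state):
        # 1. Assemble the inputs from constants or from upstream nodes.
        args = {}
        for key, value in self.inputs.items():
            if isinstance(value, ToyFunctionNode):
                # In this toy, a linked node contributes its first output.
                upstream_key = (value.name, value.outputs[0])
                if upstream_key not in state:
                    value.compute(state)
                args[key] = state[upstream_key]
            else:
                args[key] = value

        # 2. Call the wrapped function with the assembled inputs.
        result = self.func(**args)

        # 3. Capture the output as a tuple, one entry per output name.
        if len(self.outputs) == 1:
            result = (result,)

        # 4. Write each value to the state under (node name, parameter name).
        for out_name, out_value in zip(self.outputs, result):
            state[(self.name, out_name)] = out_value
        return state


linear = ToyFunctionNode("linear", lambda x, m, b: m * x + b, x=2.0, m=5.0, b=-2.0)
inc = ToyFunctionNode("inc", lambda x: x + 1, x=linear)
state = inc.compute({})
print(state)
```

Keying the state by the pair of node name and parameter name is what lets both toy nodes store a `function_node_result` value side by side, which is the same indexing idea the notebook describes for model parameters.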