jax
JAX backend for lowering symbolic expressions to executable functions.
This module implements the JAX lowering backend that converts symbolic expression AST nodes into JAX functions with automatic differentiation support. The lowering uses a visitor pattern where each expression type has a corresponding visitor method.
Architecture
The JAX lowerer follows a visitor pattern with centralized registration:
- Visitor Registration: The @visitor decorator registers handler functions for each expression type in the _JAX_VISITORS dictionary
- Dispatch: The dispatch() function looks up and calls the appropriate visitor based on the expression's type
- Recursive Lowering: Each visitor recursively lowers child expressions and composes JAX operations
- Standardized Signature: All lowered functions have signature (x, u, node, params) -> result for uniformity
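The registration and dispatch flow described above can be sketched in plain Python. This is a toy illustration, not the module's actual code: `Expr`, `Const`, `Add`, `ToyLowerer`, and `_VISITORS` are hypothetical stand-ins, and the exact-type lookup in `dispatch` is an assumption about how the real registry is queried.

```python
from typing import Callable, Dict, Type


# Hypothetical stand-ins for the real expression classes.
class Expr:
    pass


class Const(Expr):
    def __init__(self, value):
        self.value = value


class Add(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right


_VISITORS: Dict[Type[Expr], Callable] = {}


def visitor(expr_cls):
    """Register the decorated function as the handler for expr_cls."""
    def register(fn):
        _VISITORS[expr_cls] = fn
        return fn  # returned unchanged
    return register


def dispatch(lowerer, expr):
    """Look up the handler by exact type (an assumption) and call it."""
    fn = _VISITORS.get(type(expr))
    if fn is None:
        raise NotImplementedError(f"No visitor for {type(expr).__name__}")
    return fn(lowerer, expr)


class ToyLowerer:
    def lower(self, expr):
        return dispatch(self, expr)

    @visitor(Const)
    def _visit_const(self, node):
        value = node.value  # captured by the closure
        return lambda x, u, node, params: value

    @visitor(Add)
    def _visit_add(self, node):
        # Lower children first, then compose them.
        left_fn = self.lower(node.left)
        right_fn = self.lower(node.right)
        return lambda x, u, node, params: (
            left_fn(x, u, node, params) + right_fn(x, u, node, params)
        )


f = ToyLowerer().lower(Add(Const(1.0), Add(Const(2.0), Const(3.0))))
total = f(None, None, 0, {})
```

Every closure keeps the same four-argument signature, so a parent visitor can call its children without knowing their types.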
Key Features
- Automatic Differentiation: Lowered functions can be differentiated using JAX's jacfwd/jacrev for computing Jacobians
- JIT Compilation: All functions are JAX-traceable and JIT-compatible
- Functional Closures: Each visitor returns a closure that captures necessary constants and child functions
- Broadcasting: Supports NumPy-style broadcasting through jnp operations
Lowered Function Signature
All lowered functions have a uniform signature::
f(x, u, node, params) -> result
Where:
- x: State vector (jnp.ndarray)
- u: Control vector (jnp.ndarray)
- node: Node index for time-varying behavior (scalar or array)
- params: Dictionary of parameter values (dict[str, Any])
- result: JAX array (scalar, vector, or matrix)
Example
Basic usage::
from openscvx.symbolic.lowerers.jax import JaxLowerer
import openscvx as ox
# Create symbolic expression
x = ox.State("x", shape=(3,))
u = ox.Control("u", shape=(2,))
expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
# Lower to JAX
lowerer = JaxLowerer()
f = lowerer.lower(expr)
# Evaluate
import jax.numpy as jnp
x_val = jnp.array([1.0, 2.0, 3.0])
u_val = jnp.array([0.5, 0.5])
result = f(x_val, u_val, node=0, params={})
# Differentiate
from jax import jacfwd
df_dx = jacfwd(f, argnums=0)
gradient = df_dx(x_val, u_val, node=0, params={})
For Contributors
Adding Support for New Expression Types
To add support for a new symbolic expression type to JAX lowering:
- Define the visitor method in JaxLowerer with the @visitor decorator::

      @visitor(MyNewExpr)
      def _visit_my_new_expr(self, node: MyNewExpr):
          # Lower child expressions recursively
          operand_fn = self.lower(node.operand)
          # Return a closure with signature (x, u, node, params) -> result
          return lambda x, u, node, params: jnp.my_operation(
              operand_fn(x, u, node, params)
          )

- Key requirements:
  - Use the @visitor(ExprType) decorator to register the handler
  - Name the method with the _visit_ prefix followed by the expression name (private, lowercase, snake_case), as in _visit_add
  - Recursively lower all child expressions using self.lower()
  - Return a closure with signature (x, u, node, params) -> jax_array
  - Use jnp.* operations (not np.*) for JAX traceability
  - Ensure the result is JAX-differentiable (avoid Python control flow on traced values)
- Example patterns:
  - Unary operation: Lower operand, apply jnp function
  - Binary operation: Lower both operands, combine with jnp operation
  - N-ary operation: Lower all operands, reduce or combine them
  - Conditional logic: Use jax.lax.cond for branching (see _visit_ctcs)
- Testing: Ensure your visitor works with:
  - JAX JIT compilation: jax.jit(lowered_fn)
  - Automatic differentiation: jax.jacfwd(lowered_fn, argnums=0)
  - Vectorization: jax.vmap(lowered_fn)
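The three checks in the testing list can be exercised as follows. This is a minimal sketch: `lowered_fn` is a hand-written stand-in with the standard signature, not something produced by JaxLowerer, and the `in_axes` choice for `jax.vmap` is one reasonable assumption about how nodes would be batched.

```python
import jax
import jax.numpy as jnp


# Hand-written stand-in for a lowered function: ||x||^2 + 0.1 * ||u||^2.
def lowered_fn(x, u, node, params):
    return jnp.sum(x**2) + 0.1 * jnp.sum(u**2)


x_val = jnp.array([1.0, 2.0, 3.0])
u_val = jnp.array([0.5, 0.5])

# 1. JIT compilation
jit_value = jax.jit(lowered_fn)(x_val, u_val, 0, {})

# 2. Forward-mode derivative with respect to x (argument 0)
grad_x = jax.jacfwd(lowered_fn, argnums=0)(x_val, u_val, 0, {})

# 3. Vectorization over a batch of nodes; params is shared (in_axes=None)
X = jnp.stack([x_val, 2.0 * x_val])
U = jnp.stack([u_val, u_val])
nodes = jnp.arange(2)
batched = jax.vmap(lowered_fn, in_axes=(0, 0, 0, None))(X, U, nodes, {})
```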
See Also
- lower_to_jax(): Convenience wrapper in symbolic/lower.py
- CVXPyLowerer: Alternative backend for convex constraints
- dispatch(): Core dispatch function for visitor pattern
_JAX_VISITORS: Dict[Type[Expr], Callable] = {}
Registry mapping expression types to their visitor functions.
JaxLowerer
JAX backend for lowering symbolic expressions to executable functions.
This class implements the visitor pattern for converting symbolic expression AST nodes to JAX functions. Each expression type has a corresponding visitor method decorated with @visitor that handles the lowering logic.
The lowering process is recursive: each visitor lowers its child expressions first, then composes them into a JAX operation. All lowered functions have a standardized signature (x, u, node, params) -> result.
Example
Set up the JaxLowerer and lower an expression to a JAX function:
lowerer = JaxLowerer()
expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
f = lowerer.lower(expr)
result = f(x_val, u_val, node=0, params={})
Note
The lowerer is stateless and can be reused for multiple expressions. All visitor methods are instance methods to maintain a clean interface, but they don't modify instance state.
Source code in openscvx/symbolic/lowerers/jax.py
_visit_abs(node: Abs)
_visit_add(node: Add)
Lower addition to JAX function.
Recursively lowers all terms and composes them with element-wise addition. Supports broadcasting following NumPy/JAX rules.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | Add | Add expression node with multiple terms | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> sum of all terms |
_visit_concat(node: Concat)
Lower concatenation to JAX function (concatenates along axis 0).
_visit_constant(node: Constant)
Lower a constant value to a JAX function.
Captures the constant value and returns a function that always returns it. Scalar constants are squeezed to ensure they're true scalars, not (1,) arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | Constant | Constant expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> constant_value |
_visit_constraint(node: Constraint)
Lower constraint to residual function.
Both equality (lhs == rhs) and inequality (lhs <= rhs) constraints are lowered to their residual form: lhs - rhs. The constraint is satisfied when the residual equals zero (equality) or is non-positive (inequality).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | Constraint | Equality or Inequality constraint node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> lhs - rhs (constraint residual) |
Note
The returned residual is used in penalty methods and Lagrangian terms. For an equality constraint the residual should be 0; for an inequality constraint it should be <= 0.
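The residual form can be illustrated with two hand-written closures. This is a sketch only: `make_residual` and the parameter name `r_max` are hypothetical, and the real visitor obtains `lhs_fn` and `rhs_fn` by lowering the constraint's two sides.

```python
import jax.numpy as jnp


def make_residual(lhs_fn, rhs_fn):
    """Combine two lowered sides into the residual lhs - rhs."""
    return lambda x, u, node, params: (
        lhs_fn(x, u, node, params) - rhs_fn(x, u, node, params)
    )


# Constraint ||x|| <= r_max, written as already-lowered closures.
lhs = lambda x, u, node, params: jnp.linalg.norm(x)
rhs = lambda x, u, node, params: params["r_max"]
residual = make_residual(lhs, rhs)

# ||[3, 4]|| = 5, so the residual is negative (satisfied) for r_max = 10
satisfied = residual(jnp.array([3.0, 4.0]), None, 0, {"r_max": 10.0})
# ||[6, 8]|| = 10, so the residual is positive (violated) for r_max = 5
violated = residual(jnp.array([6.0, 8.0]), None, 0, {"r_max": 5.0})
```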
_visit_control(node: Control)
Lower a control variable to a JAX function.
Extracts the appropriate slice from the unified control vector u using the slice assigned during unification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | Control | Control expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> u[slice] |
Raises:
| Type | Description |
|---|---|
| ValueError | If the control has no slice assigned (unification not run) |
_visit_cos(node: Cos)
_visit_cross_node_constraint(node: CrossNodeConstraint)
Lower CrossNodeConstraint to trajectory-level function.
CrossNodeConstraint wraps constraints that reference multiple trajectory nodes via NodeReference (e.g., rate limits like x.at(k) - x.at(k-1) <= r).
Unlike regular nodal constraints which have signature (x, u, node, params) and are vmapped across nodes, cross-node constraints operate on full trajectory arrays and return a scalar residual.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | CrossNodeConstraint | CrossNodeConstraint expression wrapping the inner constraint | required |
Returns:
| Type | Description |
|---|---|
| | Function with signature (X, U, params) -> scalar residual, where X is the full state trajectory of shape (N, n_x), U is the full control trajectory of shape (N, n_u), and params is the dictionary of problem parameters. Returns a scalar constraint residual (g <= 0 convention). |
Note
The inner constraint is lowered first (producing a function with the
standard (x, u, node, params) signature), then wrapped to provide the
trajectory-level (X, U, params) signature. The node parameter is
unused since NodeReference nodes have fixed indices baked in.
Example
For constraint: position.at(5) - position.at(4) <= max_step
The lowered function evaluates: X[5, pos_slice] - X[4, pos_slice] - max_step
And returns a scalar residual.
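The wrapping step described in the note can be sketched like this. The helper names (`lower_node_reference`, `lower_rate_limit`), the use of a single state column instead of a slice, and the parameter name `max_step` are all illustrative assumptions, not the library's internals.

```python
import jax.numpy as jnp


def lower_node_reference(node_idx, column):
    # node_param is ignored: the index is fixed when the closure is built.
    return lambda X, U, node_param, params: X[node_idx, column]


def lower_rate_limit(k, column, limit_name):
    cur = lower_node_reference(k, column)
    prev = lower_node_reference(k - 1, column)
    # Inner constraint in the standard (x, u, node, params) signature.
    inner = lambda X, U, node, params: (
        cur(X, U, node, params) - prev(X, U, node, params) - params[limit_name]
    )
    # Trajectory-level wrapper: pass a dummy node index.
    return lambda X, U, params: inner(X, U, 0, params)


X = jnp.arange(12.0).reshape(6, 2) ** 2  # 6 nodes, 2 states
U = jnp.zeros((6, 1))
g = lower_rate_limit(k=5, column=0, limit_name="max_step")
# X[5, 0] - X[4, 0] - max_step = 100 - 64 - 40
residual = g(X, U, {"max_step": 40.0})
```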
_visit_ctcs(node: CTCS)
Lower CTCS (Continuous-Time Constraint Satisfaction) to JAX function.
CTCS constraints use penalty methods to enforce constraints over continuous time intervals. The lowered function includes conditional logic to activate the penalty only within the specified node interval.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | CTCS | CTCS constraint node with penalty expression and optional node range | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, current_node, params) -> penalty value or 0 |
Note
Uses jax.lax.cond for JAX-traceable conditional evaluation. The penalty is active only when current_node is in [start_node, end_node). If no node range is specified, the penalty is always active.
See Also
- CTCS: The symbolic CTCS constraint class
- penalty functions: PositivePart, Huber, SmoothReLU
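The node-interval gating can be sketched with jax.lax.cond as follows. `make_windowed_penalty` and the quadratic `penalty` are hypothetical; the sketch assumes the penalty returns a scalar with the same dtype as x so that both branches agree in shape and dtype.

```python
import jax
import jax.numpy as jnp


def make_windowed_penalty(penalty_fn, start_node, end_node):
    """Activate penalty_fn only for start_node <= current_node < end_node."""
    def f(x, u, current_node, params):
        active = jnp.logical_and(
            current_node >= start_node, current_node < end_node
        )
        return jax.lax.cond(
            active,
            lambda: penalty_fn(x, u, current_node, params),
            lambda: jnp.zeros((), dtype=x.dtype),  # must match the penalty's shape/dtype
        )
    return f


# Illustrative penalty: squared positive part of (x - 1).
def penalty(x, u, node, params):
    return jnp.sum(jnp.maximum(x - 1.0, 0.0) ** 2)


f = jax.jit(make_windowed_penalty(penalty, start_node=1, end_node=4))
x_val = jnp.array([2.0, 0.5])
inside = f(x_val, None, 2, {})   # node 2 lies in [1, 4)
outside = f(x_val, None, 5, {})  # node 5 lies outside the interval
```

Because the branch is selected with jax.lax.cond rather than a Python `if`, the same compiled function serves every node index.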
_visit_diag(node: Diag)
Lower diagonal matrix construction to JAX function.
_visit_div(node: Div)
Lower element-wise division to JAX function.
_visit_exp(node: Exp)
Lower exponential function to JAX function.
_visit_hstack(node: Hstack)
Lower horizontal stacking to JAX function.
_visit_huber(node)
Lower Huber penalty function to JAX.
Huber penalty is quadratic for small values and linear for large values:
- |x| <= delta: 0.5 * x^2
- |x| > delta: delta * (|x| - 0.5 * delta)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | | Huber expression node with delta parameter | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> Huber penalty |
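The piecewise definition translates directly to a traceable expression. This is a sketch of the formula only; selecting the branch with jnp.where is an assumption, and the actual visitor may be written differently.

```python
import jax.numpy as jnp


def huber(x, delta=1.0):
    """Quadratic for |x| <= delta, linear beyond it."""
    abs_x = jnp.abs(x)
    quadratic = 0.5 * x**2
    linear = delta * (abs_x - 0.5 * delta)
    return jnp.where(abs_x <= delta, quadratic, linear)


vals = huber(jnp.array([-3.0, -0.5, 0.0, 0.5, 3.0]), delta=1.0)
```

The two pieces agree at |x| = delta (both give 0.5 * delta^2), so the penalty is continuous there.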
_visit_index(node: Index)
Lower indexing/slicing operation to JAX function.
_visit_log(node: Log)
_visit_logsumexp(node: LogSumExp)
Lower log-sum-exp to JAX function.
Computes log(sum(exp(x_i))) for multiple operands, which is a smooth approximation to the maximum function. Uses JAX's numerically stable logsumexp implementation. Performs element-wise log-sum-exp with broadcasting support.
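An element-wise log-sum-exp over several operands could look like the sketch below. Broadcasting the operands and stacking them along a new leading axis is an assumption about how the reduction is arranged; `lse` is a hypothetical helper, while `jax.scipy.special.logsumexp` is JAX's numerically stable implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def lse(*operands):
    """Element-wise log(sum_i exp(operand_i)) with broadcasting."""
    stacked = jnp.stack(jnp.broadcast_arrays(*operands), axis=0)
    return logsumexp(stacked, axis=0)


a = jnp.array([0.0, 10.0, -5.0])
b = jnp.array(1.0)  # scalar, broadcast against a
out = lse(a, b)
upper = jnp.maximum(a, b) + jnp.log(2.0)  # bound for two operands
```

For n operands the result lies between max_i x_i and max_i x_i + log(n), which is why it serves as a smooth upper approximation of the maximum.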
_visit_matmul(node: MatMul)
Lower matrix multiplication to JAX function using jnp.matmul.
_visit_max(node: Max)
Lower element-wise maximum to JAX function.
_visit_mul(node: Mul)
Lower element-wise multiplication to JAX function (Hadamard product).
_visit_neg(node: Neg)
Lower negation (unary minus) to JAX function.
_visit_nodal_constraint(node: NodalConstraint)
Lower a NodalConstraint by lowering its underlying constraint.
NodalConstraint is a wrapper that specifies which nodes a constraint applies to. The lowering just unwraps and lowers the inner constraint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | NodalConstraint | NodalConstraint wrapper | required |
Returns:
| Type | Description |
|---|---|
| | Function from lowering the wrapped constraint expression |
_visit_node_reference(node: NodeReference)
Lower NodeReference - extract value at a specific trajectory node.
NodeReference extracts a state/control value at a specific node from the full trajectory arrays. The node index is baked into the lowered function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | NodeReference | NodeReference expression with base and node_idx (integer) | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node_param, params) that extracts from the trajectory, where x and u are the full trajectories of shape (N, n_x) and (N, n_u), node_param is unused (kept for signature compatibility), and params holds the problem parameters |
Example
position.at(5) lowers to a function that extracts x[5, position_slice].
position.at(k-1) where k=7 lowers to a function that extracts x[6, position_slice].
_visit_norm(node: Norm)
Lower norm operation to JAX function.
Converts symbolic norm to jnp.linalg.norm with appropriate ord parameter. Handles string ord values like "inf", "-inf", "fro".
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | Norm | Norm expression node with ord attribute | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> norm of operand |
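One way the string ord values could be translated before calling jnp.linalg.norm is sketched below. `resolve_ord` is a hypothetical helper; the only assumption taken from the text is that "inf" and "-inf" arrive as strings, while "fro" is already accepted by jnp.linalg.norm for matrices.

```python
import jax.numpy as jnp


def resolve_ord(ord_value):
    """Map string spellings of infinity to the values jnp.linalg.norm expects."""
    if ord_value == "inf":
        return jnp.inf
    if ord_value == "-inf":
        return -jnp.inf
    return ord_value  # None, numbers, and "fro" pass through unchanged


def norm(x, ord_value=None):
    return jnp.linalg.norm(x, ord=resolve_ord(ord_value))


v = jnp.array([3.0, -4.0])
two_norm = norm(v)          # default Euclidean norm
inf_norm = norm(v, "inf")   # largest absolute entry
ninf_norm = norm(v, "-inf") # smallest absolute entry
one_norm = norm(v, 1)
fro_norm = norm(jnp.array([[3.0, 0.0], [0.0, 4.0]]), "fro")
```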
_visit_or(node: Or)
Lower STL disjunction (Or) to JAX using STLJax library.
Converts a symbolic Or constraint to an STLJax Or formula for handling disjunctive task specifications. Each operand becomes an STLJax predicate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | Or | Or expression node with multiple operands | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> STL robustness value |
Note
Uses STLJax library for signal temporal logic evaluation. The returned function computes the robustness metric for the disjunction, which is positive when at least one operand is satisfied.
Example
Used for task specifications like "reach goal A OR goal B"::
goal_A = ox.Norm(x - target_A) <= 1.0
goal_B = ox.Norm(x - target_B) <= 1.0
task = ox.Or(goal_A, goal_B)
See Also
- stljax.formula.Or: Underlying STLJax implementation
- STL robustness: Quantitative measure of constraint satisfaction
_visit_parameter(node: Parameter)
Lower a parameter to a JAX function.
Parameters are looked up by name in the params dictionary at evaluation time, allowing runtime parameter updates without recompilation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | Parameter | Parameter expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> params[name] |
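The benefit of looking parameters up at evaluation time can be seen in a small experiment. `lower_parameter`, `scaled_sum`, and the parameter name `gain` are hypothetical; `trace_log` relies on the fact that Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_log = []  # grows only when JAX (re-)traces the function


def lower_parameter(name):
    """Closure that reads params[name] when the function is evaluated."""
    return lambda x, u, node, params: params[name]


gain_fn = lower_parameter("gain")


def scaled_sum(x, u, node, params):
    trace_log.append("traced")  # executes at trace time, not on cached calls
    return gain_fn(x, u, node, params) * jnp.sum(x)


f = jax.jit(scaled_sum)
x_val = jnp.array([1.0, 2.0, 3.0])
first = f(x_val, None, 0, {"gain": jnp.asarray(2.0)})
second = f(x_val, None, 0, {"gain": jnp.asarray(5.0)})  # new value, same shape/dtype
```

Since the value is an input rather than a captured constant, changing it reuses the compiled function as long as its shape and dtype stay the same.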
_visit_pos(node)
Lower positive part function to JAX.
Computes max(x, 0), used in penalty methods for inequality constraints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | | PositivePart expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> max(operand, 0) |
_visit_power(node: Power)
Lower element-wise power (base**exponent) to JAX function.
_visit_qdcm(node: QDCM)
Lower quaternion to direction cosine matrix (DCM) conversion.
Converts a unit quaternion [q0, q1, q2, q3] to a 3x3 rotation matrix. Used in 6-DOF spacecraft and robotics applications.
The quaternion is normalized before conversion to ensure a valid rotation matrix. The DCM is computed using the standard quaternion-to-DCM formula.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | QDCM | QDCM expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> 3x3 rotation matrix |
Note
Quaternion convention: [w, x, y, z] where w is the scalar part
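A scalar-first quaternion can be turned into a rotation matrix as sketched below. This uses one common textbook form of the formula; whether the library's DCM matches this matrix or its transpose is not stated in the text, so treat the orientation convention as an assumption. `quat_to_dcm` is a hypothetical name.

```python
import jax.numpy as jnp


def quat_to_dcm(q):
    """Rotation matrix from a scalar-first quaternion [w, x, y, z]."""
    q = q / jnp.linalg.norm(q)  # normalize so the result is a valid rotation
    w, x, y, z = q
    return jnp.array([
        [1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x**2 + y**2)],
    ])


identity = quat_to_dcm(jnp.array([1.0, 0.0, 0.0, 0.0]))
# Unnormalized quaternion for a 90 degree rotation about the z axis.
R = quat_to_dcm(jnp.array([2.0, 0.0, 0.0, 2.0]))
```

Normalizing first means an input that has drifted off the unit sphere still yields an orthonormal matrix.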
_visit_sin(node: Sin)
_visit_sqrt(node: Sqrt)
_visit_square(node)
Lower square function to JAX.
Computes x^2 element-wise. Used in quadratic penalty methods.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | | Square expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> operand^2 |
_visit_srelu(node)
Lower smooth ReLU penalty function to JAX.
Smooth approximation to ReLU: sqrt(max(x, 0)^2 + c^2) - c. Differentiable everywhere; approaches ReLU as c -> 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | | SmoothReLU expression node with smoothing parameter c | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> smooth ReLU penalty |
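The smooth ReLU formula can be written directly with jnp operations. This is a sketch of the stated expression; the function name and the default value of c are illustrative.

```python
import jax
import jax.numpy as jnp


def smooth_relu(x, c=1e-2):
    """sqrt(max(x, 0)^2 + c^2) - c, a smooth stand-in for max(x, 0)."""
    return jnp.sqrt(jnp.maximum(x, 0.0) ** 2 + c**2) - c


vals = smooth_relu(jnp.array([-2.0, 0.0, 3.0]), c=4.0)  # sqrt(9 + 16) - 4 = 1 at x = 3
near_relu = smooth_relu(jnp.array(3.0), c=1e-3)         # close to max(3, 0) = 3
slope_at_zero = jax.grad(smooth_relu)(0.0)              # finite because c > 0
```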
_visit_ssm(node: SSM)
Lower skew-symmetric matrix for cross product (3x3).
Creates a 3x3 skew-symmetric matrix from a vector such that SSM(a) @ b = a x b (cross product).
The SSM is the matrix representation of the cross product operator, allowing cross products to be computed as matrix-vector multiplication.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | SSM | SSM expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> 3x3 skew-symmetric matrix |
Note
For vector w = [x, y, z], returns:
[[ 0, -z,  y],
 [ z,  0, -x],
 [-y,  x,  0]]
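The cross-product identity can be checked numerically. The function below builds the matrix exactly as laid out in the note; the name `ssm` is illustrative.

```python
import jax.numpy as jnp


def ssm(w):
    """3x3 skew-symmetric matrix such that ssm(a) @ b equals a x b."""
    x, y, z = w
    return jnp.array([
        [0.0, -z, y],
        [z, 0.0, -x],
        [-y, x, 0.0],
    ])


a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
via_matrix = ssm(a) @ b
via_cross = jnp.cross(a, b)
```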
_visit_ssmp(node: SSMP)
Lower skew-symmetric matrix for quaternion dynamics (4x4).
Creates a 4x4 skew-symmetric matrix from angular velocity vector for quaternion kinematic propagation: q_dot = 0.5 * SSMP(omega) @ q
The SSMP matrix is used in quaternion kinematics to compute quaternion derivatives from angular velocity vectors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | SSMP | SSMP expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> 4x4 skew-symmetric matrix |
Note
For angular velocity w = [x, y, z], returns:
[[0, -x, -y, -z],
 [x,  0,  z, -y],
 [y, -z,  0,  x],
 [z,  y, -x,  0]]
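The 4x4 matrix and the kinematic relation q_dot = 0.5 * SSMP(omega) @ q can be sketched as follows. The function name `ssmp` is illustrative, and the entries are copied from the note above.

```python
import jax.numpy as jnp


def ssmp(w):
    """4x4 skew-symmetric matrix for scalar-first quaternion kinematics."""
    x, y, z = w
    return jnp.array([
        [0.0, -x, -y, -z],
        [x, 0.0, z, -y],
        [y, -z, 0.0, x],
        [z, y, -x, 0.0],
    ])


omega = jnp.array([0.1, 0.2, 0.3])
q = jnp.array([1.0, 0.0, 0.0, 0.0])  # identity attitude
M = ssmp(omega)
q_dot = 0.5 * M @ q
```

Because the matrix is skew-symmetric, q_dot is orthogonal to q, so the exact continuous-time dynamics preserve the quaternion norm.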
_visit_stack(node: Stack)
Lower vertical stacking to JAX function (stack along axis 0).
_visit_state(node: State)
Lower a state variable to a JAX function.
Extracts the appropriate slice from the unified state vector x using the slice assigned during unification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node | State | State expression node | required |
Returns:
| Type | Description |
|---|---|
| | Function (x, u, node, params) -> x[slice] |
Raises:
| Type | Description |
|---|---|
| ValueError | If the state has no slice assigned (unification not run) |
_visit_sub(node: Sub)
Lower subtraction to JAX function (element-wise left - right).
_visit_sum(node: Sum)
Lower sum reduction to JAX function (sums all elements).
_visit_tan(node: Tan)
_visit_transpose(node: Transpose)
Lower matrix transpose to JAX function.
_visit_vstack(node: Vstack)
Lower vertical stacking to JAX function.
lower(expr: Expr)
Lower a symbolic expression to a JAX function.
Main entry point for lowering. Delegates to dispatch() which looks up the appropriate visitor method based on the expression type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| expr | Expr | Symbolic expression to lower (any Expr subclass) | required |
Returns:
| Type | Description |
|---|---|
| | JAX function with signature (x, u, node, params) -> result |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If no visitor exists for the expression type |
| ValueError | If the expression is malformed (e.g., State without slice) |
Example
Lower an expression to a JAX function:
lowerer = JaxLowerer()
x = ox.State("x", shape=(3,))
expr = ox.Norm(x)
f = lowerer.lower(expr)
# f is now callable
dispatch(lowerer: Any, expr: Expr)
Dispatch an expression to its registered visitor function.
Looks up the visitor function for the expression's type and calls it. This is the core of the visitor pattern implementation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lowerer | Any | The JaxLowerer instance (provides context for visitor methods) | required |
| expr | Expr | The expression node to lower | required |
Returns:
| Type | Description |
|---|---|
| | The result of calling the visitor function (typically a JAX callable) |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If no visitor is registered for the expression type |
Example
Dispatch an expression to lower it to a JAX function:
lowerer = JaxLowerer()
expr = Add(x, y)
fn = dispatch(lowerer, expr)  # Calls _visit_add
visitor(expr_cls: Type[Expr])
Decorator to register a visitor function for an expression type.
This decorator registers a visitor method to handle a specific expression type during JAX lowering. The decorated function is stored in _JAX_VISITORS and will be called by dispatch() when lowering that expression type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| expr_cls | Type[Expr] | The Expr subclass this visitor handles (e.g., Add, Mul, Norm) | required |
Returns:
| Type | Description |
|---|---|
| | Decorator function that registers the visitor and returns it unchanged |
Example
Register a visitor function for the Add expression:
@visitor(Add)
def _visit_add(self, node: Add):
# Lower addition to JAX
...
Note
Multiple expression types can share a visitor by stacking decorators::
@visitor(Equality)
@visitor(Inequality)
def _visit_constraint(self, node: Constraint):
# Handle both equality and inequality
...