Source code for physika.utils.ast_utils

from __future__ import annotations

import re
from typing import Any, Callable, Literal, Union, cast
from physika.utils.print_utils import print_unified_ast
from physika.elf import REGISTRY
# AST TYPE DEFINITIONS
# The parser produces a tree of tagged tuples.  Every non-leaf node is a
# tuple whose first element is a string (tag) and whose remaining elements are
# other nodes, lists of nodes, or scalar leaves.
#
# The type aliases below make the tag vocabulary explicit so that mypy
# can flag typos, missing branches, and wrong argument types.

# Tag literals (every valid first element of an AST tuple)

# Tags that may open an expression node, grouped by kind.
# NOTE(review): ast_to_torch_expr in this module also handles the tags
# "complex", "chain_index", "indexN", "for_expr" and "for_expr_range",
# which are not listed here -- confirm whether they belong in this Literal.
ExprTag = Literal[
    # binary arithmetic
    "add",
    "sub",
    "mul",
    "div",
    "pow",
    "matmul",
    # unary arithmetic
    "neg",
    # literals / references
    "num",
    "var",
    # string literals
    "string",
    "equation_string",
    # array literal: [elem, ...]
    "array",
    # indexing and slicing: arr[i], arr[a:b]
    "index",
    "slice",
    # calls and indexed calls: f(x), f(x)[i]
    "call",
    "call_index",
    # complex unit  i
    "imaginary",
]

# Tags of program-level statements.  Each comment shows the surface syntax
# followed by the tuple layout the parser produces for that tag.
StmtTag = Literal[
    # x : R = expr            -> (tag, name, type, expr, lineno)
    "decl",
    # x = expr                -> (tag, name, expr, lineno)
    "assign",
    # expr                    -> (tag, expr, lineno)
    "expr",
    # def f(...)              -> (tag, name)
    "func_def",
    # class C(...)            -> (tag, name)
    "class_def",
    # for i: ...              -> (tag, var, [body], [arrays], lineno)
    "for_loop",
    # for i: ℕ(n) / ℕ(s,e):   -> (tag, var, start, end, [body], lineno)
    "for_loop_range",
]

# Tags of statements that appear inside function bodies and loop bodies.
BodyStmtTag = Literal[
    # --- statements directly inside a function body ---
    "body_assign",  # x = expr
    "body_decl",  # x : T = expr
    "body_tuple_unpack",  # a, b = expr
    "body_for",  # for k: ...  (for-loop inside a function body)
    # C : ℝ[n,o]  type annotation for an accumulation target
    "body_zeros_decl",
    # Map loop: emit_body_stmts reads the first body statement as
    # (tag, tensor, [idx, ...], rhs) and emits nested torch.stack calls,
    # one per loop variable.
    "body_for_map",
    # for i j k: ...  accumulation loop; emits one torch.stack expression
    # per accumulation target
    "body_for_accum",
    # --- statements inside a for-loop body within a function ---
    "loop_assign",  # x = expr
    "loop_pluseq",  # x += expr
    "loop_index_pluseq",  # C[i,...] += expr  (nD accumulation)
    # --- statements inside a program-level for body ---
    "for_assign",  # x = expr
    "for_pluseq",  # x += expr
    "for_call",  # f(x)
]

# Tags of type-annotation nodes, with the tuple layout of each.
TypeTag = Literal[
    # R -> R      -> (tag, input_type, output_type)
    "func_type",
    # T_x M       -> (tag, point_id, manifold_type)
    "tangent",
    # R[3,3]      -> (tag, [(dim, variance), ...])
    "tensor",
]

# Any tag that may appear as the first element of a tagged AST tuple.
ASTTag = Union[ExprTag, StmtTag, BodyStmtTag, TypeTag]

# Composite node type: everything that can occur as a node or a child of a
# node in the tree.
ASTNode = Union[
    # tagged nodes: ("add", left, right), ("num", 1.0), ...
    tuple[Any, ...],
    # child sequences: function args, loop bodies, ...
    list["ASTNode"],
    # identifiers, string literal values
    str,
    # integer values (line numbers, dimension sizes)
    int,
    # numeric literal values
    float,
    # empty / no-op placeholder
    None,
]


def ast_uses_solve(node: ASTNode) -> bool:
    """Report whether a ``solve`` call occurs anywhere inside *node*.

    The subtree is searched depth-first for a tuple of the form
    ``("call", "solve", ...)``.

    Parameters
    ----------
    node : ASTNode
        A tagged tuple, a list of nodes, or a scalar leaf.

    Returns
    -------
    bool
        ``True`` when a ``("call", "solve", ...)`` tuple is found,
        ``False`` otherwise (including for scalar leaves and tuples with
        fewer than two elements).

    Examples
    --------
    >>> from physika.utils.ast_utils import ast_uses_solve
    >>> ast_uses_solve(("call", "solve", [("var", "eq1"), ("var", "eq2")]))
    True
    >>> ast_uses_solve(("add", ("num", 1.0), ("var", "x")))
    False
    """
    if isinstance(node, list):
        for entry in node:
            if ast_uses_solve(entry):
                return True
        return False

    # Only tagged tuples carrying at least one payload element are inspected.
    if not isinstance(node, tuple) or len(node) < 2:
        return False

    tag, *payload = node
    if tag == "call" and payload[0] == "solve":
        return True

    # Scalar payload elements return False immediately from the recursion.
    for part in payload:
        if ast_uses_solve(part):
            return True
    return False
def ast_uses_func(node: ASTNode, func_name: str) -> bool:
    """Report whether *func_name* is called anywhere inside *node*.

    Both plain calls ``("call", func_name, ...)`` and indexed calls
    ``("call_index", func_name, ..., idx)`` count as a match.  Used during
    calling of ``from_ast_to_torch`` to decide which runtime helpers need
    to be imported.

    Parameters
    ----------
    node : ASTNode
        A tagged tuple, a list of nodes, or a scalar leaf.
    func_name : str
        Function identifier to look for (e.g. ``"train"``, ``"grad"``,
        ``"simulate"``).

    Returns
    -------
    bool
        ``True`` if a matching call node exists in the subtree,
        ``False`` otherwise.

    Examples
    --------
    >>> from physika.utils.ast_utils import ast_uses_func
    >>> ast_uses_func(("call", "train", [("var", "model")]), "train")
    True
    >>> ast_uses_func(("call_index", "grad", [("var", "H")], ("num", 0.0)), "grad")  # noqa: E501
    True
    >>> ast_uses_func(("add", ("num", 1.0), ("num", 2.0)), "train")
    False
    """
    # Explicit work-list instead of recursion; visiting order does not matter
    # because only the existence of a match is reported.
    pending = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, list):
            pending.extend(current)
            continue
        if not isinstance(current, tuple) or len(current) < 2:
            continue
        tag = current[0]
        if (tag == "call" or tag == "call_index") and current[1] == func_name:
            return True
        pending.extend(current[1:])
    return False
def ast_uses_sympy(node: ASTNode) -> bool:
    """Report whether *node* contains a Symbol or Function declaration.

    The subtree is searched for nodes tagged ``"symbol_decl"``,
    ``"symbol_decl_multi"``, ``"function_decl"`` or
    ``"function_decl_multi"``, which indicate that sympy is needed as a
    backend for symbolic math.

    Parameters
    ----------
    node : ASTNode
        A tagged tuple, a list of nodes, or a scalar leaf.

    Returns
    -------
    bool
        ``True`` if one of the declaration tags above occurs anywhere in
        the subtree, ``False`` otherwise.

    Examples
    --------
    >>> from physika.utils.ast_utils import ast_uses_sympy
    >>> ast_uses_sympy(("symbol_decl", "x"))
    True
    >>> ast_uses_sympy(("function_decl", "u"))
    True
    >>> ast_uses_sympy(("num", "1.0"))
    False
    """
    sympy_tags = (
        "symbol_decl",
        "symbol_decl_multi",
        "function_decl",
        "function_decl_multi",
    )

    if isinstance(node, tuple):
        if not node:
            return False
        if node[0] in sympy_tags:
            return True
        # Scalar children return False immediately from the recursion.
        return any(ast_uses_sympy(part) for part in node[1:])

    if isinstance(node, list):
        return any(map(ast_uses_sympy, node))

    return False
def collect_grad_targets(node: ASTNode, targets: set[str]) -> None:
    """Gather the variables that ``grad()`` calls differentiate against.

    Walks *node* looking for ``("call", "grad", [output, input])`` and
    ``("call_index", "grad", [output, input], idx)`` nodes.  Whenever the
    second argument is a ``("var", name)`` node, *name* is added to
    *targets* so that ``generate_statement`` can initialise those
    variables with ``requires_grad=True``.

    Parameters
    ----------
    node : ASTNode
        A tagged tuple, a list of nodes, or a scalar leaf.
    targets : set[str]
        Set that receives the target names.  Modified in place; nothing
        is returned.

    Examples
    --------
    >>> from physika.utils.ast_utils import collect_grad_targets
    >>> targets = set()
    >>> stmt = ("expr", ("call", "grad", [("var", "H"), ("var", "t")]))
    >>> collect_grad_targets(stmt, targets)
    >>> targets
    {'t'}
    """
    if isinstance(node, list):
        for entry in node:
            collect_grad_targets(entry, targets)
        return

    if not isinstance(node, tuple) or len(node) < 2:
        return

    is_grad_call = (node[0] in ("call", "call_index") and node[1] == "grad"
                    and len(node) >= 3)
    if is_grad_call:
        call_args = node[2]
        if len(call_args) >= 2:
            wrt = call_args[1]
            if isinstance(wrt, tuple) and wrt[0] == "var":
                targets.add(wrt[1])

    # Keep descending: grad calls may be nested inside the arguments.
    # Scalar children are ignored by the recursion.
    for part in node[1:]:
        collect_grad_targets(part, targets)
def replace_class_params(code: str, class_params: list[tuple[str,
                                                             ASTNode]]) -> str:
    """Prefix class parameter references with ``self.`` in generated code.

    Rewrites bare parameter names inside the generated ``forward`` and
    ``loss`` method bodies.  Every whole-word occurrence of a parameter
    name is rewritten, whatever follows it: a call (``f(`` ->
    ``self.f(``), an index (``W[`` -> ``self.W[``) or nothing at all
    (``b`` -> ``self.b``).

    An occurrence is left untouched when it is preceded by a dot, i.e.
    when it is already an attribute access such as ``self.W`` or
    ``other.W``.  This makes the function idempotent and prevents output
    like ``self.self.W``.

    Parameters
    ----------
    code : str
        The generated Python source string to transform.
    class_params : list[tuple[str, ASTNode]]
        List of ``(name, type_spec)`` pairs from the class definition.
        Only the names are used; type specs are ignored.

    Returns
    -------
    str
        A new string with class parameter names prefixed by ``self.``.

    Examples
    --------
    >>> from physika.utils.ast_utils import replace_class_params
    >>> replace_class_params("(W @ x + b)", [("W", "ℝ"), ("b", "ℝ")])
    '(self.W @ x + self.b)'
    >>> replace_class_params("self.W[0] + W[1]", [("W", "ℝ")])
    'self.W[0] + self.W[1]'
    """
    for cp_name, _ in class_params:
        # Whole-word match that is not part of a longer identifier and not
        # already an attribute access.  The name is escaped so it is always
        # matched literally.
        pattern = rf"(?<![\w.]){re.escape(cp_name)}(?!\w)"
        # A callable replacement avoids backslash/group interpretation of the
        # replacement text.
        code = re.sub(pattern, lambda m: f"self.{m.group(0)}", code)
    return code
def _is_loop_var(expr: ASTNode, var: str) -> bool:
    """Return True if *expr* represents the loop variable *var*.

    Handles both the ``("var", name)`` form and the special
    ``("imaginary",)`` form, which is used when the loop variable is named
    ``"i"`` (since the lexer emits ``IMAGINARY`` for the token ``i``).

    Parameters
    ----------
    expr : ASTNode
        An AST expression node to test.
    var : str
        The loop variable name to match against.

    Returns
    -------
    bool
        ``True`` if *expr* refers to the loop variable *var*.  Non-tuple
        and empty-tuple inputs yield ``False``.

    Examples
    --------
    >>> from physika.utils.ast_utils import _is_loop_var
    >>> _is_loop_var(("var", "k"), "k")
    True
    >>> _is_loop_var(("imaginary",), "i")
    True
    >>> _is_loop_var(("var", "j"), "k")
    False
    """
    if not isinstance(expr, tuple) or not expr:
        return False
    tag = expr[0]
    if tag == "var":
        return len(expr) >= 2 and expr[1] == var
    # The token ``i`` is lexed as the imaginary unit, never as ("var", "i").
    return tag == "imaginary" and var == "i"


def _decompose_chain(expr: ASTNode) -> tuple[str | None, list[ASTNode]]:
    """Decompose an index node into ``(array_name, [idx_expr, ...])``.

    Recursively walks left-associative ``("chain_index", base, idx)``
    nodes back to the underlying array name and collects all index
    expressions in order.  The base case is 1-D indexing
    ``("index", arr, idx)``.

    Parameters
    ----------
    expr : ASTNode
        A ``("chain_index", ...)`` or ``("index", ...)`` node, or any other
        node (returns ``(None, [])`` for unrecognised shapes).

    Returns
    -------
    array_name : str or None
        The name of the array being indexed, or ``None`` if the expression
        is not a recognised index form.
    idx_exprs : list of ASTNode
        Index expressions in outermost-to-innermost order, matching the
        dimension order of the underlying array.  When the innermost base
        is not a recognised index form, ``array_name`` is ``None`` but the
        indices collected on the way out are still returned.

    Examples
    --------
    >>> from physika.utils.ast_utils import _decompose_chain
    >>> _decompose_chain(("index", "A", ("var", "i")))
    ('A', [('var', 'i')])
    >>> _decompose_chain(("chain_index", ("index", "A", ("var", "i")), ("var", "k")))  # noqa: E501
    ('A', [('var', 'i'), ('var', 'k')])
    """
    if not isinstance(expr, tuple):
        return None, []
    if expr[0] == "index":
        _, arr, idx = expr
        if isinstance(arr, str):
            return arr, [idx]
        return None, []
    if expr[0] == "chain_index":
        base_name, base_idxs = _decompose_chain(expr[1])
        return base_name, base_idxs + [expr[2]]
    return None, []


def _infer_range(var: str, expr: ASTNode, skip: str) -> str | None:
    """Walk an AST expression and return a ``shape`` string for *var*.

    Searches the expression tree for array-index nodes where *var* appears
    as a subscript, then returns the corresponding shape as a string.  The
    array named *skip* is excluded from the search as it is the
    accumulation target being defined.

    Handles ``("indexN", arr, [idx, ...])``, ``("index", arr, idx)``
    (1-D indexing), and ``("chain_index", ...)`` (chained bracket indexing
    ``A[i][k]``).  Both tuple and list children are searched, so index
    accesses inside list-valued children such as call arguments
    (``sin(A[i])``) are found as well.

    Parameters
    ----------
    var : str
        The loop variable whose range we want to determine.
    expr : ASTNode
        The RHS AST expression to search.
    skip : str
        Name of the tensor being accumulated into.

    Returns
    -------
    str or None
        A Python expression such as ``"A.shape[0]"`` giving the loop range,
        or ``None`` if no suitable index access was found.

    Examples
    --------
    >>> from physika.utils.ast_utils import _infer_range
    >>> rhs = ("indexN", "A", [("var", "i"), ("var", "k")])
    >>> _infer_range("i", rhs, "C")
    'A.shape[0]'
    >>> _infer_range("k", rhs, "C")
    'A.shape[1]'
    >>> _infer_range("k", ("call", "sin", [("index", "B", ("var", "k"))]), "C")
    'B.shape[0]'
    """
    if isinstance(expr, list):
        # Child sequences (call arguments, array elements, index lists).
        for item in expr:
            found = _infer_range(var, item, skip)
            if found is not None:
                return found
        return None

    if not isinstance(expr, tuple) or not expr:
        return None

    op = expr[0]
    if op == "indexN":
        arr = expr[1]
        if arr != skip:
            for dim, ie in enumerate(expr[2]):
                if _is_loop_var(ie, var):
                    return f"{arr}.shape[{dim}]"
    elif op == "index":
        _, arr, ie = expr
        if isinstance(arr, str) and arr != skip and _is_loop_var(ie, var):
            return f"{arr}.shape[0]"
    elif op == "chain_index":
        base_name, idx_exprs = _decompose_chain(expr)
        if base_name and base_name != skip:
            for dim, ie in enumerate(idx_exprs):
                if _is_loop_var(ie, var):
                    return f"{base_name}.shape[{dim}]"

    # No direct hit on this node: search every structured child.
    for child in expr[1:]:
        if isinstance(child, (tuple, list)):
            found = _infer_range(var, child, skip)
            if found is not None:
                return found
    return None


def _lhs_var_name(expr: ASTNode) -> str | None:
    """Extract the loop-variable name from an LHS index expression.

    Used to classify which loop variables appear as output dimensions
    (LHS indices of ``T[i, j]``).

    Parameters
    ----------
    expr : ASTNode
        An index expression from the LHS of a ``loop_index_pluseq`` node.

    Returns
    -------
    str or None
        The variable name, or ``None`` if the expression is not a plain
        variable reference.

    Examples
    --------
    >>> from physika.utils.ast_utils import _lhs_var_name
    >>> _lhs_var_name(("var", "j"))
    'j'
    >>> _lhs_var_name(("imaginary",))
    'i'
    >>> _lhs_var_name(("num", 0.0))
    """
    if isinstance(expr, tuple) and expr[0] == "var":
        return expr[1]
    if isinstance(expr, tuple) and expr[0] == "imaginary":
        # The token ``i`` is lexed as the imaginary unit.
        return "i"
    return None


def _has_complex(node: ASTNode) -> bool:
    """Recursively determine whether an AST subtree contains a complex literal.

    A complex literal is a ``("complex", value)`` tuple whose ``value`` is
    a Python ``complex``.

    Parameters
    ----------
    node : ASTNode
        AST subtree.

    Returns
    -------
    bool
        True if the subtree contains a complex number, else False.

    Examples
    --------
    >>> from physika.utils.ast_utils import _has_complex
    >>> _has_complex(("array", [("num", 1), ("num", 5)]))
    False
    >>> _has_complex(("array", [("complex", 3j), ("complex", 2j)]))
    True
    >>> _has_complex(("array", [('add', ('num', 1), ('complex', 3j,),), ('add', ('num', 2), ('complex', 5j,),)]))  # noqa
    True
    """
    if (isinstance(node, tuple) and len(node) >= 2 and node[0] == "complex"
            and isinstance(node[1], complex)):
        return True
    # Recursively traverse tuples/lists (nested arrays).
    if isinstance(node, (tuple, list)):
        return any(_has_complex(child) for child in node)
    return False
def ast_to_torch_expr(node: ASTNode,
                      indent: int = 0,
                      current_loop_var: str | set[str] | None = None) -> str:
    """Convert an AST expression node to a PyTorch source code string.

    Recursively translates a Physika AST subtree into a Python/PyTorch
    expression string.  Handles arithmetic operators, array construction,
    indexing, slicing, function calls (mapping builtins like ``sin`` to
    ``torch.sin``), symbolic helpers (``subs``, ``diff``, ``lambdify``,
    ``symbolic_solve``), for-expressions and complex numbers.  Tags that
    are not handled here are offered to ``REGISTRY.dispatch_forward``.

    This is the core of the string-codegen path used by
    ``generate_function``, ``generate_class``, and ``generate_statement``.

    Parameters
    ----------
    node : ASTNode
        AST expression node (tagged tuple) or a scalar leaf.  Non-tuple
        inputs are returned as their ``repr``.
    indent : int, default 0
        Current indentation level.  Only forwarded to recursive calls and
        to the ELF registry.
    current_loop_var : str or set of str or None, default None
        Name(s) of the enclosing loop variable(s).  When ``"i"`` is among
        them, an ``("imaginary",)`` node emits the loop variable ``i``
        instead of ``torch.tensor(1j)``, disambiguating the complex unit
        from the loop index.

    Returns
    -------
    str
        A torch Python expression string corresponding to the given node.
        For a tag that neither this function nor the ELF registry handles,
        the string ``/* unknown: <node> */`` is returned.

    Raises
    ------
    IndexError
        If a single-argument builtin (``sin``, ``exp``, ...) is called
        without arguments, or ``grad`` with fewer than two.

    Examples
    --------
    >>> from physika.utils.ast_utils import ast_to_torch_expr
    >>> ast_to_torch_expr(("add", ("num", 1.0), ("var", "x")))
    '(1.0 + x)'
    >>> expected = (
    ...     "torch.sin(theta if isinstance(theta, torch.Tensor) "
    ...     "else torch.tensor(float(theta)))"
    ... )
    >>> ast_to_torch_expr(("call", "sin", [("var", "theta")])) == expected
    True
    >>> ast_to_torch_expr(("array", [("num", 1.0), ("num", 2.0)]))
    'torch.tensor([1.0, 2.0])'
    """
    if not isinstance(node, tuple):
        return repr(node)

    def render(child: ASTNode) -> str:
        # Translate a child with the same indent / loop-variable context.
        return ast_to_torch_expr(child, indent, current_loop_var)

    def enclosing_loop_vars() -> set[str]:
        # Normalise current_loop_var (None, one name, or a set) to a set.
        if isinstance(current_loop_var, set):
            return current_loop_var
        return {current_loop_var} if current_loop_var else set()

    # Binary operator tags and the Python operator each one emits.
    binary_ops = {
        "add": "+",
        "sub": "-",
        "mul": "*",
        "div": "/",
        "matmul": "@",
        "pow": "**",
    }

    op = node[0]
    call_fallback = None

    if op == "num":
        val = node[1]
        # is_integer() is False for inf/nan, so non-finite literals no longer
        # raise from int(val).
        if isinstance(val, float) and val.is_integer():
            return f"{val}"
        return repr(val)

    elif op == "complex":
        return repr(node[1])

    elif op == "var":
        return node[1]

    elif isinstance(op, str) and op in binary_ops:
        left = render(node[1])
        right = render(node[2])
        return f"({left} {binary_ops[op]} {right})"

    elif op == "neg":
        return f"(-{render(node[1])})"

    elif op == "array":
        elements = node[1]
        # A nested array contains other array nodes as direct elements.
        has_nested = any(
            isinstance(e, tuple) and e[0] == "array" for e in elements)
        contains_complex = any(_has_complex(e) for e in elements)

        if has_nested:
            # Generate a list-of-lists literal and wrap it in torch.tensor.
            def array_to_list(arr_node):
                if isinstance(arr_node, tuple) and arr_node[0] == "array":
                    inner = [array_to_list(e) for e in arr_node[1]]
                    return f"[{', '.join(inner)}]"
                return render(arr_node)

            inner_lists = [array_to_list(e) for e in elements]
            if contains_complex:
                return f"torch.tensor([{', '.join(inner_lists)}], dtype=torch.complex64)"  # noqa
            return f"torch.tensor([{', '.join(inner_lists)}])"

        # Purely literal elements (numbers, complex, negated numbers) can go
        # straight into torch.tensor.
        all_numeric = all(
            isinstance(e, tuple) and (
                e[0] == "num" or e[0] == "complex" or
                (e[0] == "neg" and isinstance(e[1], tuple)
                 and e[1][0] == "num")) for e in elements)
        elem_strs = [render(e) for e in elements]

        if all_numeric:
            if contains_complex:
                return f"torch.tensor([{', '.join(elem_strs)}], dtype=torch.complex64)"  # noqa
            return f"torch.tensor([{', '.join(elem_strs)}])"

        if contains_complex:
            wrapped = [
                f"torch.as_tensor({s}, dtype=torch.complex64)"
                for s in elem_strs
            ]
            return f"torch.stack([{', '.join(wrapped)}])"

        # Elements may be tensors (e.g., x[1], sin(x[0])): use torch.stack.
        wrapped = [f"torch.as_tensor({s})" for s in elem_strs]
        return f"torch.stack([{', '.join(wrapped)}])"

    elif op == "index":
        var_name = node[1]
        idx = render(node[2])
        return f"{var_name}[int({idx})]"

    elif op == "slice":
        var_name = node[1]
        start = render(node[2])
        end = render(node[3])
        # Bounds rendered with a "." (e.g. float literals) are wrapped in
        # int().
        start_int = f"int({start})" if "." in start else start
        end_int = f"int({end})" if "." in end else end
        return f"{var_name}[{start_int}:{end_int}]"

    elif op == "chain_index":
        obj = render(node[1])
        idx = render(node[2])
        return f"{obj}[int({idx})]"

    elif op == "indexN":
        arr = node[1]
        idx_codes = [f"int({render(e)})" for e in node[2]]
        return f"{arr}[{', '.join(idx_codes)}]"

    elif op == "call":
        func_name = node[1]
        args = node[2]
        arg_strs = [render(arg) for arg in args]

        # Map built-in functions to PyTorch equivalents.
        torch_funcs = {
            "exp": "torch.exp",
            "log": "torch.log",
            "sin": "torch.sin",
            "cos": "torch.cos",
            "sqrt": "torch.sqrt",
            "abs": "torch.abs",
            "sum": "torch.sum",
            "mean": "torch.mean",
            "real": "torch.real",
        }
        multi_arg_funcs = {
            "roll": "torch.roll",
        }

        if func_name in torch_funcs:
            # The first argument is read only here, so zero-argument calls
            # to other functions no longer fail before dispatch.
            arg = arg_strs[0]
            return f"{torch_funcs[func_name]}({arg} if isinstance({arg}, torch.Tensor) else torch.tensor(float({arg})))"  # noqa: E501

        elif func_name in multi_arg_funcs:
            return f"{multi_arg_funcs[func_name]}({', '.join(arg_strs)})"

        elif func_name == "grad":
            # grad(output, input) -> compute_grad(output, input)
            inner = args[0]
            diff_var_node = args[1]
            diff_var_name = (diff_var_node[1]
                             if isinstance(diff_var_node, tuple)
                             and diff_var_node[0] == "var" else None)
            if (isinstance(inner, tuple) and inner[0] == "call"
                    and diff_var_name):
                # Wrap in a lambda so compute_grad evaluates on a fresh leaf.
                inner_func = inner[1]
                inner_call_args = inner[2]
                lamb_var = f"_d{diff_var_name}"
                new_arg_strs = [
                    lamb_var if a == ("var", diff_var_name) else render(a)
                    for a in inner_call_args
                ]
                inner_call = f"{inner_func}({', '.join(new_arg_strs)})"
                return f"compute_grad(lambda {lamb_var}: {inner_call}, {arg_strs[1]})"  # noqa
            return f"compute_grad({', '.join(arg_strs)})"

        elif func_name == "subs":
            # subs(expr, old1, new1, old2, new2, ...): replace every
            # occurrence of each old term with the matching new one.
            expr_code = arg_strs[0]
            sub_pairs = ", ".join(f"({arg_strs[i]}, {arg_strs[i+1]})"
                                  for i in range(1,
                                                 len(arg_strs) - 1, 2))
            return f"{expr_code}.subs([{sub_pairs}])"

        elif func_name == "diff":
            # Derivative of an expression w.r.t. a variable; with three
            # arguments the third is the order, emitted as an integer.
            if len(arg_strs) == 3:
                expr, var, order = arg_strs
                order = str(int(float(order)))
                return f"sp.diff({expr}, {var}, {order})"
            return f"sp.diff({', '.join(arg_strs)})"

        elif func_name == "lambdify":
            # Convert a sympy expression into one that can be evaluated
            # numerically.
            args0 = args[0]
            if isinstance(args0, tuple) and args0[0] == "array":
                sym_names = [e[1] for e in args0[1]]
                vars_code = "[" + ", ".join(sym_names) + "]"
            else:
                vars_code = arg_strs[0]
            expr_code = arg_strs[1]
            # NOTE(review): this interpolates the repr of torch_funcs, whose
            # values are strings such as 'torch.exp' -- confirm that is what
            # the generated lambdify call is meant to receive.
            return (f"sp.lambdify({vars_code}, {expr_code}, "
                    f"modules={torch_funcs})")

        elif func_name == "symbolic_solve":
            # Wrapper for sympy's solve (equations, systems, roots).
            return f"sp.solve({', '.join(arg_strs)})"

        else:
            # Unknown function: give the ELF registry a chance under the
            # tag "call:<name>", otherwise fall back to a plain call below.
            op = f"call:{func_name}"
            call_fallback = f"{func_name}({', '.join(arg_strs)})"

    elif op == "call_index":
        # Indexed function call: func(args)[index]
        func_name = node[1]
        args = node[2]
        index_ast = node[3]
        arg_strs = [render(arg) for arg in args]
        idx = render(index_ast)
        if func_name == "grad":
            # grad(output, input)[i] -> compute_grad(output, input)[i]
            return f"compute_grad({', '.join(arg_strs)})[int({idx})]"
        return f"{func_name}({', '.join(arg_strs)})[int({idx})]"

    elif op == "imaginary":
        # Inside a loop whose variable is 'i', the token means the index.
        if "i" in enclosing_loop_vars():
            return "i"
        # Use torch.tensor(1j) so it can be used with torch.exp.
        return "torch.tensor(1j)"

    elif op == "for_expr":
        # active_vars accumulates all enclosing loop variable names so that
        # nested loops are handled.
        loop_var = node[1]
        size_expr = node[2]
        body_expr = node[3]
        outer_active = enclosing_loop_vars()
        active_vars = outer_active | {loop_var}
        # The size is evaluated outside the new loop variable's scope.
        n_code = ast_to_torch_expr(size_expr, indent, outer_active or None)
        body_code = ast_to_torch_expr(body_expr, indent, active_vars)
        tmp = f"_fi_{loop_var}"
        return (f"torch.stack(["
                f"{body_code} "
                f"for {tmp} in range(int({n_code})) "
                f"for {loop_var} in [torch.tensor(float({tmp}))]])")

    elif op == "for_expr_range":
        # for i : ℕ(start, end) → body — range(start, end), end-exclusive
        loop_var = node[1]
        start_expr = node[2]
        end_expr = node[3]
        body_expr = node[4]
        outer_active = enclosing_loop_vars()
        active_vars = outer_active | {loop_var}
        start_code = ast_to_torch_expr(start_expr, indent, outer_active
                                       or None)
        end_code = ast_to_torch_expr(end_expr, indent, outer_active or None)
        body_code = ast_to_torch_expr(body_expr, indent, active_vars)
        tmp = f"_fi_{loop_var}"
        return (f"torch.stack(["
                f"{body_code} "
                f"for {tmp} in range(int({start_code}), int({end_code})) "
                f"for {loop_var} in [torch.tensor(float({tmp}))]])")

    elif op == "equation_string":
        return repr(node[1])

    elif op == "string":
        return repr(node[1])

    # Expression tags contributed by ELF features.
    elf_result = REGISTRY.dispatch_forward(
        op,
        node,
        to_expr=render,
        loop_var=current_loop_var,
        indent=indent,
    )
    if elf_result is not None:
        return elf_result
    if call_fallback is not None:
        return call_fallback
    return f"/* unknown: {node} */"
def condition_to_expr(cond: ASTNode,
                      current_loop_var: str | set[str] | None = None) -> str:
    """Render a condition AST node as a Python comparison expression.

    Parameters
    ----------
    cond : tuple[str, ...]
        A condition tuple such as ``("cond_eq", left, right)``.  The tag
        must be one of ``cond_eq``, ``cond_neq``, ``cond_lt``, ``cond_gt``,
        ``cond_leq`` or ``cond_geq``.
    current_loop_var : str or set, optional
        Active loop variable(s), forwarded to ``ast_to_torch_expr`` for
        disambiguating the imaginary token ``i``.

    Returns
    -------
    str
        A Python boolean expression (e.g. ``"n == 0.0"``).

    Raises
    ------
    KeyError
        If the tag of *cond* is not one of the comparison tags above.

    Examples
    --------
    >>> from physika.utils.ast_utils import condition_to_expr
    >>> condition_to_expr(("cond_eq", ("var", "n"), ("num", 0.0)))
    'n == 0.0'
    >>> condition_to_expr(("cond_lt", ("var", "x"), ("num", 1.0)))
    'x < 1.0'
    """
    comparators = dict(
        cond_eq="==",
        cond_neq="!=",
        cond_lt="<",
        cond_gt=">",
        cond_leq="<=",
        cond_geq=">=",
    )

    parts = cast(tuple[Any, ...], cond)
    tag = parts[0]
    # Both operands are rendered before the tag is looked up.
    lhs_code = ast_to_torch_expr(parts[1], current_loop_var=current_loop_var)
    rhs_code = ast_to_torch_expr(parts[2], current_loop_var=current_loop_var)
    symbol = comparators[tag]
    return " ".join((lhs_code, symbol, rhs_code))
def emit_func_loop_body(
    loop_body: list,
    indent_level: int,
    lines: list[str],
    loop_var,
) -> None:
    """Emit code lines for a list of ``func_loop_stmt`` AST nodes.

    Recurses for nested ``loop_for_range``, ``loop_if``, and
    ``loop_if_else`` nodes, extending the set of active loop variables
    with each new inner variable.  The active set is passed to
    ``ast_to_torch_expr`` so that the imaginary-unit token ``i`` resolves
    to the loop variable name instead of ``torch.tensor(1j)`` when ``i``
    is an active loop variable.

    Parameters
    ----------
    loop_body : list[ASTNode]
        ``func_loop_stmt`` nodes.  ``None`` entries are skipped.
        Supported tags:

        - ``loop_assign``
        - ``loop_index_assign_nd``
        - ``loop_pluseq``
        - ``loop_index_pluseq``
        - ``loop_for_range``
        - ``loop_if``
        - ``loop_if_else``

        Any other tag is offered to ``REGISTRY.dispatch_forward``; a
        non-``None`` result is emitted as a single line, otherwise the
        statement produces no output.
    indent_level : int
        Current indentation depth.  Each level adds 4 spaces.
    lines : list[str]
        Output list.  Source lines are appended.
    loop_var : str or set[str]
        Active loop variable name(s).  Grows as inner loops are entered.
    """
    # Four spaces per level, as documented above.
    prefix = "    " * indent_level
    # Normalise to a set so inner loops can extend it with ``|``.
    active = loop_var if isinstance(
        loop_var, set) else ({loop_var} if loop_var else set())

    for loop_stmt in loop_body:
        if loop_stmt is None:
            continue
        tag = loop_stmt[0]

        if tag == "loop_assign":
            _, var_name, expr = loop_stmt
            rhs_code = ast_to_torch_expr(expr, current_loop_var=active)
            lines.append(f"{prefix}{var_name} = {rhs_code}")

        elif tag == "loop_index_assign_nd":
            _, arr_name, idx_list, rhs = loop_stmt
            # Uses the normalised ``active`` set like every other branch.
            indices = ", ".join(
                f"int({ast_to_torch_expr(idx, current_loop_var=active)})"
                for idx in idx_list)
            rhs_code = ast_to_torch_expr(rhs, current_loop_var=active)
            lines.append(f"{prefix}{arr_name}[{indices}] = {rhs_code}")

        elif tag == "loop_pluseq":
            # Emitted as ``x = x + expr`` rather than an in-place ``+=``.
            _, var_name, expr = loop_stmt
            rhs_code = ast_to_torch_expr(expr, current_loop_var=active)
            lines.append(f"{prefix}{var_name} = {var_name} + {rhs_code}")

        elif tag == "loop_index_pluseq":
            _, arr_name, idx_list, rhs = loop_stmt
            idx_codes = [
                ast_to_torch_expr(e, current_loop_var=active)
                for e in idx_list
            ]
            rhs_code = ast_to_torch_expr(rhs, current_loop_var=active)
            indices = ", ".join(f"int({c})" for c in idx_codes)
            lines.append(f"{prefix}{arr_name}[{indices}] += {rhs_code}")

        elif tag == "loop_for_range":
            _, inner_var, start_expr, end_expr, inner_body = loop_stmt
            start_code = ast_to_torch_expr(start_expr,
                                           current_loop_var=active)
            end_code = ast_to_torch_expr(end_expr, current_loop_var=active)
            lines.append(
                f"{prefix}for {inner_var} in range(int({start_code}), int({end_code})):"  # noqa: E501
            )
            # The inner body sees the new loop variable as active too.
            emit_func_loop_body(inner_body, indent_level + 1, lines,
                                active | {inner_var})

        elif tag == "loop_if":
            _, cond, then_body = loop_stmt
            cond_code = condition_to_expr(cond, current_loop_var=active)
            lines.append(f"{prefix}if {cond_code}:")
            emit_func_loop_body(then_body, indent_level + 1, lines, active)

        elif tag == "loop_if_else":
            _, cond, then_body, else_body = loop_stmt
            cond_code = condition_to_expr(cond, current_loop_var=active)
            lines.append(f"{prefix}if {cond_code}:")
            emit_func_loop_body(then_body, indent_level + 1, lines, active)
            lines.append(f"{prefix}else:")
            emit_func_loop_body(else_body, indent_level + 1, lines, active)

        else:
            # Statement tags contributed by ELF features.
            result = REGISTRY.dispatch_forward(
                tag,
                loop_stmt,
                to_expr=ast_to_torch_expr,
                current_loop_var=active,
            )
            if result is not None:
                lines.append(f"{prefix}{result}")
# Code generators (function / class / statement)
def emit_body_stmts(
    stmts: list[ASTNode],
    indent_level: int,
    lines: list[str],
    known_vars: list[str],
    equation_vars: set[str],
    generate_solve_call: Callable[[ASTNode], str],
    scalar_only: bool = False,
    expr_fn=ast_to_torch_expr,
    _equation_vars: set[str] | None = None,
) -> None:
    """Recursively emit Python code lines for a function body.

    Converts a sequence of body-statement AST tuples into indented Python
    source lines and appends them to `lines`. The handled tags are
    ``body_decl``, ``body_assign``, ``body_index_assign``,
    ``body_index_assign_nd``, ``body_tuple_unpack``, ``body_if_return``,
    ``body_if_else_return``, ``body_if_else``, ``body_if``, ``body_for``,
    ``body_for_range``, ``body_zeros_decl``, ``body_for_map`` and
    ``body_for_accum``. Any other tag is handed to
    ``REGISTRY.dispatch_forward``.

    Parameters
    ----------
    stmts: list[ASTNode]
        Body statement tuples to emit. Entries that are not tuples
        (e.g. ``None``) are skipped.
    indent_level: int
        Nesting depth: 1 directly inside the function body, 2 inside an
        if/else branch, and so on. Each level adds four spaces.
    lines: list[str]
        Output list; generated source lines are appended here.
    known_vars: list[str]
        Running list of variable names in scope. Extended in place when
        new locals are declared or assigned.
    equation_vars: set[str]
        Names of variables bound to string expressions. Updated in place
        when a ``body_decl`` has a ``("string", ...)`` right-hand side.
    generate_solve_call: Callable[[ASTNode], str]
        Converts the right-hand side of declarations, assignments and
        tuple unpacking to a Python expression string.
    scalar_only: bool, default False
        When true, ``body_if_else_return`` is emitted as a Python
        ``if``/``else`` (so recursion works); otherwise it is emitted as
        a single ``torch.where`` return.
    expr_fn : callable, optional
        Expression code-generator passed as ``to_expr`` to the registry
        for tags not handled here; defaults to ``ast_to_torch_expr``.
        It is forwarded to nested if/else branches.
    _equation_vars : set, optional
        Accepted for backward compatibility; not used by this function
        (equation-string names are recorded in `equation_vars`).

    Raises
    ------
    ValueError
        If a ``body_for_accum`` statement contains no
        ``loop_index_pluseq`` statement.

    Examples
    --------
    >>> from physika.utils.ast_utils import emit_body_stmts
    >>> from physika.utils.ast_utils import ast_to_torch_expr
    >>> lines = []
    >>> known_vars = ["x"]
    >>> equation_vars = set()
    >>> emit_body_stmts(
    ...     [("body_assign", "y", ("mul", ("var", "x"), ("num", 2.0)))],
    ...     1, lines, known_vars, equation_vars, ast_to_torch_expr,
    ... )
    >>> lines
    ['    y = (x * 2.0)']
    """
    if expr_fn is None:
        expr_fn = ast_to_torch_expr

    # One indentation level is four spaces in the generated source.
    prefix = "    " * indent_level

    def emit_branch(branch_stmts):
        # Nested branches share every piece of state with the parent call,
        # including the expression generator used for registry dispatch.
        emit_body_stmts(branch_stmts, indent_level + 1, lines, known_vars,
                        equation_vars, generate_solve_call, scalar_only,
                        expr_fn)

    for stmt in stmts:
        if not isinstance(stmt, tuple):
            continue

        stmt_op = stmt[0]

        if stmt_op == "body_decl":
            _, var_name, _var_type, expr = stmt
            if isinstance(expr, tuple) and expr[0] == "string":
                equation_vars.add(var_name)
            expr_code = generate_solve_call(expr)
            lines.append(f"{prefix}{var_name} = {expr_code}")
            known_vars.append(var_name)

        elif stmt_op == "body_assign":
            _, var_name, expr = stmt
            expr_code = generate_solve_call(expr)
            lines.append(f"{prefix}{var_name} = {expr_code}")
            known_vars.append(var_name)

        elif stmt_op == "body_index_assign":
            _, name, idx, val = stmt
            idx_code = ast_to_torch_expr(idx)
            val_code = ast_to_torch_expr(val)
            lines.append(f"{prefix}{name}[int({idx_code})] = {val_code}")

        elif stmt_op == "body_index_assign_nd":
            _, name, indices, val = stmt
            idx_code = ", ".join(f"int({ast_to_torch_expr(i)})"
                                 for i in indices)
            val_code = ast_to_torch_expr(val)
            lines.append(f"{prefix}{name}[{idx_code}] = {val_code}")

        elif stmt_op == "body_tuple_unpack":
            _, var_names, expr = stmt
            expr_code = generate_solve_call(expr)
            lines.append(f"{prefix}{', '.join(var_names)} = {expr_code}")
            known_vars.extend(var_names)

        elif stmt_op == "body_if_return":
            _, cond, return_expr = stmt
            cond_code = condition_to_expr(cond)
            return_code = ast_to_torch_expr(return_expr)
            lines.append(f"{prefix}if {cond_code}:")
            lines.append(f"{prefix}    return {return_code}")

        elif stmt_op == "body_if_else_return":
            _, cond, then_expr, else_expr = stmt
            cond_code = condition_to_expr(cond)
            then_code = ast_to_torch_expr(then_expr)
            else_code = ast_to_torch_expr(else_expr)
            if scalar_only:
                # Scalar functions: use Python if/else so recursion works
                lines.append(f"{prefix}if {cond_code}:")
                lines.append(f"{prefix}    return {then_code}")
                lines.append(f"{prefix}else:")
                lines.append(f"{prefix}    return {else_code}")
            else:
                # Vector functions: use torch.where for elementwise
                # differentiability
                lines.append(
                    f"{prefix}return torch.where(torch.as_tensor({cond_code}), {then_code}, {else_code})"  # noqa: E501
                )

        elif stmt_op == "body_if_else":
            _, cond, then_stmts, else_stmts = stmt
            cond_code = condition_to_expr(cond)
            lines.append(f"{prefix}if {cond_code}:")
            emit_branch(then_stmts)
            lines.append(f"{prefix}else:")
            emit_branch(else_stmts)

        elif stmt_op == "body_if":
            _, cond, then_stmts = stmt
            cond_code = condition_to_expr(cond)
            lines.append(f"{prefix}if {cond_code}:")
            emit_branch(then_stmts)

        elif stmt_op == "body_for":
            _, loop_var, loop_body, indexed_arrays = stmt
            if indexed_arrays:
                # Range is taken from the first array indexed by the loop
                lines.append(
                    f"{prefix}for {loop_var} in range(len({indexed_arrays[0]})):"  # noqa: E501
                )
            else:
                lines.append(f"{prefix}for {loop_var} in range(n):")
            emit_func_loop_body(loop_body, indent_level + 1, lines, loop_var)

        elif stmt_op == "body_for_range":
            _, loop_var, start_expr, end_expr, loop_body = stmt
            start_code = ast_to_torch_expr(start_expr)
            end_code = ast_to_torch_expr(end_expr)
            lines.append(
                f"{prefix}for {loop_var} in range(int({start_code}), int({end_code})):"  # noqa: E501
            )
            emit_func_loop_body(loop_body, indent_level + 1, lines, loop_var)

        elif stmt_op == "body_zeros_decl":
            # Type annotation for an accumulation target.
            # The paired body_for_accum emits the `torch.stack` expression
            # that defines the tensor, so nothing is emitted here.
            pass

        elif stmt_op == "body_for_map":
            _, loop_vars, loop_body = stmt
            # Only the first statement of the loop body is used.
            assign_stmt = loop_body[0]
            _, tensor_name, _lhs_idx_exprs, rhs_expr = assign_stmt
            ranges = {
                v: _infer_range(v, rhs_expr, tensor_name)
                or f"# range unknown for {v}"
                for v in loop_vars
            }
            rhs_code = ast_to_torch_expr(rhs_expr,
                                         current_loop_var=set(loop_vars))
            # Wrap from the innermost loop variable outwards.
            inner_expr = rhs_code
            for v in reversed(loop_vars):
                inner_expr = (f"torch.stack([{inner_expr}"
                              f" for {v} in range({ranges[v]})])")
            lines.append(f"{prefix}{tensor_name} = {inner_expr}")

        elif stmt_op == "body_for_accum":
            # Generates one differentiable torch.stack per += target.
            # Emits one `name = torch.stack(...)` line per unique +=
            # target tensor.
            _, loop_vars, loop_body = stmt
            active = set(loop_vars)

            # Collect all unique accumulation targets (first one wins)
            accums: dict = {}
            for s in loop_body:
                if s and s[0] == "loop_index_pluseq":
                    _, name, idx_list, rhs = s
                    if name not in accums:
                        accums[name] = (idx_list, rhs)

            if not accums:
                raise ValueError(
                    "body_for_accum has no loop_index_pluseq statement")

            # Generate one differentiable torch.stack expression per target
            # tensor
            for tensor_name, (lhs_idx_exprs, rhs_expr) in accums.items():
                ranges = {
                    v: _infer_range(v, rhs_expr, tensor_name)
                    or f"# range unknown for {v}"
                    for v in loop_vars
                }
                lhs_vars = [
                    n for n in (_lhs_var_name(e) for e in lhs_idx_exprs)
                    if n
                ]
                # Loop variables absent from the left-hand side indices are
                # summed over; the others become stacked output axes.
                reduction_vars = [v for v in loop_vars if v not in lhs_vars]
                rhs_code = ast_to_torch_expr(rhs_expr,
                                             current_loop_var=active)
                inner_expr = rhs_code
                for rv in reversed(reduction_vars):
                    inner_expr = (f"torch.sum(torch.stack([{inner_expr}"
                                  f" for {rv} in range({ranges[rv]})]))")
                for ov in reversed(lhs_vars):
                    inner_expr = (f"torch.stack([{inner_expr}"
                                  f" for {ov} in range({ranges[ov]})])")
                lines.append(f"{prefix}{tensor_name} = {inner_expr}")

        else:
            # Tags not handled above are offered to the ELF registry.
            elf_line = REGISTRY.dispatch_forward(
                stmt_op,
                stmt,
                to_expr=expr_fn,
            )
            if elf_line is not None:
                lines.append(f"{prefix}{elf_line}")
                var_name = stmt[1] if len(stmt) > 1 and isinstance(
                    stmt[1], str) else None
                if var_name:
                    known_vars.append(var_name)
def generate_function(name: str, func_def: dict[str, Any]) -> str:
    """Generate a Python/PyTorch function definition from a function AST.

    Translates a Physika function (params, body statements, optional
    trailing for-loop, return expression) into a Python function
    definition string. Body statements are emitted through
    ``emit_body_stmts``; a ``solve(...)`` call on the right-hand side of
    a body statement is rendered as ``solve(<positional args>)``.

    Parameters
    ----------
    name : str
        The function identifier (e.g. ``"sigma"``, ``"U"``).
    func_def : dict[str, ASTNode]
        A dict from ``unified_ast["functions"]`` with keys ``"params"``
        (list of ``(name, type)`` pairs) and ``"body"`` (return expression
        AST, or ``None`` for no return statement). Optional keys read
        here: ``"statements"`` (list of body statement ASTs),
        ``"has_loop"``, ``"init_stmts"``, ``"loop_var"`` (default
        ``"k"``), ``"loop_indexed_arrays"`` and ``"loop_body"``.

    Returns
    -------
    str
        A multi-line Python source string containing the complete
        function definition, indented with four spaces per level.

    Examples
    --------
    >>> from physika.utils.ast_utils import generate_function
    >>> func_def = {
    ...     "params": [("x", "ℝ")],
    ...     "body": ("call", "exp", [("var", "x")]),
    ...     "statements": [],
    ... }
    >>> print(generate_function("f", func_def))  # noqa: E501
    def f(x):
        return torch.exp(x if isinstance(x, torch.Tensor) else torch.tensor(float(x)))
    """
    params = func_def["params"]
    body = func_def["body"]
    statements = func_def.get("statements", [])

    # Parameter types are not emitted; only the names appear in the
    # signature.
    param_names = [param_name for param_name, _ in params]
    signature = ", ".join(f"{param_name}" for param_name in param_names)
    lines = [f"def {name}({signature}):"]

    # Track known variables (params + locals)
    known_vars = list(param_names)

    # Track equation string variable names
    equation_vars: set[str] = set()

    # Helper to render a right-hand side, special-casing solve(...).
    # (kept local so it sits next to the known_vars/equation_vars state
    # that emit_body_stmts updates as statements are processed)
    def generate_solve_call(expr):
        if (isinstance(expr, tuple) and expr[0] == "call"
                and expr[1] == "solve"):
            arg_strs = [ast_to_torch_expr(arg) for arg in expr[2]]
            # Only the explicit positional arguments are emitted.
            return f"solve({', '.join(arg_strs)})"
        return ast_to_torch_expr(expr)

    # Use if/else (not torch.where) when all params are scalars
    # — allows recursion
    scalar_only = all(pt == "\u211d" for _, pt in params)

    # Generate body statements
    emit_body_stmts(statements, 1, lines, known_vars, equation_vars,
                    generate_solve_call, scalar_only)

    # Generate for-loop body
    if func_def.get("has_loop"):
        init_stmts = func_def.get("init_stmts", [])
        loop_var = func_def.get("loop_var", "k")
        indexed_arrays = func_def.get("loop_indexed_arrays", [])
        loop_body = func_def.get("loop_body", [])

        # Emit pre-loop initialisation
        for stmt in init_stmts:
            if stmt is None:
                continue
            if stmt[0] == "init_assign":
                _, var_name, expr = stmt
                expr_code = ast_to_torch_expr(expr)
                lines.append(f"    {var_name} = {expr_code}")

        # Emit loop header — range inferred from the first indexed array
        if indexed_arrays:
            lines.append(
                f"    for {loop_var} in range(len({indexed_arrays[0]})):")
        else:
            lines.append(f"    for {loop_var} in range(n):")

        # Emit loop body statements one level deeper than the header
        for stmt in loop_body:
            if stmt is None:
                continue
            if stmt[0] == "loop_assign":
                _, var_name, expr = stmt
                expr_code = ast_to_torch_expr(expr,
                                              current_loop_var=loop_var)
                lines.append(f"        {var_name} = {expr_code}")
            elif stmt[0] == "loop_pluseq":
                _, var_name, expr = stmt
                expr_code = ast_to_torch_expr(expr,
                                              current_loop_var=loop_var)
                lines.append(
                    f"        {var_name} = {var_name} + {expr_code}")

    # Generate return statement only when there is a final expression
    if body is not None:
        body_code = ast_to_torch_expr(body)
        lines.append(f"    return {body_code}")

    return "\n".join(lines)
def emit_for_stmts(
    stmts: list[ASTNode],
    indent: int = 4,
    loop_var: str | set[str] | None = None,
) -> list[str]:
    """Emit Python code for a top-level for-loop or if/else branch body.

    Handles ``for_assign``, ``for_pluseq``, ``for_index_assign_nd``,
    ``for_call``, ``for_if``, ``for_if_else`` and nested ``for_loop`` /
    ``for_loop_range`` nodes. Nested loops and branches recurse with the
    indentation increased by 4 spaces. Any other tag is handed to
    ``REGISTRY.dispatch_forward``.

    Parameters
    ----------
    stmts: list[ASTNode]
        Statement tuples making up the body. Entries that are not tuples
        are skipped.
    indent: int
        Number of spaces placed in front of every emitted line.
    loop_var: str, set[str] or None
        Enclosing loop variable name(s), forwarded to
        ``ast_to_torch_expr`` / ``condition_to_expr`` as
        ``current_loop_var``.

    Returns
    -------
    list[str]
        The generated Python code lines.

    Examples
    --------
    >>> from physika.utils.ast_utils import emit_for_stmts
    >>> stmts = [("for_assign", "z", ("mul", ("var", "a"), ("var", "b")))]
    >>> emit_for_stmts(stmts, 4)
    ['    z = (a * b)']
    """
    pad = " " * indent
    deeper = indent + 4

    def to_code(node):
        # Expression rendering always carries the enclosing loop scope.
        return ast_to_torch_expr(node, current_loop_var=loop_var)

    def scope_with(inner_name):
        # Accumulate all active loop vars so an inner body can reference
        # outer vars (e.g. 'i').
        if isinstance(loop_var, set):
            enclosing = loop_var
        else:
            enclosing = {loop_var} if loop_var else set()
        return enclosing | {inner_name}

    out = []
    for node in stmts:
        if not isinstance(node, tuple):
            continue

        tag = node[0]

        if tag == "for_assign":
            _, target, value = node
            out.append(f"{pad}{target} = {to_code(value)}")

        elif tag == "for_pluseq":
            _, target, value = node
            out.append(f"{pad}{target} = {target} + {to_code(value)}")

        elif tag == "for_index_assign_nd":
            _, array_name, index_nodes, value = node
            subscript = ", ".join(f"int({to_code(ix)})"
                                  for ix in index_nodes)
            value_code = to_code(value)
            out.append(f"{pad}{array_name}[{subscript}] = {value_code}")

        elif tag == "for_call":
            _, callee, arg_nodes = node
            rendered_args = [to_code(arg) for arg in arg_nodes]
            out.append(f"{pad}{callee}({', '.join(rendered_args)})")

        elif tag == "for_loop_range":
            nested_var = node[1]
            lower = to_code(node[2])
            upper = to_code(node[3])
            nested_body = node[4]
            out.append(
                f"{pad}for {nested_var} in range(int({lower}), int({upper})):"  # noqa: E501
            )
            out.extend(
                emit_for_stmts(nested_body, deeper, scope_with(nested_var)))

        elif tag == "for_loop":
            nested_var = node[1]
            nested_body = node[2]
            arrays = node[3]
            if arrays:
                out.append(
                    f"{pad}for {nested_var} in range(len({arrays[0]})):")
            else:
                out.append(f"{pad}for {nested_var} in range(n):")
            out.extend(
                emit_for_stmts(nested_body, deeper, scope_with(nested_var)))

        elif tag == "for_if":
            _, test, when_true = node
            test_code = condition_to_expr(test, current_loop_var=loop_var)
            out.append(f"{pad}if {test_code}:")
            out.extend(emit_for_stmts(when_true, deeper, loop_var))

        elif tag == "for_if_else":
            _, test, when_true, when_false = node
            test_code = condition_to_expr(test, current_loop_var=loop_var)
            out.append(f"{pad}if {test_code}:")
            out.extend(emit_for_stmts(when_true, deeper, loop_var))
            out.append(f"{pad}else:")
            out.extend(emit_for_stmts(when_false, deeper, loop_var))

        else:
            # Tags not handled above are offered to the ELF registry.
            registry_line = REGISTRY.dispatch_forward(
                tag,
                node,
                to_expr=to_code,
                indent=indent,
            )
            if registry_line is not None:
                out.append(f"{pad}{registry_line}")

    return out
def generate_class(name: str, class_def: dict[str, ASTNode]) -> str:
    """Generate an ``nn.Module`` subclass from a class AST entry.

    Translates a Physika class into a Python class string with
    ``__init__`` (wrapping tensor and ``ℝ`` params as ``nn.Parameter``),
    ``forward`` (body statements, optional loop, return expression), and
    an optional ``loss`` method. Class parameter references in the
    forward/loss bodies are rewritten via ``replace_class_params``.

    Parameters
    ----------
    name : str
        The class identifier (e.g. ``"OneLayerNet"``).
    class_def : dict[str, ASTNode]
        A dict from ``unified_ast["classes"]`` with keys:

        * ``"class_params"`` — list of ``(name, type)`` pairs.
        * ``"lambda_params"`` — list of ``(name, type)`` pairs.
        * ``"body"`` — forward return expression AST.
        * ``"statements"`` (optional) — forward body statement ASTs.
        * ``"has_loop"`` (optional) — whether forward contains a loop.
        * ``"loop_var"``, ``"loop_body"`` (optional) — loop details.
        * ``"has_loss"``, ``"loss_body"``, ``"loss_params"``,
          ``"loss_statements"`` (optional).

    Returns
    -------
    str
        A multi-line Python source string containing the complete
        ``nn.Module`` subclass definition, indented with four spaces per
        level.

    Examples
    --------
    >>> from physika.utils.ast_utils import generate_class
    >>> class_def = {
    ...     "class_params": [("w", "ℝ")],
    ...     "lambda_params": [("x", "ℝ")],
    ...     "body": ("mul", ("var", "w"), ("var", "x")),
    ...     "has_loop": False, "has_loss": False,
    ... }
    >>> code = generate_class("Linear", class_def)
    >>> "class Linear(nn.Module):" in code
    True
    """
    class_params: list[tuple[str, ASTNode]] = cast(
        list[tuple[str, ASTNode]], class_def["class_params"])
    lambda_params: list[tuple[str, ASTNode]] = cast(
        list[tuple[str, ASTNode]], class_def["lambda_params"])
    body = class_def["body"]
    statements: list[ASTNode] = cast(list[ASTNode],
                                     class_def.get("statements", []))
    has_loop = class_def.get("has_loop", False)
    loop_var = class_def.get("loop_var")
    loop_body: list[ASTNode] = cast(list[ASTNode],
                                    class_def.get("loop_body", []))
    has_loss = class_def.get("has_loss", False)
    loss_body = class_def.get("loss_body")

    def with_self(code):
        # Rewrite class-parameter references in generated code.
        return replace_class_params(code, class_params)

    lines = [f"class {name}(nn.Module):"]

    # __init__ method
    init_params = ", ".join([p[0] for p in class_params])
    lines.append(f"    def __init__(self, {init_params}):")
    lines.append("        super().__init__()")

    for param_name, param_type in class_params:
        # Tensor-typed params and real scalars become learnable parameters
        is_tensor = (isinstance(param_type, tuple)
                     and param_type[0] == "tensor"
                     ) or param_type == "\u211d"

        if is_tensor:
            # Handle both tensors and scalars
            lines.append(
                f"        self.{param_name} = nn.Parameter(torch.tensor({param_name}).float() if not isinstance({param_name}, torch.Tensor) else {param_name}.clone().detach().float())"  # noqa: E501
            )
        else:
            # Non-tensor (like function 'f' or int 'n')
            lines.append(f"        self.{param_name} = {param_name}")

    # forward method (lambda)
    lambda_param_names = [p[0] for p in lambda_params]
    lines.append("")
    lines.append(
        f"    def forward(self, {', '.join(lambda_param_names)}):")

    # Convert inputs to tensors
    for pname, ptype in lambda_params:
        if ptype == "\u211d" or ptype == "\u2115" or (isinstance(
                ptype, tuple) and ptype[0] == "tensor"):
            lines.append(
                f"        {pname} = torch.as_tensor({pname}).float()")

    # Generate forward body statements (multi-statement lambda body)
    if statements:
        stmt_lines: list[str] = []
        known_vars = [p[0] for p in lambda_params]
        equation_vars: set[str] = set()
        scalar_only = all(pt == "\u211d" for _, pt in lambda_params)
        # Indent level 2: statements sit inside a method of the class.
        emit_body_stmts(statements, 2, stmt_lines, known_vars,
                        equation_vars, ast_to_torch_expr, scalar_only)
        for line in stmt_lines:
            lines.append(with_self(line))

    # Generate loop if present
    if has_loop and loop_body:
        # The trip count comes from self.n when present, else from the
        # leading dimension of the last class parameter, else 2.
        lines.append(
            f"        n = int(self.n) if hasattr(self, 'n') else self.{class_params[-1][0]}.shape[0] if hasattr(self.{class_params[-1][0]}, 'shape') else 2"  # noqa: E501
        )
        lines.append(f"        for {loop_var} in range(n):")
        for stmt in loop_body:
            if isinstance(stmt, tuple) and stmt[0] == "loop_assign":
                var_name = stmt[1]
                expr = stmt[2]
                expr_code = with_self(ast_to_torch_expr(expr))
                lines.append(f"            {var_name} = {expr_code}")

    # Generate return
    body_code = with_self(ast_to_torch_expr(body))
    lines.append(f"        return {body_code}")

    # loss method if present
    if has_loss and loss_body:
        loss_params: list[tuple[str, ASTNode]] = cast(
            list[tuple[str, ASTNode]],
            class_def.get("loss_params", [("y", "\u211d"),
                                          ("target", "\u211d")]))
        loss_param_names = [p[0] for p in loss_params]
        loss_stmts: list[ASTNode] = cast(
            list[ASTNode], class_def.get("loss_statements", []))

        # Check if loss uses grad — also scan loss body statements
        loss_uses_grad = ast_uses_func(loss_body, "grad") or any(
            ast_uses_func(s, "grad") for s in loss_stmts)

        lines.append("")
        if loss_uses_grad and lambda_param_names:
            # Add the input parameter (x) to loss params
            input_param = lambda_param_names[0]  # typically 'x'
            lines.append(
                f"    def loss(self, {', '.join(loss_param_names)}, {input_param}):"  # noqa: E501
            )
        else:
            lines.append(
                f"    def loss(self, {', '.join(loss_param_names)}):")

        # Emit loss body statements; the three supported tags differ only
        # in how the assignment target is spelled.
        for stmt in loss_stmts:
            if not isinstance(stmt, tuple):
                continue
            stmt_op = stmt[0]
            if stmt_op == "body_decl":
                _, target, _var_type, expr = stmt
            elif stmt_op == "body_assign":
                _, target, expr = stmt
            elif stmt_op == "body_tuple_unpack":
                _, var_names, expr = stmt
                target = ", ".join(var_names)
            else:
                continue
            expr_code = with_self(ast_to_torch_expr(expr))
            lines.append(f"        {target} = {expr_code}")

        loss_code = with_self(ast_to_torch_expr(loss_body))
        lines.append(f"        return {loss_code}")

    return "\n".join(lines)
def generate_statement(stmt: ASTNode,
                       grad_target_vars: set[str]) -> str | None:
    """Generate a PyTorch code string for a program-level statement.

    Handles ``decl``, ``assign``, ``index_assign``, ``index_assign_nd``,
    ``expr`` (wrapped in ``physika_print`` unless it is a call to
    ``simulate``/``animate``), the SymPy declarations ``symbol_decl``,
    ``symbol_decl_multi``, ``function_decl``, ``function_decl_multi`` and
    ``equation_decl``, the loops ``for_loop`` / ``for_loop_range``, and
    the branches ``if_else`` / ``if_only``. ``func_def`` and ``class_def``
    produce ``None``. Any other tag is offered to
    ``REGISTRY.dispatch_forward``; if that also yields nothing, a
    ``# Unknown: ...`` comment line is returned.

    Variables whose names appear in *grad_target_vars* are initialised
    with gradient tracking enabled when declared as ``ℝ`` or as a tensor.

    Parameters
    ----------
    stmt : ASTNode
        An AST statement tuple (e.g. ``("decl", name, type, expr,
        lineno)``) or a non-tuple value such as ``None``.
    grad_target_vars : set[str]
        Variable names used as differentiation targets.

    Returns
    -------
    str or None
        A Python source string for the statement, or ``None`` if the
        statement should be skipped (``func_def``, ``class_def``, or a
        non-tuple input).

    Examples
    --------
    >>> from physika.utils.ast_utils import generate_statement
    >>> generate_statement(("decl", "x", "ℝ", ("num", 3.0), 1), set())
    'x = 3.0'
    >>> generate_statement(("decl", "t", "ℝ", ("num", 0.0), 2), {"t"})
    't = torch.tensor(0.0, requires_grad=True)'
    >>> generate_statement(("expr", ("var", "x"), 0), set())
    'physika_print(x)'
    """
    if not isinstance(stmt, tuple):
        return None

    op = stmt[0]

    # Definitions were already generated elsewhere.
    if op == "func_def" or op == "class_def":
        return None

    if op == "decl":
        name, type_spec, value = stmt[1], stmt[2], stmt[3]
        value_code = ast_to_torch_expr(value)
        is_tensor_type = (isinstance(type_spec, tuple)
                          and type_spec[0] == "tensor")

        # Variables used as grad targets need to be tensors with
        # requires_grad
        if name in grad_target_vars:
            if type_spec == "\u211d":
                return f"{name} = torch.tensor({value_code}, requires_grad=True)"  # noqa: E501
            if is_tensor_type:
                if type_spec[1] == "ℂ":
                    return f"{name} = torch.as_tensor({value_code}, dtype=torch.complex64).requires_grad_(True)"  # noqa: E501
                return f"{name} = torch.as_tensor({value_code}).float().requires_grad_(True)"  # noqa: E501

        # Tensor value
        if is_tensor_type:
            element_type = type_spec[1]
            if element_type == "ℂ":
                return (f"{name} = "
                        f"torch.as_tensor({value_code}, dtype=torch.complex64)"
                        )
            if element_type == "ℝ":
                return (f"{name} = "
                        f"torch.as_tensor({value_code}, dtype=torch.float32)")

        # Scalar value
        if type_spec == "\u2124":
            return f"{name} = int({value_code})"
        if type_spec == "\u2102":
            return f"{name} = torch.tensor({value_code}, dtype=torch.complex64)"  # noqa: E501
        return f"{name} = {value_code}"

    if op == "assign":
        target = stmt[1]
        return f"{target} = {ast_to_torch_expr(stmt[2])}"

    if op == "index_assign":
        target, index, value = stmt[1], stmt[2], stmt[3]
        # A bare string index is treated as a variable reference.
        if isinstance(index, str):
            index_code = ast_to_torch_expr(("var", index))
        else:
            index_code = ast_to_torch_expr(index)
        return f"{target}[int({index_code})] = {ast_to_torch_expr(value)}"

    if op == "index_assign_nd":
        target, index_nodes, value = stmt[1], stmt[2], stmt[3]
        subscript = ", ".join(f"int({ast_to_torch_expr(ix)})"
                              for ix in index_nodes)
        return f"{target}[{subscript}] = {ast_to_torch_expr(value)}"

    if op == "expr":
        inner = stmt[1]
        inner_code = ast_to_torch_expr(inner)
        # Don't wrap side-effect-only calls in physika_print
        is_side_effect_call = (isinstance(inner, tuple)
                               and inner[0] == "call"
                               and inner[1] in ("simulate", "animate"))
        if is_side_effect_call:
            return inner_code
        return f"physika_print({inner_code})"

    if op == "symbol_decl":
        symbol = stmt[1]
        return f"{symbol} = sp.Symbol('{symbol}')"

    if op == "symbol_decl_multi":
        return "\n".join(f"{symbol} = sp.Symbol('{symbol}')"
                         for symbol in stmt[1])

    if op == "function_decl":
        func = stmt[1]
        return f"{func} = sp.Function('{func}')"

    if op == "function_decl_multi":
        return "\n".join(f"{func} = sp.Function('{func}')"
                         for func in stmt[1])

    if op == "equation_decl":
        target = stmt[1]
        lhs_code = ast_to_torch_expr(stmt[2])
        rhs_code = ast_to_torch_expr(stmt[3])
        return f"{target} = sp.Eq({lhs_code}, {rhs_code})"

    if op == "for_loop":
        # For loop: ("for_loop", loop_var, body_statements,
        #            indexed_arrays[, lineno])
        loop_var, loop_stmts, arrays = stmt[1], stmt[2], stmt[3]
        if arrays:
            header = f"for {loop_var} in range(len({arrays[0]})):"
        else:
            header = f"for {loop_var} in range(n): # TODO: determine n"
        return "\n".join([header] + emit_for_stmts(loop_stmts, 4, loop_var))

    if op == "for_loop_range":
        # Explicit-range for loop: ("for_loop_range", loop_var, start_expr,
        #                           end_expr, body_stmts, lineno)
        loop_var = stmt[1]
        lower = ast_to_torch_expr(stmt[2])
        upper = ast_to_torch_expr(stmt[3])
        loop_stmts = stmt[4]
        emitted = [f"for {loop_var} in range(int({lower}), int({upper})):"]
        emitted += emit_for_stmts(loop_stmts, 4, loop_var)
        return "\n".join(emitted)

    if op in ("if_else", "if_only"):
        test = stmt[1]
        when_true = stmt[2]
        test_code = condition_to_expr(test)
        emitted = [f"if {test_code}:"]
        emitted.extend(emit_for_stmts(when_true))
        if op == "if_else":
            when_false = stmt[3]
            emitted.append("else:")
            emitted.extend(emit_for_stmts(when_false))
        return "\n".join(emitted)

    # Adds ELF features statement tags
    registry_code = REGISTRY.dispatch_forward(
        op,
        stmt,
        to_expr=lambda n: ast_to_torch_expr(n),
        indent=0,
    )
    if registry_code is not None:
        return registry_code

    return f"# Unknown: {stmt}"
def build_unified_ast(
    program_ast: list[ASTNode],
    symbol_table: dict[str, dict[str, Any]],
    print_ast: bool = False,
) -> dict[str, Any]:
    """Build a unified AST combining definitions and program statements.

    Merges the flat ``program_ast`` with the ``symbol_table`` into a single
    dict with three sections: ``"functions"``, ``"classes"``, and
    ``"program"``.

    Parameters
    ----------
    program_ast : list[ASTNode]
        Top-level statement AST tuples. ``None`` entries are dropped.
    symbol_table : dict[str, dict[str, Any]]
        Maps names to entries with a ``"type"`` key and a ``"value"`` key.
        Entries whose type is ``"function"`` or ``"class"`` are copied
        into the matching section; other types are ignored.
    print_ast : bool, default False
        If ``True``, print the unified AST to stdout for debugging.

    Returns
    -------
    dict[str, dict[str, ASTNode] | list[ASTNode]]
        A dict with keys:

        * ``"functions"`` — ``{name: func_def, ...}``
        * ``"classes"`` — ``{name: class_def, ...}``
        * ``"program"`` — ``[stmt, ...]``

    Examples
    --------
    >>> from physika.utils.ast_utils import build_unified_ast
    >>> ast = [("expr", ("num", 42.0), 1)]
    >>> sym = {}
    >>> unified = build_unified_ast(ast, sym)
    >>> unified["program"]
    [('expr', ('num', 42.0), 1)]
    >>> unified["functions"]
    {}
    """
    # Split the symbol table by entry type, keeping its iteration order.
    function_defs = {
        symbol: record["value"]
        for symbol, record in symbol_table.items()
        if record["type"] == "function"
    }
    class_defs = {
        symbol: record["value"]
        for symbol, record in symbol_table.items()
        if record["type"] == "class"
    }

    # Keep every real statement; the parser may leave None placeholders.
    program = [node for node in program_ast if node is not None]

    unified: dict[str, Any] = {
        "functions": function_defs,
        "classes": class_defs,
        "program": program,
    }

    if print_ast:
        print("\n=== UNIFIED AST ===")
        print(print_unified_ast(unified))

    return unified