Source code for physika.codegen

from typing import Dict, Set, Any

from physika.utils.ast_utils import (ast_uses_solve, ast_uses_func,
                                     collect_grad_targets, generate_function,
                                     generate_class, generate_statement,
                                     ast_uses_sympy, ast_to_torch_expr)
from physika.elf import REGISTRY


def _collect_analysis_nodes(unified_ast: Dict[str, Any]) -> list:
    """Return every AST root that the analysis pass should inspect.

    The result contains, in order: the top-level program statements, each
    function's ``body`` and ``statements``, and each class's ``loss_body``,
    ``body``, ``statements`` and ``loss_statements``.  Missing (``None``)
    entries are dropped so the ``ast_uses_*`` helpers only ever see real
    nodes.
    """
    nodes = list(unified_ast["program"])
    for func_def in unified_ast["functions"].values():
        nodes.append(func_def.get("body"))
        nodes.extend(func_def.get("statements") or [])
    for class_def in unified_ast["classes"].values():
        nodes.append(class_def.get("loss_body"))
        nodes.append(class_def.get("body"))
        nodes.extend(class_def.get("statements") or [])
        nodes.extend(class_def.get("loss_statements") or [])
    return [node for node in nodes if node is not None]


def from_ast_to_torch(unified_ast: Dict[str, Any],
                      print_code: bool = True) -> str:
    """Convert a unified AST into an executable Python/PyTorch source string.

    This conversion is done in two passes:

    1. **Analysis pass** — walks the AST to determine which ``runtime.py``
       helpers (``solve``, ``train``, ``evaluate``, ``compute_grad``,
       ``simulate``, ``animate``) and whether ``sympy`` are referenced, and
       collects variables used as ``grad()`` differentiation targets.
       Helper usage is looked for uniformly in program statements, function
       bodies/statements and class bodies/loss bodies/statements, so a
       helper used only inside a function or class is still imported.
    2. **Code-generation pass** — uses ``generate_function``,
       ``generate_class``, and ``generate_statement`` (from
       ``utils.ast_utils``) to emit Python source for each AST entry,
       preceded by the import header.  When ``REGISTRY.features`` is not an
       empty list, classes are emitted through
       ``REGISTRY.dispatch_forward`` instead of ``generate_class``.

    The returned string is ready to be executed with ``exec()``.

    Parameters
    ----------
    unified_ast : Dict[str, Any]
        The unified AST dict produced by ``build_unified_ast()``, with keys:

        * ``"functions"`` — ``Dict[str, dict]`` mapping function names to
          their AST definitions (params, body, statements).
        * ``"classes"`` — ``Dict[str, dict]`` mapping class names to their
          AST definitions (class_params, lambda_params, body, loss_body,
          statements, loss_statements, …).
        * ``"program"`` — ``List[tuple]`` of top-level statement AST nodes
          (decl, assign, expr, for_loop, func_def, class_def).
    print_code : bool, default True
        If ``True``, print the generated code.

    Returns
    -------
    str
        A complete Python/PyTorch source string containing ``import``
        statements, function definitions, ``nn.Module`` class definitions,
        and program-level statements.  Variables that appear as ``grad()``
        targets are initialised with ``requires_grad=True``.

    Raises
    ------
    AssertionError
        If ``REGISTRY.dispatch_forward`` returns ``None`` for a class.

    Examples
    --------
    >>> # Example #1: simple expression
    >>> unified_ast = {
    ...     "functions": {},
    ...     "classes": {},
    ...     "program": [("expr", ("num", 42.0), 1)],
    ... }
    >>> code = from_ast_to_torch(unified_ast, print_code=False)
    >>> "import torch" in code
    True
    >>> "physika_print(42.0)" in code
    True
    >>> print(code)
    import torch
    import torch.nn as nn
    import torch.optim as optim
    <BLANKLINE>
    from physika.runtime import physika_print
    <BLANKLINE>
    # === Program ===
    physika_print(42.0)

    >>> # Example #2: function definition and call
    >>> unified_ast = {
    ...     "functions": {
    ...         "f": {"params": [("x", "ℝ")], "body": ("call", "exp",
    ...               [("var", "x")]), "statements": []},
    ...     },
    ...     "classes": {},
    ...     "program": [("expr", ("call", "f", [("num", 1.0)]), 2)],
    ... }
    >>> code = from_ast_to_torch(unified_ast, print_code=False)
    >>> "def f(x):" in code
    True
    >>> "torch.exp" in code
    True
    >>> print(code)  # noqa: E501
    import torch
    import torch.nn as nn
    import torch.optim as optim
    <BLANKLINE>
    from physika.runtime import physika_print
    <BLANKLINE>
    # === Functions ===
    def f(x):
        return torch.exp(x if isinstance(x, torch.Tensor) else torch.tensor(float(x)))
    <BLANKLINE>
    # === Program ===
    physika_print(f(1.0))
    """
    # --- Analysis pass -----------------------------------------------------
    # Every helper is looked up over the same set of AST roots.  Importing a
    # helper that ends up unused is harmless, and because the header is
    # emitted before the generated definitions it cannot shadow them.
    nodes = _collect_analysis_nodes(unified_ast)

    def uses(func_name: str) -> bool:
        return any(ast_uses_func(node, func_name) for node in nodes)

    needs_solve = any(ast_uses_solve(node) for node in nodes)
    needs_sympy = any(ast_uses_sympy(node) for node in nodes)

    # (flag, import line) pairs in the order they appear in the header.
    optional_imports = [
        (needs_solve, "from physika.runtime import solve"),
        (uses("train"), "from physika.runtime import train"),
        (uses("evaluate"), "from physika.runtime import evaluate"),
        (uses("grad"), "from physika.runtime import compute_grad"),
        (uses("simulate"), "from physika.runtime import simulate"),
        (uses("animate"), "from physika.runtime import animate"),
        (needs_sympy, "import sympy as sp"),
    ]

    # Variables used as differentiation targets in program-level grad()
    # calls; they are handed to generate_statement below.
    grad_target_vars: Set[str] = set()
    for stmt in unified_ast["program"]:
        collect_grad_targets(stmt, grad_target_vars)

    # --- Code-generation pass ----------------------------------------------
    # Header
    code_lines = [
        "import torch",
        "import torch.nn as nn",
        "import torch.optim as optim",
    ]
    if needs_solve:
        code_lines.append("import re")
    code_lines.append("")

    # Import helpers from runtime.py (physika_print is always imported).
    imports = ["from physika.runtime import physika_print"]
    imports.extend(line for needed, line in optional_imports if needed)
    code_lines.append("\n".join(imports))
    code_lines.append("")

    # Generate functions
    if unified_ast["functions"]:
        code_lines.append("# === Functions ===")
        for name, func_def in unified_ast["functions"].items():
            code_lines.append(generate_function(name, func_def))
            code_lines.append("")

    # Generate classes
    if unified_ast["classes"]:
        code_lines.append("# === Classes ===")
        for name, class_def in unified_ast["classes"].items():
            if REGISTRY.features != []:
                node = ("class_def", name, class_def)
                class_code = REGISTRY.dispatch_forward(
                    "class_def", node, to_expr=ast_to_torch_expr)
                # Explicit check instead of a bare ``assert`` so that it is
                # not stripped under ``python -O``; the exception type is
                # kept so existing callers see the same error.
                if class_code is None:
                    raise AssertionError(
                        "REGISTRY.dispatch_forward returned None for "
                        f"class {name!r}")
                code_lines.append(class_code)
            else:
                code_lines.append(generate_class(name, class_def))
            code_lines.append("")

    # Generate program statements (falsy results are skipped).
    code_lines.append("# === Program ===")
    for stmt in unified_ast["program"]:
        stmt_code = generate_statement(stmt, grad_target_vars)
        if stmt_code:
            code_lines.append(stmt_code)

    # Join all code
    generated_code = "\n".join(code_lines)

    if print_code:
        print("\n=== Physika generated Pytorch code ===")
        print(generated_code)
        print("=== End Pytorch code ===\n")

    return generated_code