Source code for physika.type_checker

from typing import Any, Dict

from physika.utils.types import (
    check_function,
    check_statement,
    check_class,
    TInstance,
    counter,
)

# Type aliases used in annotations throughout this module.
ASTExpr = Any  # tagged tuple, scalar, or None
TypeSpec = Any  # "ℝ", "ℕ", ("tensor", [...]), ("func_type", ...), None
UnifiedAST = Dict[str, Any]


[docs] class TypeChecker: """ Type checker for Physika programs. The Hindley-Milner algorithm rests on two main steps: **Substitution** performs a mapping ``{αN: concrete_type}`` from unknown type variables (TVar `αN` or TDim `δN`) to valid types. When ``unify`` determines that ``αN`` must equal some type ``T``, it extends the substitution with the binding ``αN → T``. Calling ``s.apply(t)`` replaces every bound type variable in ``t`` with its mapped type, following chains of bindings until a concrete type is reached. **Unification** (``unify(t1, t2, s)``) uses the accumulated version of ``s`` that makes ``s.apply(t1) == s.apply(t2)``. If either side is a free ``TVar``, a new binding is added to ``s``. If both sides are concrete types of the same constructor (arrays, matrices, tensors as ``TTensor`` types), their components are unified recursively. If the shapes are incompatible (e.g. ``ℝ`` vs. ``ℝ[3]``), the mismatch is recorded as a type error. Physika's type checker performs three passes over the unified AST: 1. **Signature registration**: All function and class signatures are stored in ``func_env`` and ``class_env`` before any body is examined. Class constructors are stored in ``func_env`` as ``(field_types, TInstance(name))``. 2. **Body checking** (``check_function``, ``check_class``): For each ``def`` and ``class``, ``infer_stmts`` walks statements in order, threading ``s`` through every expression to build a local type environment. The return expression is inferred and unified against the declared return type. A mismatch is recorded as an error prefixed with the function or class name. 3. **Program statement checking** (``check_statement``): Top-level stmts nodes are checked in source order. The line number is read from the last element of each statement tuple and prepended to error messages. Type mismatches are accumulated in ``self.errors`` as plain strings. Parameters ---------- unified_ast : dict The unified AST dict produced by ``build_unified_ast()``, with keys ``"functions"``, ``"classes"``, and ``"program"``. Examples -------- >>> # Example 1 >>> # No errors >>> from physika.type_checker import TypeChecker >>> ast = { ... "functions": {}, ... "classes": {}, ... "program": [("decl", "x", "ℝ", ("num", 1.0), 1)], ... } >>> TypeChecker(ast).run() [] >>> # Example 2 >>> # function called with wrong number of arguments: >>> fdef = { ... "params": [("x", "ℝ"), ("y", "ℝ")], ... "statements": [], ... "body": ("add", ("var", "x"), ("var", "y")), ... "return_type": "ℝ", ... } >>> ast = { ... "functions": {"add2": fdef}, ... "classes": {}, ... "program": [("expr", ("call", "add2", [("num", 1.0)]), 3)], ... } >>> TypeChecker(ast).run() ["Line 3: Function 'add2' expects 2 args, got 1"] """ def __init__(self, unified_ast: dict) -> None: self.unified_ast = unified_ast self.errors: list[str] = [] self.type_env: dict = {} self.func_env: dict = {} self.class_env: dict = {}
[docs] def run(self) -> list[str]: """Run type inference over the full unified AST. Three passes: 1. Register all function and class signatures. 2. Check function bodies 3. Check class bodies. 4. Check top-level statements. Returns ------- list[str] Accumulated type error messages. Empty if the program is well-typed. Examples -------- >>> # No errors >>> from physika.type_checker import TypeChecker >>> ast = { ... "functions": {}, ... "classes": {}, ... "program": [("decl", "x", "ℝ", ("num", 1.0), 1)], ... } >>> TypeChecker(ast).run() [] """ counter.reset() for name, fdef in self.unified_ast["functions"].items(): params = fdef["params"] self.func_env[name] = ([pt for _, pt in params], fdef.get("return_type")) for name, cdef in self.unified_ast["classes"].items(): all_fields = (list(cdef.get("class_params", [])) + list(cdef.get("fields", []))) methods = cdef.get("methods", []) self.class_env[name] = { "fields": all_fields, "methods": { m["name"]: { "params": m.get("params", []), "return_type": m.get("return_type") } for m in methods }, } self.func_env[name] = ( [pt for _, pt in cdef.get("class_params", [])], TInstance(name), ) for name, fdef in self.unified_ast["functions"].items(): check_function(name, fdef, self.func_env, self.class_env, self.errors.append) for name, cdef in self.unified_ast["classes"].items(): check_class(name, cdef, self.func_env, self.class_env, self.errors.append) for stmt in self.unified_ast["program"]: if stmt and stmt[0] not in ("func_def", "class_def"): check_statement(stmt, self.type_env, self.func_env, self.class_env, self.errors.append) return self.errors