# Source code for physika.utils.print_utils

from __future__ import annotations

import sys
from typing import Any

import torch
import torch.nn as nn






def _pformat(value: Any, indent: int = 0) -> str:
    """Pretty-format an AST value with indentation.

    Recursively formats tuples, lists, dicts, and scalars into readable,
    indented string. Short collections that fit within 80 columns are
    kept on a single line and longer ones are expanded vertically.

    Parameters
    ----------
    value : Any
        The AST value to format — a tuple, list, dict, ``int``,
        ``float``, ``str``, or ``None``.
    indent : int, default 0
        Current indentation level (each level = 2 spaces).

    Returns
    -------
    str
        A formatted, possibly multi-line string.

    Examples
    --------
    >>> from physika.utils.print_utils import _pformat
    >>> _pformat(("num", 3.0))
    "('num', 3.0)"
    >>> _pformat({"x": 1}, indent=1)
    '  {\\n    x: 1\\n  }'
    """
    prefix = "  " * indent

    if value is None:
        return f"{prefix}None"

    if isinstance(value, (int, float)):
        return f"{prefix}{value}"

    if isinstance(value, str):
        return f"{prefix}{repr(value)}"

    if isinstance(value, dict):
        if not value:
            return f"{prefix}{{}}"
        lines = [f"{prefix}{{"]
        for k, v in value.items():
            v_str = _pformat(v, indent + 2)
            # If the formatted value is a single line, put key: value on one
            # line
            v_stripped = v_str.strip()
            if "\n" not in v_stripped:
                lines.append(f"{prefix}  {k}: {v_stripped}")
            else:
                lines.append(f"{prefix}  {k}:")
                lines.append(v_str)
        lines.append(f"{prefix}}}")
        return "\n".join(lines)

    if isinstance(value, list):
        if not value:
            return f"{prefix}[]"
        # Check if all items are simple scalars
        if all(
                isinstance(v, (int, float, str)) and not isinstance(v, bool)
                for v in value):
            items = ", ".join(
                repr(v) if isinstance(v, str) else str(v) for v in value)
            oneline = f"{prefix}[{items}]"
            if len(oneline) <= 80:
                return oneline
        lines = [f"{prefix}["]
        for item in value:
            lines.append(f"{_pformat(item, indent + 1)},")
        lines.append(f"{prefix}]")
        return "\n".join(lines)

    if isinstance(value, tuple):
        if not value:
            return f"{prefix}()"
        # Check if all items are simple scalars
        if all(
                isinstance(v, (int, float, str)) and not isinstance(v, bool)
                for v in value):
            items = ", ".join(
                repr(v) if isinstance(v, str) else str(v) for v in value)
            oneline = f"{prefix}({items})"
            if len(oneline) <= 80:
                return oneline
        lines = [f"{prefix}("]
        for item in value:
            lines.append(f"{_pformat(item, indent + 1)},")
        lines.append(f"{prefix})")
        return "\n".join(lines)

    return f"{prefix}{repr(value)}"






def _from_torch(v: Any) -> Any:
    """Convert a torch value to a plain Python value for display.

    Scalars are unwrapped via ``.item()``, tensors are converted to
    nested Python lists via ``.tolist()``, and complex values whose
    imaginary part is negligible (``< 1e-10``) are reduced to their
    real part.

    Parameters
    ----------
    v : Any
        The value to convert. A ``torch.Tensor``, ``complex``,
        or any other Python value (returned as-is).

    Returns
    -------
    Any
        A plain Python scalar (``int``, ``float``, ``complex``) or
        nested ``list``.

    Examples
    --------
    >>> from physika.utils.print_utils import _from_torch
    >>> _from_torch(torch.tensor(3.0))
    3.0
    >>> _from_torch(torch.tensor([1.0, 2.0]))
    [1.0, 2.0]
    >>> _from_torch(complex(2.0, 0.0))
    2.0
    """
    if not isinstance(v, torch.Tensor):
        if isinstance(v, complex):
            if abs(v.imag) < 1e-10:
                return v.real
            return v
        return v
    if v.numel() == 1:
        val = v.item()
        if isinstance(val, complex) and abs(val.imag) < 1e-10:
            return val.real
        return val
    return v.detach().tolist()


def _infer_type(v: Any) -> str:
    """Infer the Physika type string for a value.

    Maps Python / PyTorch values to their Physika type notation:
    ``"ℝ"`` for real scalars, ``"ℂ"`` for complex, ``"ℝ[n]"`` for
    vectors, ``"ℝ[m,n]"`` for matrices, and the class name for
    ``nn.Module`` subclasses.

    Parameters
    ----------
    v : Any
        The value whose type to infer — ``torch.Tensor``, ``int``,
        ``float``, ``complex``, ``list``, or ``nn.Module``.

    Returns
    -------
    str
        A Physika type string (e.g. ``"ℝ"``, ``"ℝ[3]"``,
        ``"ℝ[2,3]"``, ``"ℂ"``).

    Examples
    --------
    >>> from physika.utils.print_utils import _infer_type
    >>> _infer_type(3.0)
    'ℝ'
    >>> _infer_type(torch.tensor([1.0, 2.0, 3.0]))
    'ℝ[3]'
    >>> _infer_type(complex(1, 2))
    'ℂ'
    >>> _infer_type(torch.tensor([1j, 2j, 3j]))
    'ℂ[3]'
    """
    if isinstance(v, complex):
        if v.imag == 0:
            return "ℝ"
        return "ℂ"
    if isinstance(v, torch.Tensor) and v.is_complex():
        if v.numel() == 1:
            return "ℂ"
        if v.dim() == 1:
            return f"ℂ[{v.shape[0]}]"
        dims = ",".join(str(d) for d in v.shape)
        return f"ℂ[{dims}]"
    if isinstance(v, torch.Tensor):
        if v.numel() == 1:
            return "ℝ"
        if v.dim() == 1:
            return f"ℝ[{v.shape[0]}]"
        dims = ",".join(str(d) for d in v.shape)
        return f"ℝ[{dims}]"
    if isinstance(v, (int, float)):
        return "ℝ"
    if isinstance(v, list):
        shape = []
        current = v
        while isinstance(current, list) and len(current) > 0:
            shape.append(len(current))
            current = current[0]
        return f"ℝ[{','.join(str(d) for d in shape)}]"
    if isinstance(v, nn.Module):
        return type(v).__name__
    return str(type(v).__name__)