
Runner

runner

MojoGenerator

Bases: Protocol

Definition of a function that generates a MojoModel instance.

MojoRuntime

Bases: Protocol

Definition of a function that executes a generated MojoModel.

MojoObjective

Bases: Protocol

Definition of a function that scores a completed Mojo simulation.
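The exact protocol signatures are not listed on this page. As a rough sketch, the call sites in Trial.run suggest that a generator receives a pre-seeded MojoModel plus the overrides and returns the model, and that a runtime receives the model, a RuntimeManager, and the MuJoCo MjModel/MjData. The names GeneratorLike, RuntimeLike, my_generator, and my_runtime are hypothetical, and Any stands in for the library's own types. The objective protocol is omitted because its call site is not shown here.

```python
from typing import Any, Protocol


class GeneratorLike(Protocol):
    """Hypothetical shape of a MojoGenerator, inferred from Trial.run's call site."""

    def __call__(self, mojo_model: Any, overrides: Any, *args: Any, **kwargs: Any) -> Any: ...


class RuntimeLike(Protocol):
    """Hypothetical shape of a MojoRuntime, inferred from Trial.run's call site."""

    def __call__(
        self,
        mojo_model: Any,
        runtime_manager: Any,
        mj_model: Any,
        mj_data: Any,
        *args: Any,
        **kwargs: Any,
    ) -> Any: ...


def my_generator(mojo_model: Any, overrides: Any, *args: Any, **kwargs: Any) -> Any:
    # A real generator would populate mojo_model (e.g., build its MJCF) and
    # must return it so Trial.run can bundle assets and save it to disk.
    return mojo_model


def my_runtime(
    mojo_model: Any, runtime_manager: Any, mj_model: Any, mj_data: Any, *args: Any, **kwargs: Any
) -> Any:
    # A real runtime would step the simulation using mj_model and mj_data here.
    return mojo_model


# Plain functions satisfy the protocols structurally; no subclassing is needed.
gen: GeneratorLike = my_generator
rt: RuntimeLike = my_runtime
```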

BaseConfig

Bases: MojoBaseModel

n_proc class-attribute instance-attribute

Python
n_proc: int = DEFAULT_N_PROC

Number of processes to allow.

This value determines how many jobs can run in parallel. It is also used when discovering the status of existing trials. A value of 1 gives the slowest runtime but the highest reliability.

Important

Be a good citizen. Use a reasonable number if you are working on a shared resource. You are a jerk if you use everything.

resume class-attribute instance-attribute

Python
resume: bool = DEFAULT_RESUME

Whether to resume a study if the study already exists.

padding_style property

Python
padding_style: str

The padding style for trial numbers, derived dynamically from the number of trials. Zero-padding ensures the filesystem sorts the trial directories consistently.

Examples:

- Suppose your n_trial is 2000 (and trial numbering starts at 0).
- This property returns '04d'.
- Trial number 0 maps to 0000.
- Trial number 123 maps to 0123.
- Trial number 1999 still maps to 1999.
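The example above can be reproduced with a small sketch. It assumes the width is the digit count of the largest trial number (n_trial - 1); the real property may count digits differently (for instance from n_trial itself), which gives the same '04d' for 2000 trials. The function name padding_style here is a hypothetical stand-in for the property.

```python
def padding_style(n_trial: int) -> str:
    """Hypothetical re-creation of BaseConfig.padding_style.

    Assumption: the width is the number of digits in the largest
    trial number (n_trial - 1), since trial numbering starts at 0.
    """
    width = len(str(max(n_trial - 1, 0)))
    return f"0{width}d"


style = padding_style(2000)
print(style)  # 04d
print([format(n, style) for n in (0, 123, 1999)])  # ['0000', '0123', '1999']
```

Zero-padded names make lexical order match numeric order, so `trial_0123` sorts before `trial_1999` in any file browser.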

MonteCarloConfig

Bases: BaseConfig

n_trial class-attribute instance-attribute

Python
n_trial: int = DEFAULT_MC_N_TRIAL

Number of trials to run.

You can resume a previous job and change the number of trials by modifying this value. Changing this value while a job is running does not stop that job dynamically.

n_proc class-attribute instance-attribute

Python
n_proc: int = DEFAULT_N_PROC

Number of processes to allow.

This value determines how many jobs can run in parallel. It is also used when discovering the status of existing trials. A value of 1 gives the slowest runtime but the highest reliability.

Important

Be a good citizen. Use a reasonable number if you are working on a shared resource. You are a jerk if you use everything.

resume class-attribute instance-attribute

Python
resume: bool = DEFAULT_RESUME

Whether to resume a study if the study already exists.

padding_style property

Python
padding_style: str

The padding style for trial numbers, derived dynamically from the number of trials. Zero-padding ensures the filesystem sorts the trial directories consistently.

Examples:

- Suppose your n_trial is 2000 (and trial numbering starts at 0).
- This property returns '04d'.
- Trial number 0 maps to 0000.
- Trial number 123 maps to 0123.
- Trial number 1999 still maps to 1999.

OptimizerConfig

Bases: BaseConfig

n_trial class-attribute instance-attribute

Python
n_trial: int = DEFAULT_OP_N_TRIAL

Number of trials to run.

study_name class-attribute instance-attribute

Python
study_name: str = DEFAULT_OP_STUDY_NAME

Unique identifier for the Optuna study.

direction instance-attribute

Python
direction: Literal['minimize', 'maximize']

Whether we want to find the lowest or highest objective value.

timeout class-attribute instance-attribute

Python
timeout: float | None = DEFAULT_OP_TIMEOUT

Stop the study after this many seconds, regardless of trial count.

storage class-attribute instance-attribute

Python
storage: str | None = DEFAULT_OP_STORAGE

Database URL (e.g., 'sqlite:///study.db') for multi-node persistence.

sampler class-attribute instance-attribute

Python
sampler: SamplerOptions = DEFAULT_OP_SAMPLER

The search algorithm. TPE is often a good default for noisy physics objectives.

evals_per_trial class-attribute instance-attribute

Python
evals_per_trial: int = Field(
    default=DEFAULT_OP_EVALS_PER_TRIAL, ge=1
)

Number of times to run the simulation per trial, each with a different seed. The average score is reported to Optuna.

This reduces the influence of 'lucky' trials in noisy physics simulations.
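A minimal sketch of seed-averaging, assuming each evaluation i uses seed base_seed + i (the runner's actual seeding scheme is not shown on this page). noisy_score is a hypothetical stand-in for one simulation run.

```python
import random
from statistics import mean


def noisy_score(x: float, seed: int) -> float:
    """Hypothetical stand-in for one simulation: true score x**2 plus seeded noise."""
    rng = random.Random(seed)
    return x * x + rng.gauss(0.0, 0.1)


def averaged_score(x: float, base_seed: int, evals_per_trial: int) -> float:
    """Average noisy_score over evals_per_trial seeds derived from base_seed."""
    if evals_per_trial < 1:
        raise ValueError("evals_per_trial must be >= 1")
    # Assumption: evaluation i uses seed base_seed + i.
    return mean(noisy_score(x, base_seed + i) for i in range(evals_per_trial))


print(averaged_score(2.0, base_seed=0, evals_per_trial=1))
print(averaged_score(2.0, base_seed=0, evals_per_trial=50))
```

Averaging n independent evaluations shrinks the noise's standard deviation by a factor of about sqrt(n), at n times the compute cost per trial.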

refine_search_factor class-attribute instance-attribute

Python
refine_search_factor: float | None = Field(
    default=DEFAULT_OP_REFINE_SEARCH_FACTOR, gt=0, lt=1
)

If set (e.g., 0.5) and resume is True, the runner shrinks the search-space bounds by this factor around the current best trial to focus on local refinement.

Smaller values narrow the search space more aggressively.
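One way to shrink a single parameter's bounds is sketched below, assuming the new window has width factor * (high - low), is centered on the best value, and is clipped to the original bounds. The helper refine_bounds is hypothetical; the runner's actual shrinking rule is not shown here.

```python
def refine_bounds(low: float, high: float, best: float, factor: float) -> tuple[float, float]:
    """Shrink [low, high] to a window of width factor * (high - low) centered on best.

    Hypothetical helper: the window is clipped to the original bounds, so it can
    end up narrower than requested when best sits near an edge.
    """
    if not 0 < factor < 1:
        raise ValueError("factor must satisfy 0 < factor < 1")
    half_width = 0.5 * factor * (high - low)
    return max(low, best - half_width), min(high, best + half_width)


print(refine_bounds(0.0, 10.0, 5.0, 0.5))  # (2.5, 7.5)
print(refine_bounds(0.0, 10.0, 9.0, 0.5))  # (6.5, 10.0)
```

Clipping keeps every candidate inside the originally valid range; an alternative is to slide the window inward so it keeps its full width near the edges.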

prune_failed_trials class-attribute instance-attribute

Python
prune_failed_trials: bool = DEFAULT_OP_PRUNE_FAILED_TRIALS

Whether to immediately stop trials that violate physical constraints (e.g., MuJoCo instability) to save compute time.

n_proc class-attribute instance-attribute

Python
n_proc: int = DEFAULT_N_PROC

Number of processes to allow.

This value determines how many jobs can run in parallel. It is also used when discovering the status of existing trials. A value of 1 gives the slowest runtime but the highest reliability.

Important

Be a good citizen. Use a reasonable number if you are working on a shared resource. You are a jerk if you use everything.

resume class-attribute instance-attribute

Python
resume: bool = DEFAULT_RESUME

Whether to resume a study if the study already exists.

padding_style property

Python
padding_style: str

The padding style for trial numbers, derived dynamically from the number of trials. Zero-padding ensures the filesystem sorts the trial directories consistently.

Examples:

- Suppose your n_trial is 2000 (and trial numbering starts at 0).
- This property returns '04d'.
- Trial number 0 maps to 0000.
- Trial number 123 maps to 0123.
- Trial number 1999 still maps to 1999.

Trial dataclass

Python
Trial(
    trial_num: int,
    base_dir: Path,
    xml_name: str,
    model_config_name: str,
    padding_style: str,
)

Handles the lifecycle of a single simulation run.

The Trial object is responsible for the 'dirty work' of a Monte Carlo run:

- Creating directories
- Writing the MJCF XML
- Saving the configuration snapshot
- Triggering the physics runtime

trial_num instance-attribute

Python
trial_num: int

Unique identifier for this trial iteration.

base_dir instance-attribute

Python
base_dir: Path

Root directory where all simulation trials are stored.

xml_name instance-attribute

Python
xml_name: str

Filename for the generated MJCF XML (e.g., 'model.xml').

model_config_name instance-attribute

Python
model_config_name: str

Filename for the serialized MojoModel configuration (e.g., 'config.json').

padding_style instance-attribute

Python
padding_style: str

Format specifier for directory naming (e.g., '04d').

trial_dir property

Python
trial_dir: Path

The path to this trial's unique workspace inside base_dir.

Example

If base_dir is './sims' and trial_num is 7 with '03d' padding, this returns './sims/trial_007'.
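The documented example corresponds to a nested format specification. This is a hypothetical re-creation of the property as a free function, not the library's code.

```python
from pathlib import Path


def trial_dir(base_dir: Path, trial_num: int, padding_style: str) -> Path:
    """Hypothetical re-creation of Trial.trial_dir based on the documented example."""
    # Nested format spec: padding_style (e.g. '03d') controls zero-padding.
    return base_dir / f"trial_{trial_num:{padding_style}}"


print(trial_dir(Path("./sims"), 7, "03d"))  # sims/trial_007 (on POSIX)
```

pathlib drops the leading "./" when printing, so the result reads `sims/trial_007` rather than `./sims/trial_007`; both refer to the same location.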

xml_path property

Python
xml_path: Path

The full path to the MJCF XML file for this trial.

model_config_path property

Python
model_config_path: Path

The full path to the JSON configuration file for this trial.

named_value_path property

Python
named_value_path: Path

The full path to the JSON NamedValue file for this trial.

run

Python
run(
    generator: MojoGenerator,
    runtime: MojoRuntime | None,
    seed: int | None,
    overrides: NamedValueDict[NDArray],
    gen_args: list[Any],
    gen_kwargs: dict[str, Any],
    run_args: list[Any],
    run_kwargs: dict[str, Any],
) -> tuple[
    MojoModel | None,
    TrialStatus,
    MjModel | None,
    MjData | None,
]

Executes the complete simulation pipeline for this trial.

This method coordinates three main phases:

1. Generation: Calls the user-provided generator to build a MojoModel.
2. Persistence: Creates the workspace and writes the model/config to disk.
3. Execution: Triggers the physics runtime if one is provided.

Parameters:

- generator (MojoGenerator, required): Function that returns a MojoModel instance.
- runtime (MojoRuntime | None, required): Function that runs the simulation in MuJoCo, or None to skip simulation.
- seed (int | None, required): Seed used to define the trial.
- overrides (NamedValueDict[NDArray], required): Key-value pairs that override random distributions.
- gen_args (list[Any], required): Positional arguments for the generator.
- gen_kwargs (dict[str, Any], required): Keyword arguments for the generator.
- run_args (list[Any], required): Positional arguments for the runtime.
- run_kwargs (dict[str, Any], required): Keyword arguments for the runtime.

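Returns:

- tuple[MojoModel | None, TrialStatus, MjModel | None, MjData | None]: A tuple of (result, status, mj_model, mj_data). result is the MojoModel returned by the runtime (or the generated model when no runtime is given), or None if the trial failed before a result was produced. status is the TrialStatus of the trial. mj_model and mj_data are the MuJoCo objects used for simulation, or None if no runtime was provided or the trial failed before the simulation was prepared.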

Source code in src/mujoco_mojo/utils/runner.py
Python
def run(
    self,
    generator: MojoGenerator,
    runtime: MojoRuntime | None,
    seed: int | None,
    overrides: NamedValueDict[NDArray],
    gen_args: list[Any],
    gen_kwargs: dict[str, Any],
    run_args: list[Any],
    run_kwargs: dict[str, Any],
) -> tuple[
    MojoModel | None, TrialStatus, mujoco.MjModel | None, mujoco.MjData | None
]:
    """
    Executes the complete simulation pipeline for this trial.

    This method coordinates three main phases:
    1.  **Generation**: Calls the user-provided generator to build a `MojoModel`.
    2.  **Persistence**: Creates the workspace and writes the model/config to disk.
    3.  **Execution**: Triggers the physics runtime if one is provided.

    Args:
        generator: Function that returns a `MojoModel` instance.
        runtime: Optional function to run the simulation (MuJoCo).
        seed: Seed to use to define the trial.
        overrides: Key-value pairs that override random distributions.
        gen_args: Positional arguments for the generator.
        gen_kwargs: Keyword arguments for the generator.
        run_args: Positional arguments for the runtime.
        run_kwargs: Keyword arguments for the runtime.

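    Returns:
        A tuple of (result, status, mj_model, mj_data). `result` is the `MojoModel`
        returned by the runtime (or the generated model when no runtime is given),
        or None if the trial failed before a result was produced. `status` is the
        `TrialStatus` of the trial. `mj_model` and `mj_data` are the MuJoCo objects
        used for simulation, or None if no runtime was provided or the trial failed
        before the simulation was prepared.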

    """
    status = TrialStatus(trial_num=self.trial_num)
    status._path = self.trial_dir / TRIAL_STATUS_FNAME

    with status.record_step(step_name="pending"):
        pass

    result = None
    mj_model = None
    mj_data = None
    try:
        # 1. Generate
        with status.record_step(step_name="generating"):
            logger.info(f"Generating trial_num={self.trial_num}")
            mojo_model = (
                MojoModel()
                .with_overrides(overrides=overrides)
                .with_seed(seed=seed)
                .with_trial_num(self.trial_num)
            )
            mojo_model = generator(mojo_model, overrides, *gen_args, **gen_kwargs)

            # 2. Setup Workspace & Save Metadata
            logger.info(f"Saving trial_num={self.trial_num} to {self.trial_dir}")
            self.trial_dir.mkdir(parents=True, exist_ok=True)

            # bundle assets, this remaps DepPath attributes to point to the shared asset dir
            rel_to_xml = Path(
                os.path.relpath(self.shared_asset_dir, self.trial_dir)
            )
            mojo_model.mjcf.bundle_assets(
                target_dir=self.shared_asset_dir, rel_to_xml=rel_to_xml
            )

            # save XML (with modified DepPath)
            if runtime is None:
                mojo_model.mjcf.write_xml(self.xml_path)
            mojo_model.dump_to_path(self.model_config_path)
            self.named_value_path.write_text(mojo_model.named.model_dump_json())

        with status.record_step(step_name="solving"):
            # 3. Execute (if runtime provided)
            if runtime is not None:
                logger.info(f"Executing trial_num={self.trial_num} runtime")
                import mujoco_mojo.runtime as rt

                runtime_manager = rt.RuntimeManager(
                    signal_manager=rt.SignalManager(
                        export_path=self.trial_dir
                        / rt.SignalManager.default_output_name()
                    )
                )
                mj_model, mj_data = mojo_model.mjcf.prep_for_sim(self.xml_path)
                result = runtime(
                    mojo_model,
                    runtime_manager,
                    mj_model,
                    mj_data,
                    *run_args,
                    **run_kwargs,
                )
            else:
                logger.info(
                    f"No runtime definition was provided for trial_num={self.trial_num} so MuJoCo will not be run."
                )
                result = mojo_model
                mj_model = None
                mj_data = None

        # serialize again in case new named values were added during the run
        mojo_model.dump_to_path(self.model_config_path)
        self.named_value_path.write_text(mojo_model.named.model_dump_json())
        status.step = "done"
        status.completion = Completion.SUCCESS

    except (BdbQuit, KeyboardInterrupt):
        logger.warning("Quit command detected. Exiting execution...")
        raise
    except Exception as e:
        status.step = "done"
        status.completion = Completion.FAILED
        logger.exception(
            f"Trial {self.trial_num} failed with the following error: {e}"
        )
    finally:
        status.dump_to_path(status._path)

    return result, status, mj_model, mj_data

MojoRunner dataclass

Python
MojoRunner(
    generator: MojoGenerator,
    generator_path: str | None = None,
    runtime: MojoRuntime | None = DEFAULT_RUNTIME,
    runtime_path: str | None = None,
    objective: MojoObjective | None = None,
    objective_path: str | None = None,
    seed: int | None = DEFAULT_SEED,
    workdir: Path = DEFAULT_WORKDIR,
    model_config_name: str = DEFAULT_MODEL_CONFIG_NAME,
    xml_name: str = DEFAULT_XML_NAME,
    config: MonteCarloConfig
    | OptimizerConfig = MonteCarloConfig(),
    gen_args: list[Any] = list(),
    gen_kwargs: dict[str, Any] = dict(),
    run_args: list[Any] = list(),
    run_kwargs: dict[str, Any] = dict(),
)

slurm_trial_id property

Python
slurm_trial_id: int | None

Returns the current SLURM task ID if running as part of an array job.

run

Python
run(
    global_overrides: NamedValueDict[
        NDArray
    ] = NamedValueDict[NDArray](),
    clean_workdir: bool = False,
    cleanup_delay: int = 10,
    execution_mode: ExecutionMode = LOCAL,
    trial_nums: list[int] | None = None,
) -> bool

Routes a job to either local execution or SLURM orchestration, depending on execution_mode. clean_workdir and resume are mutually exclusive; passing both raises a ValueError.

Source code in src/mujoco_mojo/utils/runner.py
Python
def run(
    self,
    global_overrides: NamedValueDict[NDArray] = NamedValueDict[NDArray](),
    clean_workdir: bool = False,
    cleanup_delay: int = 10,
    execution_mode: ExecutionMode = ExecutionMode.LOCAL,
    trial_nums: list[int] | None = None,
) -> bool:
    """Vectors a job to be either computed locally or to be orchestrated by SLURM."""
    if clean_workdir:
        if self.config.resume:
            msg = "clean_workdir and resume are mutually exclusive with one another. Use one or the other."
            logger.error(msg)
            raise ValueError(msg)

        self.force_remove_dir(countdown_from=cleanup_delay, path=self.workdir)

    self.workdir.mkdir(parents=True, exist_ok=True)
    if not (self.workdir / ".gitignore").exists():
        (self.workdir / ".gitignore").write_text("*", encoding="utf-8")
    self.capture_environment()

    match execution_mode:
        case ExecutionMode.LOCAL:
            return self.run_local(
                global_overrides=global_overrides,
                trial_nums=trial_nums,
            )
        case ExecutionMode.SLURM:
            return self.orchestrate_slurm(
                global_overrides=global_overrides,
                trial_nums=trial_nums,
            )
        case _:
            msg = f"No run command has been configured for execution mode {execution_mode}"
            logger.error(msg)
            raise NotImplementedError(msg)

orchestrate_slurm

Python
orchestrate_slurm(
    global_overrides: NamedValueDict[NDArray],
    trial_nums: list[int] | None = None,
) -> bool

Generates an sbatch script and submits the job array to SLURM for a given config.

Source code in src/mujoco_mojo/utils/runner.py
Python
def orchestrate_slurm(
    self,
    global_overrides: NamedValueDict[NDArray],
    trial_nums: list[int] | None = None,
) -> bool:
    """Generates an sbatch script and submits the job array to SLURM for a given config."""
    (self.workdir / "logs").mkdir(exist_ok=True)

    if isinstance(self.config, MonteCarloConfig):
        try:
            had_fails = self.orchestrate_slurm_monte_carlo(
                global_overrides=global_overrides,
                trial_nums=trial_nums,
            )
        except (BdbQuit, KeyboardInterrupt):
            print("\n")
            logger.error("Aborted SLURM orchestration!")
            had_fails = True
    else:
        msg = f"A SLURM configuration for {self.config.__class__.__name__} has not been implemented."
        logger.error(msg)
        raise NotImplementedError(msg)
    return had_fails

execute_single_trial

Python
execute_single_trial(
    trial_num: int,
    seed: int | None,
    overrides_payload: dict,
) -> tuple[
    MojoModel | None,
    TrialStatus,
    MjModel | None,
    MjData | None,
]

Helper to package a Trial and run it.

Source code in src/mujoco_mojo/utils/runner.py
Python
def execute_single_trial(
    self, trial_num: int, seed: int | None, overrides_payload: dict
) -> tuple[
    MojoModel | None, TrialStatus, mujoco.MjModel | None, mujoco.MjData | None
]:
    """Helper to package a Trial and run it."""
    overrides = NamedValueDict[NDArray].model_validate(overrides_payload)

    trial = Trial(
        trial_num=trial_num,
        base_dir=self.workdir,
        xml_name=self.xml_name,
        model_config_name=self.model_config_name,
        padding_style=self.config.padding_style,
    )

    return trial.run(
        generator=self.generator,
        runtime=self.runtime,
        seed=seed,
        overrides=overrides,
        gen_args=self.gen_args,
        gen_kwargs=self.gen_kwargs,
        run_args=self.run_args,
        run_kwargs=self.run_kwargs,
    )

run_monte_carlo

Python
run_monte_carlo(
    global_overrides: NamedValueDict[
        NDArray
    ] = NamedValueDict[NDArray](),
    trial_nums: list[int] | None = None,
) -> bool

Orchestrates a Monte Carlo job.

Source code in src/mujoco_mojo/utils/runner.py
Python
def run_monte_carlo(
    self,
    global_overrides: NamedValueDict[NDArray] = NamedValueDict[NDArray](),
    trial_nums: list[int] | None = None,
) -> bool:
    """Orchestrates a Monte Carlo job."""
    if self.slurm_trial_id is not None:
        tn = self.slurm_trial_id
        logger.info(f"SLURM Worker detected. Executing Trial {tn}")
        _mojo_model, trial_status, _, __ = self.execute_single_trial(
            trial_num=tn,
            seed=self.seed,
            overrides_payload=global_overrides.model_dump(),
        )

        return trial_status.completion == Completion.FAILED

    job_trial_nums = trial_nums if trial_nums else self.config.trial_nums

    # initialize the status tracker
    status_tracker = JobStatus(
        workdir=self.workdir.resolve(),
        job_type=JobType.MONTE_CARLO,
        execution_mode=ExecutionMode.LOCAL,
        n_proc=self.config.n_proc,
        seed=self.seed,
        padding_style=self.config.padding_style,
        generator=MojoRunner.inspect_protocol(self.generator),
        runtime=MojoRunner.inspect_protocol(self.runtime),
        objective=MojoRunner.inspect_protocol(self.objective),
        gen_args_used=bool(self.gen_args),
        gen_kwargs_used=bool(self.gen_kwargs),
        run_args_used=bool(self.run_args),
        run_kwargs_used=bool(self.run_kwargs),
        trial_nums=job_trial_nums,
    )

    # decide which trials to execute
    if self.config.resume:
        self._renumber_trial_folders(
            self.workdir.resolve(), self.config.padding_style
        )
        status_tracker.refresh_from_disk(n_proc=self.config.n_proc)
    status_tracker.dump_to_path(self.workdir / JOB_STATUS_FNAME)

    to_run = status_tracker.pending_trial_nums

    if not to_run:
        logger.info("All trials were already completed. Nothing to do.")
        return bool(status_tracker.failed_trial_nums)

    if self.config.is_parallel:
        logger.info(
            f"Running {len(to_run)} trials with {self.config.n_proc} processors. {status_tracker.n_done}/{self.config.n_trial} ({status_tracker.progress:.2%}) trials completed."
        )
        # needed for logging on Windows
        with multiprocessing.Manager() as m:
            log_queue = m.Queue()

            parent_log_level = logging.getLogger().getEffectiveLevel()
            listener = QueueListener(log_queue, *logging.getLogger().handlers)
            listener.start()

            executor = ProcessPoolExecutor(
                max_workers=self.config.n_proc,
                initializer=worker_init,
                initargs=(log_queue, parent_log_level),
            )
            try:
                future_to_tn = {
                    executor.submit(
                        self.execute_single_trial,
                        tn,
                        self.seed,
                        global_overrides.model_dump(),
                    ): tn
                    for tn in to_run
                }
                for f in as_completed(future_to_tn):
                    tn = future_to_tn[f]
                    try:
                        _mojo_model, trial_status, _, __ = f.result()
                        status_tracker.update_trial(status=trial_status)
                    except (BdbQuit, KeyboardInterrupt):
                        # user is quitting from breakpoint() or CTRL+C
                        raise
                    except Exception as e:
                        logger.exception(f"Trial {tn} failed: {e}")
                        status_tracker.update_trial(
                            status=TrialStatus(
                                trial_num=tn, completion=Completion.FAILED
                            )
                        )
                    status_tracker.generate_report()
            except (BdbQuit, KeyboardInterrupt):
                # allows killing the job with one CTRL+C
                logger.warning("Interrupt recieved. Stopping all trials.")
                executor.shutdown(wait=False, cancel_futures=True)
                raise
            finally:
                listener.stop()
                executor.shutdown(wait=True)
    else:
        for tn in to_run:
            try:
                _mojo_model, trial_status, _, __ = self.execute_single_trial(
                    trial_num=tn,
                    seed=self.seed,
                    overrides_payload=global_overrides.model_dump(),
                )
                status_tracker.update_trial(
                    status=trial_status,
                )
            except (BdbQuit, KeyboardInterrupt):
                # user is quitting from breakpoint() or CTRL+C
                raise
            except Exception as e:
                logger.exception(f"A trial failed with error: {e}")
                status_tracker.update_trial(
                    status=TrialStatus(trial_num=tn, completion=Completion.FAILED)
                )
            status_tracker.generate_report()

    status_tracker.generate_report(alert_generation=True)
    return bool(status_tracker.failed_trial_nums)

get_slurm_array_string staticmethod

Python
get_slurm_array_string(ids: list[int]) -> str

Collapses [0, 1, 2, 5, 6] into '0-2,5-6' for SLURM.

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def get_slurm_array_string(ids: list[int]) -> str:
    """Collapses [0, 1, 2, 5, 6] into '0-2,5-6' for SLURM."""
    if not ids:
        return ""

    ranges = []
    # Identify groups of consecutive integers
    for _, group in itertools.groupby(
        enumerate(sorted(ids)), lambda x: x[1] - x[0]
    ):
        group = list(group)
        start = group[0][1]
        end = group[-1][1]
        if start == end:
            ranges.append(str(start))
        else:
            ranges.append(f"{start}-{end}")

    return ",".join(ranges)

get_slurm_partitions staticmethod

Python
get_slurm_partitions() -> tuple[list[str], str | None]

Queries sinfo for available partitions and identifies the default.

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def get_slurm_partitions() -> tuple[list[str], str | None]:
    """Queries sinfo for available partitions and identifies the default."""
    try:
        result = subprocess.run(
            ["sinfo", "-h", "--format=%P"],
            capture_output=True,
            text=True,
            timeout=5,
        )
        if result.returncode == 0:
            raw_partitions = [
                p.strip() for p in result.stdout.splitlines() if p.strip()
            ]

            default_partition = None
            clean_partitions = []

            for p in raw_partitions:
                if p.endswith("*"):
                    name = p.replace("*", "")
                    default_partition = name
                    clean_partitions.append(name)
                else:
                    clean_partitions.append(p)

            return sorted(list(set(clean_partitions))), default_partition
    except Exception:
        pass
    return [], None

normalize_to_mb staticmethod

Python
normalize_to_mb(mem_str: str) -> int

Converts SLURM memory strings (e.g., '1000', '1G', '1024M') to integer MB.

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def normalize_to_mb(mem_str: str) -> int:
    """Converts SLURM memory strings (e.g., '1000', '1G', '1024M') to integer MB."""
    # Split the number from the unit (e.g., '1024M' -> '1024', 'M')
    match = re.match(r"(\d+)([KMGTP]?)", mem_str.upper())
    if not match:
        return 0

    value, unit = match.groups()
    value = int(value)

    # SLURM units are powers of 1024
    multipliers = {
        "K": 1 / 1024,  # Kilobytes to MB
        "M": 1,  # Megabytes
        "G": 1024,  # Gigabytes to MB
        "T": 1024**2,  # Terabytes to MB
        "P": 1024**3,  # Petabytes to MB
    }

    return int(value * multipliers.get(unit or "M", 1))

format_bytes staticmethod

Python
format_bytes(mb_value: int) -> str

Scales MB back up to the most readable unit (G, T, etc.).

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def format_bytes(mb_value: int) -> str:
    """Scales MB back up to the most readable unit (G, T, etc.)."""
    units = ["M", "G", "T", "P"]
    value = float(mb_value)
    unit_index = 0

    # Divide by 1024 while the value is at least 1024 and a larger unit exists
    while value >= 1024 and unit_index < len(units) - 1:
        value /= 1024
        unit_index += 1

    # If it's a whole number (like 1.0G), show it as 1G.
    # Otherwise, show one decimal place (like 1.5G).
    if value.is_integer():
        return f"{int(value)}{units[unit_index]}"
    return f"{value:.1f}{units[unit_index]}"

get_slurm_node_mem_limit staticmethod

Python
get_slurm_node_mem_limit(partition_name: str) -> str

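Finds the minimum RealMemory across the nodes in a partition, normalized to MB and returned as a human-readable string via format_bytes, or '<UNKNOWN>' if it cannot be determined.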

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def get_slurm_node_mem_limit(partition_name: str) -> str:
    """Finds the MINIMUM RealMemory limit, normalized to MB."""
    try:
        # 1. Get nodes from partition
        part_info = subprocess.run(
            ["scontrol", "show", "partition", partition_name, "-o"],
            capture_output=True,
            text=True,
        ).stdout
        node_match = re.search(r"\bNodes=(\S+)", part_info)
        if not node_match:
            return "<UNKNOWN>"

        # 2. Get node info (can handle ranges like c[1-2])
        node_info = subprocess.run(
            ["scontrol", "show", "node", node_match.group(1), "-o"],
            capture_output=True,
            text=True,
        ).stdout

        # 3. Find all RealMemory values (capturing optional suffixes)
        # Regex captures digits and any trailing letters (like G or M)
        mem_matches = re.findall(r"\bRealMemory=(\d+[KMGTP]?)", node_info)
        if mem_matches:
            # Normalize every match to MB and find the lowest one
            min_mb = min(MojoRunner.normalize_to_mb(m) for m in mem_matches)
            return MojoRunner.format_bytes(min_mb)

    except Exception:
        pass

    return "<UNKNOWN>"

get_slurm_cpu_limit staticmethod

Python
get_slurm_cpu_limit(partition_name: str) -> str

Finds the minimum of the max allowed CPUs (CPUTot) among nodes in a partition.

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def get_slurm_cpu_limit(partition_name: str) -> str:
    """Finds the minimum of the max allowed CPUs (CPUTot) among nodes in a partition."""
    try:
        # 1. Get the nodes in the partition
        part_cmd = ["scontrol", "show", "partition", partition_name, "-o"]
        part_info = subprocess.run(part_cmd, capture_output=True, text=True).stdout

        node_match = re.search(r"\bNodes=(\S+)", part_info)
        if not node_match:
            return "<UNKNOWN>"

        # 2. Get the node info
        node_cmd = ["scontrol", "show", "node", node_match.group(1), "-o"]
        nodes_info = subprocess.run(node_cmd, capture_output=True, text=True).stdout

        # 3. Find all CPUTot values (The total physical/logical CPUs on the node)
        cpu_values = re.findall(r"\bCPUTot=(\d+)", nodes_info)

        if cpu_values:
            # Find the minimum to ensure any node can handle the task
            return str(min(int(v) for v in cpu_values))

    except Exception:
        pass

    return "<UNKNOWN>"

slurm_time_to_seconds staticmethod

Python
slurm_time_to_seconds(time_str: str) -> int

Converts a SLURM time string (MM, MM:SS, HH:MM:SS, D-HH, D-HH:MM, or D-HH:MM:SS) to total seconds. Returns -1 for UNLIMITED.

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def slurm_time_to_seconds(time_str: str) -> int:
    """Converts a SLURM time string (MM, MM:SS, HH:MM:SS, D-HH, D-HH:MM, or D-HH:MM:SS) to total seconds."""
    if time_str.upper() == "UNLIMITED":
        return -1

    # With a day component the remainder is Hours[:Minutes[:Seconds]]
    if "-" in time_str:
        days_part, time_str = time_str.split("-", 1)
        parts = list(map(int, time_str.split(":")))
        if len(parts) > 3:
            return 0
        # Pad missing minutes/seconds with zeros
        hours, minutes, seconds = (parts + [0, 0])[:3]
        return int(days_part) * 86400 + hours * 3600 + minutes * 60 + seconds

    parts = list(map(int, time_str.split(":")))
    if len(parts) == 3:  # HH:MM:SS
        return parts[0] * 3600 + parts[1] * 60 + parts[2]
    if len(parts) == 2:  # MM:SS
        return parts[0] * 60 + parts[1]
    if len(parts) == 1:  # MM
        return parts[0] * 60
    return 0

get_slurm_time_limit staticmethod

Python
get_slurm_time_limit(partition_name: str) -> str

Finds the MaxTime limit for a specific partition.

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def get_slurm_time_limit(partition_name: str) -> str:
    """Finds the MaxTime limit for a specific partition."""
    try:
        # Get partition info
        part_cmd = ["scontrol", "show", "partition", partition_name, "-o"]
        part_info = subprocess.run(part_cmd, capture_output=True, text=True).stdout

        # Look for MaxTime= followed by the time string (e.g., 1-00:00:00 or UNLIMITED)
        time_match = re.search(r"\bMaxTime=(\S+)", part_info)
        if time_match:
            return time_match.group(1)

    except Exception:
        pass

    return "<UNKNOWN>"

get_max_array_size staticmethod

Python
get_max_array_size() -> int

Queries the global SLURM configuration for MaxArraySize.

Source code in src/mujoco_mojo/utils/runner.py
Python
@staticmethod
def get_max_array_size() -> int:
    """Queries the global SLURM configuration for MaxArraySize."""
    try:
        result = subprocess.run(
            ["scontrol", "show", "config"], capture_output=True, text=True
        ).stdout
        match = re.search(r"MaxArraySize=(\d+)", result)
        return int(match.group(1)) if match else 1001
    except Exception:
        return 1001
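Unlike the one-liner partition output, scontrol show config typically prints one Key = Value pair per line with the key padded by spaces, so a pattern that expects MaxArraySize= immediately followed by digits will not match it. The sketch below uses a whitespace-tolerant pattern; parse_max_array_size is a hypothetical helper and the sample text is illustrative.

```python
import re

DEFAULT_MAX_ARRAY_SIZE = 1001  # SLURM's documented default


def parse_max_array_size(config_text: str) -> int:
    """Hypothetical helper: read MaxArraySize from `scontrol show config` text."""
    # allow optional whitespace around "=" and match at the start of a line
    match = re.search(r"^MaxArraySize\s*=\s*(\d+)", config_text, re.MULTILINE)
    return int(match.group(1)) if match else DEFAULT_MAX_ARRAY_SIZE


sample = (
    "Configuration data as of 2024-01-01T00:00:00\n"
    "MaxArraySize            = 4001\n"
    "MaxJobCount             = 10000\n"
)
print(parse_max_array_size(sample))  # 4001
```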

orchestrate_slurm_monte_carlo

Python
orchestrate_slurm_monte_carlo(
    global_overrides: NamedValueDict[
        NDArray
    ] = NamedValueDict[NDArray](),
    trial_nums: list[int] | None = None,
) -> bool

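Orchestrates a Monte Carlo SLURM submission: records the job status, prompts for sbatch resources (checking them against the partition limits), writes mujoco_mojo_submit.sh into the working directory, and optionally submits it as a job array with one task per pending trial. Returns True if the user aborts at a limit warning or sbatch fails, and False otherwise.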

Source code in src/mujoco_mojo/utils/runner.py
Python
    def orchestrate_slurm_monte_carlo(
        self,
        global_overrides: NamedValueDict[NDArray] = NamedValueDict[NDArray](),
        trial_nums: list[int] | None = None,
    ) -> bool:
        """Orchestrates a Monte Carlo SLURM submission."""
        from rich.console import Console
        from rich.prompt import Confirm, Prompt

        console = Console()

        project_root = Path.cwd().resolve()

        # persist overrides so workers can access them
        overrides_path = self.workdir.resolve() / "global_overrides.json"
        if len(global_overrides) > 0:
            logger.info(f"Persisting global overrides to {overrides_path}")
            overrides_path.write_text(global_overrides.model_dump_json())

        # initialize the status tracker
        job_trial_nums = trial_nums if trial_nums else self.config.trial_nums

        status_tracker = JobStatus(
            workdir=self.workdir.resolve(),
            job_type=JobType.MONTE_CARLO,
            execution_mode=ExecutionMode.SLURM,
            n_proc=self.config.n_proc,
            seed=self.seed,
            padding_style=self.config.padding_style,
            generator=MojoRunner.inspect_protocol(self.generator),
            runtime=MojoRunner.inspect_protocol(self.runtime),
            objective=MojoRunner.inspect_protocol(self.objective),
            gen_args_used=bool(self.gen_args),
            gen_kwargs_used=bool(self.gen_kwargs),
            run_args_used=bool(self.run_args),
            run_kwargs_used=bool(self.run_kwargs),
            trial_nums=job_trial_nums,
        )

        # decide which trials to execute
        if self.config.resume:
            self._renumber_trial_folders(
                self.workdir.resolve(), self.config.padding_style
            )
            status_tracker.refresh_from_disk(n_proc=self.config.n_proc)
        status_tracker.dump_to_path(self.workdir / JOB_STATUS_FNAME)

        to_run = status_tracker.pending_trial_nums

        if not to_run:
            logger.info("All trials were already completed. Nothing to do.")
            return False
        else:
            logger.info(f"{len(to_run)} trials were identified for running.")

        # reconstruct the CLI command for the worker
        gen_args_str = " ".join([f'--gen-arg "{a}"' for a in self.gen_args])
        gen_kwargs_str = " ".join(
            [f'--gen-kwarg "{k}={v}"' for k, v in self.gen_kwargs.items()]
        )
        run_args_str = " ".join([f'--run-arg "{a}"' for a in self.run_args])
        run_kwargs_str = " ".join(
            [f'--run-kwarg "{k}={v}"' for k, v in self.run_kwargs.items()]
        )

        runtime_flag = f'--runtime "{self.runtime_path}"' if self.runtime_path else ""
        seed_flag = f"--seed {self.seed}" if self.seed is not None else ""
        overrides_flag = (
            f'--overrides "{overrides_path}"' if len(global_overrides) > 0 else ""
        )

        # get the path to the mujoco-mojo CLI executable
        py_bin_dir = Path(sys.executable).parent.resolve()
        mojo_cmd = py_bin_dir / "mujoco-mojo"

        cmd = (
            f"{mojo_cmd} run monte-carlo "
            f'--generator "{self.generator_path}" '
            f"{runtime_flag} {seed_flag} {overrides_flag} "
            f'--workdir "{self.workdir.resolve()}" '
            f"{gen_args_str} {gen_kwargs_str} "
            f"{run_args_str} {run_kwargs_str} "
            f"--n-trials 0 "  # force to zero to suppress a warning that would otherwise be emitted
            f"--trial-num $SLURM_ARRAY_TASK_ID "  # each array task executes its own trial_num
            f"--execution-mode local "  # local, since slurm mode would route back to this method
            f"--n-proc 1"  # a worker only needs 1 process
        )

        # ask for sbatch settings with a bunch of console inputs with default values
        available_partitions, default_partition = self.get_slurm_partitions()

        console.print(
            "\n[bold cyan]MuJoCo Mojo Orchestrator: SLURM Resource Setup[/bold cyan]"
        )
        console.print(f"\t            [dim]Python:[/dim] {sys.executable}")
        console.print(f"\t              [dim]Root:[/dim] {project_root}")
        console.print(f"\t[dim]mujoco-mojo binary:[/dim] {mojo_cmd}\n")

        # Standard colors only for Rich compatibility
        job_name = Prompt.ask("  [white]Job Name[/]", default="mojo-sim")
        if available_partitions:
            # Use the actual SLURM default if we found one, otherwise the first in list
            initial_default = (
                default_partition if default_partition else available_partitions[0]
            )

            partition = Prompt.ask(
                "  [white]Partition[/]",
                choices=available_partitions,
                default=initial_default,
            )
        else:
            partition = Prompt.ask(
                "  [white]Partition[/] [dim](optional)[/]", default=""
            )

        # === get partition limits ===
        cpu_limit = self.get_slurm_cpu_limit(partition)
        mem_limit = self.get_slurm_node_mem_limit(partition)
        time_limit_hint = self.get_slurm_time_limit(partition)

        # === get cpus per task ===
        cpus_per_task = Prompt.ask(
            f"  [white]CPUs per task[/] [dim](Node Limit: {cpu_limit})[/]",
            default="1",
        )
        cpus_per_task = max([1, int(cpus_per_task)])
        if cpu_limit != "<UNKNOWN>" and cpus_per_task > int(cpu_limit):
            console.print(
                f"\n[bold red]WARNING:[/] Requested CPUs ({cpus_per_task}) exceeds "
                f"the physical node limit ({cpu_limit})."
            )
            if not Confirm.ask("Do you want to proceed anyway?", default=False):
                return True

        # === get memory ===
        mem_per_node = Prompt.ask(
            f"  [white]Memory per node[/] (e.g., 256M) [dim](Node Limit: {mem_limit})[/]",
            default=mem_limit,
        )
        if self.normalize_to_mb(mem_limit) > 0 and self.normalize_to_mb(
            mem_per_node
        ) > self.normalize_to_mb(mem_limit):
            console.print(
                f"\n[bold red]WARNING:[/] Requested memory ({mem_per_node}) exceeds "
                f"the partition node limit ({mem_limit})."
            )
            console.print("[red]This job will likely be rejected by SLURM.[/]\n")
            if not Confirm.ask("Do you want to proceed anyway?", default=False):
                return True

        # === get time ===
        time_limit = Prompt.ask(
            f"  [white]Time limit[/] (HH:MM:SS) [dim](Partition Limit: {time_limit_hint})[/]",
            default="01:00:00",
        )
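        requested_seconds = self.slurm_time_to_seconds(time_limit)
        # "<UNKNOWN>" means the partition limit could not be queried, so skip the check
        max_seconds = (
            self.slurm_time_to_seconds(time_limit_hint)
            if time_limit_hint != "<UNKNOWN>"
            else -1
        )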
        if requested_seconds > max_seconds and max_seconds != -1:  # -1 means infinite
            console.print(
                f"\n[bold red]WARNING:[/] Requested time ({time_limit}) exceeds "
                f"partition MaxTime ({time_limit_hint})."
            )
            if not Confirm.ask("Proceed anyway?", default=False):
                return True

        current_pythonpath = os.getenv("PYTHONPATH", "")
        if str(project_root) not in current_pythonpath:
            console.print(
                f"\n[yellow]Warning:[/] Project root [italic]{project_root}[/] is not in your PYTHONPATH."
            )
            if Confirm.ask(
                "Should Mojo automatically include it in the SLURM submission?",
                default=True,
            ):
                # We'll handle this in the sbatch content generation
                include_root_in_path = True
            else:
                include_root_in_path = False
        else:
            include_root_in_path = True

        partition_line = f"#SBATCH --partition={partition}" if partition else ""
        python_path_line = (
            f"export PYTHONPATH=$PYTHONPATH:{project_root}"
            if include_root_in_path
            else ""
        )

        # generate the .sh script
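        max_array = self.get_max_array_size()
        # SLURM array task IDs must be below MaxArraySize; here the task IDs are trial numbers
        if max(to_run) >= max_array:
            console.print(
                f"\n[bold red]ERROR:[/] Highest trial number ({max(to_run)}) exceeds "
                f"the largest array index allowed by SLURM MaxArraySize ({max_array - 1})."
            )
            if not Confirm.ask("Proceed anyway?", default=False):
                return True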
        array_range = self.get_slurm_array_string(to_run)
        script_path = (self.workdir / "mujoco_mojo_submit.sh").resolve()

        sbatch_content = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --array={array_range}
#SBATCH --output={self.workdir.resolve()}/logs/trial_%a.log
#SBATCH --cpus-per-task={cpus_per_task}
#SBATCH --mem={mem_per_node}
#SBATCH --time={time_limit}
{partition_line}

# Move to the project root so imports work
cd {project_root}
{python_path_line}

# Execute the worker command
{cmd}
"""

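        # SLURM does not create the --output directory, so make sure it exists
        (self.workdir.resolve() / "logs").mkdir(parents=True, exist_ok=True)
        script_path.write_text(sbatch_content, encoding="utf-8")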
        logger.info(f"SLURM submission script written to {script_path}")

        # final submission
        if Confirm.ask(
            f"\n[cyan]Submit {len(to_run)} trials to SLURM now?[/]", default=True
        ):
            # automatic submission
            logger.info(f"Submitting {len(to_run)} trials...")
            result = subprocess.run(
                ["sbatch", str(script_path)], capture_output=True, text=True
            )
            if result.returncode == 0:
                job_id_msg = result.stdout.strip()
                # Extract just the numeric ID if possible (e.g. "Submitted batch job 2" -> "2")
                job_id = job_id_msg.split()[-1] if job_id_msg else "UNKNOWN"

                console.print(f"\n[bold green]Success![/] {job_id_msg}")

                # === Monitoring Dashboard ===
                console.print("\n[bold cyan]Monitoring Status:[/bold cyan]")
                console.print(
                    f"  - [white]Check status:[/]       [green]squeue -j {job_id}[/]"
                )
                console.print(
                    f"  - [white]Watch live:[/]         [green]watch -n 1 squeue -j {job_id}[/]"
                )
                console.print(
                    f"  - [white]View first log:[/]     [green]tail -f {self.workdir.resolve()}/logs/trial_{min(to_run)}.log[/]"
                )
                console.print(
                    f"  - [white]Cancel all trials:[/]  [green]scancel {job_id}[/]"
                )

                console.print(
                    f"\n[dim]Logs are being written to: {self.workdir.resolve()}/logs/[/]"
                )
                return False
            else:
                logger.error(f"SLURM Submission Failed: {result.stderr}")
                return True
        else:
            # deferred submission
            console.print(
                f"\n[yellow]Orchestration complete.[/] Submit manually with:\n[bold green]sbatch {script_path}[/]"
            )
            return False