import json
import sys
import types
import click
from .. import echo
from . import ersilia_cli
# [docs] -- Sphinx "view source" link marker left over from HTML extraction; not part of the module.
def run_cmd():
    """
    Build the ``run`` command of the CLI.

    The returned command runs predictions with the model that is currently
    served in the session. It takes a single-column CSV file as input and
    writes the results to a ``.csv`` or ``.h5`` output file.

    Returns
    -------
    function
        The run command function to be used by the CLI and for testing in the pytest.

    Examples
    --------
    .. code-block:: console

        Run the served model on an input file and write a CSV output:
        $ ersilia run -i <input_data>.csv -o <output_data>.csv

        Run the served model with a custom batch size:
        $ ersilia run -i <input_data>.csv -o <output_data>.h5 -b 50
    """

    def validate_input_output_types(input, output):
        """
        Abort the command when the input or output argument is not acceptable.

        Parameters
        ----------
        input : str
            Value of the ``--input`` option; must end with ``.csv``.
        output : str or None
            Value of the ``--output`` option; must end with ``.csv`` or ``.h5``.

        Raises
        ------
        SystemExit
            With exit code 1, after printing an error message, when either
            argument fails validation.
        """
        from ...utils.terminal import is_quoted_list

        def fail(message):
            # Every validation failure is reported the same way: red bold
            # message followed by a non-zero exit code.
            echo(message, fg="red", bold=True)
            sys.exit(1)

        is_non_csv_string = isinstance(input, str) and not input.endswith(".csv")
        # NOTE(review): is_quoted_list presumably detects a list literal passed
        # as a string on the command line -- confirm against utils.terminal.
        if is_non_csv_string or is_quoted_list(json.dumps(input)):
            fail(
                "Input must be a single-column CSV file. String and list inputs are not supported."
            )
        # Check for a missing output first so the extension test below can
        # assume a string.
        if output is None:
            fail("Please specify a valid output file with extension .csv or .h5")
        if not output.endswith((".csv", ".h5")):
            fail(
                "This output type is not allowed in Ersilia. Valid output types are .csv or .h5"
            )

    # Example usage: ersilia run -i {INPUT} -o {OUTPUT} [-b {BATCH_SIZE}]
    @ersilia_cli.command(
        short_help="Run predictions on the served model",
        help="Run predictions using the currently served model. Input must be a single-column CSV file. Output can be saved as .csv or .h5. A model must be served before running this command.",
    )
    @click.option(
        "-i",
        "--input",
        "input",
        required=True,
        type=click.STRING,
        help="Path to a single-column CSV file containing the input data.",
    )
    @click.option(
        "-o",
        "--output",
        "output",
        required=True,
        default=None,
        type=click.STRING,
        help="Path to the output file. Accepted formats: .csv, .h5.",
    )
    @click.option(
        "-b",
        "--batch_size",
        "batch_size",
        required=False,
        default=100,
        type=click.INT,
        help="Number of inputs processed per batch.",
    )
    def run(input, output, batch_size):
        # Imports are kept local to the command, as in the rest of this
        # function, so they are only resolved when the command is invoked.
        import os
        import re

        from ... import ErsiliaModel
        from ...core.session import Session

        validate_input_output_types(input, output)

        session = Session(config_json=None)
        model_id = session.current_model_id()
        service_class = session.current_service_class()
        output_source = session.current_output_source()
        if model_id is None:
            # NOTE(review): unlike the other error paths this one returns
            # normally instead of exiting with status 1; left unchanged because
            # callers may rely on it.
            echo(
                "No model seems to be served. Please run 'ersilia serve ...' before.",
                fg="red",
            )
            return

        # Refuse to write into a file whose name carries the identifier of a
        # different model than the one being served. Only the first identifier
        # found in the file name is compared.
        output_basename = os.path.basename(output)
        output_model_ids = re.findall(r"eos[0-9][a-z0-9]{3}", output_basename)
        if output_model_ids and output_model_ids[0] != model_id:
            echo(
                f"Output filename contains model identifier '{output_model_ids[0]}' but the served model is '{model_id}'. Please use a correct output filename.",
                fg="red",
                bold=True,
            )
            sys.exit(1)

        mdl = ErsiliaModel(
            model_id,
            output_source=output_source,
            service_class=service_class,
            config_json=None,
        )
        result = mdl.run(input=input, output=output, batch_size=batch_size)
        if isinstance(result, types.GeneratorType):
            # A generator does no work until it is consumed, so drain the one
            # we already have rather than calling mdl.run() a second time. The
            # yielded items are not used here, so they are not accumulated.
            for _ in result:
                pass
        echo(
            f"✅ Output successfully written in {output} file!",
            fg="green",
            bold=False,
        )

    return run