# Source code for ersilia.cli.commands.run

import json
import sys
import types

import click

from .. import echo
from . import ersilia_cli


def run_cmd():
    """
    Runs a specified model.

    This command runs the currently served model on the inputs contained in a
    single-column CSV file and writes the predictions to a .csv or .h5 file.
    A model must have been served (``ersilia serve ...``) beforehand.

    Returns
    -------
    function
        The run command function to be used by the CLI and for testing in the pytest.

    Examples
    --------
    .. code-block:: console

        Run the served model on a CSV input and save the output:
        $ ersilia run -i <input_file.csv> -o <output_file.csv>

        Run the served model with a custom batch size:
        $ ersilia run -i <input_file.csv> -o <output_file.h5> -b 50
    """

    def validate_input_output_types(input, output):
        """
        Validate the input and output paths, exiting with status 1 on failure.

        Parameters
        ----------
        input : str
            Input argument as received from the CLI; must be a path ending in ``.csv``.
        output : str or None
            Output path; must be given and end in ``.csv`` or ``.h5``.
        """
        from ...utils.terminal import is_quoted_list

        # Only CSV files are accepted as input. The is_quoted_list check on the
        # JSON-encoded value presumably rejects list literals typed on the
        # command line (helper is defined elsewhere -- confirm).
        if (isinstance(input, str) and not input.endswith(".csv")) or is_quoted_list(
            json.dumps(input)
        ):
            echo(
                "Input must be a single-column CSV file. String and list inputs are not supported.",
                fg="red",
                bold=True,
            )
            sys.exit(1)
        if output is None:
            echo(
                "Please specify a valid output file with extension .csv or .h5",
                fg="red",
                bold=True,
            )
            sys.exit(1)
        if not output.endswith((".csv", ".h5")):
            echo(
                "This output type is not allowed in Ersilia. Valid output types are .csv or .h5",
                fg="red",
                bold=True,
            )
            sys.exit(1)

    # Example usage: ersilia run -i {INPUT} -o {OUTPUT} [-b {BATCH_SIZE}]
    @ersilia_cli.command(
        short_help="Run predictions on the served model",
        help="Run predictions using the currently served model. Input must be a single-column CSV file. Output can be saved as .csv or .h5. A model must be served before running this command.",
    )
    @click.option(
        "-i",
        "--input",
        "input",
        required=True,
        type=click.STRING,
        help="Path to a single-column CSV file containing the input data.",
    )
    @click.option(
        "-o",
        "--output",
        "output",
        required=True,
        default=None,
        type=click.STRING,
        help="Path to the output file. Accepted formats: .csv, .h5.",
    )
    @click.option(
        "-b",
        "--batch_size",
        "batch_size",
        required=False,
        default=100,
        type=click.INT,
        help="Number of inputs processed per batch.",
    )
    def run(input, output, batch_size):
        """
        Run the currently served model and write its predictions to ``output``.

        Parameters
        ----------
        input : str
            Path to a single-column CSV input file.
        output : str
            Destination file (.csv or .h5).
        batch_size : int
            Number of inputs processed per batch.
        """
        import os
        import re

        from ... import ErsiliaModel
        from ...core.session import Session

        validate_input_output_types(input, output)

        session = Session(config_json=None)
        model_id = session.current_model_id()
        service_class = session.current_service_class()
        output_source = session.current_output_source()
        if model_id is None:
            # NOTE(review): this path returns normally (exit status 0), unlike
            # the validation failures which sys.exit(1).
            echo(
                "No model seems to be served. Please run 'ersilia serve ...' before.",
                fg="red",
            )
            return

        # Guard against writing one model's predictions into a file whose name
        # carries a different model identifier (eos + digit + 3 alphanumerics).
        output_basename = os.path.basename(output)
        id_match = re.search(r"eos[0-9][a-z0-9]{3}", output_basename)
        if id_match is not None and id_match.group(0) != model_id:
            echo(
                f"Output filename contains model identifier '{id_match.group(0)}' but the served model is '{model_id}'. Please use a correct output filename.",
                fg="red",
                bold=True,
            )
            sys.exit(1)

        mdl = ErsiliaModel(
            model_id,
            output_source=output_source,
            service_class=service_class,
            config_json=None,
        )
        result = mdl.run(input=input, output=output, batch_size=batch_size)
        # A generator result is lazy: drain the one we already have so the run
        # completes, instead of invoking mdl.run a second time.
        if isinstance(result, types.GeneratorType):
            for _ in result:
                pass
        echo(
            f"✅ Output successfully written in {output} file!",
            fg="green",
            bold=False,
        )

    return run