Source code for ersilia.cli.commands.run

import json
import os
import sys
import types

import click

from ... import ErsiliaModel
from ...core.session import Session
from ...utils.exceptions_utils.api_exceptions import UnprocessableInputError
from ...utils.terminal import is_quoted_list
from .. import echo
from . import ersilia_cli


def run_cmd():
    """
    Build the ``run`` CLI command, which runs the currently served model.

    The command takes no model identifier: it runs whichever model the
    current session reports as served, after validating the input and
    output file names.

    Returns
    -------
    function
        The run command function to be used by the CLI and for testing in the pytest.

    Examples
    --------
    .. code-block:: console

        Run the served model on a CSV input and write a CSV output:
        $ ersilia run -i input.csv -o output.csv

        Run the served model with a custom batch size:
        $ ersilia run -i input.csv -o output.csv -b 50
    """

    # Output file extensions accepted by the command.
    allowed_output_extensions = (".csv", ".h5", ".json")

    def validate_input_output_types(input, output):
        """
        Exit with status 1 unless the input and output file names are acceptable.

        Parameters
        ----------
        input : str
            Input given on the command line; must end with ``.csv`` and must
            not be recognised as a quoted list by ``is_quoted_list``.
        output : str or None
            Output file name; must be given and must end with ``.csv``,
            ``.h5`` or ``.json``.
        """
        if (isinstance(input, str) and not input.endswith(".csv")) or is_quoted_list(
            json.dumps(input)
        ):
            echo(
                "String and list input types are not allowed in Ersilia. Please a csv input instead",
                fg="red",
                bold=True,
            )
            sys.exit(1)

        # str.endswith accepts a tuple, so one call covers every extension.
        if output is not None and not output.endswith(allowed_output_extensions):
            echo(
                "This output type is not allowed in Ersilia. A valid output types are .csv, .h5 or .json",
                fg="red",
                bold=True,
            )
            sys.exit(1)

        if output is None:
            echo(
                "Please specify a valid output types which are .csv, .h5 or .json",
                fg="red",
                bold=True,
            )
            sys.exit(1)

    # Example usage: ersilia run -i {INPUT} -o {OUTPUT} [-b {BATCH_SIZE}]
    @ersilia_cli.command(short_help="Run a served model", help="Run a served model")
    @click.option("-i", "--input", "input", required=True, type=click.STRING)
    @click.option(
        "-o", "--output", "output", required=False, default=None, type=click.STRING
    )
    @click.option(
        "-b", "--batch_size", "batch_size", required=False, default=100, type=click.INT
    )
    def run(input, output, batch_size):
        """
        Run the model served in the current session on ``input``.

        Parameters
        ----------
        input : str
            Path of the input ``.csv`` file.
        output : str or None
            Path of the output file (``.csv``, ``.h5`` or ``.json``).
        batch_size : int
            Batch size forwarded to ``ErsiliaModel.run``.
        """
        validate_input_output_types(input, output)

        session = Session(config_json=None)
        model_id = session.current_model_id()
        service_class = session.current_service_class()
        output_source = session.current_output_source()
        print(f"Session: {session._session_dir}")
        print(f"Model id: {model_id}")
        print(f"Service class: {service_class}")
        print(f"Output source: {output_source}")

        if model_id is None:
            echo(
                "No model seems to be served. Please run 'ersilia serve ...' before.",
                fg="red",
            )
            return

        mdl = ErsiliaModel(
            model_id,
            output_source=output_source,
            service_class=service_class,
            config_json=None,
        )
        try:
            print(output)
            result = mdl.run(input=input, output=output, batch_size=batch_size)
            print(f"Result: {result}")
            if isinstance(result, types.GeneratorType):
                # Consume the generator already obtained instead of calling
                # mdl.run() a second time. The yielded items are not kept
                # because nothing reads them afterwards.
                # NOTE(review): presumably the output file is written as a
                # side effect of the run - confirm against ErsiliaModel.run.
                for _ in result:
                    pass
            echo(
                f"✅ The output successfully generated in {output} file!",
                fg="green",
                bold=True,
            )
        except UnprocessableInputError as e:
            echo(f"❌ Error: {e.message}", fg="red")
            echo(f"💡 {e.hints}")
            # Remove any output file left behind by the failed run.
            if output and os.path.exists(output):
                os.remove(output)
            return

    return run