Source code for ersilia.publish.rebase

import os
import shutil

from .. import ErsiliaBase
from ..default import GITHUB_ORG
from ..utils.terminal import run_command


class _FileFolderRebaser(object):
    def __init__(self, model_path, template_path):
        self.model_path = model_path
        self.template_path = template_path

    def do_file(self, name, overwrite):
        src = os.path.join(self.template_path, name)
        trg = os.path.join(self.model_path, name)
        if not os.path.exists(src):
            return
        if overwrite:
            shutil.copyfile(src, trg)
        else:
            if os.path.exists(trg):
                return
            else:
                shutil.copyfile(src, trg)

    def do_folder(self, name, overwrite):
        src = os.path.join(self.template_path, name)
        trg = os.path.join(self.model_path, name)
        if not os.path.exists(src):
            return
        if overwrite:
            shutil.copytree(src, trg)
        else:
            if os.path.exists(trg):
                return
            else:
                shutil.copytree(src, trg)


[docs] class TemplateRebaser(ErsiliaBase): """ Class for rebasing model repositories with a template repository. Parameters ---------- model_id : str The ID of the model to be rebased. template_repo : str, optional The name of the template repository. Default is 'eos-template'. config_json : str, optional Path to the configuration JSON file. credentials_json : str, optional Path to the credentials JSON file. """ def __init__( self, model_id: str, template_repo="eos-template", config_json=None, credentials_json=None, ): ErsiliaBase.__init__( self, config_json=config_json, credentials_json=credentials_json ) self.model_id = model_id self.template_repo = template_repo self.root = os.path.abspath(self._tmp_dir) self.cwd = os.getcwd() self.model_path = os.path.join(self.root, self.model_id) self.template_path = os.path.join(self.root, self.template_repo) self.clean() self.file_folder_rebaser = _FileFolderRebaser( self.model_path, self.template_path )
[docs] def clone_template(self): """ Clone the template repository. """ os.chdir(self.root) run_command("gh repo clone {0}/{1}".format(GITHUB_ORG, self.template_repo)) os.chdir(self.cwd)
[docs] def clone_current_model(self): """ Clone the current model repository. """ os.chdir(self.root) run_command("gh repo clone {0}/{1}".format(GITHUB_ORG, self.model_id)) os.chdir(self.cwd)
[docs] def dvc_part(self): """ Set up DVC (Data Version Control) for the model repository. """ self.file_folder_rebaser.do_file("data.h5", overwrite=False) self.file_folder_rebaser.do_file(".dvcignore", overwrite=False) self.file_folder_rebaser.do_folder(".dvc", overwrite=False)
[docs] def clean(self): """ Clean up temporary directories. """ if os.path.exists(self.template_path): self.logger.debug("Cleaning {0}".format(self.template_path)) run_command("rm -rf {0}".format(self.template_path)) if os.path.exists(self.model_path): self.logger.debug("Cleaning {0}".format(self.model_path)) run_command("rm -rf {0}".format(self.model_path))
# TODO: Add other rebasing options
[docs] def rebase(self): """ Rebase the model repository with the template repository. """ self.clone_template() self.clone_current_model() self.dvc_part()