"""Source code for ``ttt.cli.train``: command-line entry point for training an agent."""

import fire
from loguru import logger

from ttt.training import train


def train_agent(
    output_path: str,
    lr: float = 0.1,
    exploration_iterations: int = 2500,
    exploitation_iterations: int = 1500,
    exploration_exploitation_iterations: int = 1000,
) -> None:
    """
    Train a new agent.

    Thin wrapper that forwards every argument, unchanged and in order,
    to ``ttt.training.train``.

    Parameters
    ----------
    output_path: str
        Where to save the agent's weights, e.g. ``weights.json``.
    lr: float
        Learning rate.
    exploration_iterations: int
        Number of iterations during the exploration phase.
    exploitation_iterations: int
        Number of iterations during the exploitation phase.
    exploration_exploitation_iterations: int
        Number of iterations during the exploration-exploitation phase.
    """
    # Arguments are passed positionally, so their order must match the
    # parameter order expected by ``train``.
    train(
        output_path,
        lr,
        exploration_iterations,
        exploitation_iterations,
        exploration_exploitation_iterations,
    )
@logger.catch
def main() -> None:
    """
    Command-line entry point.

    Hands ``train_agent`` to ``fire.Fire`` so that its parameters can be
    supplied from the command line. The function is wrapped in
    ``logger.catch`` so that an exception raised while it runs is reported
    through the loguru logger.
    """
    fire.Fire(train_agent)
# Start the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()