#!/usr/bin/env python3
import argparse
import subprocess
import venv
from pathlib import Path

_DIR = Path(__file__).parent
_PROGRAM_DIR = _DIR.parent
_VENV_DIR = _PROGRAM_DIR / ".venv"

# Every option is a simple on/off switch, so register them from one table.
# Registration order is preserved, which keeps the --help output unchanged.
parser = argparse.ArgumentParser()
for flag, help_text in (
    ("--dev", "Install dev requirements"),
    ("--train", "Install training requirements"),
    ("--zh", "Install Chinese requirements"),
    ("--torch-cpu", "Install CPU-only version of PyTorch"),
):
    parser.add_argument(flag, action="store_true", help=help_text)
args = parser.parse_args()

# Build the project-local virtual environment with pip bootstrapped into it.
# ensure_directories() is called here to obtain the context object (its
# env_exe is the venv's Python interpreter); create() then does the setup.
builder = venv.EnvBuilder(with_pip=True)
context = builder.ensure_directories(_VENV_DIR)
builder.create(_VENV_DIR)

# Refresh the packaging toolchain inside the venv: pip first, then
# setuptools and wheel using the freshly upgraded pip.
pip = [context.env_exe, "-m", "pip"]
for upgrade_targets in (["pip"], ["setuptools", "wheel"]):
    subprocess.check_call([*pip, "install", "--upgrade", *upgrade_targets])

# Optionally pre-install PyTorch from the CPU-only wheel index. It is added as
# an *extra* index, so PyPI remains available for torch's own dependencies.
# NOTE(review): presumably done before the project install so that step sees
# torch already satisfied — confirm against the project's requirements.
if args.torch_cpu:
    torch_cpu_install = [
        "install",
        "torch>=2,<3",
        "--extra-index-url",
        "https://download.pytorch.org/whl/cpu",
    ]
    subprocess.check_call([*pip, *torch_cpu_install])

# Collect the optional-dependency groups requested on the command line, in a
# fixed order (dev, train, zh).
extras = [
    extra_name
    for extra_name, wanted in (
        ("dev", args.dev),
        ("train", args.train),
        ("zh", args.zh),
    )
    if wanted
]

# pip's "path[extra1,extra2]" syntax; just the bare path when nothing is selected.
extras_str = "[{}]".format(",".join(extras)) if extras else ""

# Editable install of the project itself, with any selected extras.
subprocess.check_call([*pip, "install", "-e", f"{_PROGRAM_DIR}{extras_str}"])
