forked from ayaka14732/llama-2-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpodrun
More file actions
77 lines (63 loc) · 3.02 KB
/
podrun
File metadata and controls
77 lines (63 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/python3
import argparse
import os
import shlex
import sys

import fabric
def shell_escape(arg):
    """Quote ``arg`` so a POSIX shell treats it as exactly one literal word.

    The result is wrapped in single quotes; every single quote inside ``arg``
    is written as ``'\\''`` (close the quoted run, emit an escaped quote,
    reopen the quoted run). Unlike ``shlex.quote`` the output is always
    quoted, even for "safe" strings.
    """
    escaped_quote = "'\\''"
    pieces = arg.split("'")
    return "'" + escaped_quote.join(pieces) + "'"
def get_ips(path='~/podips.txt'):
    """Return the pod host addresses listed in ``path``, one per line.

    Surrounding whitespace is stripped from each line and blank lines are
    skipped, so an empty or padded line never becomes a bogus host name that
    fabric would then try to connect to.

    Args:
        path: Host-list file to read; ``~`` is expanded. Defaults to
            ``~/podips.txt``.

    Returns:
        The host strings, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(os.path.expanduser(path), 'r', encoding='utf-8') as f:
        return [host for host in (line.strip() for line in f) if host]
def run_command(hosts, command, activate_venv=False):
    """Run a shell command concurrently on every host in ``hosts``.

    Args:
        hosts: Host strings handed to ``fabric.ThreadingGroup``.
        command: Shell command line executed on each host.
        activate_venv: If true, source ``venv/bin/activate`` first, so the
            command runs inside the ``venv`` directory that the setup step
            creates with ``python3.11 -m venv venv``.

    Returns:
        True if the command completed without raising, False otherwise.
        Failures are reported on stderr instead of being raised, so a caller
        looping over several commands keeps going.
    """
    if activate_venv:
        # '&&' rather than ';': if activation fails, the (first part of the)
        # command is skipped instead of silently running against the
        # system interpreter (e.g. pip installing outside the venv).
        command = f'. venv/bin/activate && {command}'
    with fabric.ThreadingGroup(*hosts) as group:
        try:
            group.run(command)
        except Exception as e:  # e.g. fabric's GroupException when any host fails
            print(
                f"An error occurred while executing the command: {command}. Error: {e}",
                file=sys.stderr,
            )
            return False
    return True
# Bootstrap commands run on every host WITHOUT the venv: install Python 3.11
# from the deadsnakes PPA, then create the ``venv`` virtual environment that
# later commands activate.
setup_commands_pyt = [
    'sudo add-apt-repository -y ppa:deadsnakes/ppa',
    'sudo apt update',
    'sudo apt install python3.11-full -y',
    'python3.11 -m venv venv',
]
# Setup commands run on every host INSIDE the venv (main passes
# activate_venv=True for these): Python packages, the project checkout and
# its requirements.
setup_cmd = [
    "pip install -U pip",
    "pip install -U wheel",
    "pip install jupyter notebook",
    "git clone https://github.com/divyapatel4/llama-2-jax.git",
    "pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu",
    "pip install git+https://github.com/huggingface/transformers.git",
    "pip install git+https://github.com/deepmind/optax.git",
    "pip install -r llama-2-jax/requirements.txt",
    # -y answers the confirmation prompt automatically, matching the other
    # apt commands in setup_commands_pyt.
    "sudo apt-get install -y libpython3.11",
]

# Hugging Face access token, read from the LOCAL environment when the script
# starts instead of being committed to source control. If HF_TOKEN is unset,
# the login step is simply omitted.
# NOTE(review): the token that used to be hard-coded here is still in the
# repository history and should be revoked.
_HF_TOKEN = os.environ.get('HF_TOKEN')
if _HF_TOKEN:
    # The token travels as a separately quoted argv word, so no character in
    # it can break out of the remote shell command.
    setup_cmd.append(
        "python -c 'import sys; from huggingface_hub.hf_api import HfFolder; "
        "HfFolder.save_token(sys.argv[1])' " + shlex.quote(_HF_TOKEN)
    )

# Disabled: redirect the Hugging Face cache to a larger disk.
# "export TRANSFORMERS_CACHE=/mnt/mydisk1/huggingface/cache",
def main():
    """Parse the command line and run the requested command on all pod hosts.

    Hosts come from ``~/podips.txt`` (plus 127.0.0.1 with ``-i``).

    The single command ``setup`` is special. It runs ``setup_commands_pyt``
    without the venv and then ``setup_cmd`` inside it. ``--clean-up``,
    ``--cwd`` and ``--venv`` do not apply to it.

    Any other command is shell-quoted argument by argument. It is optionally
    prefixed with the clean-up and ``cd`` steps, then run on every host.
    """
    parser = argparse.ArgumentParser(description='A helper script to execute commands on multiple hosts of a TPU pod.')
    parser.add_argument('-i', '--include-local', action='store_true', help='include local host (127.0.0.1) in the host list')
    parser.add_argument('-c', '--clean-up', action='store_true', help='clean up temporary files generated by previous TPU processes before executing the command')
    parser.add_argument('-w', '--cwd', action='store_true', help='run the command in the current working directory, assuming the directory exists on all hosts')
    parser.add_argument('-v', '--venv', action='store_true', help='activate the virtual environment before running the command')
    parser.add_argument('command', nargs=argparse.ONE_OR_MORE, help='command and arguments to run on every host, or "setup" to bootstrap the hosts')
    args = parser.parse_args()

    try:
        hosts = get_ips()
    except OSError as e:
        parser.error(f'cannot read the host list: {e}')  # exits the program
    if args.include_local:
        hosts.append('127.0.0.1')
    if not hosts:
        parser.error('no hosts to run on: ~/podips.txt lists none (use -i to include the local host)')

    # Quote every argument separately so each reaches the remote shell verbatim.
    command = ' '.join(shell_escape(part) for part in args.command)
    print(f'Executing command:{command}')

    if args.command == ['setup']:
        for cmd in setup_commands_pyt:
            run_command(hosts, cmd)
        for cmd in setup_cmd:
            run_command(hosts, cmd, activate_venv=True)
        return

    if args.clean_up:
        # Remove the TPU lock file and logs left by previous TPU processes.
        command = f'sudo rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs; {command}'
    if args.cwd:
        # The path is interpolated into a remote shell command line, so it must
        # be quoted (spaces or shell metacharacters would otherwise break it).
        command = f'cd {shell_escape(os.getcwd())}; {command}'
    run_command(hosts, command, activate_venv=args.venv)


if __name__ == '__main__':
    main()