Skip to content

Commit

Permalink
Add create command
Browse files Browse the repository at this point in the history
  • Loading branch information
shawwn committed Feb 15, 2021
1 parent c065e7e commit b700a8c
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 15 deletions.
69 changes: 69 additions & 0 deletions tpunicorn/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import time
import random
import math
from pprint import pprint as pp

import logging as pylogging
Expand Down Expand Up @@ -149,6 +150,74 @@ def do_step(label=None, command=None, dry_run=False, delay_after=1.0, args=(), k
time.sleep(delay_after)
return result

@cli.command()
@click.argument('tpu', type=click.STRING, metavar="[TPU; default\"automatic\"]", autocompletion=complete_tpu_id)
@click.option('--zone', type=click.Choice(tpunicorn.tpu.get_tpu_zones()))
@click.option('-v', '--version', type=click.STRING, metavar="[VERSION; default=\"1.15.3\"]", default="1.15.3",
              help="By default, the TPU is imaged with TF version 1.15.3, as we've found this to be the most stable over time."
                   " If your TPU is giving strange errors, try setting this to `nightly` and pray to moloch.")
@click.option('-a', '--accelerator-type', metavar="[ACCELERATOR_TYPE; default=\"v2-8\"]", type=click.STRING, default=None)
@click.option('--async', 'async_', is_flag=True)
@click.option('-d', '--description', metavar="DESCRIPTION", type=click.STRING, default=None)
@click.option('-n', '--network', metavar="[NETWORK; default=\"default\"]", type=click.STRING, default="default")
# `-p` is NOT declared here: it is also the short form of --project below, and
# two options cannot share one short flag. The leading-space ' /-np' form keeps
# `-np` as an alias for the "off" side of the flag only.
@click.option('--preemptible/--non-preemptible', ' /-np', default=True)
@click.option('-r', '--range', metavar="[RANGE]", type=click.STRING, default=None)
@click.option('-p', '--project', metavar="[PROJECT]", type=click.STRING, default=None)
@click.option('-y', '--yes', is_flag=True)
@click.option('--dry-run', is_flag=True)
def create(tpu, zone, version, accelerator_type, async_, description, network, preemptible, range, project, yes, dry_run):
    # Build and run the command that creates a TPU, filling in every option the
    # user left out (accelerator type, zone, CIDR range, name, project).
    # Documented with comments rather than a docstring on purpose: click would
    # publish a docstring as this command's --help text.
    #
    # Raises ValueError when no zone can be inferred, or when the CIDR range
    # falls in 10.48.* for anything larger than 8 cores.
    index = tpunicorn.tpu.parse_tpu_index(tpu)
    if accelerator_type is None:
        accelerator_type = tpunicorn.tpu.parse_tpu_accelerator_type(tpu)
    # parse the TPU type and core count.
    tpu_type, cores = accelerator_type.rsplit("-", 1)
    tpu_type = tpu_type.lower()
    cores = int(cores)
    # give reasonable defaults for TFRC members
    if zone is None:
        zone = tpunicorn.tpu.parse_tpu_zone(tpu)
    if zone is None:
        if tpu_type == "v2" and cores == 8:
            zone = "us-central1-f"
        elif tpu_type == "v2" and cores > 8:
            zone = "us-central1-a"
        elif tpu_type == "v3":
            zone = "europe-west4-a"
        else:
            raise ValueError("Please specify --zone")
    # Derive a CIDR range from the TPU index when the user gave none.
    # NOTE(review): `index >= 0` presumably means "the name carried an index";
    # confirm against parse_tpu_index.
    if range is None and index >= 0:
        if cores == 8:
            range = "10.48.{i}.0/29".format(i=index)
        else:
            i = index + 2
            cidr = int(32 + 2 - math.log2(cores))
            range = "10.{i}.0.0/{cidr}".format(i=i, cidr=cidr)
    # `range` is still None when none was given and none could be derived, so
    # guard before calling a str method on it.
    if range is not None and range.startswith("10.48.") and cores > 8:
        raise ValueError("The range {range!r} conflicts with the default 10.48.* range of v2-8's and v3-8's. I decided to raise an error rather than a warning, because we rely on this specific range for our own internal networking. If you're making a TPU pod, try a different index other than {index}. If you really, really wanted to use 10.48.* for you TPU pods, I'm very sorry; ping me on twitter (@theshawwn) and I'll change this.".format(range=range, index=index))
    try:
        index = int(tpu)
        # the TPU name is just an integer, so try to build a new name
        # automatically for convenience.
        zone_abbrev = tpunicorn.tpu.infer_zone_abbreviation(zone)
        tpu = "tpu-{accelerator_type}-{zone_abbrev}-{index}".format(
            accelerator_type=accelerator_type,
            zone_abbrev=zone_abbrev,
            index=index)
    except ValueError:
        # the TPU name is not a bare integer; use it as given.
        pass
    if project is None:
        project = tpunicorn.tpu.get_default_project()
    command = tpunicorn.create_tpu_command(tpu, zone=zone, version=version, accelerator_type=accelerator_type, async_=async_, description=description, network=network, preemptible=preemptible, range=range, project=project)
    if not yes:
        # show the command and let the user back out before anything runs.
        print_step('Step 1: create TPU.', command)
        if not click.confirm('Proceed? {}'.format('(dry run)' if dry_run else '')):
            return
    do_step('Step 1: create TPU...', command, dry_run=dry_run)
    click.echo('TPU {} {} created.'.format(
        tpunicorn.tpu.parse_tpu_id(tpu),
        'would be' if dry_run else 'is'))


@cli.command()
@click.argument('tpu', type=click.STRING, autocompletion=complete_tpu_id)
@click.option('--zone', type=click.Choice(tpunicorn.tpu.get_tpu_zones()))
Expand Down
142 changes: 127 additions & 15 deletions tpunicorn/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,110 @@ def parse_tpu_index(tpu):
idx = int(idx[0])
return idx

def parse_tpu_accelerator_type(tpu):
    """Extract the accelerator type (e.g. ``"v3-32"``) from a TPU name.

    ``tpu`` is either the name itself (a string) or a mapping whose
    ``'name'`` entry holds it. The first ``v<digits>-<digits>`` substring
    is returned; ``"v2-8"`` is returned when the name contains none.
    """
    name = tpu['name'] if not isinstance(tpu, str) else tpu
    found = re.search(r'(v[0-9]+[-][0-9]+)', name)
    if found is None:
        return "v2-8"
    return found.group(1)

def parse_tpu_zone(tpu):
    """Infer a zone from an abbreviation embedded in a TPU name.

    ``tpu`` is either the name itself (a string) or a mapping whose
    ``'name'`` entry holds it. Each unambiguous zone abbreviation (one that
    maps to at most one zone) is searched for as a whole word in the name.

    Returns the expansion of the first abbreviation found, as produced by
    ``expand_zone_abbreviations``, or ``None`` when no abbreviation occurs
    in the name.
    """
    fqn = tpu if isinstance(tpu, str) else tpu['name']
    abbreviations = get_zone_abbreviations(only_unambiguous_results=True)
    for abbrev, zones in abbreviations.items():
        if len(zones) > 1:
            continue
        # escape the key so it is always matched literally.
        if re.search(r'\b{}\b'.format(re.escape(abbrev)), fqn):
            # expand only the abbreviation that actually matched.
            return expand_zone_abbreviations(abbrev)
    return None

from collections import defaultdict

# Short form -> full name of the first dash-separated component of a zone
# name (e.g. 'eu' -> 'europe' for 'europe-west4-a'). The values must match
# that component exactly for get_zone_abbreviations to recognise the zone.
country_abbrevs = {
    'as': 'asia',
    'eu': 'europe',
    'au': 'australia',
    'us': 'us',
    'na': 'northamerica',
    'sa': 'southamerica',
}

# Short form -> full name of the direction part of the second component of a
# zone name, i.e. that component without its trailing number
# (e.g. 'w' -> 'west' for 'west4').
region_abbrevs = {
    'n': 'north',
    's': 'south',
    'e': 'east',
    'w': 'west',
    'c': 'central',
    'ne': 'northeast',
    'nw': 'northwest',
    'se': 'southeast',
    'sw': 'southwest',
}

@ring.lru(expire=3600) # cache tpu abbrevs for an hour
def get_zone_abbreviations(full_zone_names=None, only_unambiguous_results=False): # e.g. ['europe-west4-a']
    """Map zone abbreviations to the full zone names they can stand for.

    Args:
        full_zone_names: a list of zone names, a comma-separated string of
            zone names, or ``None`` to use ``get_tpu_zones()``. Each name
            must have exactly three dash-separated parts, the second one
            ending in a number (e.g. ``'europe-west4-a'``).
        only_unambiguous_results: when true, only the fully qualified
            abbreviation (e.g. ``'euw4a'``) is produced for each zone;
            otherwise the shorter forms (``'euw4'``, ``'eu'``, ``'w'``, ...)
            are included as well.

    Returns:
        A dict from abbreviation to the list of full zone names it covers.
        Zones whose country or region has no entry in ``country_abbrevs`` /
        ``region_abbrevs`` contribute nothing.

    Raises:
        ValueError: if a zone name does not have three dash-separated parts
            or its region part has no trailing number.
    """
    if full_zone_names is None:
        full_zone_names = get_tpu_zones()
    if isinstance(full_zone_names, str):
        full_zone_names = full_zone_names.split(',')
    results = defaultdict(list)
    for full_zone_name in full_zone_names:
        country, region, zone_id = full_zone_name.split('-')
        # Split 'west4' into 'west' + '4'. Stripping all trailing digits also
        # handles multi-digit regions such as 'west10'.
        region_name = region.rstrip('0123456789')
        region_id = region[len(region_name):]
        if not region_id:
            # explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                "Zone {!r} has no region number in {!r}".format(full_zone_name, region))
        for cshort, cfull in country_abbrevs.items():
            for rshort, rfull in region_abbrevs.items():
                if cfull == country and rfull == region_name:
                    # e.g. 'euw4a'
                    results[cshort + rshort + region_id + zone_id].append(full_zone_name)
                    if not only_unambiguous_results:
                        # e.g. 'euw4'
                        results[cshort + rshort + region_id].append(full_zone_name)
                        # e.g. 'euw'
                        results[cshort + rshort].append(full_zone_name)
                        # e.g. 'eu'
                        results[cshort].append(full_zone_name)
                        # e.g. '4'
                        results[region_id].append(full_zone_name)
                        # e.g. '4a'
                        results[region_id + zone_id].append(full_zone_name)
                        # e.g. 'w4'
                        results[rshort + region_id].append(full_zone_name)
                        # e.g. 'w'
                        results[rshort].append(full_zone_name)
                        # e.g. 'west4'
                        results[rfull + region_id].append(full_zone_name)
                        # e.g. 'west'
                        results[rfull].append(full_zone_name)
    return dict(results)

def infer_zone_abbreviation(zone):
    """Return the first unambiguous abbreviation generated for ``zone``.

    Raises ``IndexError`` when ``get_zone_abbreviations`` produces no
    abbreviation for it.
    """
    unambiguous = get_zone_abbreviations(zone, only_unambiguous_results=True)
    # Index into a list (not next(iter(...))) so an empty result stays an IndexError.
    return list(unambiguous)[0]

def expand_zone_abbreviations(zone):
    """Expand a comma-separated list of zone abbreviations to full zone names.

    Each entry found in ``get_zone_abbreviations()`` is replaced by the zone
    names it maps to; any other entry is kept as-is. Duplicates are dropped
    while keeping first-seen order. Returns a comma-separated string, or
    ``None`` when ``zone`` is ``None``.
    """
    if zone is None:
        return None
    expanded = []
    for token in zone.split(','):
        for name in get_zone_abbreviations().get(token, [token]):
            if name in expanded:
                continue
            expanded.append(name)
    return ','.join(expanded)

def get_tpu_zone_choices(project=None):
    """List every full zone name and every abbreviation, without duplicates.

    For each abbreviation from ``get_zone_abbreviations()``, the zone names
    it expands to are listed first, followed by the abbreviation itself.
    ``project`` is accepted but not used.
    """
    ordered = []

    def remember(value):
        # append only values not listed yet, preserving first-seen order.
        if value not in ordered:
            ordered.append(value)

    for short_name, full_names in get_zone_abbreviations().items():
        for full_name in full_names:
            remember(full_name)
        remember(short_name)
    return ordered


def parse_tpu_network(tpu):
    """Return the last ``/``-separated segment of a TPU's network path.

    ``tpu`` is either the network path itself (a string) or a mapping whose
    ``'network'`` entry holds it.
    """
    if isinstance(tpu, str):
        path = tpu
    else:
        path = tpu['network']
    return path.rsplit('/', 1)[-1]
Expand Down Expand Up @@ -362,25 +466,33 @@ def format(tpu, spec=None, formatter=NamespaceFormatter, project=None):
spec = get_default_format_spec(thin=len(format_widths(project=project)) == 0)
return fmt.format(spec)

def create_tpu_command(tpu, zone=None, project=None, version=None, description=None, preemptible=None, async_=False):
if zone is None:
zone = parse_tpu_zone(tpu)
if project is None:
project = parse_tpu_project(tpu)
if version is None:
version = parse_tpu_version(tpu)
if description is None:
description = parse_tpu_description(tpu)
if preemptible is None:
preemptible = True if parse_tpu_preemptible(tpu) else None
def create_tpu_command(tpu=None, zone=None, version=None, accelerator_type=None, project=None, description=None, network=None, range=None, preemptible=None, async_=False):
name = parse_tpu_id(tpu)
if not isinstance(tpu, str):
if zone is None:
zone = parse_tpu_zone(tpu)
if project is None:
project = parse_tpu_project(tpu)
if version is None:
version = parse_tpu_version(tpu)
if accelerator_type is None:
accelerator_type = parse_tpu_type(tpu)
if description is None:
description = parse_tpu_description(tpu)
if preemptible is None:
preemptible = True if parse_tpu_preemptible(tpu) else None
if network is None:
network = parse_tpu_network(tpu)
if range is None:
range = parse_tpu_range(tpu)
return build_commandline("gcloud compute tpus create",
parse_tpu_id(tpu),
name,
zone=zone,
project=project,
network=parse_tpu_network(tpu),
range=parse_tpu_range(tpu),
network=network,
range=range,
version=version,
accelerator_type=parse_tpu_type(tpu),
accelerator_type=accelerator_type,
preemptible=preemptible,
description=description,
async_=async_,
Expand Down

0 comments on commit b700a8c

Please sign in to comment.