import json
import os
import socket
import subprocess
import typing as ty
from pathlib import Path
from ablator.config.main import ConfigBase, configclass
from ablator.config.types import Optional
from ablator.modules.loggers.file import FileLogger
[docs]@configclass
class GcpConfig(ConfigBase):
"""
Configuration for Google Cloud Storage.
Attributes
----------
bucket : str
The bucket to use.
exclude_glob : Optional[str]
A glob to exclude from the rsync.
exclude_chkpts : bool
Whether to exclude checkpoints from the rsync.
"""
bucket: str
exclude_glob: Optional[str] = None
exclude_chkpts: bool = False
[docs] def __init__(self, *args, **kwargs):
"""
Initialize the GcpConfig class for managing Google Cloud Platform configurations.
Parameters
----------
*args
Positional arguments.
**kwargs
Keyword arguments.
Raises
------
AssertionError
If the GCP instance is not found.
Notes
-----
The IP address check avoids overly generic hostnames to match with existing instances.
"""
super().__init__(*args, **kwargs)
self.bucket = self.bucket.lstrip("gs:").lstrip("/").rstrip("/")
self.list_bucket()
hostname = socket.gethostname()
nodes = self._find_gcp_nodes(hostname)
# The IP address check avoids overly generic hostnames to match with existing instances.
ip_address = socket.gethostbyname(hostname)
assert (
len(nodes) == 1
and sum(
network_interface["networkIP"] == ip_address
for network_interface in nodes[0]["networkInterfaces"]
)
== 1
), "Can only use GcpConfig from Google Cloud Server. Consider switching to RemoteConfig."
def _make_cmd_up(self, local_path: Path, destination: str):
"""
Make the command to upload files to the bucket.
Parameters
----------
local_path : Path
The local path to upload.
destination : str
Bucket path.
Returns
-------
list[str]
The command to upload the file.
"""
destination = str(Path(self.bucket) / destination / local_path.name)
src = local_path
cmd = ["gsutil", "-m", "rsync", "-r"]
if self.exclude_glob is not None:
cmd += ["--exclude", f"{self.exclude_glob}"]
if self.exclude_chkpts:
cmd += ["--exclude", "*.pt"]
cmd += [f"{src}", f"gs://{destination}"]
return cmd
def _make_cmd_down(self, src_path: str, local_path: Path):
"""
Make the command to download files from the bucket.
Parameters
----------
src_path : str
The source path in the bucket.
local_path : Path
The local path to download to.
Returns
-------
list[str]
The command to download the file.
"""
src = Path(self.bucket) / src_path / local_path.name
destination = local_path
cmd = ["gsutil", "-m", "rsync", "-r"]
cmd += [f"gs://{src}", f"{destination}"]
return cmd
[docs] def list_bucket(self, destination: str | None = None):
"""
List the contents of a bucket. If destination is None, list the bucket itself.
Parameters
----------
destination : str | None
Bucket path.
Returns
-------
list[str]
List of files in the bucket.
"""
destination = str(
Path(self.bucket) / destination
if destination is not None
else Path(self.bucket)
)
cmd = ["gsutil", "ls", f"gs://{destination}"]
p = self._make_process(cmd, verbose=False)
stdout, stderr = p.communicate()
assert len(stderr) == 0, (
f"There was an error running `{' '.join(cmd)}`. "
"Make sure gsutil is installed and that the destination exists. "
f"`{stderr.decode('utf-8').strip()}`"
)
return stdout.decode("utf-8").strip().split("\n")
[docs] def rsync_up(
self,
local_path: Path,
remote_path: str,
logger: FileLogger | None = None,
):
"""
Rsync files to the bucket.
Parameters
----------
local_path : Path
The local path to upload.
remote_path : str
The destination path in the bucket.
logger : FileLogger | None
The logger to use.
Raises
------
AssertionError
If the rsync fails.
"""
cmd = self._make_cmd_up(local_path, remote_path)
p = self._make_process(cmd, verbose=logger is not None)
hostname = socket.gethostname()
if logger is not None:
logger.info(f"Rsync {hostname}:{cmd[-2]} to {cmd[-1]}")
p.wait()
def _make_process(self, cmd, verbose) -> subprocess.Popen:
"""
Make a subprocess.Popen object.
Parameters
----------
cmd : list[str]
The command to run.
verbose : bool
Whether to print the output.
Returns
-------
subprocess.Popen
The process object.
"""
if verbose:
stdout = subprocess.DEVNULL
stderr = subprocess.DEVNULL
else:
stdout = subprocess.PIPE
stderr = subprocess.PIPE
p = subprocess.Popen(cmd, stdout=stdout, stderr=stderr, preexec_fn=os.setsid)
return p
def _find_gcp_nodes(self, node_hostname: None | str = None) -> list[dict[str, ty.Any]]:
"""
Find the GCP instances with the given hostname.
Parameters
----------
node_hostname : None | str
The hostname of the node to find. If None, find all nodes.
Returns
-------
list[dict[str, ty.Any]]
List of GCP instances.
Raises
------
AssertionError
no nodes are found.
"""
cmd = ["gcloud", "compute", "instances", "list"]
if node_hostname is not None:
cmd += ["--filter", f'"{node_hostname}"']
cmd += ["--format", "json"]
p = self._make_process(cmd, verbose=False)
stdout, stderr = p.communicate()
assert len(stderr) == 0 and len(stdout) > 0
return json.loads(stdout.decode("utf-8"))
[docs] def rsync_down(
self,
remote_path: str,
local_path: Path,
logger: FileLogger | None = None,
verbose=True,
):
"""
Rsync files from the bucket.
Parameters
----------
remote_path : str
The source path in the bucket.
local_path : Path
The local path to download to.
logger : FileLogger | None
The logger to use.
verbose : bool
Whether to print the output.
"""
cmd = self._make_cmd_down(remote_path, local_path)
p = self._make_process(cmd, verbose)
hostname = socket.gethostname()
if logger is not None:
logger.info(f"Rsync {cmd[-2]} to {hostname}:{cmd[-1]}")
p.wait()
[docs] def rsync_down_node(
self,
node_hostname,
remote_path: str,
local_path: Path,
logger: FileLogger | None = None,
verbose=True,
):
"""
Rsync files from the bucket to all nodes with the given hostname.
Parameters
----------
node_hostname : str
The hostname of the nodes to rsync to.
remote_path : str
The source path in the bucket.
local_path : Path
The local path to download to.
logger : FileLogger | None
The logger to use.
verbose : bool
Whether to print the output.
"""
nodes = self._find_gcp_nodes(node_hostname)
ps: list[subprocess.Popen] = []
for node in nodes:
zone = node["zone"].split("/")[-1]
name = node["name"]
rsync_cmd = self._make_cmd_down(remote_path, local_path)
cmd = [
"gcloud",
"compute",
"ssh",
name,
"--zone",
zone,
"--tunnel-through-iap",
"--quiet",
"--",
"mkdir",
"-p",
f"{local_path};",
] + rsync_cmd
p = self._make_process(cmd, verbose)
ps.append(p)
if logger is not None:
logger.info(f"Rsync {cmd[-2]} to {name}:{cmd[-1]}")
for p in ps:
p.wait()