Coverage for sfkit/utils/sfgwas_protocol.py: 100%
206 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-07 15:59 -0400
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-07 15:59 -0400
1"""
2Run the sfgwas protocol
3"""
5import copy
6import fileinput
7import os
8import random
9import shutil
10import time
12import tomlkit
13from nacl.encoding import HexEncoder
14from nacl.public import Box, PrivateKey, PublicKey
16from sfkit.api import get_doc_ref_dict, update_firestore
17from sfkit.utils import constants
18from sfkit.utils.helper_functions import condition_or_fail, run_command
19from sfkit.utils.sfgwas_helper_functions import (
20 get_file_paths,
21 post_process_results,
22 run_sfgwas_with_task_updates,
23 to_float_int_or_bool,
24 use_existing_config,
25)
def run_sfgwas_protocol(role: str, phase: str = "", demo: bool = False) -> None:
    """
    Run the sfgwas protocol end to end for this party

    Installs and later builds sfgwas only when it is not already provided (neither a
    Docker image nor a script install). Generates shared keys unless in demo mode,
    refreshes the local/global configs, then starts the protocol.

    :param role: 0, 1, 2, ...
    :param phase: "", "1", "2", "3"
    :param demo: True or False
    """
    needs_local_build = not (constants.IS_DOCKER or constants.IS_INSTALLED_VIA_SCRIPT)

    if needs_local_build:
        install_sfgwas()

    if not demo:
        # demo runs skip key generation
        generate_shared_keys(int(role))

    print("Begin updating config files")
    update_config_local(role)
    update_config_global()
    update_config_global_phase(phase, demo)

    if needs_local_build:
        build_sfgwas()

    start_sfgwas(role, demo)
def install_sfgwas() -> None:
    """
    Install sfgwas and its dependencies

    Installs wget/git/zip/unzip via apt, Go 1.18.1 under /usr/local/go (up to 3
    attempts), plink2 into /usr/local/bin and numpy via pip3. It then clones lattigo
    (switched to the lattigo_pca branch), mpc-core and sfgwas into the current
    directory, skipping any repository whose directory already exists.

    Fails via condition_or_fail if /usr/local/go or /usr/local/bin/plink2 is missing
    afterwards.
    """
    update_firestore("update_firestore::task=Installing dependencies")
    print("Begin installing dependencies")

    plink2_download_link = "https://s3.amazonaws.com/plink2-assets/plink2_linux_avx2_latest.zip"
    plink2_zip_file = plink2_download_link.split("/")[-1]
    go_download_link = "https://golang.org/dl/go1.18.1.linux-amd64.tar.gz"
    go_tarball = go_download_link.split("/")[-1]

    run_command("sudo apt-get update -y")
    run_command("sudo apt-get install wget git zip unzip -y")

    print("Installing go")
    max_retries = 3
    for _ in range(max_retries):
        # Remove the *local* tarball before downloading. `wget -nc` never overwrites an
        # existing file, so otherwise a partial download left by a failed attempt would
        # be re-extracted on every retry and the retry could never succeed.
        run_command(f"rm -f {go_tarball}")
        run_command(f"wget -nc {go_download_link}")
        run_command(f"sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf {go_tarball}")
        if os.path.isdir("/usr/local/go"):
            break
    condition_or_fail(os.path.isdir("/usr/local/go"), "go failed to install")
    print("go successfully installed")

    run_command(f"wget -nc {plink2_download_link}")
    run_command(f"sudo unzip -o {plink2_zip_file} -d /usr/local/bin")
    run_command("pip3 install numpy")

    # make sure plink2 successfully installed
    condition_or_fail(
        os.path.isfile("/usr/local/bin/plink2"),
        "plink2 not installed (probably need to get new version)",
    )

    # (directory, clone command) pairs; an existing directory is left untouched.
    # The `cd` in the lattigo command matters because it runs in the same shell as
    # `git switch`. A trailing `cd` after the sfgwas clone would be a no-op, so it is
    # omitted.
    repositories = (
        ("lattigo", "git clone https://github.com/hcholab/lattigo && cd lattigo && git switch lattigo_pca"),
        ("mpc-core", "git clone https://github.com/hhcho/mpc-core"),
        ("sfgwas", "git clone https://github.com/hcholab/sfgwas"),
    )
    for directory, clone_command in repositories:
        if os.path.isdir(directory):
            print(f"{directory} already exists")
        else:
            print(f"Installing {directory}")
            run_command(clone_command)

    print("Finished installing dependencies")
def generate_shared_keys(role: int) -> None:
    """
    Generate shared keys for the sfgwas protocol

    For every other participant, this:
    - polls the study document (every 5 s) until that participant's PUBLIC_KEY is set;
    - derives the pairwise Box shared key from this party's private key (read from
      my_private_key.txt in constants.SFKIT_DIR);
    - writes it to shared_key_{min}_{max}.bin in that directory.

    It also writes shared_key_global.bin, derived deterministically from the "Broad"
    participant's PUBLIC_KEY value so every party obtains the same bytes.

    :param role: 0, 1, 2, ... (this party's index in doc_ref_dict["participants"])
    """
    doc_ref_dict: dict = get_doc_ref_dict()
    update_firestore("update_firestore::task=Generating cryptographic keys")
    print("Generating shared keys...")

    private_key_path = os.path.join(constants.SFKIT_DIR, "my_private_key.txt")
    with open(private_key_path, "r") as f:
        my_private_key = PrivateKey(f.readline().rstrip().encode(), encoder=HexEncoder)

    for i, other_username in enumerate(doc_ref_dict["participants"]):
        if i == role:
            continue
        other_public_key_str: str = doc_ref_dict["personal_parameters"][other_username]["PUBLIC_KEY"]["value"]
        # Wait until the other party has published its public key.
        while not other_public_key_str:
            if other_username == "Broad":
                print("Waiting for the Broad (CP0) to set up...")
            else:
                print(f"No public key found for {other_username}. Waiting...")
            time.sleep(5)
            doc_ref_dict = get_doc_ref_dict()
            other_public_key_str = doc_ref_dict["personal_parameters"][other_username]["PUBLIC_KEY"]["value"]
        other_public_key = PublicKey(other_public_key_str.encode(), encoder=HexEncoder)
        # Compare public halves. A PrivateKey never compares equal to a PublicKey, so
        # comparing my_private_key itself made this guard always pass.
        condition_or_fail(
            my_private_key.public_key != other_public_key,
            "Private and public keys must be different",
        )
        shared_key = Box(my_private_key, other_public_key).shared_key()
        shared_key_path = os.path.join(constants.SFKIT_DIR, f"shared_key_{min(role, i)}_{max(role, i)}.bin")
        with open(shared_key_path, "wb") as f:
            f.write(shared_key)

    # A private Random instance seeded with the same string yields exactly the same
    # bytes as random.seed() + random.getrandbits(), without resetting the module-level
    # RNG that other code may rely on.
    # NOTE(review): the seed is a participant's PUBLIC_KEY value and `random` is not a
    # cryptographic RNG, so anyone who can read that value can recompute this "global
    # shared key". It is kept as-is because every party must derive identical bytes;
    # changing the derivation needs a coordinated protocol change.
    global_seed = doc_ref_dict["personal_parameters"]["Broad"]["PUBLIC_KEY"]["value"]
    global_shared_key = random.Random(global_seed).getrandbits(256).to_bytes(32, "big")
    with open(os.path.join(constants.SFKIT_DIR, "shared_key_global.bin"), "wb") as f:
        f.write(global_shared_key)

    print(f"Shared keys generated and saved to {constants.SFKIT_DIR}.")
def update_config_local(role: str, protocol: str = "gwas") -> None:
    """
    Update configLocal.Party{role}.toml

    If the study description contains constants.BLOCKS_MODE, this delegates to
    use_existing_config and writes nothing else. Otherwise it:
    - creates the party's local config from the Party2 template if it is missing;
    - sets data file paths (non-zero roles only) and the shared-key, output and cache
      directories;
    - derives thread, parallelism and memory settings from the participant's NUM_CPUS.

    :param role: 0, 1, 2, ...
    :param protocol: config sub-directory under sfgwas/config ("gwas" by default)
    """
    doc_ref_dict: dict = get_doc_ref_dict()
    if constants.BLOCKS_MODE in doc_ref_dict["description"]:
        use_existing_config(role, doc_ref_dict)
        return

    config_dir = f"{constants.EXECUTABLES_PREFIX}sfgwas/config/{protocol}"
    config_file_path = f"{config_dir}/configLocal.Party{role}.toml"

    try:
        with open(config_file_path, "r") as f:
            config_text = f.read()
    except FileNotFoundError:
        print(f"File {config_file_path} not found.")
        print("Creating it...")
        # A missing local config is seeded from the Party2 template.
        shutil.copyfile(f"{config_dir}/configLocal.Party2.toml", config_file_path)
        with open(config_file_path, "r") as f:
            config_text = f.read()
    data = tomlkit.parse(config_text)

    # Only non-zero roles get input data file paths.
    if role != "0":
        update_data_file_paths(data)
    data["shared_keys_path"] = constants.SFKIT_DIR
    data["output_dir"] = f"out/party{role}"
    data["cache_dir"] = f"cache/party{role}"

    # Reuse the document fetched above rather than querying it a second time.
    user_id: str = doc_ref_dict["participants"][int(role)]
    num_threads = int(doc_ref_dict["personal_parameters"][user_id]["NUM_CPUS"]["value"])
    data["local_num_threads"] = num_threads
    # `num_threads // 8` alone is 0 on machines with fewer than 8 CPUs, which would
    # configure zero blocks to run in parallel; keep at least one.
    data["assoc_num_blocks_parallel"] = max(1, num_threads // 8)
    # 8 GB (8 * 10**9 bytes) of memory budget per thread.
    data["memory_limit"] = num_threads * 8 * 1_000_000_000

    with open(config_file_path, "w") as f:
        f.write(tomlkit.dumps(data))
def update_data_file_paths(data: dict) -> None:
    """
    Point a local sfgwas config at this party's input files.

    The genotype prefix and data directory come from get_file_paths(). Every other
    input file is expected directly inside that data directory. The config mapping
    is modified in place and nothing is returned.

    :param data: parsed configLocal TOML document to update
    """
    geno_file_prefix, data_path = get_file_paths()

    data["geno_binary_file_prefix"] = f"{geno_file_prefix}"

    # config key -> file name inside data_path (insertion order matches the config layout)
    files_in_data_dir = {
        "geno_block_size_file": "chrom_sizes.txt",
        "pheno_file": "pheno.txt",
        "covar_file": "cov.txt",
        "snp_position_file": "snp_pos.txt",
        "sample_keep_file": "sample_keep.txt",
        "snp_ids_file": "snp_ids.txt",
        "geno_count_file": "all.gcount.transpose.bin",
    }
    for config_key, file_name in files_in_data_dir.items():
        data[config_key] = f"{data_path}/{file_name}"
def update_config_global(protocol: str = "gwas") -> None:
    """
    Update configGlobal.toml

    The update proceeds in three steps:
    - Writes every participant's IP address and ports into servers.party{i}. When a
      participant has no table yet, servers.party{i-1} is deep-copied first.
    - Unless the study description contains constants.BLOCKS_MODE, also writes the
      number of main parties, the per-party row counts (from NUM_INDS) and the column
      count.
    - Every study parameter / advanced parameter whose key already exists in the
      config overwrites that key.

    :param protocol: config sub-directory under sfgwas/config ("gwas" by default;
        "pca" switches the row/column keys to num_rows/num_columns)
    """
    print("Updating configGlobal.toml")
    doc_ref_dict: dict = get_doc_ref_dict()
    config_file_path = f"{constants.EXECUTABLES_PREFIX}sfgwas/config/{protocol}/configGlobal.toml"
    with open(config_file_path, "r") as f:
        data = tomlkit.parse(f.read())

    participants = doc_ref_dict["participants"]
    personal_parameters = doc_ref_dict["personal_parameters"]

    # Update the ip addresses and ports. The tables are indexed directly: a
    # `.get("ports", {})[...] = port` chain writes into a throwaway dict when a party
    # table has no "ports" entry, silently leaving the ports unconfigured.
    servers = data["servers"]
    for i, participant in enumerate(participants):
        party = f"party{i}"
        if party not in servers:
            servers[party] = copy.deepcopy(servers[f"party{i-1}"])
        server = servers[party]
        server["ipaddr"] = personal_parameters[participant]["IP_ADDRESS"]["value"]

        ports: list = personal_parameters[participant]["PORTS"]["value"].split(",")
        for j, port in enumerate(ports):
            # Strip so a list entered as "8020, 8040" does not yield " 8040".
            server["ports"][f"party{j}"] = port.strip()

    if constants.BLOCKS_MODE not in doc_ref_dict["description"]:
        data["num_main_parties"] = len(participants) - 1

        row_name = "num_rows" if protocol == "pca" else "num_inds"
        col_name = "num_columns" if protocol == "pca" else "num_snps"
        data[row_name] = []
        for i, participant in enumerate(participants):
            data[row_name].append(int(personal_parameters[participant]["NUM_INDS"]["value"]))
            print(f"{row_name} for {participant} is {data[row_name][i]}")
            # Party 0 is exempt from the positive-count check.
            condition_or_fail(i == 0 or data[row_name][i] > 0, f"{row_name} must be greater than 0")
        data[col_name] = (
            int(doc_ref_dict["parameters"]["num_snps"]["value"])
            if protocol == "gwas"
            else int(doc_ref_dict["parameters"]["num_columns"]["value"])
        )
        print(f"{col_name} is {data[col_name]}")
        condition_or_fail(data[col_name] > 0, f"{col_name} must be greater than 0")

    # shared and advanced parameters: only keys already present in the config are overwritten
    pars = {**doc_ref_dict["parameters"], **doc_ref_dict["advanced_parameters"]}
    for key, value in pars.items():
        if key in data:
            data[key] = to_float_int_or_bool(value["value"])

    with open(config_file_path, "w") as f:
        f.write(tomlkit.dumps(data))
def update_config_global_phase(phase: str, demo: bool, protocol: str = "gwas") -> None:
    """
    Record the requested phase (and demo overrides) in configGlobal.toml

    Phase "2" sets use_cached_qc. Phase "3" sets both use_cached_qc and
    use_cached_pca. In demo mode, num_power_iters, iter_per_eigenval and
    num_pcs_to_remove are all forced to 2.

    :param phase: "1", "2", "3"
    :param demo: True when running the demo
    :param protocol: config sub-directory under sfgwas/config ("gwas" by default)
    """
    config_file_path = f"{constants.EXECUTABLES_PREFIX}sfgwas/config/{protocol}/configGlobal.toml"
    with open(config_file_path, "r") as f:
        data = tomlkit.parse(f.read())

    data["phase"] = phase
    if phase in ("2", "3"):
        data["use_cached_qc"] = True
    if phase == "3":
        data["use_cached_pca"] = True

    if demo:
        demo_overrides = {"num_power_iters": 2, "iter_per_eigenval": 2, "num_pcs_to_remove": 2}
        for key, value in demo_overrides.items():
            data[key] = value

    with open(config_file_path, "w") as f:
        f.write(tomlkit.dumps(data))
def update_sfgwas_go(protocol: str = "gwas") -> None:
    """
    Update sfgwas.go

    Rewrites sfgwas.go in place. Any line containing "CONFIG_PATH = " is replaced by
    `var CONFIG_PATH = "config/{protocol}"`; every other line is copied unchanged.

    :param protocol: config sub-directory to point CONFIG_PATH at ("gwas" by default)
    """
    sfgwas_go_path = f"{constants.EXECUTABLES_PREFIX}sfgwas/sfgwas.go"
    # With inplace=True, fileinput redirects stdout into the file being rewritten. The
    # context manager guarantees the file is closed and stdout restored even if an
    # error interrupts the loop.
    with fileinput.input(sfgwas_go_path, inplace=True) as lines:
        for line in lines:
            if "CONFIG_PATH = " in line:
                print(f'var CONFIG_PATH = "config/{protocol}"')
            else:
                print(line, end="")
def build_sfgwas() -> None:
    """
    Compile the sfgwas Go sources in ./sfgwas.

    Reports the "Compiling code" task to Firestore, then runs the build as a single
    chained shell command. The chain puts /usr/local/go/bin on PATH and sets
    HOME/GOCACHE before running `go get` and `go build`.
    """
    update_firestore("update_firestore::task=Compiling code")
    print("Building sfgwas code")
    build_steps = [
        "export PYTHONUNBUFFERED=TRUE",
        "export PATH=$PATH:/usr/local/go/bin",
        "export HOME=~",
        "export GOCACHE=~/.cache/go-build",
        "cd sfgwas",
        "go get -t github.com/hcholab/sfgwas",
        "go build",
    ]
    run_command(" && ".join(build_steps))
    print("Finished building sfgwas code")
def start_sfgwas(role: str, demo: bool = False, protocol: str = "SF-GWAS") -> None:
    """
    Start the actual sfgwas program

    Builds the shell command for this party, or for all three demo parties in demo
    mode, and runs it through run_sfgwas_with_task_updates. Afterwards, role "0" marks
    the study as finished; every other role post-processes its results.

    :param role: 0, 1, 2, ...
    :param demo: True if running demo
    :param protocol: protocol name passed to the task-update and post-processing helpers
    """
    update_firestore("update_firestore::task=Initiating Protocol")
    print("Begin SF-GWAS protocol")

    # A Docker image or a script install ships a prebuilt `sfgwas` binary.
    prebuilt = constants.IS_DOCKER or constants.IS_INSTALLED_VIA_SCRIPT
    sfgwas_dir = f"{constants.EXECUTABLES_PREFIX}sfgwas"

    if demo and prebuilt:
        # cannot use "go run" from run_example.sh in Docker, so reproducing that script in Python here:
        # start parties 0-2 in the background and wait for all of them.
        party_jobs = [f"(cd {sfgwas_dir} && PID={r} sfgwas | tee stdout_party{r}.txt)" for r in range(3)]
        protocol_command = " & ".join(party_jobs) + " & wait $(jobs -p)"
    elif demo:
        protocol_command = "bash run_example.sh"
    elif prebuilt:
        protocol_command = f"cd {sfgwas_dir} && PID={role} sfgwas | tee stdout_party{role}.txt"
    else:
        protocol_command = f"export PID={role} && go run sfgwas.go | tee stdout_party{role}.txt"

    if prebuilt:
        command = f"export PYTHONUNBUFFERED=TRUE && {protocol_command}"
    else:
        command = (
            "export PYTHONUNBUFFERED=TRUE && export PATH=$PATH:/usr/local/go/bin && export HOME=~ "
            f"&& export GOCACHE=~/.cache/go-build && cd {sfgwas_dir} && {protocol_command}"
        )

    run_sfgwas_with_task_updates(command, protocol, demo, role)
    print(f"Finished {protocol} protocol")

    if role != "0":
        post_process_results(role, demo, protocol)
    else:
        update_firestore("update_firestore::status=Finished protocol!")