Coverage for sfkit/utils/sfgwas_protocol.py: 100%

206 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-07 15:59 -0400

1""" 

2Run the sfgwas protocol 

3""" 

4 

5import copy 

6import fileinput 

7import os 

8import random 

9import shutil 

10import time 

11 

12import tomlkit 

13from nacl.encoding import HexEncoder 

14from nacl.public import Box, PrivateKey, PublicKey 

15 

16from sfkit.api import get_doc_ref_dict, update_firestore 

17from sfkit.utils import constants 

18from sfkit.utils.helper_functions import condition_or_fail, run_command 

19from sfkit.utils.sfgwas_helper_functions import ( 

20 get_file_paths, 

21 post_process_results, 

22 run_sfgwas_with_task_updates, 

23 to_float_int_or_bool, 

24 use_existing_config, 

25) 

26 

27 

def run_sfgwas_protocol(role: str, phase: str = "", demo: bool = False) -> None:
    """
    Drive one party through the complete sfgwas workflow.

    Steps, in order: install dependencies (only outside Docker / script
    installs), derive shared keys (skipped in demo mode), rewrite the local
    and global TOML configs, compile (only outside Docker / script installs),
    and finally launch the protocol.

    :param role: party index as a string: "0", "1", "2", ...
    :param phase: "", "1", "2" or "3"; forwarded to the global config
    :param demo: True when running the bundled demo
    """
    needs_local_toolchain = not (constants.IS_DOCKER or constants.IS_INSTALLED_VIA_SCRIPT)

    if needs_local_toolchain:
        install_sfgwas()

    if not demo:
        generate_shared_keys(int(role))

    print("Begin updating config files")
    update_config_local(role)
    update_config_global()
    update_config_global_phase(phase, demo)

    if needs_local_toolchain:
        build_sfgwas()

    start_sfgwas(role, demo)

46 

47 

def install_sfgwas() -> None:
    """
    Install sfgwas and its dependencies on the local machine.

    Installs the apt packages wget, git, zip and unzip. Installs Go 1.18.1
    under /usr/local/go, retrying the download up to 3 times. Installs plink2
    into /usr/local/bin and numpy via pip3. Clones the lattigo (switched to
    the ``lattigo_pca`` branch), mpc-core and sfgwas repositories into the
    current directory unless they already exist. Fails via
    ``condition_or_fail`` if Go or plink2 is missing afterwards.
    """
    update_firestore("update_firestore::task=Installing dependencies")
    print("Begin installing dependencies")

    plink2_download_link = "https://s3.amazonaws.com/plink2-assets/plink2_linux_avx2_latest.zip"
    plink2_zip_file = plink2_download_link.split("/")[-1]

    go_download_link = "https://golang.org/dl/go1.18.1.linux-amd64.tar.gz"
    go_tar_file = go_download_link.split("/")[-1]

    run_command("sudo apt-get update -y")
    run_command("sudo apt-get install wget git zip unzip -y")

    print("Installing go")
    max_retries = 3
    for _ in range(max_retries):
        # Delete the *local* tarball before downloading. `wget -nc` never
        # overwrites an existing file, so a partial or corrupt download would
        # otherwise be reused on every retry. (This previously passed the URL
        # to `rm`, which is not a local path and removed nothing.)
        run_command(f"rm -f {go_tar_file}")
        run_command(f"wget -nc {go_download_link}")
        run_command(f"sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf {go_tar_file}")
        if os.path.isdir("/usr/local/go"):
            break
    condition_or_fail(os.path.isdir("/usr/local/go"), "go failed to install")
    print("go successfully installed")

    run_command(f"wget -nc {plink2_download_link}")
    run_command(f"sudo unzip -o {plink2_zip_file} -d /usr/local/bin")
    run_command("pip3 install numpy")

    # make sure plink2 successfully installed
    condition_or_fail(
        os.path.isfile("/usr/local/bin/plink2"),
        "plink2 not installed (probably need to get new version)",
    )

    if os.path.isdir("lattigo"):
        print("lattigo already exists")
    else:
        print("Installing lattigo")
        run_command("git clone https://github.com/hcholab/lattigo && cd lattigo && git switch lattigo_pca")

    if os.path.isdir("mpc-core"):
        print("mpc-core already exists")
    else:
        print("Installing mpc-core")
        run_command("git clone https://github.com/hhcho/mpc-core")

    if os.path.isdir("sfgwas"):
        print("sfgwas already exists")
    else:
        print("Installing sfgwas")
        # A trailing `&& cd sfgwas` used to follow the clone. It had no effect,
        # because a child shell's `cd` cannot change this process's directory.
        run_command("git clone https://github.com/hcholab/sfgwas")

    print("Finished installing dependencies")

104 

105 

def generate_shared_keys(role: int) -> None:
    """
    Derive and write the pairwise and global shared keys for the sfgwas protocol.

    For every other participant, this function:

    * waits, polling every 5 seconds, until that participant's PUBLIC_KEY is
      published in the study document;
    * computes a NaCl ``Box`` shared key from this party's private key
      (``my_private_key.txt``) and their public key;
    * writes it to ``shared_key_{min}_{max}.bin`` in ``constants.SFKIT_DIR``.

    It then writes ``shared_key_global.bin`` to the same directory.

    :param role: this party's index into ``doc_ref_dict["participants"]``: 0, 1, 2, ...
    """
    doc_ref_dict: dict = get_doc_ref_dict()
    update_firestore("update_firestore::task=Generating cryptographic keys")
    print("Generating shared keys...")

    private_key_path = os.path.join(constants.SFKIT_DIR, "my_private_key.txt")
    with open(private_key_path, "r") as f:
        my_private_key = PrivateKey(f.readline().rstrip().encode(), encoder=HexEncoder)

    for i, other_username in enumerate(doc_ref_dict["participants"]):
        if i == role:
            continue
        other_public_key_str: str = doc_ref_dict["personal_parameters"][other_username]["PUBLIC_KEY"]["value"]
        while not other_public_key_str:
            if other_username == "Broad":
                print("Waiting for the Broad (CP0) to set up...")
            else:
                print(f"No public key found for {other_username}. Waiting...")
            time.sleep(5)
            doc_ref_dict = get_doc_ref_dict()
            other_public_key_str = doc_ref_dict["personal_parameters"][other_username]["PUBLIC_KEY"]["value"]
        other_public_key = PublicKey(other_public_key_str.encode(), encoder=HexEncoder)
        # Compare like with like. A PrivateKey never equals a PublicKey, so the
        # former `my_private_key != other_public_key` check could never fail.
        # The meaningful check is that the other party is not publishing our
        # own public key.
        condition_or_fail(
            my_private_key.public_key != other_public_key, "Private and public keys must be different"
        )
        shared_key = Box(my_private_key, other_public_key).shared_key()
        shared_key_path = os.path.join(constants.SFKIT_DIR, f"shared_key_{min(role, i)}_{max(role, i)}.bin")
        with open(shared_key_path, "wb") as f:
            f.write(shared_key)

    # NOTE(review): the global key is seeded from the Broad's PUBLIC_KEY. Every
    # participant reads that value from the study document, and the key is
    # generated with the non-cryptographic Mersenne Twister, so anyone with
    # read access to that document can reproduce it. Changing the derivation
    # would break key agreement with parties still running this scheme, so it
    # is only flagged here.
    # A private Random instance seeded with the same value yields exactly the
    # same bytes as seeding the module-level generator, without overwriting
    # the process-wide `random` state.
    global_seed = doc_ref_dict["personal_parameters"]["Broad"]["PUBLIC_KEY"]["value"]
    global_shared_key = random.Random(global_seed).getrandbits(256).to_bytes(32, "big")
    with open(os.path.join(constants.SFKIT_DIR, "shared_key_global.bin"), "wb") as f:
        f.write(global_shared_key)

    print(f"Shared keys generated and saved to {constants.SFKIT_DIR}.")

144 

145 

def update_config_local(role: str, protocol: str = "gwas") -> None:
    """
    Update configLocal.Party{role}.toml for this party.

    In blocks mode, ``use_existing_config`` handles the config and nothing
    else is changed. Otherwise the party's TOML file is created from the
    Party2 template if it is missing. The function then sets the data file
    paths (non-zero roles only), the shared-key/output/cache directories, the
    thread count, the number of association blocks run in parallel, and the
    memory limit.

    :param role: party index as a string: "0", "1", "2", ...
    :param protocol: config subdirectory under sfgwas/config, e.g. "gwas"
    """
    doc_ref_dict: dict = get_doc_ref_dict()
    if constants.BLOCKS_MODE in doc_ref_dict["description"]:
        use_existing_config(role, doc_ref_dict)
        return

    config_dir = f"{constants.EXECUTABLES_PREFIX}sfgwas/config/{protocol}"
    config_file_path = f"{config_dir}/configLocal.Party{role}.toml"

    try:
        with open(config_file_path, "r") as f:
            contents = f.read()
    except FileNotFoundError:
        print(f"File {config_file_path} not found.")
        print("Creating it...")
        shutil.copyfile(f"{config_dir}/configLocal.Party2.toml", config_file_path)
        with open(config_file_path, "r") as f:
            contents = f.read()
    data = tomlkit.parse(contents)

    if role != "0":
        update_data_file_paths(data)
    data["shared_keys_path"] = constants.SFKIT_DIR
    data["output_dir"] = f"out/party{role}"
    data["cache_dir"] = f"cache/party{role}"

    # Reuse the study document fetched above instead of making a second round-trip.
    user_id: str = doc_ref_dict["participants"][int(role)]
    num_threads = int(doc_ref_dict["personal_parameters"][user_id]["NUM_CPUS"]["value"])
    data["local_num_threads"] = num_threads
    # One association block per 8 threads. Clamp to at least 1: with fewer
    # than 8 CPUs, plain integer division gives 0 parallel blocks.
    # NOTE(review): confirm sfgwas cannot make progress with 0 parallel blocks.
    data["assoc_num_blocks_parallel"] = max(1, num_threads // 8)
    # 8 GB of memory per thread.
    data["memory_limit"] = num_threads * 8 * 1_000_000_000

    with open(config_file_path, "w") as f:
        f.write(tomlkit.dumps(data))

185 

186 

def update_data_file_paths(data: dict) -> None:
    """
    Point a local config at this party's input data files, in place.

    The genotype prefix and data directory come from ``get_file_paths()``.
    Every other input file is expected inside that data directory.

    :param data: parsed configLocal TOML document; it is mutated and nothing is returned
    """
    geno_file_prefix, data_path = get_file_paths()

    data["geno_binary_file_prefix"] = f"{geno_file_prefix}"

    # Config key -> file name within the data directory (assigned in this order).
    files_in_data_dir = {
        "geno_block_size_file": "chrom_sizes.txt",
        "pheno_file": "pheno.txt",
        "covar_file": "cov.txt",
        "snp_position_file": "snp_pos.txt",
        "sample_keep_file": "sample_keep.txt",
        "snp_ids_file": "snp_ids.txt",
        "geno_count_file": "all.gcount.transpose.bin",
    }
    for config_key, file_name in files_in_data_dir.items():
        data[config_key] = f"{data_path}/{file_name}"

200 

201 

def update_config_global(protocol: str = "gwas") -> None:
    """
    Update configGlobal.toml from the study document.

    For each participant, sets the server IP address and the per-peer ports
    from the comma-separated PORTS value. A party missing from the template
    is created by cloning the previous party's server entry. Outside blocks
    mode it also sets the number of main parties, the per-party row counts,
    and the column count. Finally, any shared or advanced parameter whose key
    already exists in the config overwrites that key.

    :param protocol: config subdirectory under sfgwas/config, e.g. "gwas" or "pca"
    """
    print("Updating configGlobal.toml")
    doc_ref_dict: dict = get_doc_ref_dict()
    config_file_path = f"{constants.EXECUTABLES_PREFIX}sfgwas/config/{protocol}/configGlobal.toml"
    with open(config_file_path, "r") as f:
        data = tomlkit.parse(f.read())

    participants = doc_ref_dict["participants"]
    personal_parameters = doc_ref_dict["personal_parameters"]

    # Update the ip addresses and ports
    servers = data.get("servers", {})
    for i, participant in enumerate(participants):
        if f"party{i}" not in servers:
            # The template lists a fixed set of parties; clone the previous one for extras.
            servers[f"party{i}"] = copy.deepcopy(servers[f"party{i-1}"])
        server = servers.get(f"party{i}", {})
        server["ipaddr"] = personal_parameters[participant]["IP_ADDRESS"]["value"]

        ports: list = personal_parameters[participant]["PORTS"]["value"].split(",")
        server_ports = server.get("ports", {})
        for j, port in enumerate(ports):
            # Strip surrounding whitespace so "8020, 8040" yields "8040", not " 8040".
            server_ports[f"party{j}"] = port.strip()

    if constants.BLOCKS_MODE not in doc_ref_dict["description"]:
        data["num_main_parties"] = len(participants) - 1

        row_name = "num_rows" if protocol == "pca" else "num_inds"
        col_name = "num_columns" if protocol == "pca" else "num_snps"
        data[row_name] = []
        rows = data.get(row_name, [])
        for i, participant in enumerate(participants):
            rows.append(int(personal_parameters[participant]["NUM_INDS"]["value"]))
            print(f"{row_name} for {participant} is {rows[i]}")
            # Party 0 may have no rows; every other party must contribute data.
            condition_or_fail(i == 0 or rows[i] > 0, f"{row_name} must be greater than 0")
        col_param = "num_snps" if protocol == "gwas" else "num_columns"
        data[col_name] = int(doc_ref_dict["parameters"][col_param]["value"])
        print(f"{col_name} is {data[col_name]}")
        condition_or_fail(data.get(col_name, 0) > 0, f"{col_name} must be greater than 0")

    # shared and advanced parameters (only keys already present in the config are overwritten)
    pars = {**doc_ref_dict["parameters"], **doc_ref_dict["advanced_parameters"]}
    for key, value in pars.items():
        if key in data:
            data[key] = to_float_int_or_bool(value["value"])

    with open(config_file_path, "w") as f:
        f.write(tomlkit.dumps(data))

250 

251 

def update_config_global_phase(phase: str, demo: bool, protocol: str = "gwas") -> None:
    """
    Apply phase- and demo-specific settings to configGlobal.toml.

    ``phase`` is always written. Phase "2" turns on ``use_cached_qc``, and
    phase "3" turns on both ``use_cached_qc`` and ``use_cached_pca``. In demo
    mode, ``num_power_iters``, ``iter_per_eigenval`` and ``num_pcs_to_remove``
    are all set to 2.

    :param phase: "", "1", "2" or "3"
    :param demo: True when running the bundled demo
    :param protocol: config subdirectory under sfgwas/config
    """
    config_file_path = f"{constants.EXECUTABLES_PREFIX}sfgwas/config/{protocol}/configGlobal.toml"
    with open(config_file_path, "r") as f:
        data = tomlkit.parse(f.read())

    # Collect overrides in the order they should be written.
    overrides: dict = {"phase": phase}
    if phase in ("2", "3"):
        overrides["use_cached_qc"] = True
    if phase == "3":
        overrides["use_cached_pca"] = True
    if demo:
        overrides.update(num_power_iters=2, iter_per_eigenval=2, num_pcs_to_remove=2)

    for key, value in overrides.items():
        data[key] = value

    with open(config_file_path, "w") as f:
        f.write(tomlkit.dumps(data))

275 

276 

def update_sfgwas_go(protocol: str = "gwas") -> None:
    """
    Rewrite sfgwas.go in place so its CONFIG_PATH points at config/{protocol}.

    Every line containing ``CONFIG_PATH = `` is replaced with
    ``var CONFIG_PATH = "config/{protocol}"``. All other lines are copied unchanged.

    :param protocol: config subdirectory name, e.g. "gwas" or "pca"
    """
    # With inplace=True, fileinput redirects stdout into the file being
    # rewritten. The context manager guarantees stdout is restored and the
    # file is closed even if something interrupts the loop.
    with fileinput.input(f"{constants.EXECUTABLES_PREFIX}sfgwas/sfgwas.go", inplace=True) as lines:
        for line in lines:
            if "CONFIG_PATH = " in line:
                print(f'var CONFIG_PATH = "config/{protocol}"')
            else:
                print(line, end="")

286 

287 

def build_sfgwas() -> None:
    """
    Fetch the Go dependencies of the local ``sfgwas`` checkout and compile it.

    Go is expected at /usr/local/go/bin. The build cache is placed under ~/.cache/go-build.
    """
    update_firestore("update_firestore::task=Compiling code")
    print("Building sfgwas code")
    build_steps = [
        "export PYTHONUNBUFFERED=TRUE",
        "export PATH=$PATH:/usr/local/go/bin",
        "export HOME=~",
        "export GOCACHE=~/.cache/go-build",
        "cd sfgwas",
        "go get -t github.com/hcholab/sfgwas",
        "go build",
    ]
    run_command(" && ".join(build_steps))
    print("Finished building sfgwas code")

297 

298 

def start_sfgwas(role: str, demo: bool = False, protocol: str = "SF-GWAS") -> None:
    """
    Launch the sfgwas program for this party and handle its completion.

    How the shell command is built:

    * Normal install: ``go run sfgwas.go`` from source.
    * Docker / script install: the prebuilt ``sfgwas`` binary.
    * Demo mode: either ``run_example.sh``, or (in a prebuilt install) all
      three parties 0-2 launched in parallel.

    The command runs through ``run_sfgwas_with_task_updates``. Afterwards,
    role "0" marks the protocol finished and every other role post-processes
    its results.

    :param role: party index as a string: "0", "1", "2", ...
    :param demo: True if running the demo
    :param protocol: label passed to the task-update and post-processing helpers
    """
    update_firestore("update_firestore::task=Initiating Protocol")
    print("Begin SF-GWAS protocol")

    prebuilt = constants.IS_DOCKER or constants.IS_INSTALLED_VIA_SCRIPT
    sfgwas_dir = f"{constants.EXECUTABLES_PREFIX}sfgwas"

    if demo and prebuilt:
        # cannot use "go run" from run_example.sh in Docker, so reproducing that script in Python here
        per_party = [f"(cd {sfgwas_dir} && PID={r} sfgwas | tee stdout_party{r}.txt)" for r in range(3)]
        protocol_command = " & ".join(per_party) + " & wait $(jobs -p)"
    elif demo:
        protocol_command = "bash run_example.sh"
    elif prebuilt:
        protocol_command = f"cd {sfgwas_dir} && PID={role} sfgwas | tee stdout_party{role}.txt"
    else:
        protocol_command = f"export PID={role} && go run sfgwas.go | tee stdout_party{role}.txt"

    if prebuilt:
        command = f"export PYTHONUNBUFFERED=TRUE && {protocol_command}"
    else:
        go_env = (
            "export PYTHONUNBUFFERED=TRUE && export PATH=$PATH:/usr/local/go/bin"
            " && export HOME=~ && export GOCACHE=~/.cache/go-build"
        )
        command = f"{go_env} && cd {sfgwas_dir} && {protocol_command}"

    run_sfgwas_with_task_updates(command, protocol, demo, role)
    print(f"Finished {protocol} protocol")

    if role != "0":
        post_process_results(role, demo, protocol)
    else:
        update_firestore("update_firestore::status=Finished protocol!")

330 return 

331 

332 post_process_results(role, demo, protocol)