Coverage for sfkit/encryption/mpc/encrypt_data.py: 100%
112 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-07 15:11 -0400
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-07 15:11 -0400
1# for MPC-GWAS
3import os
4import shutil
5import sys
6import time
8import checksumdir
9import nacl.secret
10import nacl.utils
11from nacl.encoding import HexEncoder
12from nacl.public import Box, PrivateKey, PublicKey
14from sfkit.api import get_doc_ref_dict, get_username
15from sfkit.encryption.mpc.random_number_generator import PseudoRandomNumberGenerator
16from sfkit.utils import constants
17from sfkit.utils.helper_functions import condition_or_fail
19# from tqdm import tqdm
def encrypt_GMP(prng: PseudoRandomNumberGenerator, input_dir: str, output_dir: str) -> None:
    # sourcery skip: avoid-global-variables, avoid-single-character-names-variables, ensure-file-closed, require-parameter-annotation, snake-case-functions, switch
    """
    Converts the data to GMP vectors (genotype, missing data, phenotype), encrypts
    them, and writes them to files.

    Reads ``geno.txt``, ``pheno.txt`` and ``cov.txt`` from ``input_dir`` line by line
    (the number of lines processed is the number of lines in ``pheno.txt``) and writes
    one line per input line to each of the following files in ``output_dir``:

    * ``g.bin``: three consecutive groups of values, one group per genotype value
      "0", "1" and "2"; each entry is a 0/1 indicator masked with a PRNG draw.
    * ``m.bin``: a masked 0/1 indicator per genotype entry that is none of "0", "1", "2".
    * ``p.bin``: the phenotype values followed by the covariate values, each masked
      with a PRNG draw.

    Every masked value is ``(value - prng.next()) % prng.base_p``, written as
    space-separated decimal text encoded as UTF-8.

    Args:
        prng: generator providing ``next()`` and ``base_p``.
        input_dir: directory containing geno.txt, pheno.txt and cov.txt.
        output_dir: directory for g.bin, m.bin and p.bin; created if missing.
    """
    # Function-scope import so this block does not depend on the module's import list.
    from contextlib import ExitStack

    # Row of `g` that receives the +1 for each recognised genotype value.
    # Any other token is counted as missing data (goes to `m`).
    genotype_row = {"0": 0, "1": 1, "2": 2}

    # ExitStack closes (and flushes) every handle even if a line fails to parse.
    with ExitStack() as stack:
        geno_file = stack.enter_context(open(f"{input_dir}/geno.txt", "r"))
        pheno_file = stack.enter_context(open(f"{input_dir}/pheno.txt", "r"))
        cov_file = stack.enter_context(open(f"{input_dir}/cov.txt", "r"))

        # make directory for encrypted data if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        g_file = stack.enter_context(open(f"{output_dir}/g.bin", "wb"))
        m_file = stack.enter_context(open(f"{output_dir}/m.bin", "wb"))
        p_file = stack.enter_context(open(f"{output_dir}/p.bin", "wb"))

        # Separate short-lived handle: counting must not consume `pheno_file`.
        with open(f"{input_dir}/pheno.txt", "r") as counting_file:
            num_lines = sum(1 for _ in counting_file)

        for i in range(num_lines):
            # Progress report roughly every 10% (only for inputs longer than 10 lines).
            if i > 0 and num_lines > 10 and i % (num_lines // 10) == 0:
                print(f"Finished encrypting {i} lines out of {num_lines} lines")

            # NOTE: the order of prng.next() calls (p, then g row by row, then m)
            # determines the output and must not be rearranged.
            p = pheno_file.readline().rstrip().split() + cov_file.readline().rstrip().split()
            p = [str((int(x) - prng.next()) % prng.base_p) for x in p]

            geno_line = geno_file.readline().rstrip().split()
            g = [[-prng.next() % prng.base_p for _ in range(len(geno_line))] for _ in range(3)]
            m = [-prng.next() % prng.base_p for _ in range(len(geno_line))]
            for j, val in enumerate(geno_line):
                row = genotype_row.get(val)
                if row is None:
                    m[j] = (m[j] + 1) % prng.base_p
                else:
                    g[row][j] = (g[row][j] + 1) % prng.base_p

            # Joining the three per-row strings keeps the separators identical to the
            # original format even when a genotype line is empty.
            g_text = " ".join(" ".join(map(str, row)) for row in g) + "\n"
            g_file.write(g_text.encode("utf-8"))
            m_file.write((" ".join(map(str, m)) + "\n").encode("utf-8"))
            p_file.write((" ".join(p) + "\n").encode("utf-8"))
def get_shared_mpcgwas_keys(my_private_key: PrivateKey, other_public_key: PublicKey, debug: bool = False) -> list:
    """
    Derive the two shared keys used by MPC-GWAS from this user's private key and
    the other user's public key.

    A fixed message is encrypted under the Box shared key with nonce 1 and then
    nonce 2; the resulting ciphertexts (minus their first 16 bytes) are the keys
    for role 1 and role 2 respectively.

    Returns a three-element list ``[None, role_1_key, role_2_key]`` so that it can
    be indexed directly by role number. With ``debug=True`` the intermediate
    values are printed.
    """
    raw_shared_secret = Box(my_private_key, other_public_key).shared_key()
    if debug:
        print(
            "Shared key prototype: ",
            raw_shared_secret.hex(),  # type: ignore
        )

    secret_box = nacl.secret.SecretBox(raw_shared_secret)
    plaintext = b"bqvbiknychqjywxwjihfrfhgroxycxxj"  # some arbitrary 32 letter string

    derived_keys = []
    for counter in (1, 2):
        nonce_bytes = counter.to_bytes(24, byteorder="big")
        if debug:
            print("nonce in hex: ", nonce_bytes.hex())
            if counter == 1:
                print("Byte message: ", plaintext)
        # python library has extra 16 bytes at beginning; see
        # https://doc.libsodium.org/secret-key_cryptography/secretbox#notes
        derived_keys.append(secret_box.encrypt(plaintext, nonce=nonce_bytes).ciphertext[16:])

    role_1_key, role_2_key = derived_keys
    if debug:
        print("Role 1's Key:", role_1_key.hex())
        print("Role 2's Key:", role_2_key.hex())

    # Index 0 is a placeholder so the list lines up with role numbers 1 and 2.
    return [None, role_1_key, role_2_key]
def encrypt_data() -> None:
    """
    Encrypt this participant's input data and stage it in the encrypted-data folder.

    Steps, in order:
      1. Look up the current user and their role (index in the study's participants list).
      2. Obtain the other participant's public key and derive the shared keys.
      3. Read the input directory path from ``data_path.txt`` and verify its md5
         directory hash against the DATA_HASH recorded for this user.
      4. Encrypt the data with a PRNG seeded by this role's shared key.
      5. Save the other role's shared key as ``other_shared_key.bin`` and copy
         ``pos.txt`` next to the encrypted output.
    """
    username: str = get_username()
    doc_ref_dict: dict = get_doc_ref_dict()
    role = doc_ref_dict["participants"].index(username)

    other_key_hex = get_other_user_public_key(doc_ref_dict, role)
    other_public_key = PublicKey(other_key_hex, encoder=HexEncoder)  # type: ignore

    print("Generating shared keys...")
    with open(os.path.join(constants.SFKIT_DIR, "my_private_key.txt"), "r") as key_file:
        private_key_hex = key_file.readline().rstrip()
    my_private_key = PrivateKey(private_key_hex, encoder=HexEncoder)  # type: ignore
    # NOTE(review): this compares a PrivateKey object with a PublicKey object;
    # confirm whether this check can ever fail as written.
    condition_or_fail(my_private_key != other_public_key, "Private and public keys must be different")

    shared_keys = get_shared_mpcgwas_keys(my_private_key, other_public_key)

    with open(os.path.join(constants.SFKIT_DIR, "data_path.txt"), "r") as path_file:
        input_dir = path_file.readline().rstrip()

    # Refuse to encrypt data that differs from what was registered for this user.
    data_hash = checksumdir.dirhash(input_dir, "md5")
    expected_hash = doc_ref_dict["personal_parameters"][username]["DATA_HASH"]["value"]
    condition_or_fail(data_hash == expected_hash, "Data hash mismatch")

    print("Encrypting data...")
    base_p: int = int(doc_ref_dict["advanced_parameters"]["BASE_P"]["value"])
    own_prng = PseudoRandomNumberGenerator(shared_keys[role], base_p)
    encrypt_GMP(own_prng, input_dir, output_dir=constants.ENCRYPTED_DATA_FOLDER)

    print("saving shared key")
    other_key_path = os.path.join(constants.ENCRYPTED_DATA_FOLDER, "other_shared_key.bin")
    with open(other_key_path, "wb") as out_file:
        # roles are 1 and 2, so 3 - role selects the other participant's key
        out_file.write(shared_keys[3 - role])

    print("copying over pos.txt")
    pos_destination = os.path.join(constants.ENCRYPTED_DATA_FOLDER, "pos.txt")
    shutil.copyfile(f"{input_dir}/pos.txt", pos_destination)

    print("\n\nThe encryption is complete.")
def get_other_user_public_key(doc_ref_dict: dict, role: int) -> str:
    """
    Return the other participant's public key (the ``PUBLIC_KEY`` value stored
    under their entry in ``personal_parameters``).

    If the key is not yet present, the study document is re-fetched every 5
    seconds; after more than 60 failed checks the process exits with status 1.
    The process also exits with status 1 if the participants list does not
    contain exactly 3 entries.
    """
    print("Downloading other party's public key...")

    # roles are 1 and 2, so 3 - role is the index of the other participant
    other_role = 3 - role

    participant_count = len(doc_ref_dict["participants"])
    if participant_count != 3:
        print("Expected 2 participants (excluding Broad). Exiting.")
        sys.exit(1)

    failed_checks = 0
    while True:
        other_username = doc_ref_dict["participants"][other_role]
        key_entry = doc_ref_dict["personal_parameters"][other_username].get("PUBLIC_KEY", {})
        found_key = key_entry.get("value")
        if found_key:
            return found_key

        print("No public key found for other user. Waiting...")
        failed_checks += 1
        if failed_checks > 60:
            print("Failed to find public key for other user after 5 minutes. Exiting.")
            sys.exit(1)

        time.sleep(5)
        # Refresh the document in case the other user has uploaded their key since.
        doc_ref_dict = get_doc_ref_dict()