Coverage for sfkit/encryption/mpc/encrypt_data.py: 100%

112 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-07 15:11 -0400

1# for MPC-GWAS 

2 

3import os 

4import shutil 

5import sys 

6import time 

7 

8import checksumdir 

9import nacl.secret 

10import nacl.utils 

11from nacl.encoding import HexEncoder 

12from nacl.public import Box, PrivateKey, PublicKey 

13 

14from sfkit.api import get_doc_ref_dict, get_username 

15from sfkit.encryption.mpc.random_number_generator import PseudoRandomNumberGenerator 

16from sfkit.utils import constants 

17from sfkit.utils.helper_functions import condition_or_fail 

18 

19# from tqdm import tqdm 

20 

21 

def encrypt_GMP(prng: "PseudoRandomNumberGenerator", input_dir: str, output_dir: str) -> None:
    # sourcery skip: avoid-global-variables, avoid-single-character-names-variables, require-parameter-annotation, snake-case-functions, switch
    """
    Convert the input data to GMP vectors (genotype, missing data, phenotype),
    mask them with values drawn from ``prng``, and write them to ``output_dir``.

    Reads ``geno.txt``, ``pheno.txt`` and ``cov.txt`` from ``input_dir`` (one
    whitespace-separated row per line) and writes ``g.bin``, ``m.bin`` and
    ``p.bin`` to ``output_dir``, creating the directory if needed. The number
    of rows processed is the number of lines in ``pheno.txt``.

    Every written value equals ``(plaintext - prng.next()) % prng.base_p``.
    For each input row, the output is:
      * ``p.bin``: the phenotype values followed by the covariate values;
      * ``g.bin``: three concatenated one-hot rows, for genotypes "0", "1" and "2";
      * ``m.bin``: a missing-data indicator row (1 for any other genotype token).

    The order in which masks are drawn from ``prng`` is part of the output
    format: for each row, first p, then g[0], g[1], g[2], then m.

    Args:
        prng: mask source exposing ``next()`` and an integer ``base_p``.
        input_dir: directory containing the plaintext input files.
        output_dir: directory that receives the masked ``.bin`` files.
    """
    base_p = prng.base_p

    # Context managers guarantee every handle is closed, even if parsing fails part-way.
    with open(f"{input_dir}/geno.txt", "r") as geno_file, \
            open(f"{input_dir}/pheno.txt", "r") as pheno_file, \
            open(f"{input_dir}/cov.txt", "r") as cov_file:
        os.makedirs(output_dir, exist_ok=True)

        with open(f"{output_dir}/g.bin", "wb") as g_file, \
                open(f"{output_dir}/m.bin", "wb") as m_file, \
                open(f"{output_dir}/p.bin", "wb") as p_file:
            # pheno.txt defines how many rows are processed. It is counted
            # with its own short-lived handle so progress can be reported.
            with open(f"{input_dir}/pheno.txt", "r") as counter:
                num_lines = sum(1 for _ in counter)

            for i in range(num_lines):
                if i > 0 and num_lines > 10 and i % (num_lines // 10) == 0:
                    print(f"Finished encrypting {i} lines out of {num_lines} lines")

                # Phenotype values followed by covariates, each masked in order.
                p = pheno_file.readline().split() + cov_file.readline().split()
                p = [str((int(x) - prng.next()) % base_p) for x in p]

                geno_line = geno_file.readline().split()
                width = len(geno_line)
                # Start from masked zeros (three genotype rows, then the
                # missing row). Then add 1 at each position's "hot" row.
                g = [[-prng.next() % base_p for _ in range(width)] for _ in range(3)]
                m = [-prng.next() % base_p for _ in range(width)]
                for j, val in enumerate(geno_line):
                    row = g[int(val)] if val in ("0", "1", "2") else m
                    row[j] = (row[j] + 1) % base_p

                g_text = " ".join(" ".join(map(str, row)) for row in g) + "\n"
                g_file.write(g_text.encode("utf-8"))
                m_file.write((" ".join(map(str, m)) + "\n").encode("utf-8"))
                p_file.write((" ".join(p) + "\n").encode("utf-8"))

73 

74 

def get_shared_mpcgwas_keys(my_private_key: PrivateKey, other_public_key: PublicKey, debug: bool = False) -> list:
    """
    Derive the two MPC-GWAS shared keys from this user's private key and the partner's public key.

    A ``Box`` shared secret is built from the key pair. A fixed 32-byte message is then
    encrypted under it with big-endian 24-byte nonces 1 and 2. The resulting bytes form the
    keys; the 16-byte prefix that PyNaCl places in ``ciphertext`` is removed.

    Returns:
        ``[None, role_1_key, role_2_key]``, so the list can be indexed directly by role (1 or 2).
    """
    box_prototype = Box(my_private_key, other_public_key)
    if debug:
        print(
            "Shared key prototype: ",
            box_prototype.shared_key().hex(),  # type: ignore
        )
    secret_box = nacl.secret.SecretBox(box_prototype.shared_key())
    fixed_message = b"bqvbiknychqjywxwjihfrfhgroxycxxj"  # arbitrary 32-byte plaintext known to both sides

    role_keys = []
    for nonce_value in (1, 2):
        nonce_bytes = nonce_value.to_bytes(24, byteorder="big")
        if debug:
            print("nonce in hex: ", nonce_bytes.hex())
            if nonce_value == 1:
                print("Byte message: ", fixed_message)
        # PyNaCl's ciphertext carries an extra 16 bytes at the front;
        # see https://doc.libsodium.org/secret-key_cryptography/secretbox#notes
        role_keys.append(secret_box.encrypt(fixed_message, nonce=nonce_bytes).ciphertext[16:])

    if debug:
        for label, key in zip(("Role 1's Key:", "Role 2's Key:"), role_keys):
            print(label, key.hex())

    return [None, *role_keys]

104 

105 

def encrypt_data() -> None:
    """
    Mask this participant's input data for MPC-GWAS and stage it in
    ``constants.ENCRYPTED_DATA_FOLDER``.

    Steps, in order:
      1. Find this user's role (its index in ``participants``) and fetch the
         other party's public key.
      2. Derive the two shared keys from our private key (hex, first line of
         ``constants.SFKIT_DIR/my_private_key.txt``) and the partner's public key.
      3. Read the input directory path from ``constants.SFKIT_DIR/data_path.txt``
         and check its MD5 ``dirhash`` against the recorded ``DATA_HASH``.
      4. Mask the data with a PRNG keyed by our role's shared key, write the
         other role's shared key to ``other_shared_key.bin``, and copy ``pos.txt``.

    Only roles 1 and 2 work here. ``shared_keys`` is ``[None, key_1, key_2]``
    and is indexed by both ``role`` and ``3 - role``.
    """
    username: str = get_username()
    doc_ref_dict: dict = get_doc_ref_dict()
    role = doc_ref_dict["participants"].index(username)

    other_public_key_str = get_other_user_public_key(doc_ref_dict, role)

    other_public_key = PublicKey(other_public_key_str, encoder=HexEncoder)  # type: ignore

    print("Generating shared keys...")
    private_key_path = os.path.join(constants.SFKIT_DIR, "my_private_key.txt")
    with open(private_key_path, "r") as f:
        my_private_key = PrivateKey(f.readline().rstrip(), encoder=HexEncoder)  # type: ignore
    # Refuse to proceed if the "other" party's public key is actually our own.
    # Comparing a PrivateKey object directly with a PublicKey object can never
    # report equality, so the raw public-key bytes are compared instead.
    condition_or_fail(
        bytes(my_private_key.public_key) != bytes(other_public_key), "Private and public keys must be different"
    )

    shared_keys = get_shared_mpcgwas_keys(my_private_key, other_public_key)

    input_dir_path = os.path.join(constants.SFKIT_DIR, "data_path.txt")
    with open(input_dir_path, "r") as f:
        input_dir = f.readline().rstrip()

    data_hash = checksumdir.dirhash(input_dir, "md5")
    condition_or_fail(
        data_hash == doc_ref_dict["personal_parameters"][username]["DATA_HASH"]["value"], "Data hash mismatch"
    )

    print("Encrypting data...")
    base_p: int = int(doc_ref_dict["advanced_parameters"]["BASE_P"]["value"])
    encrypt_GMP(
        PseudoRandomNumberGenerator(shared_keys[role], base_p), input_dir, output_dir=constants.ENCRYPTED_DATA_FOLDER
    )

    print("saving shared key")
    # The partner's role is 3 - role (1 <-> 2); its key is written out here.
    with open(os.path.join(constants.ENCRYPTED_DATA_FOLDER, "other_shared_key.bin"), "wb") as f:
        f.write(shared_keys[3 - role])
    print("copying over pos.txt")
    shutil.copyfile(f"{input_dir}/pos.txt", os.path.join(constants.ENCRYPTED_DATA_FOLDER, "pos.txt"))

    print("\n\nThe encryption is complete.")

145 

146 

def get_other_user_public_key(doc_ref_dict: dict, role: int) -> str:
    """
    Return the partner's public-key string, polling the study document until it appears.

    The partner is ``participants[3 - role]``. The process exits with status 1 when:
      * ``participants`` does not contain exactly three entries; or
      * the partner's ``PUBLIC_KEY`` value is still empty after 61 checks. Between
        checks the document is re-fetched via ``get_doc_ref_dict`` after a 5-second sleep.
    """
    print("Downloading other party's public key...")

    if len(doc_ref_dict["participants"]) != 3:
        print("Expected 2 participants (excluding Broad). Exiting.")
        sys.exit(1)

    partner_index = 3 - role
    failed_checks = 0
    while True:
        partner_name = doc_ref_dict["participants"][partner_index]
        key_entry = doc_ref_dict["personal_parameters"][partner_name].get("PUBLIC_KEY", {})
        public_key = key_entry.get("value")
        if public_key:
            return public_key

        print("No public key found for other user. Waiting...")
        failed_checks += 1
        if failed_checks > 60:
            print("Failed to find public key for other user after 5 minutes. Exiting.")
            sys.exit(1)

        time.sleep(5)
        doc_ref_dict = get_doc_ref_dict()