Coverage for kye/loader/loader.py: 24%

76 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-13 15:17 -0700

1import duckdb 

2from duckdb import DuckDBPyRelation, DuckDBPyConnection 

3from kye.dataset import Models 

4from kye.loader.json_lines import from_json 

5from kye.dataset import Type, DefinedType, Edge, TYPE_REF 

6 

7 

def append_table(con: DuckDBPyConnection, orig: DuckDBPyRelation, new: DuckDBPyRelation):
    """
    Append the rows of ``new`` to the existing staging table behind ``orig``.

    The two relations may have different column sets: columns that only exist
    in ``new`` are added to the table, and columns missing from ``new`` are
    filled with NULL.  Columns present in both sides must have identical types.

    This function will not be needed in the future if we can figure out a standard way
    to create the staging tables with the correct types before any data is uploaded.

    Raises:
        ValueError: if a column shared by both relations has conflicting types.
    """

    def get_dtypes(r: DuckDBPyRelation) -> dict:
        # Map column name -> DuckDB type for a relation.
        return dict(zip(r.columns, r.dtypes))

    orig_dtypes = get_dtypes(orig)
    new_dtypes = get_dtypes(new)

    # Check that the types of the shared columns match
    for col in set(orig_dtypes) & set(new_dtypes):
        if orig_dtypes[col] != new_dtypes[col]:
            raise ValueError(f'''Column {col} has conflicting types: {orig_dtypes[col]} != {new_dtypes[col]}''')

    # Alter the original table to include any new columns.
    # Quote the column identifier so names that are reserved words or contain
    # special characters do not break the statement (the table alias was
    # already quoted; the column was not).
    for col in set(new_dtypes) - set(orig_dtypes):
        con.sql(f'''ALTER TABLE "{orig.alias}" ADD COLUMN "{col}" {new_dtypes[col]}''')

    # Preserve the order of columns from the (possibly just altered) table
    # and cast any columns missing from `new` to NULL.
    new = new.select(', '.join(
        f'"{col}"' if col in new_dtypes
        else f'CAST(NULL as {orig_dtypes[col]}) as "{col}"'
        for col in con.table(f'"{orig.alias}"').columns
    ))

    new.insert_into(f'"{orig.alias}"')

38 

def get_struct_keys(r: DuckDBPyRelation):
    """Return the field names of the struct held in the relation's `val` column."""
    assert r.columns[1] == 'val'
    val_type = r.dtypes[1]
    assert val_type.id == 'struct'
    # Each child of a struct dtype is a (name, type) pair; we only need names.
    keys = []
    for child in val_type.children:
        keys.append(child[0])
    return keys

43 

def struct_pack(edges: list[str], r: DuckDBPyRelation):
    """
    Build a DuckDB ``struct_pack(...)`` SQL expression over the given edge names.

    Edge names that are not columns of `r` are silently skipped, so the
    expression only references columns that actually exist.
    """
    available = r.columns
    fields = ','.join(
        f'''"{name}":="{name}"'''
        for name in edges
        if name in available
    )
    return 'struct_pack(' + fields + ')'

50 

def get_index(typ: Type, r: DuckDBPyRelation):
    """
    Prefix `r` with an `_index` column hashing the type's index edges,
    and drop rows where any index edge is NULL.
    """
    # Sort the index edges so the hash is stable regardless of edge order.
    hashed = r.select(f'''hash({struct_pack(sorted(typ.index), r)}) as _index, *''')

    # Rows with a NULL anywhere in the index cannot be identified — drop them.
    not_null = ' AND '.join(edge + ' IS NOT NULL' for edge in typ.index)
    hashed = hashed.filter(f'''{not_null}''')
    return hashed

58 

class Loader:
    """
    The loader is responsible for normalizing the shape of the data. It makes sure that
    all of the columns are present (filling in nulls where necessary) and also computes
    the index hash for each row so that it is easy to join the data together later.

    The loader operates for each chunk of the data while it is loading. So it does not
    do any cross table aggregations or validation.

    Any value normalization needs to be done here so that the index hash is consistent.
    """
    # One staging relation per model, keyed by the model's type ref.
    tables: dict[TYPE_REF, duckdb.DuckDBPyRelation]

    def __init__(self, models: Models):
        self.tables = {}
        self.models = models
        # All staging tables live in a private in-memory DuckDB database.
        self.db = duckdb.connect(':memory:')
        # Raw chunk relations keyed by chunk id (see _load for id format).
        self.chunks = {}

    def _insert(self, model_name: TYPE_REF, r: duckdb.DuckDBPyRelation):
        # Create the staging table on first insert for this model; afterwards
        # append, reconciling any column-set differences (see append_table).
        table_name = f'"{model_name}.staging"'
        if model_name not in self.tables:
            r.create(table_name)
        else:
            append_table(self.db, self.tables[model_name], r)
        # Re-read the table so the cached relation reflects any columns that
        # append_table may have added.
        self.tables[model_name] = self.db.table(table_name)

    def _load(self, typ: DefinedType, r: duckdb.DuckDBPyRelation):
        # Register `r` as a new chunk for `typ` and normalize it into the
        # staging tables.  Chunk ids are "<type name>_<running count>".
        chunk_id = typ.name + '_' + str(len(self.chunks) + 1)
        # `_` is a path list [chunk_id, row_number]; _get_value/_get_edge
        # append edge names / element indexes to it while recursing so every
        # nested value can be traced back to its source row.
        chunk = r.select(f'''list_value('{chunk_id}', ROW_NUMBER() OVER () - 1) as _, {struct_pack(typ.edges.keys(), r)} as val''').set_alias(chunk_id)
        self.chunks[chunk_id] = chunk
        self._get_value(typ, chunk)

    def _get_value(self, typ: Type, r: DuckDBPyRelation):
        # Normalize a (`_` path, `val`) relation for `typ` and return it.
        # Three cases: composite types (has_edges), non-varchar scalars
        # (cast + normalize to VARCHAR), and varchar scalars (pass through).
        if typ.has_edges:
            # Recurse into each declared edge present in the struct, then
            # left-join the normalized edges back together on the path `_`.
            edges = r.select('_')
            for edge_name, edge in typ.edges.items():
                if edge_name in get_struct_keys(r):
                    # Descend with the edge name appended to the path...
                    edge_rel = self._get_edge(edge, r.select(f'''list_append(_, '{edge_name}') as _, val.{edge_name} as val''')).set_alias(edge.ref)
                    # ...then pop it back off so the join keys line up.
                    edge_rel = edge_rel.select(f'''array_pop_back(_) as _, val as {edge_name}''')
                    edges = edges.join(edge_rel, '_', how='left')

            if typ.has_index:
                # Indexed types get hashed, staged, and are represented
                # upstream by their index hash.
                edges = get_index(typ, edges)
                self._insert(typ.ref, edges)
                return edges.select(f'''_, _index as val''')
            else:
                # Un-indexed composites stay inline as a struct value.
                return edges.select(f'''_, {struct_pack(typ.edges.keys(), edges)} as val''')

        elif r.dtypes[1].id != 'varchar':
            dtype = r.dtypes[1].id
            r = r.select(f'''_, CAST(val AS VARCHAR) as val''')
            # remove trailing '.0' from decimals so that
            # they will match integers of the same value
            if dtype in ['double','decimal','real']:
                r = r.select(f'''_, REGEXP_REPLACE(val, '\\.0$', '') as val''')
        # NOTE(review): this return must sit at method level — inside the
        # `elif` the varchar case would fall through and return None, which
        # would break _get_edge.  Varchar values pass through unchanged.
        return r

    def _get_edge(self, edge: Edge, r: DuckDBPyRelation):
        # Normalize one edge's values.  Plural edges are unnested so each
        # element is processed individually, then re-aggregated into a list.
        if edge.multiple:
            # Append the element index to the path `_` so elements stay
            # distinguishable after unnesting.
            r = r.select('''_, unnest(val) as val''').select('list_append(_, ROW_NUMBER() OVER (PARTITION BY _) - 1) as _, val')

        r = self._get_value(edge.type, r)

        if edge.multiple:
            # Pop the element index back off and regroup into a list per row.
            r = r.aggregate('array_pop_back(_) as _, list(val) as val','array_pop_back(_)')

        return r

    def from_json(self, model_name: TYPE_REF, data: list[dict]):
        # Public entry point: convert JSON-like rows into a relation and load.
        r = from_json(self.models[model_name], data, self.db)
        self._load(self.models[model_name], r)

    def __getitem__(self, model_name: str):
        # Access a model's staging relation; raises KeyError if not loaded.
        return self.tables[model_name]

    def __repr__(self):
        return f"<Loader {','.join(self.tables.keys())}>"