Coverage for kye/loader/loader.py: 24%
76 statements
coverage.py v7.3.2, created at 2023-11-13 15:17 -0700
import duckdb
from duckdb import DuckDBPyRelation, DuckDBPyConnection
from kye.dataset import Models
from kye.loader.json_lines import from_json
from kye.dataset import Type, DefinedType, Edge, TYPE_REF
def append_table(con: DuckDBPyConnection, orig: DuckDBPyRelation, new: DuckDBPyRelation):
    """
    This function will not be needed in the future if we can figure out a standard way
    to create the staging tables with the correct types before any data is uploaded.
    """
    def get_dtypes(r: DuckDBPyRelation):
        return dict(zip(r.columns, r.dtypes))

    orig_dtypes = get_dtypes(orig)
    new_dtypes = get_dtypes(new)

    # Check that the shared columns have matching types
    for col in set(orig_dtypes) & set(new_dtypes):
        if orig_dtypes[col] != new_dtypes[col]:
            raise ValueError(f'''Column {col} has conflicting types: {orig_dtypes[col]} != {new_dtypes[col]}''')

    # Alter the original table to include any columns that only exist in the new chunk
    for col in set(new_dtypes) - set(orig_dtypes):
        con.sql(f'''ALTER TABLE "{orig.alias}" ADD COLUMN {col} {new_dtypes[col]}''')

    # Preserve the column order of the original table
    # and fill any columns missing from the new chunk with typed NULLs
    new = new.select(', '.join(
        col if col in new_dtypes
        else f'CAST(NULL as {orig_dtypes[col]}) as {col}'
        for col in con.table(f'"{orig.alias}"').columns
    ))

    new.insert_into(f'"{orig.alias}"')
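# Illustrative example (hypothetical column names, not from this module): if the
# staging table has columns (id BIGINT, name VARCHAR) and a new chunk arrives with
# (id BIGINT, email VARCHAR), the table is ALTERed to add `email VARCHAR` and the
# chunk is inserted with `name` supplied as CAST(NULL as VARCHAR).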
def get_struct_keys(r: DuckDBPyRelation):
    assert r.columns[1] == 'val'
    assert r.dtypes[1].id == 'struct'
    return [col[0] for col in r.dtypes[1].children]
def struct_pack(edges: list[str], r: DuckDBPyRelation):
    return 'struct_pack(' + ','.join(
        f'''"{edge_name}":="{edge_name}"'''
        for edge_name in edges
        if edge_name in r.columns
    ) + ')'
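# For example, struct_pack(['name', 'age'], r) returns the SQL expression
# 'struct_pack("name":="name","age":="age")', silently skipping any edge that
# is not present as a column of `r`.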
def get_index(typ: Type, r: DuckDBPyRelation):
    # Hash the index columns
    r = r.select(f'''hash({struct_pack(sorted(typ.index), r)}) as _index, *''')

    # Filter out null indexes
    r = r.filter(f'''{' AND '.join(edge + ' IS NOT NULL' for edge in typ.index)}''')
    return r
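# For example, if typ.index == {'id'}, the relation gains an `_index` column
# computed as hash(struct_pack("id":="id")) and any row where `id IS NULL`
# is filtered out.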
class Loader:
    """
    The loader is responsible for normalizing the shape of the data. It makes sure that
    all of the columns are present (filling in nulls where necessary) and also computes
    the index hash for each row so that it is easy to join the data together later.

    The loader operates on each chunk of data as it is loaded, so it does not
    do any cross-table aggregation or validation.

    Any value normalization needs to be done here so that the index hash is consistent.
    """
    tables: dict[TYPE_REF, duckdb.DuckDBPyRelation]

    def __init__(self, models: Models):
        self.tables = {}
        self.models = models
        self.db = duckdb.connect(':memory:')
        self.chunks = {}
    def _insert(self, model_name: TYPE_REF, r: duckdb.DuckDBPyRelation):
        table_name = f'"{model_name}.staging"'
        if model_name not in self.tables:
            r.create(table_name)
        else:
            append_table(self.db, self.tables[model_name], r)
        self.tables[model_name] = self.db.table(table_name)
    def _load(self, typ: DefinedType, r: duckdb.DuckDBPyRelation):
        chunk_id = typ.name + '_' + str(len(self.chunks) + 1)
        chunk = r.select(f'''list_value('{chunk_id}', ROW_NUMBER() OVER () - 1) as _, {struct_pack(typ.edges.keys(), r)} as val''').set_alias(chunk_id)
        self.chunks[chunk_id] = chunk
        self._get_value(typ, chunk)
    def _get_value(self, typ: Type, r: DuckDBPyRelation):
        if typ.has_edges:
            edges = r.select('_')
            for edge_name, edge in typ.edges.items():
                if edge_name in get_struct_keys(r):
                    edge_rel = self._get_edge(edge, r.select(f'''list_append(_, '{edge_name}') as _, val.{edge_name} as val''')).set_alias(edge.ref)
                    edge_rel = edge_rel.select(f'''array_pop_back(_) as _, val as {edge_name}''')
                    edges = edges.join(edge_rel, '_', how='left')

            if typ.has_index:
                edges = get_index(typ, edges)
                self._insert(typ.ref, edges)
                return edges.select(f'''_, _index as val''')
            else:
                return edges.select(f'''_, {struct_pack(typ.edges.keys(), edges)} as val''')

        elif r.dtypes[1].id != 'varchar':
            dtype = r.dtypes[1].id
            r = r.select(f'''_, CAST(val AS VARCHAR) as val''')
            # remove trailing '.0' from decimals so that
            # they will match integers of the same value
            if dtype in ['double', 'decimal', 'real']:
                r = r.select(f'''_, REGEXP_REPLACE(val, '\\.0$', '') as val''')

        return r
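    # For example, a DOUBLE value 1.0 is cast to the string '1.0' and then
    # normalized to '1', so it hashes to the same index value as the integer 1.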
    def _get_edge(self, edge: Edge, r: DuckDBPyRelation):
        if edge.multiple:
            r = r.select('''_, unnest(val) as val''').select('list_append(_, ROW_NUMBER() OVER (PARTITION BY _) - 1) as _, val')

        r = self._get_value(edge.type, r)

        if edge.multiple:
            r = r.aggregate('array_pop_back(_) as _, list(val) as val', 'array_pop_back(_)')

        return r
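    # Note on _get_edge: for a multiple-valued edge the list is unnested into
    # one row per element (each tagged with its position in the `_` path),
    # each element is normalized by _get_value, and the results are then
    # re-aggregated back into a list per parent row.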
    def from_json(self, model_name: TYPE_REF, data: list[dict]):
        r = from_json(self.models[model_name], data, self.db)
        self._load(self.models[model_name], r)
    def __getitem__(self, model_name: str):
        return self.tables[model_name]

    def __repr__(self):
        return f"<Loader {','.join(self.tables.keys())}>"
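# Minimal usage sketch (hypothetical; assumes `models` is a kye.dataset.Models
# instance defining a "User" model indexed by `id`, neither of which is part of
# this module):
#
#     loader = Loader(models)
#     loader.from_json('User', [
#         {'id': 1, 'name': 'Alice'},
#         {'id': 2, 'name': 'Bob'},
#     ])
#     loader['User']   # staging relation including the computed `_index` column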