Coverage for kye/validate.py: 25%
67 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-13 15:17 -0700
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-13 15:17 -0700
1from kye.dataset import Models, Type, Edge, TYPE_REF, EDGE
2from kye.loader.loader import Loader, struct_pack
3from duckdb import DuckDBPyConnection, DuckDBPyRelation
class Validate:
    """Validate the loader's tables against their model definitions.

    For every table held by the loader, each edge declared on the matching
    model is checked (nullability, multiplicity, base-type castability).
    Offending rows are appended to the ``errors`` table on the loader's
    connection and removed from the result; what survives is materialized
    as a ``"<model>.validated"`` table and exposed through ``self.tables``
    / ``self[model_name]``.
    """

    # Loader supplying the connection, the models and the raw tables.
    loader: Loader
    # Validated relation per model, keyed by model name.
    tables: dict[TYPE_REF, DuckDBPyRelation]
    # Relation over the ``errors`` table created in ``__init__``.
    errors: DuckDBPyRelation

    def __init__(self, loader: Loader) -> None:
        """Create the ``errors`` table and validate every loaded table.

        Args:
            loader: Source of the connection (``loader.db``), the model
                definitions (``loader.models``) and the raw relations
                (``loader.tables``).
        """
        self.loader = loader
        self.tables = {}

        # NOTE(review): plain CREATE TABLE; presumably fails if an ``errors``
        # table already exists on this connection -- confirm intended.
        self.db.sql('CREATE TABLE errors (rule_ref TEXT, error_type TEXT, object_id UINT64, val JSON);')
        self.errors = self.db.table('errors')

        for model_name, raw_table in self.loader.tables.items():
            validated = self._validate_model(self.models[model_name], raw_table)
            table_name = f'"{model_name}.validated"'
            # Materialize the validated relation, then hold a handle to the
            # stored table rather than to the lazy relation.
            validated.create(table_name)
            self.tables[model_name] = self.db.table(table_name)

    @property
    def db(self) -> DuckDBPyConnection:
        """Connection owned by the loader."""
        return self.loader.db

    @property
    def models(self) -> Models:
        """Model definitions owned by the loader."""
        return self.loader.models

    @staticmethod
    def _sql_literal(value) -> str:
        """Render ``value`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so that the value cannot
        terminate the literal early.
        """
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    def _add_errors_where(self, r: DuckDBPyRelation, condition: str, rule_ref: str, error_type: str) -> DuckDBPyRelation:
        """Record rows of ``r`` matching ``condition`` as errors.

        Args:
            r: Relation exposing ``_index`` and ``val`` columns.
            condition: SQL boolean expression selecting the offending rows.
            rule_ref: Stored in the ``rule_ref`` column of ``errors``.
            error_type: Stored in the ``error_type`` column of ``errors``.

        Returns:
            ``r`` filtered by ``NOT (condition)``.
        """
        rule_literal = self._sql_literal(rule_ref)
        type_literal = self._sql_literal(error_type)
        offending = r.filter(condition)
        offending = offending.select(
            f" {rule_literal} as rule_ref, {type_literal} as error_type,"
            " _index as object_id, to_json(val) as val"
        )
        offending.insert_into('errors')
        return r.filter(f'''NOT ({condition})''')

    def check_for_index_collision(self, typ: Type, r: DuckDBPyRelation) -> DuckDBPyRelation:
        """Report index values shared by more than one ``_index``.

        The columns of every index in ``typ.indexes`` are packed (in sorted
        column order) into one list value per index. A packed value that is
        associated with more than one distinct ``_index`` is a collision and
        is recorded as a ``NON_UNIQUE_INDEX`` error under ``typ.ref``.

        Returns:
            A relation with a single ``_index`` column holding the objects
            for which none of the packed index values collided.
        """
        # NOTE(review): column names are interpolated unquoted; presumably
        # they are always plain identifiers -- confirm.
        packed_indexes = ','.join(f"list_pack({','.join(sorted(index))})" for index in typ.indexes)
        # One row per (object, packed index value).
        r = r.select(f'''_index, UNNEST([{packed_indexes}]) as index_val''')
        # Group by packed value and collect the distinct objects using it.
        r = r.aggregate('index_val, list_distinct(list(_index)) as _indexes')
        # Back to one row per (value, object), flagged when the value is shared.
        r = r.select('index_val as val, unnest(_indexes) as _index, len(_indexes) > 1 as collision')

        # Only the side effect (inserting into ``errors``) is wanted here;
        # the surviving rows are recomputed per object below.
        self._add_errors_where(
            r,
            condition='collision',
            rule_ref=typ.ref,
            error_type='NON_UNIQUE_INDEX',
        )
        # Select the good indexes
        return r.aggregate('_index, bool_or(collision) as collision').filter('not collision').select('_index')

    def _validate_model(self, typ: Type, r: DuckDBPyRelation) -> DuckDBPyRelation:
        """Validate every edge of ``typ`` against the raw relation ``r``.

        Returns:
            A relation with ``_index`` plus one column per edge in
            ``typ.edges``, built by left-joining each validated edge onto
            the set of accepted ``_index`` values.
        """
        edges = r.aggregate('_index')

        # No need to check for conflicting indexes if there is only one
        if len(typ.indexes) > 1:
            edges = self.check_for_index_collision(typ, r)

        for edge_name, edge in typ.edges.items():
            # An edge with no column in the raw data is treated as all-NULL.
            # NOTE(review): ``edge_name`` is interpolated unquoted; presumably
            # it is always a plain identifier -- confirm.
            edge_rel = r.select(f'''_index, {edge_name if edge_name in r.columns else 'CAST(NULL as VARCHAR)'} as val''')
            edge_rel = self._validate_edge(edge, edge_rel).set_alias(edge.ref)
            edge_rel = edge_rel.select(f'''_index, val as {edge_name}''')
            edges = edges.join(edge_rel, '_index', how='left')
        return edges

    def _validate_edge(self, edge: Edge, r: DuckDBPyRelation) -> DuckDBPyRelation:
        """Validate one edge's values, given ``_index`` and ``val`` columns.

        Values are first gathered into one distinct list per ``_index``;
        the ``nullable`` and ``multiple`` flags of ``edge`` are then
        enforced, and the remaining non-NULL values are checked against
        ``edge.type``.

        Returns:
            A relation with ``_index`` and ``val``; ``val`` is a list when
            ``edge.multiple`` is true and a single value otherwise.
        """
        # A list-typed column is flattened so each object ends up with one
        # flat list of distinct values either way.
        agg_fun = 'list_distinct(flatten(list(val)))' if r.val.dtypes[0].id == 'list' else 'list_distinct(list(val))'
        r = r.aggregate(f'''_index, {agg_fun} as val''')

        if not edge.nullable:
            r = self._add_errors_where(r, 'len(val) == 0', edge.ref, 'NOT_NULLABLE')

        if not edge.multiple:
            r = self._add_errors_where(r, 'len(val) > 1', edge.ref, 'NOT_MULTIPLE')
            r = r.select('''_index, val[1] as val''')
        else:
            r = r.select('''_index, unnest(val) as val''')

        r = r.filter('val IS NOT NULL')
        r = self._validate_value(edge.type, r)

        if edge.multiple:
            # Re-group the individually validated values per object.
            r = r.aggregate('_index, list(val) as val')

        return r

    def _validate_value(self, typ: Type, r: DuckDBPyRelation) -> DuckDBPyRelation:
        """Drop (and record as ``INVALID_VALUE``) values not castable to the base type.

        Only the ``Boolean`` and ``Number`` base types are checked; any
        other base type passes ``r`` through unchanged.
        """
        # TODO: Look up object references and see if they exist

        base_type = typ.base.name

        if base_type == 'Boolean':
            r = self._add_errors_where(r, 'TRY_CAST(val as BOOLEAN) IS NULL', typ.ref, 'INVALID_VALUE')
        elif base_type == 'Number':
            r = self._add_errors_where(r, 'TRY_CAST(val AS DOUBLE) IS NULL', typ.ref, 'INVALID_VALUE')

        return r

    def __getitem__(self, model_name: TYPE_REF) -> DuckDBPyRelation:
        """Return the validated relation stored for ``model_name``."""
        return self.tables[model_name]

    def __repr__(self) -> str:
        return f"<Validate {','.join(self.tables)}>"