Coverage for kye/validate.py: 25%

67 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-13 15:17 -0700

1from kye.dataset import Models, Type, Edge, TYPE_REF, EDGE 

2from kye.loader.loader import Loader, struct_pack 

3from duckdb import DuckDBPyConnection, DuckDBPyRelation 

4 

class Validate:
    """Validate the tables produced by a ``Loader`` against their models.

    On construction, each table in ``loader.tables`` is checked against the
    model of the same name in ``loader.models``. Rule violations are appended
    to an ``errors`` table on the loader's DuckDB connection, and the cleaned
    result for each model is materialized as a table named
    ``"<model>.validated"`` and kept in ``self.tables``.
    """

    loader: Loader
    # Validated relation per model, keyed by model name.
    tables: dict[TYPE_REF, DuckDBPyRelation]

    def __init__(self, loader: Loader):
        self.loader = loader
        self.tables = {}

        # Every rule violation found below is inserted into this table.
        # NOTE(review): a plain CREATE TABLE fails if an `errors` table
        # already exists on this connection (e.g. a second Validate on the
        # same loader) — confirm that is intended.
        self.db.sql('CREATE TABLE errors (rule_ref TEXT, error_type TEXT, object_id UINT64, val JSON);')
        self.errors = self.db.table('errors')

        for model_name, table in self.loader.tables.items():
            table = self._validate_model(self.models[model_name], table)
            # The name contains a '.', so it must be a quoted identifier;
            # escaping also protects against '"' inside the model name.
            table_name = self._quote_ident(f'{model_name}.validated')
            table.create(table_name)
            self.tables[model_name] = self.db.table(table_name)

    @property
    def db(self) -> DuckDBPyConnection:
        """The loader's DuckDB connection."""
        return self.loader.db

    @property
    def models(self) -> Models:
        """The loader's model definitions."""
        return self.loader.models

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + str(name).replace('"', '""') + '"'

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Return *value* as a single-quoted SQL string literal (embedded quotes doubled)."""
        return "'" + str(value).replace("'", "''") + "'"

    def _add_errors_where(self, r: DuckDBPyRelation, condition: str, rule_ref: str, error_type: str):
        """Record the rows of *r* matching *condition* as errors and return the rest.

        Matching rows are inserted into the ``errors`` table with the given
        ``rule_ref`` and ``error_type``. ``_index`` becomes ``object_id`` and
        ``val`` is stored as JSON, so *r* must have ``_index`` and ``val``
        columns.

        Returns:
            *r* filtered with ``NOT (condition)``. Rows where *condition*
            evaluates to NULL satisfy neither filter, so they are neither
            recorded nor returned.
        """
        err = r.filter(condition)
        # rule_ref / error_type become SQL string literals; escape them rather
        # than splicing raw text into the query.
        err = err.select(
            f'{self._quote_literal(rule_ref)} as rule_ref, '
            f'{self._quote_literal(error_type)} as error_type, '
            '_index as object_id, to_json(val) as val'
        )
        err.insert_into('errors')
        return r.filter(f'NOT ({condition})')

    def check_for_index_collision(self, typ: Type, r: DuckDBPyRelation):
        """Flag objects whose index values collide with another object's.

        Each entry of ``typ.indexes`` is a group of column names. For every
        row, each group is packed into a list value. An index value shared by
        more than one distinct ``_index`` is a collision, and every
        (value, object) pair involved is recorded as a ``NON_UNIQUE_INDEX``
        error.

        Returns:
            A relation with a single ``_index`` column listing the objects
            that took part in no collision.
        """
        packed_indexes = ','.join(
            'list_pack(' + ','.join(self._quote_ident(col) for col in sorted(index)) + ')'
            for index in typ.indexes
        )
        # One row per (object, index value) pair.
        r = r.select(f'_index, UNNEST([{packed_indexes}]) as index_val')
        # Group objects by index value; more than one distinct object is a collision.
        r = r.aggregate('index_val, list_distinct(list(_index)) as _indexes')
        r = r.select('index_val as val, unnest(_indexes) as _index, len(_indexes) > 1 as collision')

        self._add_errors_where(r,
            condition='collision',
            rule_ref=typ.ref,
            error_type='NON_UNIQUE_INDEX',
        )
        # Keep only objects that were not involved in any collision.
        return r.aggregate('_index, bool_or(collision) as collision').filter('not collision').select('_index')

    def _validate_model(self, typ: Type, r: DuckDBPyRelation):
        """Validate every edge of *typ* over the loaded relation *r*.

        Returns:
            A relation with ``_index`` plus one column per edge in
            ``typ.edges``. When *typ* has more than one index, objects that
            fail the index-collision check are dropped. An object with no
            surviving value for an edge gets NULL in that column (left join).
        """
        # One row per distinct object.
        edges = r.aggregate('_index')

        # No need to check for conflicting indexes if there is only one
        if len(typ.indexes) > 1:
            edges = self.check_for_index_collision(typ, r)

        for edge_name, edge in typ.edges.items():
            # Quote the edge name so keywords (e.g. "order") and special
            # characters are valid SQL identifiers.
            col = self._quote_ident(edge_name)
            # An edge missing from the loaded table is treated as all-NULL.
            source = col if edge_name in r.columns else 'CAST(NULL as VARCHAR)'
            edge_rel = r.select(f'_index, {source} as val')
            edge_rel = self._validate_edge(edge, edge_rel).set_alias(edge.ref)
            edge_rel = edge_rel.select(f'_index, val as {col}')
            edges = edges.join(edge_rel, '_index', how='left')
        return edges

    def _validate_edge(self, edge: Edge, r: DuckDBPyRelation):
        """Validate the values of one edge.

        *r* has ``_index`` and ``val`` columns, and ``val`` may itself be a
        list. Values are collected per ``_index`` and then checked for
        nullability, cardinality and type. Violations go to the ``errors``
        table.

        Returns:
            ``_index, val`` rows. ``val`` is a list when ``edge.multiple`` is
            set and a scalar otherwise. Objects left with no valid value are
            absent.
        """
        # Collect distinct values per object, flattening list-valued columns.
        agg_fun = 'list_distinct(flatten(list(val)))' if r.val.dtypes[0].id == 'list' else 'list_distinct(list(val))'
        r = r.aggregate(f'_index, {agg_fun} as val')

        if not edge.nullable:
            r = self._add_errors_where(r, 'len(val) == 0', edge.ref, 'NOT_NULLABLE')

        if not edge.multiple:
            r = self._add_errors_where(r, 'len(val) > 1', edge.ref, 'NOT_MULTIPLE')
            # DuckDB lists are 1-indexed.
            r = r.select('_index, val[1] as val')
        else:
            r = r.select('_index, unnest(val) as val')

        r = r.filter('val IS NOT NULL')
        r = self._validate_value(edge.type, r)

        if edge.multiple:
            r = r.aggregate('_index, list(val) as val')

        return r

    def _validate_value(self, typ: Type, r: DuckDBPyRelation):
        """Type-check ``val`` against the base type of *typ*.

        Only the ``Boolean`` and ``Number`` base types are checked. Rows whose
        ``val`` cannot be cast to BOOLEAN or DOUBLE respectively are recorded
        as ``INVALID_VALUE`` errors and dropped. Other base types pass through
        unchanged.
        """
        # TODO: Look up object references and see if they exist

        base_type = typ.base.name

        if base_type == 'Boolean':
            r = self._add_errors_where(r, 'TRY_CAST(val as BOOLEAN) IS NULL', typ.ref, 'INVALID_VALUE')
        elif base_type == 'Number':
            r = self._add_errors_where(r, 'TRY_CAST(val AS DOUBLE) IS NULL', typ.ref, 'INVALID_VALUE')

        return r

    def __getitem__(self, model_name: TYPE_REF):
        """Return the validated relation for *model_name*."""
        return self.tables[model_name]

    def __repr__(self):
        return f"<Validate {','.join(self.tables.keys())}>"