23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class RedefineNetworkEdgesModule(KiaraModule):
    """Redefine edges by merging duplicate edges and applying aggregation functions to certain edge attributes.

    Edges are grouped by the source and target columns (plus every attribute
    whose strategy is ``group_by``). Each other attribute that has a strategy
    is collapsed with an SQL aggregate function. Only grouped and aggregated
    columns are selected by the aggregation query.
    """

    _module_type_name = "network_data.redefine_edges"

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return `name` as a double-quoted SQL identifier (embedded quotes doubled)."""
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    def create_inputs_schema(
        self,
    ) -> ValueMapSchema:
        """Describe the inputs: the network data and an optional list of attribute map strategies."""
        result = {
            "network_data": {
                "type": "network_data",
                "doc": "The network data to flatten.",
            },
            "attribute_map_strategies": {
                "type": "kiara_model_list",
                "type_config": {"kiara_model_id": "input.attribute_map_strategy"},
                "doc": "A list of specs on how to map existing attributes onto the target network edge data.",
                "optional": True,
            },
        }
        return result

    def create_outputs_schema(
        self,
    ) -> ValueMapSchema:
        """Describe the single output: the network data carrying the redefined edges."""
        result: Dict[str, Dict[str, Any]] = {}
        result["network_data"] = {
            "type": "network_data",
            "doc": "The network data with redefined edges: duplicate edges are merged and their attributes aggregated according to the provided strategies.",
        }
        return result

    def process(self, inputs: ValueMap, outputs: ValueMap):
        """Merge duplicate edges and store the resulting network data in `outputs`.

        For each strategy without an explicit transform function, integer and
        floating point columns default to ``SUM`` and all other columns to
        ``LIST``. The special function ``group_by`` adds the column to the
        grouping key, ``string_agg_comma`` joins values with a comma, and any
        other function name is used verbatim as the SQL aggregate.

        Raises:
            KiaraProcessingException: if a strategy uses the reserved source or
                target column as its source column, or refers to a column that
                does not exist in the edges table.
        """
        import duckdb
        import pyarrow as pa

        network_data_obj = inputs.get_value_obj("network_data")
        network_data: NetworkData = network_data_obj.data

        # NOTE: the SQL below says `FROM edges_table` and nothing registers the
        # table explicitly, so duckdb has to resolve this Python variable by
        # name. Keep it as a local of this function and do not rename it.
        edges_table = network_data.edges.arrow_table

        attr_map_strategies: Union[
            None, KiaraModelList[AttributeMapStrategy]
        ] = inputs.get_value_data("attribute_map_strategies")

        strategies = attr_map_strategies.list_items if attr_map_strategies else []

        if strategies:
            edge_column_names = network_data.edges.column_names
            invalid_columns = set()
            for strategy in strategies:
                if strategy.source_column_name == SOURCE_COLUMN_NAME:
                    raise KiaraProcessingException(
                        msg=f"Can't redefine edges with provided attribute map: the source column name '{SOURCE_COLUMN_NAME}' is reserved."
                    )
                if strategy.source_column_name == TARGET_COLUMN_NAME:
                    raise KiaraProcessingException(
                        msg=f"Can't redefine edges with provided attribute map: the target column name '{TARGET_COLUMN_NAME}' is reserved."
                    )
                if strategy.source_column_name not in edge_column_names:
                    invalid_columns.add(strategy.source_column_name)

            if invalid_columns:
                # Sorted so the message is deterministic (sets have no stable order).
                msg = f"Can't redefine edges with provided attribute map strategies: the following columns are not available in the network data: {', '.join(sorted(invalid_columns))}"
                msg = f"{msg}\n\nAvailable column names:\n\n"
                # Columns starting with an underscore are left out of the listing.
                for col_name in (
                    x for x in edge_column_names if not x.startswith("_")
                ):
                    msg = f"{msg}\n - {col_name}"
                raise KiaraProcessingException(msg=msg)

        quote = self._quote_identifier

        # Identifiers are quoted so that column names containing spaces,
        # punctuation or SQL keywords do not break the generated query.
        group_bys: List[str] = [quote(SOURCE_COLUMN_NAME), quote(TARGET_COLUMN_NAME)]
        sql_tokens: List[str] = []

        for strategy in strategies:
            source_column = quote(strategy.source_column_name)

            transform_function = strategy.transform_function
            if not transform_function:
                column_type = edges_table.field(strategy.source_column_name).type
                is_numeric = pa.types.is_integer(column_type) or pa.types.is_floating(
                    column_type
                )
                transform_function = "SUM" if is_numeric else "LIST"

            transform_function = transform_function.lower()

            if transform_function == "group_by":
                # Grouped columns keep their source name; no aggregate token.
                group_bys.append(source_column)
                continue

            target_column = quote(strategy.target_column_name)
            if transform_function == "string_agg_comma":
                sql_tokens.append(
                    f"STRING_AGG({source_column}, ',') as {target_column}"
                )
            else:
                sql_tokens.append(
                    f"{transform_function.upper()}({source_column}) as {target_column}"
                )

        # Build the select list in one join so that an empty aggregate list
        # does not leave a dangling comma in the statement.
        select_list = ", ".join(group_bys + sql_tokens)
        group_by_list = ", ".join(group_bys)

        query = f"""
        SELECT
            {select_list}
        FROM edges_table
        GROUP BY {group_by_list}
        """

        result = duckdb.sql(query)
        new_edges_table = result.arrow()

        network_data = NetworkData.create_network_data(
            nodes_table=network_data.nodes.arrow_table,
            edges_table=new_edges_table,
            augment_tables=True,
        )
        outputs.set_values(network_data=network_data)
|