models
This module contains the metadata (and other) models that are used in the kiara_plugin.network_analysis
package.
Those models are convenience wrappers that make it easier for kiara to find, create, manage and version metadata -- but also other types of models -- that are attached to data, as well as kiara modules.
Metadata models must be a sub-class of [kiara.metadata.MetadataModel][]. Other models usually sub-class a pydantic BaseModel or implement custom base classes.
Classes¶
GraphType (Enum)
¶
All possible graph types.
Source code in network_analysis/models.py
class GraphType(Enum):
    """Enumeration of the graph types network data can be interpreted as.

    Each member's value is the string identifier of the graph type.
    """

    # at most one edge per node pair, edge direction is ignored
    UNDIRECTED = "undirected"
    # at most one edge per (source, target) pair, edge direction matters
    DIRECTED = "directed"
    # parallel edges allowed, edge direction is ignored
    UNDIRECTED_MULTI = "undirected-multi"
    # parallel edges allowed, edge direction matters
    DIRECTED_MULTI = "directed-multi"
GraphTypesEnum (Enum)
¶
NetworkData (KiaraDatabase)
pydantic-model
¶
A helper class to access and query network datasets.
This class provides different ways to access the underlying network data, most notably via sql and as networkx Graph object.
Internally, network data is stored in a sqlite database with the edges stored in a table called 'edges' and the nodes, well, in a table aptly called 'nodes'.
Source code in network_analysis/models.py
class NetworkData(KiaraDatabase):
    """A helper class to access and query network datasets.

    This class provides different ways to access the underlying network data, most notably via sql and as networkx Graph object.

    Internally, network data is stored in a sqlite database with the edges stored in a table called 'edges' and the nodes, well,
    in a table aptly called 'nodes'.
    """

    _kiara_model_id = "instance.network_data"

    @classmethod
    def create_from_networkx_graph(cls, graph: "nx.Graph") -> "NetworkData":
        """Create a `NetworkData` instance from a networkx Graph object.

        The edges and nodes of the graph are extracted as tables, a matching
        sqlite schema is derived for each of them, and the data is written
        into a new database that lives in a temporary directory.

        Arguments:
            graph: the networkx graph to convert

        Returns:
            the populated (and locked) network data instance
        """

        edges_table = extract_edges_as_table(graph)
        edges_schema = create_sqlite_schema_data_from_arrow_table(edges_table)

        nodes_table = extract_nodes_as_table(graph)
        nodes_schema = create_sqlite_schema_data_from_arrow_table(nodes_table)

        # the database must stay unlocked until the data has been written
        network_data = NetworkData.create_in_temp_dir(
            edges_schema=edges_schema, nodes_schema=nodes_schema, keep_unlocked=True
        )
        insert_table_data_into_network_graph(
            network_data=network_data,
            edges_table=edges_table,
            nodes_table=nodes_table,
            chunk_size=DEFAULT_NETWORK_DATA_CHUNK_SIZE,
        )
        network_data._lock_db()

        return network_data

    @classmethod
    def create_in_temp_dir(
        cls,
        edges_schema: Union[None, SqliteTableSchema, Mapping] = None,
        nodes_schema: Union[None, SqliteTableSchema, Mapping] = None,
        keep_unlocked: bool = False,
    ):
        """Create an empty network data database in a new temporary directory.

        The temporary directory is removed again when the interpreter exits.

        Arguments:
            edges_schema: the schema for the edges table (derived automatically if not provided)
            nodes_schema: the schema for the nodes table (derived automatically if not provided)
            keep_unlocked: if set, the database is not locked again before it is returned

        Returns:
            the new instance, with (empty) 'edges' and 'nodes' tables created
        """

        temp_f = tempfile.mkdtemp()
        db_path = os.path.join(temp_f, "network_data.sqlite")

        def cleanup():
            # Remove the directory, not the database file: `rmtree` on a file
            # fails (silently, because of `ignore_errors`) and would leave
            # both the file and the directory behind.
            shutil.rmtree(temp_f, ignore_errors=True)

        atexit.register(cleanup)

        db = cls(
            db_file_path=db_path, edges_schema=edges_schema, nodes_schema=nodes_schema
        )
        db.create_if_not_exists()

        db._unlock_db()
        engine = db.get_sqlalchemy_engine()
        db.edges_schema.create_table(
            table_name=NetworkDataTableType.EDGES.value, engine=engine
        )
        db.nodes_schema.create_table(
            table_name=NetworkDataTableType.NODES.value, engine=engine
        )
        if not keep_unlocked:
            db._lock_db()

        return db

    edges_schema: SqliteTableSchema = Field(
        description="The schema information for the edges table."
    )
    nodes_schema: SqliteTableSchema = Field(
        description="The schema information for the nodes table."
    )

    @root_validator(pre=True)
    def pre_validate(cls, values):
        """Normalize and cross-check the edges and nodes schemas.

        Missing schemas are derived from each other (falling back to 'TEXT'
        ids), mappings are converted to `SqliteTableSchema` instances, and
        the column types are checked for consistency.

        Raises:
            ValueError: if a schema has an unsupported type, a required column
                is missing, or column types are inconsistent
        """

        _edges_schema = values.get("edges_schema", None)
        _nodes_schema = values.get("nodes_schema", None)

        if _edges_schema is None:
            # no edges schema: use the id type of the nodes schema (if any)
            suggested_id_type = "TEXT"
            if _nodes_schema is not None:
                if isinstance(_nodes_schema, Mapping):
                    # a mapping carries the schema fields, so the column
                    # types are found under its 'columns' key
                    _columns = _nodes_schema.get("columns", None)
                    if isinstance(_columns, Mapping):
                        suggested_id_type = _columns.get(ID_COLUMN_NAME, "TEXT")
                elif isinstance(_nodes_schema, SqliteTableSchema):
                    suggested_id_type = _nodes_schema.columns.get(
                        ID_COLUMN_NAME, "TEXT"
                    )

            edges_schema = SqliteTableSchema.construct(
                columns={
                    SOURCE_COLUMN_NAME: suggested_id_type,
                    TARGET_COLUMN_NAME: suggested_id_type,
                }
            )
        else:
            if isinstance(_edges_schema, Mapping):
                edges_schema = SqliteTableSchema(**_edges_schema)
            elif not isinstance(_edges_schema, SqliteTableSchema):
                raise ValueError(
                    f"Invalid data type for edges schema: {type(_edges_schema)}"
                )
            else:
                edges_schema = _edges_schema

        if (
            edges_schema.columns[SOURCE_COLUMN_NAME]
            != edges_schema.columns[TARGET_COLUMN_NAME]
        ):
            raise ValueError(
                f"Invalid edges schema, source and target columns have different type: {edges_schema.columns[SOURCE_COLUMN_NAME]} != {edges_schema.columns[TARGET_COLUMN_NAME]}"
            )

        if _nodes_schema is None:
            _nodes_schema = SqliteTableSchema.construct(
                columns={
                    ID_COLUMN_NAME: edges_schema.columns[SOURCE_COLUMN_NAME],
                    LABEL_COLUMN_NAME: "TEXT",
                }
            )

        if isinstance(_nodes_schema, Mapping):
            nodes_schema = SqliteTableSchema(**_nodes_schema)
        elif isinstance(_nodes_schema, SqliteTableSchema):
            nodes_schema = _nodes_schema
        else:
            raise ValueError(
                f"Invalid data type for nodes schema: {type(_nodes_schema)}"
            )

        if ID_COLUMN_NAME not in nodes_schema.columns.keys():
            raise ValueError(
                f"Invalid nodes schema: missing '{ID_COLUMN_NAME}' column."
            )

        if LABEL_COLUMN_NAME not in nodes_schema.columns.keys():
            # NOTE: this modifies a schema object passed in by the caller
            nodes_schema.columns[LABEL_COLUMN_NAME] = "TEXT"
        else:
            if nodes_schema.columns[LABEL_COLUMN_NAME] != "TEXT":
                raise ValueError(
                    f"Invalid nodes schema, '{LABEL_COLUMN_NAME}' column must be of type 'TEXT', not '{nodes_schema.columns[LABEL_COLUMN_NAME]}'."
                )

        if (
            nodes_schema.columns[ID_COLUMN_NAME]
            != edges_schema.columns[SOURCE_COLUMN_NAME]
        ):
            raise ValueError(
                f"Invalid nodes schema, id column has different type to edges source/target columns: {nodes_schema.columns[ID_COLUMN_NAME]} != {edges_schema.columns[SOURCE_COLUMN_NAME]}"
            )

        values["edges_schema"] = edges_schema
        values["nodes_schema"] = nodes_schema

        return values

    # lazily created sqlalchemy table objects
    _nodes_table_obj: Optional[Table] = PrivateAttr(default=None)
    _edges_table_obj: Optional[Table] = PrivateAttr(default=None)

    # cache of networkx graphs, keyed by the graph class they were built with
    _nx_graph = PrivateAttr(default={})

    def _invalidate_other(self):
        """Drop the cached sqlalchemy table objects."""

        self._nodes_table_obj = None
        self._edges_table_obj = None

    def get_sqlalchemy_nodes_table(self) -> Table:
        """Return the sqlalchemy nodes table instance for this network data."""

        if self._nodes_table_obj is not None:
            return self._nodes_table_obj

        self._nodes_table_obj = Table(
            NetworkDataTableType.NODES.value,
            self.get_sqlalchemy_metadata(),
            autoload_with=self.get_sqlalchemy_engine(),
        )
        return self._nodes_table_obj

    def get_sqlalchemy_edges_table(self) -> Table:
        """Return the sqlalchemy edges table instance for this network data."""

        if self._edges_table_obj is not None:
            return self._edges_table_obj

        self._edges_table_obj = Table(
            NetworkDataTableType.EDGES.value,
            self.get_sqlalchemy_metadata(),
            autoload_with=self.get_sqlalchemy_engine(),
        )
        return self._edges_table_obj

    def insert_nodes(self, *nodes: Mapping[str, Any]):
        """Add nodes to a network data item.

        Nothing is written if no nodes are provided.

        Arguments:
            nodes: a list of dicts with the nodes
        """

        if not nodes:
            # an insert with an empty parameter list must not be executed
            return

        engine = self.get_sqlalchemy_engine()
        nodes_table = self.get_sqlalchemy_nodes_table()

        with engine.connect() as conn:
            with conn.begin():
                conn.execute(nodes_table.insert(), nodes)

    def insert_edges(
        self,
        *edges: Mapping[str, Any],
        existing_node_ids: Optional[Iterable[int]] = None,
    ) -> Set[int]:
        """Add edges to a network data item.

        Node ids that are referenced by the edges but not contained in
        `existing_node_ids` are added to the nodes table first, using the
        stringified id as label. If `existing_node_ids` is not provided, no
        node is assumed to exist, so all referenced nodes are inserted.

        Arguments:
            edges: a list of dicts with the edges
            existing_node_ids: a set of ids that can be assumed to already exist, this is mainly for performance reasons

        Returns:
            a unique set of all node ids contained in source and target columns
        """

        if not edges:
            # nothing to insert, and an empty insert must not be executed
            return set()

        if existing_node_ids is None:
            # TODO: run query
            existing_node_ids = set()
        else:
            existing_node_ids = set(existing_node_ids)

        required_node_ids = {edge[SOURCE_COLUMN_NAME] for edge in edges}
        required_node_ids.update(edge[TARGET_COLUMN_NAME] for edge in edges)

        node_ids = list(required_node_ids.difference(existing_node_ids))

        if node_ids:
            self.insert_nodes(
                *(
                    {ID_COLUMN_NAME: node_id, LABEL_COLUMN_NAME: str(node_id)}
                    for node_id in node_ids
                )
            )

        engine = self.get_sqlalchemy_engine()
        with engine.connect() as conn:
            with conn.begin():
                conn.execute(self.get_sqlalchemy_edges_table().insert(), edges)

        return required_node_ids

    def as_networkx_graph(self, graph_type: Type["nx.Graph"]) -> "nx.Graph":
        """Return the network data as a networkx graph object.

        The graph is built once per graph class and cached on this instance.

        Arguments:
            graph_type: the networkx Graph class to use
        """

        if graph_type in self._nx_graph.keys():
            return self._nx_graph[graph_type]

        graph = graph_type()

        engine = self.get_sqlalchemy_engine()
        nodes = self.get_sqlalchemy_nodes_table()
        edges = self.get_sqlalchemy_edges_table()

        with engine.connect() as conn:
            with conn.begin():
                # every non-id column of a node row becomes a node attribute
                result = conn.execute(nodes.select())
                for r in result:
                    row = dict(r)
                    node_id = row.pop(ID_COLUMN_NAME)
                    graph.add_node(node_id, **row)

                # every non source/target column becomes an edge attribute
                result = conn.execute(edges.select())
                for r in result:
                    row = dict(r)
                    source = row.pop(SOURCE_COLUMN_NAME)
                    target = row.pop(TARGET_COLUMN_NAME)
                    graph.add_edge(source, target, **row)

        self._nx_graph[graph_type] = graph
        return self._nx_graph[graph_type]
Attributes¶
edges_schema: SqliteTableSchema
pydantic-field
required
¶
The schema information for the edges table.
nodes_schema: SqliteTableSchema
pydantic-field
required
¶
The schema information for the nodes table.
Methods¶
as_networkx_graph(self, graph_type)
¶
Return the network data as a networkx graph object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph_type |
Type[nx.Graph] |
the networkx Graph class to use |
required |
Source code in network_analysis/models.py
def as_networkx_graph(self, graph_type: Type["nx.Graph"]) -> "nx.Graph":
    """Build (or fetch from the cache) a networkx graph for this network data.

    The result is cached per graph class, so repeated calls with the same
    class return the same graph object.

    Arguments:
        graph_type: the networkx Graph class to use
    """
    cache = self._nx_graph
    if graph_type in cache.keys():
        return cache[graph_type]

    nx_graph = graph_type()

    engine = self.get_sqlalchemy_engine()
    nodes_table = self.get_sqlalchemy_nodes_table()
    edges_table = self.get_sqlalchemy_edges_table()

    with engine.connect() as connection, connection.begin():
        # all remaining columns of a node row end up as node attributes
        for node_record in connection.execute(nodes_table.select()):
            node_attrs = dict(node_record)
            identifier = node_attrs.pop(ID_COLUMN_NAME)
            nx_graph.add_node(identifier, **node_attrs)

        # all remaining columns of an edge row end up as edge attributes
        for edge_record in connection.execute(edges_table.select()):
            edge_attrs = dict(edge_record)
            src = edge_attrs.pop(SOURCE_COLUMN_NAME)
            tgt = edge_attrs.pop(TARGET_COLUMN_NAME)
            nx_graph.add_edge(src, tgt, **edge_attrs)

    cache[graph_type] = nx_graph
    return nx_graph
create_from_networkx_graph(graph)
classmethod
¶
Create a NetworkData
instance from a networkx Graph object.
Source code in network_analysis/models.py
@classmethod
def create_from_networkx_graph(cls, graph: "nx.Graph") -> "NetworkData":
    """Convert a networkx Graph object into a `NetworkData` instance.

    Edges and nodes are extracted as tables, a sqlite schema is derived for
    each table, and the data is written into a database created in a
    temporary directory.

    Arguments:
        graph: the networkx graph to convert

    Returns:
        the populated network data instance, locked after the data was written
    """
    edge_data = extract_edges_as_table(graph)
    edge_schema = create_sqlite_schema_data_from_arrow_table(edge_data)

    node_data = extract_nodes_as_table(graph)
    node_schema = create_sqlite_schema_data_from_arrow_table(node_data)

    # keep the database unlocked, the actual data still has to be written
    result = NetworkData.create_in_temp_dir(
        edges_schema=edge_schema,
        nodes_schema=node_schema,
        keep_unlocked=True,
    )

    insert_table_data_into_network_graph(
        network_data=result,
        edges_table=edge_data,
        nodes_table=node_data,
        chunk_size=DEFAULT_NETWORK_DATA_CHUNK_SIZE,
    )

    result._lock_db()
    return result
create_in_temp_dir(edges_schema=None, nodes_schema=None, keep_unlocked=False)
classmethod
¶
Source code in network_analysis/models.py
@classmethod
def create_in_temp_dir(
    cls,
    edges_schema: Union[None, SqliteTableSchema, Mapping] = None,
    nodes_schema: Union[None, SqliteTableSchema, Mapping] = None,
    keep_unlocked: bool = False,
):
    """Create an empty network data database in a new temporary directory.

    The temporary directory is removed again when the interpreter exits.

    Arguments:
        edges_schema: the schema for the edges table (derived automatically if not provided)
        nodes_schema: the schema for the nodes table (derived automatically if not provided)
        keep_unlocked: if set, the database is not locked again before it is returned

    Returns:
        the new instance, with (empty) 'edges' and 'nodes' tables created
    """
    temp_f = tempfile.mkdtemp()
    db_path = os.path.join(temp_f, "network_data.sqlite")

    def cleanup():
        # Remove the directory, not the database file: `rmtree` on a file
        # fails (silently, because of `ignore_errors`) and would leave both
        # the file and the directory behind.
        shutil.rmtree(temp_f, ignore_errors=True)

    atexit.register(cleanup)

    db = cls(
        db_file_path=db_path, edges_schema=edges_schema, nodes_schema=nodes_schema
    )
    db.create_if_not_exists()

    # the tables are created while the database is unlocked
    db._unlock_db()
    engine = db.get_sqlalchemy_engine()
    db.edges_schema.create_table(
        table_name=NetworkDataTableType.EDGES.value, engine=engine
    )
    db.nodes_schema.create_table(
        table_name=NetworkDataTableType.NODES.value, engine=engine
    )
    if not keep_unlocked:
        db._lock_db()

    return db
get_sqlalchemy_edges_table(self)
¶
Return the sqlalchemy edges table instance for this network data.
Source code in network_analysis/models.py
def get_sqlalchemy_edges_table(self) -> Table:
    """Return the sqlalchemy edges table instance for this network data.

    The table object is created on first access (its definition is loaded
    via the database engine) and re-used afterwards.
    """
    if self._edges_table_obj is None:
        self._edges_table_obj = Table(
            NetworkDataTableType.EDGES.value,
            self.get_sqlalchemy_metadata(),
            autoload_with=self.get_sqlalchemy_engine(),
        )
    return self._edges_table_obj
get_sqlalchemy_nodes_table(self)
¶
Return the sqlalchemy nodes table instance for this network data.
Source code in network_analysis/models.py
def get_sqlalchemy_nodes_table(self) -> Table:
    """Return the sqlalchemy nodes table instance for this network data.

    The table object is created on first access (its definition is loaded
    via the database engine) and re-used afterwards.
    """
    if self._nodes_table_obj is None:
        self._nodes_table_obj = Table(
            NetworkDataTableType.NODES.value,
            self.get_sqlalchemy_metadata(),
            autoload_with=self.get_sqlalchemy_engine(),
        )
    return self._nodes_table_obj
insert_edges(self, *edges, existing_node_ids=None)
¶
Add edges to a network data item.
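Node ids that appear in the source or target column of the edges, but are not contained in existing_node_ids, are added to the nodes table first (using the stringified id as label).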
Parameters:
Name | Type | Description | Default |
---|---|---|---|
edges |
Mapping[str, Any] |
a list of dicts with the edges |
() |
existing_node_ids |
Iterable[int] |
a set of ids that can be assumed to already exist, this is mainly for performance reasons |
None |
Returns:
Type | Description |
---|---|
Set[int] |
a unique set of all node ids contained in source and target columns |
Source code in network_analysis/models.py
def insert_edges(
    self,
    *edges: Mapping[str, Any],
    existing_node_ids: Optional[Iterable[int]] = None,
) -> Set[int]:
    """Add edges to a network data item.

    Node ids that are referenced by the edges but not contained in
    `existing_node_ids` are added to the nodes table first, using the
    stringified id as label. If `existing_node_ids` is not provided, no node
    is assumed to exist, so all referenced nodes are inserted.

    Arguments:
        edges: a list of dicts with the edges
        existing_node_ids: a set of ids that can be assumed to already exist, this is mainly for performance reasons

    Returns:
        a unique set of all node ids contained in source and target columns
    """
    if not edges:
        # nothing to insert, and an insert statement with an empty parameter
        # list must not be executed
        return set()

    if existing_node_ids is None:
        # TODO: run query
        existing_node_ids = set()
    else:
        existing_node_ids = set(existing_node_ids)

    required_node_ids = {edge[SOURCE_COLUMN_NAME] for edge in edges}
    required_node_ids.update(edge[TARGET_COLUMN_NAME] for edge in edges)

    # only nodes that are not known to exist need to be created
    node_ids = list(required_node_ids.difference(existing_node_ids))

    if node_ids:
        self.insert_nodes(
            *(
                {ID_COLUMN_NAME: node_id, LABEL_COLUMN_NAME: str(node_id)}
                for node_id in node_ids
            )
        )

    engine = self.get_sqlalchemy_engine()
    with engine.connect() as conn:
        with conn.begin():
            conn.execute(self.get_sqlalchemy_edges_table().insert(), edges)

    return required_node_ids
insert_nodes(self, *nodes)
¶
Add nodes to a network data item.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
nodes |
Mapping[str, Any] |
a list of dicts with the nodes |
() |
Source code in network_analysis/models.py
def insert_nodes(self, *nodes: Mapping[str, Any]):
    """Add nodes to a network data item.

    All nodes are written within a single transaction. Nothing is written
    (and the database is not touched) if no nodes are provided.

    Arguments:
        nodes: a list of dicts with the nodes
    """
    if not nodes:
        # an insert statement with an empty parameter list must not be executed
        return

    engine = self.get_sqlalchemy_engine()
    nodes_table = self.get_sqlalchemy_nodes_table()

    with engine.connect() as conn:
        with conn.begin():
            conn.execute(nodes_table.insert(), nodes)
pre_validate(values)
classmethod
¶
Source code in network_analysis/models.py
@root_validator(pre=True)
def pre_validate(cls, values):
    """Normalize and cross-check the edges and nodes schemas.

    Missing schemas are derived from each other (falling back to 'TEXT'
    ids), mappings are converted to `SqliteTableSchema` instances, and the
    column types are checked for consistency.

    Arguments:
        values: the raw input values for the model

    Returns:
        the input values, with 'edges_schema' and 'nodes_schema' replaced by
        validated `SqliteTableSchema` instances

    Raises:
        ValueError: if a schema has an unsupported type, a required column is
            missing, or column types are inconsistent
    """
    _edges_schema = values.get("edges_schema", None)
    _nodes_schema = values.get("nodes_schema", None)

    if _edges_schema is None:
        # no edges schema: use the id type of the nodes schema (if any)
        suggested_id_type = "TEXT"
        if _nodes_schema is not None:
            if isinstance(_nodes_schema, Mapping):
                # a mapping carries the schema fields, so the column types
                # are found under its 'columns' key
                _columns = _nodes_schema.get("columns", None)
                if isinstance(_columns, Mapping):
                    suggested_id_type = _columns.get(ID_COLUMN_NAME, "TEXT")
            elif isinstance(_nodes_schema, SqliteTableSchema):
                suggested_id_type = _nodes_schema.columns.get(
                    ID_COLUMN_NAME, "TEXT"
                )

        edges_schema = SqliteTableSchema.construct(
            columns={
                SOURCE_COLUMN_NAME: suggested_id_type,
                TARGET_COLUMN_NAME: suggested_id_type,
            }
        )
    else:
        if isinstance(_edges_schema, Mapping):
            edges_schema = SqliteTableSchema(**_edges_schema)
        elif not isinstance(_edges_schema, SqliteTableSchema):
            raise ValueError(
                f"Invalid data type for edges schema: {type(_edges_schema)}"
            )
        else:
            edges_schema = _edges_schema

    if (
        edges_schema.columns[SOURCE_COLUMN_NAME]
        != edges_schema.columns[TARGET_COLUMN_NAME]
    ):
        raise ValueError(
            f"Invalid edges schema, source and target columns have different type: {edges_schema.columns[SOURCE_COLUMN_NAME]} != {edges_schema.columns[TARGET_COLUMN_NAME]}"
        )

    if _nodes_schema is None:
        # no nodes schema: the id type is taken from the edges schema
        _nodes_schema = SqliteTableSchema.construct(
            columns={
                ID_COLUMN_NAME: edges_schema.columns[SOURCE_COLUMN_NAME],
                LABEL_COLUMN_NAME: "TEXT",
            }
        )

    if isinstance(_nodes_schema, Mapping):
        nodes_schema = SqliteTableSchema(**_nodes_schema)
    elif isinstance(_nodes_schema, SqliteTableSchema):
        nodes_schema = _nodes_schema
    else:
        raise ValueError(
            f"Invalid data type for nodes schema: {type(_nodes_schema)}"
        )

    if ID_COLUMN_NAME not in nodes_schema.columns.keys():
        raise ValueError(
            f"Invalid nodes schema: missing '{ID_COLUMN_NAME}' column."
        )

    if LABEL_COLUMN_NAME not in nodes_schema.columns.keys():
        # NOTE: this modifies a schema object passed in by the caller
        nodes_schema.columns[LABEL_COLUMN_NAME] = "TEXT"
    else:
        if nodes_schema.columns[LABEL_COLUMN_NAME] != "TEXT":
            raise ValueError(
                f"Invalid nodes schema, '{LABEL_COLUMN_NAME}' column must be of type 'TEXT', not '{nodes_schema.columns[LABEL_COLUMN_NAME]}'."
            )

    if (
        nodes_schema.columns[ID_COLUMN_NAME]
        != edges_schema.columns[SOURCE_COLUMN_NAME]
    ):
        raise ValueError(
            f"Invalid nodes schema, id column has different type to edges source/target columns: {nodes_schema.columns[ID_COLUMN_NAME]} != {edges_schema.columns[SOURCE_COLUMN_NAME]}"
        )

    values["edges_schema"] = edges_schema
    values["nodes_schema"] = nodes_schema

    return values
NetworkGraphProperties (ValueMetadata)
pydantic-model
¶
Properties of a network graph: the number of nodes, and the number of edges for each supported graph type.
Source code in network_analysis/models.py
class NetworkGraphProperties(ValueMetadata):
    """Node and edge counts of a network data value.

    The number of edges is reported separately for every graph type the
    data can be interpreted as.
    """

    _metadata_key = "graph_properties"

    number_of_nodes: int = Field(description="Number of nodes in the network graph.")
    properties_by_graph_type: List[PropertiesByGraphType] = Field(
        description="Properties of the network data, by graph type."
    )

    @classmethod
    def retrieve_supported_data_types(cls) -> Iterable[str]:
        """Return the names of the data types this metadata can be created for."""
        return ["network_data"]

    @classmethod
    def create_value_metadata(cls, value: Value) -> "NetworkGraphProperties":
        """Compute the graph properties by querying the 'nodes' and 'edges' tables.

        Arguments:
            value: the value that holds the network data

        Returns:
            the node count, and the edge count for each graph type
        """
        from sqlalchemy import text

        network_data: NetworkData = value.data

        with network_data.get_sqlalchemy_engine().connect() as con:

            def scalar(sql: str):
                # run a query and return the first column of its first row
                return con.execute(text(sql)).fetchone()[0]

            node_count = scalar("SELECT count(*) from nodes")
            # every row counts as a separate edge for the multi graph types
            edge_row_count = scalar("SELECT count(*) from edges")
            # one edge per distinct (source, target) pair
            directed_count = scalar(
                "SELECT COUNT(*) FROM (SELECT DISTINCT source, target FROM edges)"
            )
            # like the directed count, but a pair is matched in both directions
            undirected_count = scalar(
                "SELECT COUNT(*) FROM edges WHERE rowid in (SELECT DISTINCT MIN(rowid) FROM (SELECT rowid, source, target from edges UNION ALL SELECT rowid, target, source from edges) GROUP BY source, target)"
            )

        edge_counts = (
            (GraphType.DIRECTED, directed_count),
            (GraphType.UNDIRECTED, undirected_count),
            (GraphType.DIRECTED_MULTI, edge_row_count),
            (GraphType.UNDIRECTED_MULTI, edge_row_count),
        )
        by_graph_type = [
            PropertiesByGraphType(graph_type=kind, number_of_edges=count)
            for kind, count in edge_counts
        ]

        return cls(number_of_nodes=node_count, properties_by_graph_type=by_graph_type)
Attributes¶
number_of_nodes: int
pydantic-field
required
¶
Number of nodes in the network graph.
properties_by_graph_type: List[kiara_plugin.network_analysis.models.PropertiesByGraphType]
pydantic-field
required
¶
Properties of the network data, by graph type.
create_value_metadata(value)
classmethod
¶
Source code in network_analysis/models.py
@classmethod
def create_value_metadata(cls, value: Value) -> "NetworkGraphProperties":
    """Compute the graph properties by querying the 'nodes' and 'edges' tables.

    Arguments:
        value: the value that holds the network data

    Returns:
        the node count, and the edge count for each graph type
    """
    from sqlalchemy import text

    network_data: NetworkData = value.data

    with network_data.get_sqlalchemy_engine().connect() as con:

        def scalar(sql: str):
            # run a query and return the first column of its first row
            return con.execute(text(sql)).fetchone()[0]

        node_count = scalar("SELECT count(*) from nodes")
        # every row counts as a separate edge for the multi graph types
        edge_row_count = scalar("SELECT count(*) from edges")
        # one edge per distinct (source, target) pair
        directed_count = scalar(
            "SELECT COUNT(*) FROM (SELECT DISTINCT source, target FROM edges)"
        )
        # like the directed count, but a pair is matched in both directions
        undirected_count = scalar(
            "SELECT COUNT(*) FROM edges WHERE rowid in (SELECT DISTINCT MIN(rowid) FROM (SELECT rowid, source, target from edges UNION ALL SELECT rowid, target, source from edges) GROUP BY source, target)"
        )

    edge_counts = (
        (GraphType.DIRECTED, directed_count),
        (GraphType.UNDIRECTED, undirected_count),
        (GraphType.DIRECTED_MULTI, edge_row_count),
        (GraphType.UNDIRECTED_MULTI, edge_row_count),
    )
    by_graph_type = [
        PropertiesByGraphType(graph_type=kind, number_of_edges=count)
        for kind, count in edge_counts
    ]

    return cls(number_of_nodes=node_count, properties_by_graph_type=by_graph_type)
retrieve_supported_data_types()
classmethod
¶
Source code in network_analysis/models.py
@classmethod
def retrieve_supported_data_types(cls) -> Iterable[str]:
    """Return the names of the data types this metadata can be created for."""
    supported_types = ["network_data"]
    return supported_types
PropertiesByGraphType (BaseModel)
pydantic-model
¶
Properties of graph data, if interpreted as a specific graph type.
Source code in network_analysis/models.py
class PropertiesByGraphType(BaseModel):
    """Properties network data has when it is read as one particular graph type."""

    # the graph type the properties below apply to
    graph_type: GraphType = Field(
        description="The graph type name.",
    )
    # the edge count under that interpretation
    number_of_edges: int = Field(
        description="The number of edges.",
    )