
utils

NetworkDataTabularWrap (TabularWrap)

Source code in network_analysis/utils.py
class NetworkDataTabularWrap(TabularWrap):
    def __init__(self, db: "NetworkData", table_type: NetworkDataTableType):
        self._db: NetworkData = db
        self._table_type: NetworkDataTableType = table_type
        super().__init__()

    @property
    def _table_name(self):
        return self._table_type.value

    def retrieve_number_of_rows(self) -> int:

        from sqlalchemy import text

        with self._db.get_sqlalchemy_engine().connect() as con:
            result = con.execute(text(f"SELECT count(*) from {self._table_name}"))
            num_rows = result.fetchone()[0]

        return num_rows

    def retrieve_column_names(self) -> typing.Iterable[str]:

        from sqlalchemy import inspect

        engine = self._db.get_sqlalchemy_engine()
        inspector = inspect(engine)
        columns = inspector.get_columns(self._table_type.value)
        result = [column["name"] for column in columns]
        return result

    def slice(
        self, offset: int = 0, length: typing.Optional[int] = None
    ) -> "TabularWrap":

        from sqlalchemy import text

        query = f"SELECT * FROM {self._table_name}"
        if length:
            query = f"{query} LIMIT {length}"
        else:
            query = f"{query} LIMIT {self.num_rows}"
        if offset > 0:
            query = f"{query} OFFSET {offset}"
        with self._db.get_sqlalchemy_engine().connect() as con:
            result = con.execute(text(query))
            result_dict: typing.Dict[str, typing.List[typing.Any]] = {}
            for cn in self.column_names:
                result_dict[cn] = []
            for r in result:
                for i, cn in enumerate(self.column_names):
                    result_dict[cn].append(r[i])

        return DictTabularWrap(result_dict)

    def to_pydict(self) -> typing.Mapping:

        from sqlalchemy import text

        query = f"SELECT * FROM {self._table_name}"

        with self._db.get_sqlalchemy_engine().connect() as con:
            result = con.execute(text(query))
            result_dict: typing.Dict[str, typing.List[typing.Any]] = {}
            for cn in self.column_names:
                result_dict[cn] = []
            for r in result:
                for i, cn in enumerate(self.column_names):
                    result_dict[cn].append(r[i])

        return result_dict
retrieve_column_names(self)
Source code in network_analysis/utils.py
def retrieve_column_names(self) -> typing.Iterable[str]:

    from sqlalchemy import inspect

    engine = self._db.get_sqlalchemy_engine()
    inspector = inspect(engine)
    columns = inspector.get_columns(self._table_type.value)
    result = [column["name"] for column in columns]
    return result
retrieve_number_of_rows(self)
Source code in network_analysis/utils.py
def retrieve_number_of_rows(self) -> int:

    from sqlalchemy import text

    with self._db.get_sqlalchemy_engine().connect() as con:
        result = con.execute(text(f"SELECT count(*) from {self._table_name}"))
        num_rows = result.fetchone()[0]

    return num_rows
slice(self, offset=0, length=None)
Source code in network_analysis/utils.py
def slice(
    self, offset: int = 0, length: typing.Optional[int] = None
) -> "TabularWrap":

    from sqlalchemy import text

    query = f"SELECT * FROM {self._table_name}"
    if length:
        query = f"{query} LIMIT {length}"
    else:
        query = f"{query} LIMIT {self.num_rows}"
    if offset > 0:
        query = f"{query} OFFSET {offset}"
    with self._db.get_sqlalchemy_engine().connect() as con:
        result = con.execute(text(query))
        result_dict: typing.Dict[str, typing.List[typing.Any]] = {}
        for cn in self.column_names:
            result_dict[cn] = []
        for r in result:
            for i, cn in enumerate(self.column_names):
                result_dict[cn].append(r[i])

    return DictTabularWrap(result_dict)
to_pydict(self)
Source code in network_analysis/utils.py
def to_pydict(self) -> typing.Mapping:

    from sqlalchemy import text

    query = f"SELECT * FROM {self._table_name}"

    with self._db.get_sqlalchemy_engine().connect() as con:
        result = con.execute(text(query))
        result_dict: typing.Dict[str, typing.List[typing.Any]] = {}
        for cn in self.column_names:
            result_dict[cn] = []
        for r in result:
            for i, cn in enumerate(self.column_names):
                result_dict[cn].append(r[i])

    return result_dict
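
Usage sketch (not part of the source): the snippet below shows how the wrapper is typically driven. The import locations and the NetworkDataTableType.EDGES member are assumptions; only the NetworkDataTabularWrap calls themselves follow the code above.

# Illustrative sketch only: import paths and NetworkDataTableType.EDGES are
# assumptions, not verified against the plugin's public API.
from network_analysis.utils import NetworkDataTabularWrap
from network_analysis.models import NetworkData, NetworkDataTableType  # assumed module path


def preview_edges(network_data: NetworkData, max_rows: int = 10) -> dict:
    # Wrap the 'edges' table of the backing SQLite database.
    wrap = NetworkDataTabularWrap(db=network_data, table_type=NetworkDataTableType.EDGES)

    print("columns:", list(wrap.retrieve_column_names()))
    print("rows:   ", wrap.retrieve_number_of_rows())

    # 'slice' runs a LIMIT/OFFSET query and returns a DictTabularWrap backed by
    # plain Python lists; calling to_pydict on it is assumed to mirror the
    # to_pydict method shown above.
    first_rows = wrap.slice(offset=0, length=max_rows)
    return dict(first_rows.to_pydict())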

convert_graphml_type_to_sqlite(data_type)

Source code in network_analysis/utils.py
def convert_graphml_type_to_sqlite(data_type: str) -> str:

    type_map = {
        "boolean": "INTEGER",
        "int": "INTEGER",
        "long": "INTEGER",
        "float": "REAL",
        "double": "REAL",
        "string": "TEXT",
    }

    return type_map[data_type]
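
A quick illustration of the mapping (import path assumed); note that type names outside the map raise a plain KeyError.

from network_analysis.utils import convert_graphml_type_to_sqlite  # import path assumed

assert convert_graphml_type_to_sqlite("double") == "REAL"
assert convert_graphml_type_to_sqlite("long") == "INTEGER"
assert convert_graphml_type_to_sqlite("boolean") == "INTEGER"
assert convert_graphml_type_to_sqlite("string") == "TEXT"

# GraphML attribute types not listed in the map are not handled:
try:
    convert_graphml_type_to_sqlite("date")
except KeyError:
    print("unsupported GraphML attribute type: date")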

extract_edges_as_table(graph)

Source code in network_analysis/utils.py
def extract_edges_as_table(graph: "nx.Graph"):

    # adapted from networkx code
    # License: 3-clause BSD license
    # Copyright (C) 2004-2022, NetworkX Developers

    import networkx as nx
    import pyarrow as pa

    edgelist = graph.edges(data=True)
    source_nodes = [s for s, _, _ in edgelist]
    target_nodes = [t for _, t, _ in edgelist]

    all_attrs: typing.Set[str] = set().union(*(d.keys() for _, _, d in edgelist))  # type: ignore

    if SOURCE_COLUMN_NAME in all_attrs:
        raise nx.NetworkXError(
            f"Source name {SOURCE_COLUMN_NAME} is an edge attribute name"
        )
    if TARGET_COLUMN_NAME in all_attrs:
        raise nx.NetworkXError(
            f"Target name {TARGET_COLUMN_NAME} is an edge attribute name"
        )

    nan = float("nan")
    edge_attr = {k: [d.get(k, nan) for _, _, d in edgelist] for k in all_attrs}

    edge_lists = {
        SOURCE_COLUMN_NAME: source_nodes,
        TARGET_COLUMN_NAME: target_nodes,
    }

    edge_lists.update(edge_attr)
    edges_table = pa.Table.from_pydict(mapping=edge_lists)

    return edges_table
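
Usage sketch (import path assumed): the edges of a networkx graph become rows of an Arrow table, with one column per edge attribute plus the source/target columns named by the SOURCE_COLUMN_NAME and TARGET_COLUMN_NAME constants defined elsewhere in the module.

import networkx as nx
from network_analysis.utils import extract_edges_as_table  # import path assumed

g = nx.Graph()
g.add_edge("a", "b", weight=1.0)
g.add_edge("b", "c", weight=2.5)
g.add_edge("c", "a")  # edges missing an attribute get NaN in that column

edges_table = extract_edges_as_table(g)
# Columns: the source/target columns plus 'weight'.
print(edges_table.column_names)
print(edges_table.to_pydict())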

extract_nodes_as_table(graph)

Source code in network_analysis/utils.py
def extract_nodes_as_table(graph: "nx.Graph"):

    # adapted from networkx code
    # License: 3-clause BSD license
    # Copyright (C) 2004-2022, NetworkX Developers

    import networkx as nx
    import pyarrow as pa

    nodelist = graph.nodes(data=True)

    node_ids = [n for n, _ in nodelist]

    all_attrs: typing.Set[str] = set().union(*(d.keys() for _, d in nodelist))  # type: ignore

    if ID_COLUMN_NAME in all_attrs:
        raise nx.NetworkXError(
            f"Id column name {ID_COLUMN_NAME} is a node attribute name"
        )
    if SOURCE_COLUMN_NAME in all_attrs:
        raise nx.NetworkXError(
            f"Column name {SOURCE_COLUMN_NAME} is a node attribute name"
        )

    nan = float("nan")
    node_attr = {k: [d.get(k, nan) for _, d in nodelist] for k in all_attrs}

    node_attr[ID_COLUMN_NAME] = node_ids
    nodes_table = pa.Table.from_pydict(mapping=node_attr)

    return nodes_table
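
Companion sketch for node extraction (import path assumed): each node attribute becomes a column, and the node ids go into the column named by the ID_COLUMN_NAME constant defined elsewhere in the module.

import networkx as nx
from network_analysis.utils import extract_nodes_as_table  # import path assumed

g = nx.Graph()
g.add_node("a", name="Alice", group=1)
g.add_node("b", name="Bob", group=2)
g.add_node("c", name="Carol")  # missing 'group' becomes NaN

nodes_table = extract_nodes_as_table(g)
# Columns: 'name', 'group' and the id column.
print(nodes_table.column_names)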

insert_table_data_into_network_graph(network_data, edges_table, edges_column_map=None, nodes_table=None, nodes_column_map=None, chunk_size=1024)

Source code in network_analysis/utils.py
def insert_table_data_into_network_graph(
    network_data: "NetworkData",
    edges_table: "pa.Table",
    edges_column_map: typing.Optional[typing.Mapping[str, str]] = None,
    nodes_table: typing.Optional["pa.Table"] = None,
    nodes_column_map: typing.Optional[typing.Mapping[str, str]] = None,
    chunk_size: int = DEFAULT_NETWORK_DATA_CHUNK_SIZE,
):

    added_node_ids = set()

    if edges_column_map is None:
        edges_column_map = {}
    if nodes_column_map is None:
        nodes_column_map = {}

    if nodes_table is not None:
        for batch in nodes_table.to_batches(chunk_size):
            batch_dict = batch.to_pydict()

            if nodes_column_map:
                for k, v in nodes_column_map.items():
                    if k in batch_dict.keys():
                        if k == ID_COLUMN_NAME and v == LABEL_COLUMN_NAME:
                            _data = batch_dict.get(k)
                        else:
                            _data = batch_dict.pop(k)
                            if v in batch_dict.keys():
                                raise Exception(
                                    "Duplicate nodes column name after mapping: {v}"
                                )
                        batch_dict[v] = _data
            if LABEL_COLUMN_NAME not in batch_dict.keys():
                batch_dict[LABEL_COLUMN_NAME] = (
                    str(x) for x in batch_dict[ID_COLUMN_NAME]
                )

            ids = batch_dict[ID_COLUMN_NAME]
            data = [dict(zip(batch_dict, t)) for t in zip(*batch_dict.values())]
            network_data.insert_nodes(*data)

            added_node_ids.update(ids)

    for batch in edges_table.to_batches(chunk_size):

        batch_dict = batch.to_pydict()

        for k, v in edges_column_map.items():
            if k in batch_dict.keys():
                _data = batch_dict.pop(k)
                if v in batch_dict.keys():
                    raise Exception("Duplicate edges column name after mapping: {v}")
                batch_dict[v] = _data

        data = [dict(zip(batch_dict, t)) for t in zip(*batch_dict.values())]

        all_node_ids = network_data.insert_edges(
            *data,
            existing_node_ids=added_node_ids,
        )
        added_node_ids.update(all_node_ids)
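
End-to-end sketch: extract edge and node tables from a networkx graph and load them into an existing NetworkData instance. Creating a NetworkData object is not part of this module, so it is taken as given here; the import paths are assumptions.

import networkx as nx
from network_analysis.utils import (  # import path assumed
    extract_edges_as_table,
    extract_nodes_as_table,
    insert_table_data_into_network_graph,
)


def import_nx_graph(network_data: "NetworkData", graph: nx.Graph) -> None:
    edges_table = extract_edges_as_table(graph)
    nodes_table = extract_nodes_as_table(graph)

    # The column maps rename table columns to the names the network database
    # expects (for instance mapping a node attribute onto LABEL_COLUMN_NAME);
    # passing None keeps the column names as they are.
    insert_table_data_into_network_graph(
        network_data=network_data,
        edges_table=edges_table,
        edges_column_map=None,
        nodes_table=nodes_table,
        nodes_column_map=None,
        chunk_size=1024,
    )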