diff --git a/examples/gds-example.ipynb b/examples/gds-example.ipynb index c38d191..77bd74f 100644 --- a/examples/gds-example.ipynb +++ b/examples/gds-example.ipynb @@ -1,65 +1,107 @@ { "cells": [ { - "cell_type": "markdown", "metadata": {}, - "source": [ - "# Visualizing Neo4j Graph Data Science (GDS) Graphs" - ] + "cell_type": "markdown", + "source": "# Visualizing Neo4j Graph Data Science (GDS) Graphs", + "id": "12f02eab9ab5ed32" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "%pip install graphdatascience\n", "%pip install matplotlib" - ] + ], + "id": "2659974f9eb8c448" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ - "## Setup GDS graph" - ] + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" + ], + "id": "a1ba28abfe91b628" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Setup GDS graph\n", + "\n", + "To use GDS, you can either use GDS as a plugin or Aura Graph Analytics.\n", + "In the following, you can choose:\n", + "\n", + " * Provide Aura API credentials and use Aura Graph Analytics.\n", + " * Use Neo4j + GDS Plugin.\n", + "\n", + "For more information, see the [GDS documentation](https://neo4j.com/docs/graph-data-science/current/installation/)." 
+ ], + "id": "c6dc0292966144ba" + }, + { + "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "import os\n", "\n", + "from graphdatascience.session import GdsSessions, DbmsConnectionInfo, AuraAPICredentials, SessionMemory\n", "from graphdatascience import GraphDataScience\n", "\n", "# Get Neo4j DB URI, credentials and name from environment if applicable\n", - "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n", - "NEO4J_AUTH = (\"neo4j\", None)\n", - "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n", - "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n", - " NEO4J_AUTH = (\n", - " os.environ.get(\"NEO4J_USER\"),\n", - " os.environ.get(\"NEO4J_PASSWORD\"),\n", - " )\n", - "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)" - ] + "db_connection = DbmsConnectionInfo(\n", + " aura_instance_id=os.environ.get(\"AURA_INSTANCEID\"),\n", + " username=os.environ[\"NEO4J_USERNAME\"],\n", + " password=os.environ[\"NEO4J_PASSWORD\"],\n", + " uri=os.environ[\"NEO4J_URI\"],\n", + ")\n", + "\n", + "session_name = \"neo4j-viz-gds-example\"\n", + "if os.environ.get(\"AURA_API_CLIENT_ID\"):\n", + " # Use Aura Graph Analytics\n", + " sessions = GdsSessions(api_credentials=AuraAPICredentials(\n", + " client_id=os.environ[\"AURA_API_CLIENT_ID\"],\n", + " client_secret=os.environ[\"AURA_API_CLIENT_SECRET\"],\n", + " project_id=os.environ.get(\"AURA_API_PROJECT_ID\"),\n", + " ))\n", + " gds = sessions.get_or_create(session_name=session_name, memory=SessionMemory.m_2GB, db_connection=db_connection)\n", + "else:\n", + " # Use GDS Plugin\n", + " sessions = None\n", + " gds = GraphDataScience(\n", + " endpoint=db_connection.get_uri(),\n", + " auth=(db_connection.username, db_connection.password),\n", + " )" + ], + "id": "35d47e906f813781" }, { + "metadata": {}, "cell_type": "code", + "outputs": [], "execution_count": null, + "source": "G = 
gds.graph.load_cora(graph_name=\"cora\")", + "id": "2ef09c75d64ebfff" + }, + { "metadata": {}, - "outputs": [], - "source": [ - "G = gds.graph.load_cora(graph_name=\"cora\")" - ] + "cell_type": "markdown", + "source": "", + "id": "7d617acd466fc01f" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "# Run some algorithms to use later for visualization\n", "gds.nodeSimilarity.mutate(\n", @@ -67,3276 +109,99 @@ ")\n", "gds.pageRank.mutate(G, mutateProperty=\"pagerank\")\n", "gds.louvain.mutate(G, mutateProperty=\"componentId\")" - ] + ], + "id": "d37e9519da4cbfd8" }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "## Visualization" - ] + "cell_type": "markdown", + "source": "## Visualization", + "id": "ff2d387d6dccb583" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "from neo4j_viz.gds import from_gds\n", "\n", - "VG = from_gds(gds, G, max_node_count=500)" - ] + "VG = from_gds(\n", + " gds,\n", + " G,\n", + " max_node_count=100,\n", + ")" + ], + "id": "7299dbcea5d5b5bd" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], - "source": [] - }, - { "cell_type": "code", - "execution_count": null, - "metadata": {}, "outputs": [], - "source": [ - "len(VG.nodes)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "VG.render()" - ] + "execution_count": null, + "source": "VG.render()", + "id": "acf2d3dd5e3d1267" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "### Changing captions\n", "\n", "We can also change the node captions, if we want to see something other that the node labels.\n", "For this dataset it might make sense to caption by scientific subject." - ] + ], + "id": "568a08b5fe175ef1" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "for node in VG.nodes:\n", " node.caption = str(node.properties[\"subject\"])" - ] + ], + "id": "279c18f8799f8928" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 9, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "VG.render()" - ] + "outputs": [], + "execution_count": null, + "source": "VG.render()", + "id": "e8363f2fe612bf62" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "## Sizing the nodes\n", "\n", "Next, we can size the nodes by their pageRank score to show their importance." - ] + ], + "id": "986d86d1be93033d" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "VG.resize_nodes(property=\"pagerank\")\n", - "VG.color_nodes(property=\"componentId\")\n", "VG.render()" - ] + ], + "id": "c36d91d10b67b533" }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "### Coloring" - ] + "cell_type": "markdown", + "source": "### Coloring", + "id": "d2c237bf1fd2cec9" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "There are two main ways of coloring the nodes of a graph:\n", "\n", @@ -3344,3229 +209,67 @@ "* By continuous color space, in which case nodes will be colored according to a range of colors, according to their field or property value\n", "\n", "We will start be coloring the node based on our discrete node property \"subject\" using the default colors." - ] + ], + "id": "5c7e54e9b59699c6" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 11, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], + "execution_count": null, "source": [ - "VG.color_nodes(property=\"componentId\")\n", + "VG.color_nodes(property=\"subject\")\n", "VG.render()" - ] + ], + "id": "88eecc067e28ea0f" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "Now, let us color by our continuous node field \"size\" that we computed above with PageRank, again using the default colors.\n", "We set `override=True` so as to replace the previous coloring completely.\n", "Note how the nodes are colored from yellow to purple, and how that also corresponds to the nodes' sizes." - ] + ], + "id": "ddb420af6158ad2a" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 12, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], + "execution_count": null, "source": [ "from neo4j_viz.colors import ColorSpace\n", "\n", "VG.color_nodes(field=\"size\", color_space=ColorSpace.CONTINUOUS, override=True)\n", "VG.render()" - ] + ], + "id": "7cda3d298d834dec" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "#### Custom coloring\n", "\n", "In some cases, the default colors are too few.\n", "For example, if you have many communities, you might need a lot more colors.\n" - ] + ], + "id": "80923c3f3e5f8abf" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], - "source": [ - "%pip install matplotlib, palettable" - ] + "execution_count": null, + "source": "%pip install matplotlib, palettable", + "id": "a9b7ea72daf706c4" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "from palettable.colorbrewer.qualitative import Dark2_7\n", "import matplotlib.colors as mcolors\n", @@ -6578,1603 +281,20 @@ "colors = [mcolors.rgb2hex(linear_color_map(i)) for i in range(number_of_components)]\n", "\n", "VG.color_nodes(property=\"componentId\", colors=colors, override=True)" - ] + ], + "id": "e7c333a570ea3bdf" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 15, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "VG.render()" - ] + "outputs": [], + "execution_count": null, + "source": "VG.render()", + "id": "2bcf029dd33a08b9" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "### Render options\n", "\n", @@ -8185,1614 +305,32 @@ "* renderer\n", "* zoom level\n", "* initial position " - ] + ], + "id": "583745ea477a948f" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 16, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], + "execution_count": null, "source": [ "from neo4j_viz import Layout\n", "\n", "VG.render(layout=Layout.HIERARCHICAL, initial_zoom=0.1, pan_position=(2000, 0))" - ] + ], + "id": "f2dca2f10fc3b3db" }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "## Saving the visualization" - ] + "cell_type": "markdown", + "source": "## Saving the visualization", + "id": "1b48c5fd3677088" }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "from neo4j_viz.options import Renderer\n", "\n", @@ -9801,36 +339,40 @@ "# Save the visualization to a file\n", "with open(\"out/cora.html\", \"w\") as f:\n", " f.write(VG.render(renderer=Renderer.CANVAS).data)" - ] + ], + "id": "d1110d89d939af6a" }, { - "cell_type": "markdown", "metadata": {}, + "cell_type": "markdown", "source": [ "## Cleanup\n", "\n", "Lets cleanup the graphs we created in GDS." 
- ] + ], + "id": "fa7f396a5495d516" }, { + "metadata": {}, "cell_type": "code", + "outputs": [], "execution_count": null, - "metadata": { - "tags": [ - "teardown" - ] - }, + "source": "gds.graph.drop(\"cora\")", + "id": "cf86d2149972a2c9" + }, + { + "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ - "gds.graph.drop(\"cora\")" - ] + "if sessions:\n", + " sessions.delete(session_name=session_name)" + ], + "id": "db6e7d51b612cc26" } ], - "metadata": { - "language_info": { - "name": "python" - } - }, + "metadata": {}, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/justfile b/justfile index ea83288..da1d765 100644 --- a/justfile +++ b/justfile @@ -18,13 +18,23 @@ py-test-gds: trap "cd $ENV_DIR && docker compose down" EXIT cd $ENV_DIR && docker compose up -d cd - + cd python-wrapper && \ NEO4J_URI=bolt://localhost:7687 \ NEO4J_USER=neo4j \ NEO4J_PASSWORD=password \ NEO4J_DB=neo4j \ - cd python-wrapper && uv run --group dev --extra gds pytest tests --include-neo4j-and-gds + uv run --group dev --extra gds pytest tests --include-neo4j-and-gds cd .. 
+py-test-gds-sessions: + #!/usr/bin/env bash + cd python-wrapper && \ + GDS_SESSION_URI=bolt://localhost:7688 \ + NEO4J_URI=bolt://localhost:7687 \ + NEO4J_USER=neo4j \ + NEO4J_PASSWORD=password \ + uv run --group dev --extra gds pytest tests --include-neo4j-and-gds + local-neo4j-setup: #!/usr/bin/env bash set -e diff --git a/python-wrapper/pyproject.toml b/python-wrapper/pyproject.toml index d11343d..81cd409 100644 --- a/python-wrapper/pyproject.toml +++ b/python-wrapper/pyproject.toml @@ -42,7 +42,7 @@ requires-python = ">=3.10" [project.optional-dependencies] pandas = ["pandas>=2, <3", "pandas-stubs>=2, <3"] -gds = ["graphdatascience>=1, <2"] +gds = ["graphdatascience>=1.20, <2"] neo4j = ["neo4j"] snowflake = ["snowflake-snowpark-python>=1, <2"] @@ -76,9 +76,9 @@ notebook = [ "palettable>=3.3.3", "matplotlib>=3.9.4", "snowflake-snowpark-python==1.42.0", - "dotenv", "requests", "marimo", + "python-dotenv" ] [project.urls] @@ -174,9 +174,3 @@ exclude = [ ] plugins = ['pydantic.mypy'] untyped_calls_exclude=["nbconvert"] - -[tool.marimo.runtime] -output_max_bytes = 20_000_000 -# -#[tool.marimo.server] -#follow_symlink = true diff --git a/python-wrapper/src/neo4j_viz/gds.py b/python-wrapper/src/neo4j_viz/gds.py index e6a399a..8041431 100644 --- a/python-wrapper/src/neo4j_viz/gds.py +++ b/python-wrapper/src/neo4j_viz/gds.py @@ -2,11 +2,13 @@ import warnings from itertools import chain -from typing import Optional, cast +from typing import Optional, cast, Collection from uuid import uuid4 import pandas as pd from graphdatascience import Graph, GraphDataScience +from graphdatascience.graph.v2 import GraphV2 +from graphdatascience.session import AuraGraphDataScience from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace @@ -15,48 +17,38 @@ def _fetch_node_dfs( - gds: GraphDataScience, - G: Graph, + gds: GraphDataScience | AuraGraphDataScience, + G: GraphV2, node_properties_by_label: dict[str, list[str]], - node_labels: list[str], + node_labels: Collection[str], 
additional_db_node_properties: list[str], ) -> dict[str, pd.DataFrame]: return { - lbl: gds.graph.nodeProperties.stream( + lbl: gds.v2.graph.node_properties.stream( G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], - separate_property_columns=True, db_node_properties=additional_db_node_properties, ) for lbl in node_labels } -def _fetch_rel_dfs(gds: GraphDataScience, G: Graph) -> list[pd.DataFrame]: - rel_types = G.relationship_types() - - rel_props = {rel_type: G.relationship_properties(rel_type) for rel_type in rel_types} +def _fetch_rel_dfs(gds: GraphDataScience, G: GraphV2) -> list[pd.DataFrame]: + rel_props = G.relationship_properties() rel_dfs: list[pd.DataFrame] = [] # Have to call per stream per relationship type as there was a bug in GDS < 2.21 for rel_type, props in rel_props.items(): - assert isinstance(props, list) - if len(props) > 0: - rel_df = gds.graph.relationshipProperties.stream( - G, relationship_types=rel_type, relationship_properties=list(props), separate_property_columns=True - ) - else: - rel_df = gds.graph.relationships.stream(G, relationship_types=[rel_type]) - + rel_df = gds.v2.graph.relationships.stream(G, relationship_types=[rel_type], relationship_properties=list(props)) rel_dfs.append(rel_df) return rel_dfs def from_gds( - gds: GraphDataScience, - G: Graph, + gds: GraphDataScience | AuraGraphDataScience, + G: Graph | GraphV2, node_properties: Optional[list[str]] = None, db_node_properties: Optional[list[str]] = None, max_node_count: int = 10_000, @@ -76,9 +68,9 @@ def from_gds( Parameters ---------- - gds : GraphDataScience - GraphDataScience object. - G : Graph + gds + GraphDataScience object. AuraGraphDataScience object if using Aura Graph Analytics. + G Graph object. 
node_properties : list[str], optional Additional properties to include in the visualization node, by default None which means that all node @@ -91,37 +83,39 @@ def from_gds( """ if db_node_properties is None: db_node_properties = [] + if isinstance(G, Graph): + G_v2 = gds.v2.graph.get(G.name()) + else: + G_v2 = G - node_properties_from_gds = G.node_properties() - assert isinstance(node_properties_from_gds, pd.Series) - actual_node_properties: dict[str, list[str]] = cast(dict[str, list[str]], node_properties_from_gds.to_dict()) - all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values())) + gds_properties_per_label = G_v2.node_properties() + all_gds_properties = list(chain.from_iterable(gds_properties_per_label.values())) node_properties_by_label_sets: dict[str, set[str]] = dict() if node_properties is None: - node_properties_by_label_sets = {k: set(v) for k, v in actual_node_properties.items()} + node_properties_by_label_sets = {k: set(v) for k, v in gds_properties_per_label.items()} else: for prop in node_properties: - if prop not in all_actual_node_properties: + if prop not in all_gds_properties: raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'") - for label, props in actual_node_properties.items(): + for label, props in gds_properties_per_label.items(): node_properties_by_label_sets[label] = { - prop for prop in actual_node_properties[label] if prop in node_properties + prop for prop in gds_properties_per_label[label] if prop in node_properties } node_properties_by_label = {k: list(v) for k, v in node_properties_by_label_sets.items()} - node_count = G.node_count() + node_count = G_v2.node_count() if node_count > max_node_count: warnings.warn( - f"The '{G.name()}' projection's node count ({G.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. 
Increase `max_node_count` if needed" + f"The '{G_v2.name()}' projection's node count ({G_v2.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed" ) sampling_ratio = float(max_node_count) / node_count sample_name = f"neo4j-viz_sample_{uuid4()}" - G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True) + G_fetched, _ = gds.v2.graph.sample.rwr(G_v2, sample_name, sampling_ratio=sampling_ratio, node_label_stratification=True) else: - G_fetched = G + G_fetched = G_v2 property_name = None try: @@ -129,12 +123,12 @@ def from_gds( # as a temporary property to ensure that we have at least one property for each label to fetch if sum([len(props) == 0 for props in node_properties_by_label.values()]) > 0: property_name = f"neo4j-viz_property_{uuid4()}" - gds.degree.mutate(G_fetched, mutateProperty=property_name) + gds.v2.degree_centrality.mutate(G_fetched, mutate_property=property_name) for props in node_properties_by_label.values(): props.append(property_name) node_dfs = _fetch_node_dfs( - gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), db_node_properties + gds, G_fetched, node_properties_by_label, node_properties_by_label.keys(), db_node_properties ) if property_name is not None: for df in node_dfs.values(): @@ -145,7 +139,7 @@ def from_gds( if G_fetched.name() != G.name(): G_fetched.drop() elif property_name is not None: - gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name]) + gds.v2.graph.node_properties.drop(G_fetched, node_properties=[property_name]) for df in node_dfs.values(): if property_name is not None and property_name in df.columns: @@ -154,7 +148,7 @@ def from_gds( node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates(subset=["nodeId"]) for lbl, df in node_dfs.items(): - if "labels" in all_actual_node_properties: + if "labels" in all_gds_properties: 
df.rename(columns={"labels": "__labels"}, inplace=True) df["labels"] = lbl diff --git a/python-wrapper/tests/conftest.py b/python-wrapper/tests/conftest.py index 40f7f4e..2ecb600 100644 --- a/python-wrapper/tests/conftest.py +++ b/python-wrapper/tests/conftest.py @@ -1,4 +1,5 @@ import os +import random from typing import Any, Generator import pytest @@ -31,56 +32,73 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None: @pytest.fixture(scope="package") -def aura_ds_instance() -> Generator[Any, None, None]: +def aura_db_instance() -> Generator[Any, None, None]: if os.environ.get("AURA_API_CLIENT_ID", None) is None: yield None return - from tests.gds_helper import aura_api, create_aurads_instance + from tests.gds_helper import aura_api, create_auradb_instance api = aura_api() - id, dbms_connection_info = create_aurads_instance(api) + dbms_connection_info = create_auradb_instance(api) + old_uri = os.environ.get("NEO4J_URI", "") # setting as environment variables to run notebooks with this connection os.environ["NEO4J_URI"] = dbms_connection_info.get_uri() assert isinstance(dbms_connection_info.username, str) os.environ["NEO4J_USER"] = dbms_connection_info.username assert isinstance(dbms_connection_info.password, str) os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password + old_instance = os.environ.get("AURA_INSTANCEID", "") + if dbms_connection_info.aura_instance_id: + os.environ["AURA_INSTANCEID"] = dbms_connection_info.aura_instance_id + yield dbms_connection_info # Clear Neo4j_URI after test (rerun should create a new instance) - os.environ["NEO4J_URI"] = "" - api.delete_instance(id) + os.environ["NEO4J_URI"] = old_uri + os.environ["AURA_INSTANCEID"] = old_instance + assert dbms_connection_info.aura_instance_id is not None + api.delete_instance(dbms_connection_info.aura_instance_id) @pytest.fixture(scope="package") -def gds(aura_ds_instance: Any) -> Generator[Any, None, None]: - from graphdatascience import GraphDataScience +def 
gds(aura_db_instance: Any) -> Generator[Any, None, None]: + from graphdatascience.session import SessionMemory + + from tests.gds_helper import connect_to_plugin_gds, gds_sessions, connect_to_local_gds_session - from tests.gds_helper import connect_to_plugin_gds + if aura_db_instance: + sessions = gds_sessions() - if aura_ds_instance: - yield GraphDataScience( - endpoint=aura_ds_instance.uri, - auth=(aura_ds_instance.username, aura_ds_instance.password), - aura_ds=True, - database="neo4j", + gds = sessions.get_or_create( + f"neo4j-viz-ci-{os.environ.get('GITHUB_RUN_ID', random.randint(0, 10**6))}", + memory=SessionMemory.m_2GB, + db_connection=aura_db_instance, ) + + yield gds + gds.delete() else: - NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") - gds = connect_to_plugin_gds(NEO4J_URI) + neo4j_uri = os.environ["NEO4J_URI"] + neo4j_auth = (os.environ.get("NEO4J_USER", "neo4j"), os.environ.get("NEO4J_PASSWORD", "password")) + + session_uri = os.environ.get("GDS_SESSION_URI") + if session_uri: + gds = connect_to_local_gds_session(session_uri, neo4j_uri, neo4j_auth) # type: ignore + else: + gds = connect_to_plugin_gds(neo4j_uri, neo4j_auth) # type: ignore yield gds gds.close() @pytest.fixture(scope="package") -def neo4j_driver(aura_ds_instance: Any) -> Generator[Any, None, None]: +def neo4j_driver(aura_db_instance: Any) -> Generator[Any, None, None]: import neo4j - if aura_ds_instance: + if aura_db_instance: driver = neo4j.GraphDatabase.driver( - aura_ds_instance.uri, auth=(aura_ds_instance.username, aura_ds_instance.password) + aura_db_instance.uri, auth=(aura_db_instance.username, aura_db_instance.password) ) else: NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") diff --git a/python-wrapper/tests/gds_helper.py b/python-wrapper/tests/gds_helper.py index e5a0d3d..74246ce 100644 --- a/python-wrapper/tests/gds_helper.py +++ b/python-wrapper/tests/gds_helper.py @@ -1,12 +1,14 @@ import os import re -from graphdatascience import 
GraphDataScience +from graphdatascience import GdsSessions, GraphDataScience +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.semantic_version.semantic_version import SemanticVersion -from graphdatascience.session import DbmsConnectionInfo, SessionMemory +from graphdatascience.session import AuraAPICredentials, DbmsConnectionInfo, SessionMemory, AuraGraphDataScience from graphdatascience.session.aura_api import AuraApi from graphdatascience.session.aura_api_responses import InstanceCreateDetails from graphdatascience.version import __version__ +from neo4j import GraphDatabase def parse_version(version: str) -> SemanticVersion: @@ -26,12 +28,22 @@ def parse_version(version: str) -> SemanticVersion: GDS_VERSION = parse_version(__version__) -def connect_to_plugin_gds(uri: str) -> GraphDataScience: - NEO4J_AUTH = ("neo4j", "password") - if os.environ.get("NEO4J_USER"): - NEO4J_AUTH = (os.environ.get("NEO4J_USER", "DUMMY"), os.environ.get("NEO4J_PASSWORD", "neo4j")) +def connect_to_plugin_gds(uri: str, auth: tuple[str, str]) -> GraphDataScience: + return GraphDataScience(endpoint=uri, auth=auth, database="neo4j") - return GraphDataScience(endpoint=uri, auth=NEO4J_AUTH, database="neo4j") + +def connect_to_local_gds_session(session_uri: str, db_uri: str, db_auth: tuple[str, str]) -> AuraGraphDataScience: + session_bolt_connection_info = DbmsConnectionInfo(uri=session_uri, username="neo4j", password="password") + db_connection_info = DbmsConnectionInfo(uri=db_uri, username=db_auth[0], + password=db_auth[1]) + + return AuraGraphDataScience.create( + session_bolt_connection_info=session_bolt_connection_info, + arrow_authentication=UsernamePasswordAuthentication(session_bolt_connection_info.username, + session_bolt_connection_info.password), + session_lifecycle_manager=None, # type: ignore + db_endpoint=db_connection_info + ) def aura_api() -> AuraApi: @@ -49,21 +61,29 @@ def aura_api() -> AuraApi: ) -def 
create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]: - # Switch to Sessions once they can be created without a DB +def gds_sessions() -> GdsSessions: + return GdsSessions( + api_credentials=AuraAPICredentials( + client_id=os.environ["AURA_API_CLIENT_ID"], + client_secret=os.environ["AURA_API_CLIENT_SECRET"], + project_id=os.environ.get("AURA_API_TENANT_ID"), + ) + ) + + +def create_auradb_instance(api: AuraApi) -> DbmsConnectionInfo: instance_details: InstanceCreateDetails = api.create_instance( - name="ci-neo4j-viz-session", - memory=SessionMemory.m_8GB.value, + name="ci-neo4j-viz-db", + memory=SessionMemory.m_2GB.value, cloud_provider="gcp", region="europe-west1", + type="enterprise-db", ) wait_result = api.wait_for_instance_running(instance_id=instance_details.id) if wait_result.error: raise Exception(f"Error while waiting for instance to be running: {wait_result.error}") - return instance_details.id, DbmsConnectionInfo( - uri=wait_result.connection_url, - username="neo4j", - password=instance_details.password, + return DbmsConnectionInfo( + username="neo4j", password=instance_details.password, aura_instance_id=instance_details.id ) diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py index fb078aa..a35b822 100644 --- a/python-wrapper/tests/test_gds.py +++ b/python-wrapper/tests/test_gds.py @@ -20,12 +20,26 @@ def db_setup(gds: Any) -> Generator[None, None, None]: gds.run_cypher("MATCH (n:_CI_A|_CI_B) DETACH DELETE n") +def project_graph(gds: Any) -> Any: + from graphdatascience import GraphDataScience + from graphdatascience.session import AuraGraphDataScience + if isinstance(gds, GraphDataScience): + return gds.graph.project("g2", ["*"], "*") + elif isinstance(gds, AuraGraphDataScience): + return gds.v2.graph.project( + "g2", query=""" + MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m) + """ + ) + raise Exception(f"Unsupported GDS type {type(gds)}") + + 
@pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.requires_neo4j_and_gds def test_from_gds_integration_all_db_properties(gds: Any, db_setup: None) -> None: from neo4j_viz.gds import from_gds - with gds.graph.project("g2", ["_CI_A", "_CI_B"], "*") as G: + with project_graph(gds) as G: VG = from_gds(gds, G, db_node_properties=["name"]) assert len(VG.nodes) == 2 @@ -108,10 +122,10 @@ def test_from_gds_sample(gds: Any) -> None: with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G: with pytest.warns( - UserWarning, - match=re.escape( - "The 'hello' projection's node count (11000) exceeds `max_node_count` (10000), so subsampling will be applied. Increase `max_node_count` if needed" - ), + UserWarning, + match=re.escape( + "The 'hello' projection's node count (11000) exceeds `max_node_count` (10000), so subsampling will be applied. Increase `max_node_count` if needed" + ), ): VG = from_gds(gds, G) diff --git a/python-wrapper/uv.lock b/python-wrapper/uv.lock index 3472cea..a243aff 100644 --- a/python-wrapper/uv.lock +++ b/python-wrapper/uv.lock @@ -877,17 +877,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/11/208f72084084d3f6a2ed5ebfdfc846692c3f7ad6dce65e400194924f7eed/domdf_python_tools-3.10.0-py3-none-any.whl", hash = "sha256:5e71c1be71bbcc1f881d690c8984b60e64298ec256903b3147f068bc33090c36", size = 126860, upload-time = "2025-02-12T17:34:04.093Z" }, ] -[[package]] -name = "dotenv" -version = "0.9.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "python-dotenv" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" }, -] - [[package]] name = "enum-tools" version = "0.13.0" @@ -2450,7 +2439,6 @@ docs = [ { name = "sphinx" }, ] 
notebook = [ - { name = "dotenv" }, { name = "ipykernel" }, { name = "ipywidgets" }, { name = "marimo" }, @@ -2458,6 +2446,7 @@ notebook = [ { name = "neo4j" }, { name = "palettable" }, { name = "pykernel" }, + { name = "python-dotenv" }, { name = "requests" }, { name = "snowflake-snowpark-python" }, ] @@ -2466,7 +2455,7 @@ notebook = [ requires-dist = [ { name = "anywidget", specifier = ">=0.9,<1" }, { name = "enum-tools", specifier = "==0.13.0" }, - { name = "graphdatascience", marker = "extra == 'gds'", specifier = ">=1,<2" }, + { name = "graphdatascience", marker = "extra == 'gds'", specifier = ">=1.20,<2" }, { name = "ipython", specifier = ">=7,<10" }, { name = "neo4j", marker = "extra == 'neo4j'" }, { name = "pandas", marker = "extra == 'pandas'", specifier = ">=2,<3" }, @@ -2500,7 +2489,6 @@ docs = [ { name = "sphinx", specifier = "==8.1.3" }, ] notebook = [ - { name = "dotenv" }, { name = "ipykernel", specifier = ">=6.29.5" }, { name = "ipywidgets", specifier = ">=8.0.0" }, { name = "marimo" }, @@ -2508,6 +2496,7 @@ notebook = [ { name = "neo4j", specifier = ">=5.26.0" }, { name = "palettable", specifier = ">=3.3.3" }, { name = "pykernel", specifier = ">=0.1.6" }, + { name = "python-dotenv" }, { name = "requests" }, { name = "snowflake-snowpark-python", specifier = "==1.42.0" }, ]