Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ __pycache__/
/gaussdb_binary/
.vscode
.venv
myenv
activate_dev.ps1
.coverage
htmlcov
.idea
Expand Down
103 changes: 84 additions & 19 deletions gaussdb/gaussdb/_connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from . import pq
from ._tz import get_tzinfo
from .conninfo import make_conninfo


class ConnectionInfo:
Expand Down Expand Up @@ -72,26 +71,74 @@ def get_parameters(self) -> dict[str, str]:
either from the connection string and parameters passed to
`~Connection.connect()` or from environment variables. The password
is never returned (you can read it using the `password` attribute).

Note:
GaussDB does not support PGconn.info attribute, uses fallback method.
"""
pyenc = self.encoding

# Get the known defaults to avoid reporting them
defaults = {
i.keyword: i.compiled
for i in pq.Conninfo.get_defaults()
if i.compiled is not None
}
# Not returned by the libq. Bug? Bet we're using SSH.
defaults.setdefault(b"channel_binding", b"prefer")
defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()

return {
i.keyword.decode(pyenc): i.val.decode(pyenc)
for i in self.pgconn.info
if i.val is not None
and i.keyword != b"password"
and i.val != defaults.get(i.keyword)
}
# Check if info attribute is supported (GaussDB does not support)
try:
info = self.pgconn.info
if info is None:
return self._get_parameters_fallback()
except (AttributeError, NotImplementedError):
return self._get_parameters_fallback()

# PostgreSQL normal path
try:
# Get the known defaults to avoid reporting them
defaults = {
i.keyword: i.compiled
for i in pq.Conninfo.get_defaults()
if i.compiled is not None
}
# Not returned by the libq. Bug? Bet we're using SSH.
defaults.setdefault(b"channel_binding", b"prefer")
defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()

return {
i.keyword.decode(pyenc): i.val.decode(pyenc)
for i in info
if i.val is not None
and i.keyword != b"password"
and i.val != defaults.get(i.keyword)
}
except Exception:
# Use fallback on error
return self._get_parameters_fallback()

def _get_parameters_fallback(self) -> dict[str, str]:
"""Fallback method for getting connection parameters.

When PGconn.info is not available (e.g., GaussDB),
retrieve basic connection information from other sources.
"""
params = {}

# Get available information from pgconn attributes
if self.pgconn.host:
params["host"] = self.pgconn.host.decode(self.encoding, errors="replace")

if self.pgconn.port:
params["port"] = self.pgconn.port.decode(self.encoding, errors="replace")

if self.pgconn.db:
params["dbname"] = self.pgconn.db.decode(self.encoding, errors="replace")

if self.pgconn.user:
params["user"] = self.pgconn.user.decode(self.encoding, errors="replace")

# Get other available parameters
try:
if hasattr(self.pgconn, "options") and self.pgconn.options:
params["options"] = self.pgconn.options.decode(
self.encoding, errors="replace"
)
except Exception:
pass

return params

@property
def dsn(self) -> str:
Expand All @@ -103,7 +150,25 @@ def dsn(self) -> str:
password is never returned (you can read it using the `password`
attribute).
"""
return make_conninfo(**self.get_parameters())
try:
params = self.get_parameters()
except Exception:
params = self._get_parameters_fallback()

if not params:
return ""

# Build DSN string
parts = []
for key, value in params.items():
if key == "password":
continue # Do not include password
# Escape values
if " " in value or "=" in value or "'" in value:
value = "'" + value.replace("'", "\\'") + "'"
parts.append(f"{key}={value}")

return " ".join(parts)

@property
def status(self) -> pq.ConnStatus:
Expand Down
52 changes: 52 additions & 0 deletions gaussdb/gaussdb/_oids.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,55 @@
YEAR_OID = 1038

# autogenerated: end


# =====================================================
# GaussDB OID 别名映射(PostgreSQL OID -> GaussDB 等效 OID 列表)
# =====================================================

GAUSSDB_OID_ALIASES: dict[int, list[int]] = {
# date 类型可能映射到多个 OID
DATE_OID: [DATE_OID, SMALLDATETIME_OID],
# timestamp 类型
TIMESTAMP_OID: [TIMESTAMP_OID, SMALLDATETIME_OID],
TIMESTAMPTZ_OID: [TIMESTAMPTZ_OID],
# 其他类型保持一对一映射
}


def is_compatible_oid(expected_oid: int, actual_oid: int) -> bool:
"""
检查两个 OID 是否兼容

用于 GaussDB 场景下的类型比较,考虑 OID 别名。

Args:
expected_oid: 期望的 OID
actual_oid: 实际的 OID

Returns:
是否兼容
"""
if expected_oid == actual_oid:
return True

# 检查别名映射
aliases = GAUSSDB_OID_ALIASES.get(expected_oid, [expected_oid])
return actual_oid in aliases


def get_oid_name(oid: int) -> str:
"""
获取 OID 对应的类型名称

Args:
oid: 类型 OID

Returns:
类型名称字符串
"""
# 反向查找 OID 常量名
for name, value in globals().items():
if name.endswith("_OID") and value == oid:
return name.replace("_OID", "").lower()
return f"oid_{oid}"
27 changes: 27 additions & 0 deletions gaussdb/gaussdb/_py_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,33 @@
PY_TEXT = PyFormat.TEXT


def _is_empty_value(val: Any) -> bool:
"""检测值是否为等效空值(GaussDB 兼容)"""
if val is None:
return True
if isinstance(val, (bytes, str)) and len(val) == 0:
return True
if isinstance(val, (dict, list)) and len(val) == 0:
return True
return False


def _normalize_empty_value(val: Any, normalize_to_none: bool = False) -> Any:
"""
将空值规范化处理

Args:
val: 要处理的值
normalize_to_none: 如果为 True,将空字符串/空字典等转为 None

Returns:
规范化后的值
"""
if normalize_to_none and _is_empty_value(val):
return None
return val


class Transformer(AdaptContext):
"""
An object that can adapt efficiently between Python and GaussDB.
Expand Down
51 changes: 51 additions & 0 deletions gaussdb/gaussdb/_typeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from . import sql
from .abc import AdaptContext, Query
from .rows import dict_row
from ._oids import GAUSSDB_OID_ALIASES
from ._compat import TypeAlias, TypeVar
from ._typemod import TypeModifier
from ._encodings import conn_encoding
Expand Down Expand Up @@ -209,6 +210,56 @@ def get_precision(self, fmod: int) -> int | None:
def get_scale(self, fmod: int) -> int | None:
return self.typemod.get_scale(fmod)

@classmethod
def fetch_runtime_oid(cls, conn: Any, typename: str) -> int | None:
"""
运行时获取类型 OID

从数据库查询正确的 OID,处理 GaussDB 差异。

Args:
conn: 数据库连接
typename: 类型名称

Returns:
类型 OID,查询失败返回 None
"""
try:
from .connection import Connection

if isinstance(conn, Connection):
result = conn.execute(
"SELECT oid FROM pg_type WHERE typname = %s", [typename]
).fetchone()
else:
# AsyncConnection
import asyncio

async def _fetch():
result = await conn.execute(
"SELECT oid FROM pg_type WHERE typname = %s", [typename]
)
return await result.fetchone()

result = asyncio.run(_fetch())

return result[0] if result else None
except Exception:
return None

@classmethod
def get_compatible_oids(cls, base_oid: int) -> list[int]:
"""
获取兼容的 OID 列表

Args:
base_oid: 基础 OID

Returns:
包含基础 OID 及其别名的列表
"""
return GAUSSDB_OID_ALIASES.get(base_oid, [base_oid])


class TypesRegistry:
"""
Expand Down
51 changes: 49 additions & 2 deletions gaussdb/gaussdb/raw_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from typing import TYPE_CHECKING

from . import errors as e
from .abc import ConnectionType, Params, Query
from .sql import Composable
from .rows import Row
Expand All @@ -26,6 +27,19 @@


class GaussDBRawQuery(GaussDBQuery):
"""
GaussDB raw query class.

Only supports positional placeholders ($1, $2, ...), not named placeholders.
"""

# Query cache size
_CACHE_SIZE = 128

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._query_cache: dict[bytes, bytes] = {}

def convert(self, query: Query, vars: Params | None) -> None:
if isinstance(query, str):
bquery = query.encode(self._encoding)
Expand All @@ -34,14 +48,43 @@ def convert(self, query: Query, vars: Params | None) -> None:
else:
bquery = query

self.query = bquery
# Try to get from cache
if bquery in self._query_cache:
self.query = self._query_cache[bquery]
else:
# Validate query doesn't contain named placeholders
if b"%(" in bquery:
raise e.ProgrammingError(
"RawCursor does not support named placeholders (%(name)s). "
"Use positional placeholders ($1, $2, ...) instead."
)

self.query = bquery

# Cache result
if len(self._query_cache) < self._CACHE_SIZE:
self._query_cache[bquery] = bquery

self._want_formats = self._order = None
self.dump(vars)

def dump(self, vars: Params | None) -> None:
"""
Serialize parameters.

Args:
vars: Parameter sequence (must be sequence, not dict)

Raises:
TypeError: If parameters are not a sequence
"""
if vars is not None:
if not GaussDBQuery.is_params_sequence(vars):
raise TypeError("raw queries require a sequence of parameters")
raise TypeError(
"RawCursor requires a sequence of parameters (tuple or list), "
f"got {type(vars).__name__}. "
"For named parameters, use regular Cursor instead."
)
self._want_formats = [PyFormat.AUTO] * len(vars)

self.params = self._tx.dump_sequence(vars, self._want_formats)
Expand All @@ -52,6 +95,10 @@ def dump(self, vars: Params | None) -> None:
self.types = ()
self.formats = None

def clear_cache(self) -> None:
"""Clear query cache."""
self._query_cache.clear()


class RawCursorMixin(BaseCursor[ConnectionType, Row]):
_query_cls = GaussDBRawQuery
Expand Down
20 changes: 20 additions & 0 deletions gaussdb/gaussdb/types/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,23 @@ def _load_binary(data: Buffer, tx: Transformer) -> list[Any]:
out = [out[i : i + dim] for i in range(0, len(out), dim)]

return out


def array_equals_unordered(arr1: list[Any], arr2: list[Any]) -> bool:
"""
Compare two arrays without considering element order.

Used for GaussDB compatibility scenarios where array element order may differ.
"""
if arr1 is None and arr2 is None:
return True
if arr1 is None or arr2 is None:
return False
if len(arr1) != len(arr2):
return False

try:
return sorted(arr1) == sorted(arr2)
except TypeError:
# Elements not sortable, fall back to set comparison
return set(map(str, arr1)) == set(map(str, arr2))
Loading
Loading