From c65f42bb0ca807e0c0ddbdfb543af3f3367d8101 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 14:51:29 +0530 Subject: [PATCH 001/127] add capability to login with username and password --- dataspace_sdk/auth.py | 241 +++++++++++++++++++++++++++++++++++++++- dataspace_sdk/client.py | 55 ++++++++- docs/sdk/QUICKSTART.md | 32 +++++- 3 files changed, 315 insertions(+), 13 deletions(-) diff --git a/dataspace_sdk/auth.py b/dataspace_sdk/auth.py index 372f5d1..c04e162 100644 --- a/dataspace_sdk/auth.py +++ b/dataspace_sdk/auth.py @@ -1,5 +1,6 @@ """Authentication module for DataSpace SDK.""" +import time from typing import Any, Dict, Optional import requests @@ -10,19 +11,219 @@ class AuthClient: """Handles authentication with DataSpace API.""" - def __init__(self, base_url: str): + def __init__( + self, + base_url: str, + keycloak_url: Optional[str] = None, + keycloak_realm: Optional[str] = None, + keycloak_client_id: Optional[str] = None, + keycloak_client_secret: Optional[str] = None, + ): """ Initialize the authentication client. 
Args: base_url: Base URL of the DataSpace API + keycloak_url: Keycloak server URL (e.g., "https://opub-kc.civicdatalab.in") + keycloak_realm: Keycloak realm name (e.g., "DataSpace") + keycloak_client_id: Keycloak client ID (e.g., "dataspace") + keycloak_client_secret: Optional client secret for confidential clients """ self.base_url = base_url.rstrip("/") + self.keycloak_url = keycloak_url.rstrip("/") if keycloak_url else None + self.keycloak_realm = keycloak_realm + self.keycloak_client_id = keycloak_client_id + self.keycloak_client_secret = keycloak_client_secret + + # Session state self.access_token: Optional[str] = None self.refresh_token: Optional[str] = None + self.keycloak_access_token: Optional[str] = None + self.keycloak_refresh_token: Optional[str] = None + self.token_expires_at: Optional[float] = None self.user_info: Optional[Dict] = None - def login_with_keycloak(self, keycloak_token: str) -> Dict[str, Any]: + # Stored credentials for auto-relogin + self._username: Optional[str] = None + self._password: Optional[str] = None + + def login(self, username: str, password: str) -> Dict[str, Any]: + """ + Login using username and password via Keycloak. + + Args: + username: User's username or email + password: User's password + + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + """ + if not all([self.keycloak_url, self.keycloak_realm, self.keycloak_client_id]): + raise DataSpaceAuthError( + "Keycloak configuration missing. Please provide keycloak_url, " + "keycloak_realm, and keycloak_client_id when initializing the client." 
+ ) + + # Store credentials for auto-relogin + self._username = username + self._password = password + + # Get Keycloak token + keycloak_token = self._get_keycloak_token(username, password) + + # Login to DataSpace backend + return self._login_with_keycloak_token(keycloak_token) + + def _get_keycloak_token(self, username: str, password: str) -> str: + """ + Get Keycloak access token using username and password. + + Args: + username: User's username or email + password: User's password + + Returns: + Keycloak access token + + Raises: + DataSpaceAuthError: If authentication fails + """ + token_url = ( + f"{self.keycloak_url}/auth/realms/{self.keycloak_realm}/" + f"protocol/openid-connect/token" + ) + + data = { + "grant_type": "password", + "client_id": self.keycloak_client_id, + "username": username, + "password": password, + } + + if self.keycloak_client_secret: + data["client_secret"] = self.keycloak_client_secret + + try: + response = requests.post( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code == 200: + token_data = response.json() + self.keycloak_access_token = token_data.get("access_token") + self.keycloak_refresh_token = token_data.get("refresh_token") + + # Calculate token expiration time + expires_in = token_data.get("expires_in", 300) + self.token_expires_at = time.time() + expires_in + + if not self.keycloak_access_token: + raise DataSpaceAuthError("No access token in Keycloak response") + + return self.keycloak_access_token + else: + error_data = response.json() + error_msg = error_data.get( + "error_description", + error_data.get("error", "Keycloak authentication failed"), + ) + raise DataSpaceAuthError( + f"Keycloak login failed: {error_msg}", + status_code=response.status_code, + response=error_data, + ) + except requests.RequestException as e: + raise DataSpaceAuthError(f"Network error during Keycloak authentication: {str(e)}") + + def _refresh_keycloak_token(self) -> str: + 
""" + Refresh Keycloak access token using refresh token. + + Returns: + New Keycloak access token + + Raises: + DataSpaceAuthError: If token refresh fails + """ + if not self.keycloak_refresh_token: + # If no refresh token, try to relogin with stored credentials + if self._username and self._password: + return self._get_keycloak_token(self._username, self._password) + raise DataSpaceAuthError("No refresh token or credentials available") + + token_url = ( + f"{self.keycloak_url}/auth/realms/{self.keycloak_realm}/" + f"protocol/openid-connect/token" + ) + + data = { + "grant_type": "refresh_token", + "client_id": self.keycloak_client_id, + "refresh_token": self.keycloak_refresh_token, + } + + if self.keycloak_client_secret: + data["client_secret"] = self.keycloak_client_secret + + try: + response = requests.post( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code == 200: + token_data = response.json() + self.keycloak_access_token = token_data.get("access_token") + self.keycloak_refresh_token = token_data.get("refresh_token") + + expires_in = token_data.get("expires_in", 300) + self.token_expires_at = time.time() + expires_in + + if not self.keycloak_access_token: + raise DataSpaceAuthError("No access token in refresh response") + + return self.keycloak_access_token + else: + # Refresh failed, try to relogin with stored credentials + if self._username and self._password: + return self._get_keycloak_token(self._username, self._password) + raise DataSpaceAuthError("Keycloak token refresh failed") + except requests.RequestException as e: + # Network error, try to relogin with stored credentials + if self._username and self._password: + return self._get_keycloak_token(self._username, self._password) + raise DataSpaceAuthError(f"Network error during token refresh: {str(e)}") + + def _ensure_valid_keycloak_token(self) -> str: + """ + Ensure we have a valid Keycloak token, refreshing if necessary. 
+ + Returns: + Valid Keycloak access token + + Raises: + DataSpaceAuthError: If unable to get valid token + """ + # Check if token is expired or about to expire (within 30 seconds) + if ( + not self.keycloak_access_token + or not self.token_expires_at + or time.time() >= (self.token_expires_at - 30) + ): + # Token expired or about to expire, refresh it + if self.keycloak_refresh_token or (self._username and self._password): + return self._refresh_keycloak_token() + raise DataSpaceAuthError("No valid token or credentials available") + + return self.keycloak_access_token + + def _login_with_keycloak_token(self, keycloak_token: str) -> Dict[str, Any]: """ Login using a Keycloak token. @@ -140,3 +341,39 @@ def _get_auth_headers(self) -> Dict[str, str]: def is_authenticated(self) -> bool: """Check if the client is authenticated.""" return self.access_token is not None + + def ensure_authenticated(self) -> None: + """ + Ensure the client is authenticated, attempting auto-relogin if needed. + + Raises: + DataSpaceAuthError: If unable to authenticate + """ + if not self.is_authenticated(): + # Try to relogin with stored credentials + if self._username and self._password: + self.login(self._username, self._password) + else: + raise DataSpaceAuthError("Not authenticated. Please call login() first.") + + def get_valid_token(self) -> str: + """ + Get a valid access token, refreshing if necessary. 
+ + Returns: + Valid access token + + Raises: + DataSpaceAuthError: If unable to get valid token + """ + # First ensure we have a valid Keycloak token + if self.keycloak_url and self.keycloak_realm: + keycloak_token = self._ensure_valid_keycloak_token() + # Re-login to backend with fresh Keycloak token if needed + if not self.access_token: + self._login_with_keycloak_token(keycloak_token) + + if not self.access_token: + raise DataSpaceAuthError("No access token available") + + return self.access_token diff --git a/dataspace_sdk/client.py b/dataspace_sdk/client.py index 93b3e50..8103a91 100644 --- a/dataspace_sdk/client.py +++ b/dataspace_sdk/client.py @@ -33,24 +33,67 @@ class DataSpaceClient: >>> org_usecases = client.usecases.get_organization_usecases("org-uuid") """ - def __init__(self, base_url: str): + def __init__( + self, + base_url: str, + keycloak_url: Optional[str] = None, + keycloak_realm: Optional[str] = None, + keycloak_client_id: Optional[str] = None, + keycloak_client_secret: Optional[str] = None, + ): """ Initialize the DataSpace client. 
Args: base_url: Base URL of the DataSpace API (e.g., "https://api.dataspace.example.com") + keycloak_url: Keycloak server URL (e.g., "https://opub-kc.civicdatalab.in") + keycloak_realm: Keycloak realm name (e.g., "DataSpace") + keycloak_client_id: Keycloak client ID (e.g., "dataspace") + keycloak_client_secret: Optional client secret for confidential clients """ self.base_url = base_url.rstrip("/") - self._auth = AuthClient(self.base_url) + self._auth = AuthClient( + self.base_url, + keycloak_url=keycloak_url, + keycloak_realm=keycloak_realm, + keycloak_client_id=keycloak_client_id, + keycloak_client_secret=keycloak_client_secret, + ) # Initialize resource clients self.datasets = DatasetClient(self.base_url, self._auth) self.aimodels = AIModelClient(self.base_url, self._auth) self.usecases = UseCaseClient(self.base_url, self._auth) - def login(self, keycloak_token: str) -> dict: + def login(self, username: str, password: str) -> dict: """ - Login using a Keycloak token. + Login using username and password. + + Args: + username: User's username or email + password: User's password + + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + + Example: + >>> client = DataSpaceClient( + ... base_url="https://api.dataspace.example.com", + ... keycloak_url="https://opub-kc.civicdatalab.in", + ... keycloak_realm="DataSpace", + ... keycloak_client_id="dataspace" + ... ) + >>> user_info = client.login(username="user@example.com", password="secret") + >>> print(user_info["user"]["username"]) + """ + return self._auth.login(username, password) + + def login_with_token(self, keycloak_token: str) -> dict: + """ + Login using a pre-obtained Keycloak token. 
Args: keycloak_token: Valid Keycloak access token @@ -63,10 +106,10 @@ def login(self, keycloak_token: str) -> dict: Example: >>> client = DataSpaceClient(base_url="https://api.dataspace.example.com") - >>> user_info = client.login(keycloak_token="your_token") + >>> user_info = client.login_with_token(keycloak_token="your_token") >>> print(user_info["user"]["username"]) """ - return self._auth.login_with_keycloak(keycloak_token) + return self._auth._login_with_keycloak_token(keycloak_token) def refresh_token(self) -> str: """ diff --git a/docs/sdk/QUICKSTART.md b/docs/sdk/QUICKSTART.md index 915432b..6138293 100644 --- a/docs/sdk/QUICKSTART.md +++ b/docs/sdk/QUICKSTART.md @@ -21,13 +21,23 @@ pip install -e . ```python from dataspace_sdk import DataSpaceClient -# 1. Initialize the client -client = DataSpaceClient(base_url="https://api.dataspace.example.com") +# 1. Initialize the client with Keycloak configuration +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) -# 2. Login with your Keycloak token -client.login(keycloak_token="your_keycloak_token") +# 2. Login with username and password +# Credentials are stored securely for automatic token refresh +user_info = client.login( + username="your-email@example.com", + password="your-password" +) +print(f"Logged in as: {user_info['user']['username']}") -# 3. Search for datasets +# 3. 
Search for datasets (tokens auto-refresh as needed) datasets = client.datasets.search(query="health", page_size=5) print(f"Found {datasets['total']} datasets") @@ -36,6 +46,18 @@ dataset = client.datasets.get_by_id("dataset-uuid") print(f"Dataset: {dataset['title']}") ``` +### Alternative: Login with Pre-obtained Token + +If you already have a Keycloak token: + +```python +# Initialize without Keycloak config +client = DataSpaceClient(base_url="https://dataspace.civicdatalab.in") + +# Login with token +client.login_with_token(keycloak_token="your_keycloak_token") +``` + ## Common Operations ### Search Resources From 7806a5cba6218b8e5a918559a164932c9d71c06a Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 14:56:24 +0530 Subject: [PATCH 002/127] add documentation and login example --- docs/sdk/AUTHENTICATION.md | 369 ++++++++++++++++++++++++++++ examples/username_password_login.py | 69 ++++++ 2 files changed, 438 insertions(+) create mode 100644 docs/sdk/AUTHENTICATION.md create mode 100644 examples/username_password_login.py diff --git a/docs/sdk/AUTHENTICATION.md b/docs/sdk/AUTHENTICATION.md new file mode 100644 index 0000000..65aeb13 --- /dev/null +++ b/docs/sdk/AUTHENTICATION.md @@ -0,0 +1,369 @@ +# DataSpace SDK Authentication Guide + +This guide explains how to authenticate with the DataSpace SDK using username/password with automatic token management. 
+ +## Table of Contents + +- [Quick Start](#quick-start) +- [Authentication Methods](#authentication-methods) +- [Automatic Token Management](#automatic-token-management) +- [Configuration](#configuration) +- [Error Handling](#error-handling) +- [Best Practices](#best-practices) + +## Quick Start + +### Username/Password Login (Recommended) + +```python +from dataspace_sdk import DataSpaceClient + +# Initialize with Keycloak configuration +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +# Login once - credentials stored for auto-refresh +user_info = client.login( + username="your-email@example.com", + password="your-password" +) + +# Now use the client - tokens auto-refresh! +datasets = client.datasets.search(query="health") +``` + +## Authentication Methods + +### 1. Username/Password (Recommended) + +Best for: +- Scripts and automation +- Long-running applications +- Development and testing + +```python +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +user_info = client.login( + username="user@example.com", + password="secret" +) +``` + +**Features:** +- ✅ Automatic token refresh +- ✅ Automatic re-login if refresh fails +- ✅ No manual token management needed +- ✅ Credentials stored securely in memory + +### 2. Pre-obtained Keycloak Token + +Best for: +- When you already have a token from another source +- Browser-based applications +- SSO integrations + +```python +client = DataSpaceClient(base_url="https://dataspace.civicdatalab.in") + +user_info = client.login_with_token(keycloak_token="eyJhbGci...") +``` + +**Note:** This method does NOT support automatic token refresh or re-login. 
+ +## Automatic Token Management + +The SDK automatically handles token expiration and refresh: + +### How It Works + +1. **Login**: You provide username/password once +2. **Token Storage**: SDK stores credentials securely in memory +3. **Auto-Refresh**: When token expires (within 30 seconds), SDK automatically refreshes it +4. **Auto-Relogin**: If refresh fails, SDK automatically re-authenticates with stored credentials +5. **Transparent**: All of this happens automatically - you don't need to do anything! + +### Example: Long-Running Script + +```python +import time +from dataspace_sdk import DataSpaceClient + +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +# Login once +client.login(username="user@example.com", password="secret") + +# Run for hours - tokens auto-refresh! +while True: + datasets = client.datasets.search(query="health") + print(f"Found {len(datasets.get('results', []))} datasets") + + # Sleep for 10 minutes + time.sleep(600) + + # SDK automatically refreshes tokens as needed + # No manual intervention required! +``` + +## Configuration + +### Keycloak Configuration + +You need these details from your Keycloak setup: + +```python +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", # DataSpace API URL + keycloak_url="https://opub-kc.civicdatalab.in", # Keycloak server URL + keycloak_realm="DataSpace", # Realm name + keycloak_client_id="dataspace", # Client ID + keycloak_client_secret="optional-secret" # Only for confidential clients +) +``` + +### Finding Your Keycloak Details + +1. **Keycloak URL**: Your Keycloak server address +2. **Realm**: Usually shown in Keycloak admin console +3. **Client ID**: Found in Keycloak → Clients → Your Client +4. 
**Client Secret**: Only needed if client is "confidential" (check Access Type in Keycloak) + +### Environment Variables (Recommended) + +Store credentials securely: + +```python +import os +from dataspace_sdk import DataSpaceClient + +client = DataSpaceClient( + base_url=os.getenv("DATASPACE_API_URL"), + keycloak_url=os.getenv("KEYCLOAK_URL"), + keycloak_realm=os.getenv("KEYCLOAK_REALM"), + keycloak_client_id=os.getenv("KEYCLOAK_CLIENT_ID"), + keycloak_client_secret=os.getenv("KEYCLOAK_CLIENT_SECRET") # Optional +) + +client.login( + username=os.getenv("DATASPACE_USERNAME"), + password=os.getenv("DATASPACE_PASSWORD") +) +``` + +Create a `.env` file: + +```bash +DATASPACE_API_URL=https://dataspace.civicdatalab.in +KEYCLOAK_URL=https://opub-kc.civicdatalab.in +KEYCLOAK_REALM=DataSpace +KEYCLOAK_CLIENT_ID=dataspace +DATASPACE_USERNAME=your-email@example.com +DATASPACE_PASSWORD=your-password +``` + +## Error Handling + +### Authentication Errors + +```python +from dataspace_sdk import DataSpaceClient +from dataspace_sdk.exceptions import DataSpaceAuthError + +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +try: + user_info = client.login( + username="user@example.com", + password="wrong-password" + ) +except DataSpaceAuthError as e: + print(f"Login failed: {e}") + print(f"Status code: {e.status_code}") + print(f"Response: {e.response}") +``` + +### Common Errors + +| Error | Cause | Solution | +|-------|-------|----------| +| `Keycloak configuration missing` | Missing keycloak_url, realm, or client_id | Provide all required Keycloak parameters | +| `Keycloak login failed: invalid_grant` | Wrong username/password | Check credentials | +| `Keycloak login failed: invalid_client` | Wrong client_id or client requires consent | Check client_id or disable consent in Keycloak | +| `Resource not found` | Wrong Keycloak URL or realm | 
Verify keycloak_url and realm name | +| `Not authenticated` | Trying to use API before login | Call `client.login()` first | + +### Checking Authentication Status + +```python +# Check if authenticated +if client.is_authenticated(): + print("Authenticated!") +else: + print("Not authenticated") + +# Get current user info +user = client.user +if user: + print(f"Logged in as: {user['username']}") +``` + +## Best Practices + +### 1. Use Environment Variables + +Never hardcode credentials: + +```python +# ❌ Bad +client.login(username="user@example.com", password="secret123") + +# ✅ Good +client.login( + username=os.getenv("DATASPACE_USERNAME"), + password=os.getenv("DATASPACE_PASSWORD") +) +``` + +### 2. Login Once, Use Everywhere + +```python +# ✅ Good - Login once at startup +client = DataSpaceClient(...) +client.login(username=..., password=...) + +# Use client throughout your application +# Tokens auto-refresh! +datasets = client.datasets.search(...) +models = client.aimodels.search(...) +``` + +### 3. Handle Errors Gracefully + +```python +from dataspace_sdk.exceptions import DataSpaceAuthError + +try: + client.login(username=username, password=password) +except DataSpaceAuthError as e: + logger.error(f"Authentication failed: {e}") + # Handle error appropriately + sys.exit(1) +``` + +### 4. Use Confidential Clients for Production + +For production applications, use a confidential client with client_secret: + +```python +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace-prod", + keycloak_client_secret=os.getenv("KEYCLOAK_CLIENT_SECRET") +) +``` + +### 5. 
Don't Store Passwords in Code + +```python +# ❌ Bad +PASSWORD = "secret123" + +# ✅ Good - Use environment variables +password = os.getenv("DATASPACE_PASSWORD") + +# ✅ Better - Use secrets management +from your_secrets_manager import get_secret +password = get_secret("dataspace_password") +``` + +## Advanced Usage + +### Manual Token Refresh + +While automatic refresh is recommended, you can manually refresh: + +```python +# Get new access token +new_token = client.refresh_token() +print(f"New token: {new_token}") +``` + +### Get Current Access Token + +```python +# Get current token (e.g., to pass to another service) +token = client.access_token +print(f"Current token: {token}") +``` + +### Ensure Authentication + +Force authentication check and auto-relogin if needed: + +```python +# This will auto-relogin if not authenticated +client._auth.ensure_authenticated() +``` + +## Troubleshooting + +### Token Keeps Expiring + +If you're experiencing frequent token expiration: + +1. **Check token lifetime**: Keycloak admin → Realm Settings → Tokens → Access Token Lifespan +2. **Verify auto-refresh**: SDK automatically refreshes 30 seconds before expiration +3. **Check credentials**: Ensure username/password are stored correctly + +### "Resource not found" Error + +This usually means wrong Keycloak URL: + +```python +# Try with /auth prefix +keycloak_url="https://opub-kc.civicdatalab.in/auth" + +# Or without +keycloak_url="https://opub-kc.civicdatalab.in" + +# Test in browser: +# https://opub-kc.civicdatalab.in/auth/realms/DataSpace/.well-known/openid-configuration +``` + +### "Client requires user consent" Error + +In Keycloak admin console: + +1. Go to Clients → your client +2. Find "Consent Required" setting +3. Set to OFF +4. 
Save + +## Next Steps + +- [Quick Start Guide](QUICKSTART.md) +- [Full SDK Documentation](README.md) +- [Examples](../../examples/) +- [API Reference](API_REFERENCE.md) diff --git a/examples/username_password_login.py b/examples/username_password_login.py new file mode 100644 index 0000000..c477228 --- /dev/null +++ b/examples/username_password_login.py @@ -0,0 +1,69 @@ +""" +Example: Login with username and password with automatic token refresh. + +This example demonstrates: +1. Initializing the client with Keycloak configuration +2. Logging in with username and password +3. Automatic token refresh when tokens expire +4. Automatic re-login if refresh fails +""" + +from dataspace_sdk import DataSpaceClient +from dataspace_sdk.exceptions import DataSpaceAuthError + +# Initialize client with Keycloak configuration +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace", + # keycloak_client_secret="your-secret" # Only if using confidential client +) + +# Login with username and password +# Credentials are stored securely for automatic re-login +try: + user_info = client.login(username="your-email@example.com", password="your-password") + print(f"✓ Logged in as: {user_info['user']['username']}") + print(f"✓ Organizations: {[org['name'] for org in user_info['user'].get('organizations', [])]}") +except DataSpaceAuthError as e: + print(f"✗ Login failed: {e}") + exit(1) + +# Now you can use the client normally +# Tokens will be automatically refreshed when they expire + +# Example 1: Search datasets +print("\n--- Searching datasets ---") +datasets = client.datasets.search(query="health", page_size=5) +print(f"Found {len(datasets.get('results', []))} datasets") + +# Example 2: Get user's organization datasets +print("\n--- Getting organization datasets ---") +user_orgs = user_info["user"].get("organizations", []) +if user_orgs: + org_id = 
user_orgs[0]["id"] + org_datasets = client.datasets.get_organization_datasets(org_id, limit=10, offset=0) + print(f"Organization has datasets: {org_datasets}") + +# Example 3: Token automatically refreshes +# Even if you use the client after tokens expire, it will auto-refresh +print("\n--- Simulating long-running session ---") +print("The SDK will automatically refresh tokens as needed...") + +# You can also manually check authentication status +if client.is_authenticated(): + print("✓ Still authenticated") +else: + print("✗ Not authenticated") + +# Example 4: Get current user info (will auto-refresh if needed) +print("\n--- Getting user info ---") +try: + current_user = client.get_user_info() + print(f"Current user: {current_user.get('username')}") +except DataSpaceAuthError as e: + print(f"Error: {e}") + +print("\n✓ All operations completed successfully!") +print("Note: The SDK automatically handled token refresh and re-login as needed.") From 5b2235c5890fb08efec3c70422f6c2f581bd8a37 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 15:01:46 +0530 Subject: [PATCH 003/127] fix tests --- tests/test_auth.py | 54 +++++++++++++++++++++++++++++++++++++++----- tests/test_client.py | 19 +++++++++++++--- 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/tests/test_auth.py b/tests/test_auth.py index c05b502..5a6676a 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -13,18 +13,60 @@ class TestAuthClient(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" self.base_url = "https://api.test.com" - self.auth_client = AuthClient(self.base_url) + self.keycloak_url = "https://keycloak.test.com" + self.keycloak_realm = "test-realm" + self.keycloak_client_id = "test-client" + self.auth_client = AuthClient( + self.base_url, + keycloak_url=self.keycloak_url, + keycloak_realm=self.keycloak_realm, + keycloak_client_id=self.keycloak_client_id, + ) def test_init(self) -> None: """Test AuthClient initialization.""" 
self.assertEqual(self.auth_client.base_url, self.base_url) + self.assertEqual(self.auth_client.keycloak_url, self.keycloak_url) + self.assertEqual(self.auth_client.keycloak_realm, self.keycloak_realm) + self.assertEqual(self.auth_client.keycloak_client_id, self.keycloak_client_id) self.assertIsNone(self.auth_client.access_token) self.assertIsNone(self.auth_client.refresh_token) self.assertIsNone(self.auth_client.user_info) @patch("dataspace_sdk.auth.requests.post") - def test_login_success(self, mock_post: MagicMock) -> None: - """Test successful login.""" + def test_login_with_username_password(self, mock_post: MagicMock) -> None: + """Test successful login with username/password.""" + # Mock Keycloak token response + keycloak_response = MagicMock() + keycloak_response.status_code = 200 + keycloak_response.json.return_value = { + "access_token": "keycloak_access_token", + "refresh_token": "keycloak_refresh_token", + "expires_in": 300, + } + + # Mock DataSpace backend login response + backend_response = MagicMock() + backend_response.status_code = 200 + backend_response.json.return_value = { + "access": "test_access_token", + "refresh": "test_refresh_token", + "user": {"id": "123", "username": "testuser"}, + } + + mock_post.side_effect = [keycloak_response, backend_response] + + result = self.auth_client.login("testuser", "password") + + self.assertEqual(self.auth_client.access_token, "test_access_token") + self.assertEqual(self.auth_client.refresh_token, "test_refresh_token") + self.assertIsNotNone(self.auth_client.user_info) + self.assertEqual(result["user"]["username"], "testuser") + self.assertEqual(mock_post.call_count, 2) + + @patch("dataspace_sdk.auth.requests.post") + def test_login_with_token(self, mock_post: MagicMock) -> None: + """Test successful login with Keycloak token.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { @@ -34,7 +76,7 @@ def test_login_success(self, mock_post: MagicMock) -> None: } 
mock_post.return_value = mock_response - result = self.auth_client.login_with_keycloak("test_keycloak_token") + result = self.auth_client._login_with_keycloak_token("test_keycloak_token") self.assertEqual(self.auth_client.access_token, "test_access_token") self.assertEqual(self.auth_client.refresh_token, "test_refresh_token") @@ -46,11 +88,11 @@ def test_login_failure(self, mock_post: MagicMock) -> None: """Test failed login.""" mock_response = MagicMock() mock_response.status_code = 401 - mock_response.json.return_value = {"error": "Invalid token"} + mock_response.json.return_value = {"error": "invalid_grant"} mock_post.return_value = mock_response with self.assertRaises(DataSpaceAuthError): - self.auth_client.login_with_keycloak("invalid_token") + self.auth_client.login("invalid_user", "invalid_password") @patch("dataspace_sdk.auth.requests.post") def test_refresh_token_success(self, mock_post: MagicMock) -> None: diff --git a/tests/test_client.py b/tests/test_client.py index 870a94f..7d27d44 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -22,15 +22,28 @@ def test_init(self) -> None: self.assertIsNotNone(self.client.aimodels) self.assertIsNotNone(self.client.usecases) - @patch("dataspace_sdk.client.AuthClient.login_with_keycloak") + @patch("dataspace_sdk.client.AuthClient.login") def test_login(self, mock_login: MagicMock) -> None: - """Test login method.""" + """Test login method with username/password.""" mock_login.return_value = { "access": "token", "user": {"id": "123", "username": "testuser"}, } - result = self.client.login("test_token") + result = self.client.login("testuser", "password") + + self.assertEqual(result["user"]["username"], "testuser") + mock_login.assert_called_once_with("testuser", "password") + + @patch("dataspace_sdk.client.AuthClient._login_with_keycloak_token") + def test_login_with_token(self, mock_login: MagicMock) -> None: + """Test login method with Keycloak token.""" + mock_login.return_value = { + "access": 
"token", + "user": {"id": "123", "username": "testuser"}, + } + + result = self.client.login_with_token("test_token") self.assertEqual(result["user"]["username"], "testuser") mock_login.assert_called_once_with("test_token") From b660ae6cde08c7bab930112f9bf2c0d4b913e868 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 15:29:00 +0530 Subject: [PATCH 004/127] update version management --- .github/workflows/publish-sdk.yml | 34 ++++++++++++++++++++----------- dataspace_sdk/__init__.py | 2 +- dataspace_sdk/__version__.py | 3 +++ setup.py | 11 +++++++++- 4 files changed, 36 insertions(+), 14 deletions(-) create mode 100644 dataspace_sdk/__version__.py diff --git a/.github/workflows/publish-sdk.yml b/.github/workflows/publish-sdk.yml index 5c3994e..987545c 100644 --- a/.github/workflows/publish-sdk.yml +++ b/.github/workflows/publish-sdk.yml @@ -23,6 +23,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + ref: ${{ github.ref }} - name: Set up Python uses: actions/setup-python@v4 @@ -35,13 +36,30 @@ jobs: pip install build twine pip install -e ".[dev]" - - name: Update version in setup.py + - name: Update version in __version__.py run: | - sed -i "s/version=\".*\"/version=\"${{ inputs.version }}\"/" setup.py + echo '"""Version information for DataSpace SDK."""' > dataspace_sdk/__version__.py + echo "" >> dataspace_sdk/__version__.py + echo "__version__ = \"${{ inputs.version }}\"" >> dataspace_sdk/__version__.py - - name: Update version in pyproject.toml + - name: Update version in pyproject.toml (if exists) run: | - sed -i "s/version = \".*\"/version = \"${{ inputs.version }}\"/" pyproject.toml + if [ -f pyproject.toml ]; then + sed -i "s/version = \".*\"/version = \"${{ inputs.version }}\"/" pyproject.toml + fi + + - name: Commit and push version changes + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add dataspace_sdk/__version__.py + if [ -f 
pyproject.toml ]; then git add pyproject.toml; fi + git commit -m "Bump SDK version to ${{ inputs.version }}" || echo "No changes to commit" + + # Get current branch name + BRANCH_NAME=$(git rev-parse --abbrev-ref HEAD) + echo "Pushing to branch: $BRANCH_NAME" + git push origin $BRANCH_NAME || echo "No changes to push" - name: Run tests run: | @@ -65,14 +83,6 @@ jobs: run: | twine upload dist/* - - name: Commit version changes - run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" - git add setup.py pyproject.toml - git commit -m "Bump version to ${{ inputs.version }}" || echo "No changes to commit" - git push || echo "No changes to push" - - name: Create GitHub Release uses: actions/create-release@v1 env: diff --git a/dataspace_sdk/__init__.py b/dataspace_sdk/__init__.py index bb81360..1848fc9 100644 --- a/dataspace_sdk/__init__.py +++ b/dataspace_sdk/__init__.py @@ -1,5 +1,6 @@ """DataSpace Python SDK for programmatic access to DataSpace resources.""" +from dataspace_sdk.__version__ import __version__ from dataspace_sdk.client import DataSpaceClient from dataspace_sdk.exceptions import ( DataSpaceAPIError, @@ -8,7 +9,6 @@ DataSpaceValidationError, ) -__version__ = "0.1.0" __all__ = [ "DataSpaceClient", "DataSpaceAPIError", diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py new file mode 100644 index 0000000..ae71315 --- /dev/null +++ b/dataspace_sdk/__version__.py @@ -0,0 +1,3 @@ +"""Version information for DataSpace SDK.""" + +__version__ = "0.3.0" diff --git a/setup.py b/setup.py index c92e2bb..685a874 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,22 @@ """Setup configuration for DataSpace Python SDK.""" +import os +from typing import Any, Dict + from setuptools import find_packages, setup +# Read version from __version__.py +version: Dict[str, Any] = {} +version_file = os.path.join(os.path.dirname(__file__), "dataspace_sdk", "__version__.py") +with 
open(version_file, "r", encoding="utf-8") as f: + exec(f.read(), version) + with open("docs/sdk/README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setup( name="dataspace-sdk", - version="0.1.0", + version=version["__version__"], author="CivicDataLab", author_email="tech@civicdatalab.in", description="Python SDK for DataSpace API - programmatic access to datasets, AI models, and use cases", From 85cd95ad4da992184034ec6c9a64051d79b870ad Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 20 Nov 2025 10:01:53 +0000 Subject: [PATCH 005/127] Bump SDK version to 0.3.1 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index ae71315..c79f072 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.3.0" +__version__ = "0.3.1" diff --git a/pyproject.toml b/pyproject.toml index 1aac447..1e72842 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.1.0" +version = "0.3.1" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "3.8" +python_version = "0.3.1" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 54e8e3128b6889edf4f1d6b21b209c70e5fc3411 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 16:23:20 +0530 Subject: [PATCH 006/127] Remove accidentally committed dummydistrict.csv file --- files/public/resources/dummydistrict.csv | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 files/public/resources/dummydistrict.csv diff --git a/files/public/resources/dummydistrict.csv 
b/files/public/resources/dummydistrict.csv deleted file mode 100644 index 1a8dd45..0000000 --- a/files/public/resources/dummydistrict.csv +++ /dev/null @@ -1,3 +0,0 @@ -district,Value,,,,, -DHUBRI,10,,,,, -South Salmara Mancachar,30,,,,, From c3fa8299c77e085bd6b1417e5005fda1d7d7e368 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 16:32:59 +0530 Subject: [PATCH 007/127] add tenacity to requirements --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index e1bd919..ba7464d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -116,3 +116,5 @@ djangorestframework-stubs==3.14.5 # Matched with DRF version # Activity stream for tracking user actions django-activity-stream==2.0.0 + +tenacity==9.1.2 From b9b3c3e94c1fa5777b9b20d3c7fb2a9c03616e0c Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 16:36:04 +0530 Subject: [PATCH 008/127] add missing requriements --- requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index ba7464d..13b0a3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -118,3 +118,7 @@ djangorestframework-stubs==3.14.5 # Matched with DRF version django-activity-stream==2.0.0 tenacity==9.1.2 +torch==2.9.0 +transformers==4.57.1 +sentencepiece==0.2.1 +accelerate==1.11.0 From 849934fb5ec20f0b52b45b20293d91be1da3da52 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 17:45:38 +0530 Subject: [PATCH 009/127] resolve dependency conflicts --- requirements.txt | 68 ++++++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/requirements.txt b/requirements.txt index 13b0a3e..bcfd71d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,26 +18,26 @@ elasticsearch-dsl==8.12.0 googleapis-common-protos==1.65.0 graphql-core==3.2.3 gunicorn==23.0.0 -grpcio==1.62.1 +grpcio==1.68.1 idna==3.6 importlib-metadata==6.11.0 numpy==1.26.4 -opentelemetry-api==1.23.0 
-opentelemetry-exporter-otlp==1.23.0 -opentelemetry-exporter-otlp-proto-common==1.23.0 -opentelemetry-exporter-otlp-proto-grpc==1.23.0 -opentelemetry-exporter-otlp-proto-http==1.23.0 -opentelemetry-instrumentation==0.44b0 -opentelemetry-instrumentation-django==0.44b0 -opentelemetry-instrumentation-wsgi==0.44b0 -opentelemetry-instrumentation-redis==0.44b0 -opentelemetry-instrumentation-elasticsearch==0.44b0 -opentelemetry-instrumentation-requests==0.44b0 -opentelemetry-instrumentation-sqlalchemy==0.44b0 -opentelemetry-proto==1.23.0 -opentelemetry-sdk==1.23.0 -opentelemetry-semantic-conventions==0.44b0 -opentelemetry-util-http==0.44b0 +opentelemetry-api==1.28.2 +opentelemetry-exporter-otlp==1.28.2 +opentelemetry-exporter-otlp-proto-common==1.28.2 +opentelemetry-exporter-otlp-proto-grpc==1.28.2 +opentelemetry-exporter-otlp-proto-http==1.28.2 +opentelemetry-instrumentation==0.49b2 +opentelemetry-instrumentation-django==0.49b2 +opentelemetry-instrumentation-wsgi==0.49b2 +opentelemetry-instrumentation-redis==0.49b2 +opentelemetry-instrumentation-elasticsearch==0.49b2 +opentelemetry-instrumentation-requests==0.49b2 +opentelemetry-instrumentation-sqlalchemy==0.49b2 +opentelemetry-proto==1.28.2 +opentelemetry-sdk==1.28.2 +opentelemetry-semantic-conventions==0.49b2 +opentelemetry-util-http==0.49b2 packaging==24.0 pandas==2.2.2 openpyxl==3.1.2 # For Excel file support @@ -45,13 +45,13 @@ odfpy==1.4.1 # For ODS file support pyarrow==15.0.0 # For Parquet and Feather file support xlrd==2.0.1 # For Excel file support pillow==10.4.0 -protobuf==4.25.3 -pydantic==2.6.1 # Slightly downgraded for compatibility -pydantic_core==2.16.2 # Matched with pydantic +protobuf==5.28.3 +pydantic==2.11.7 # Updated for deepeval and ML packages compatibility +pydantic_core==2.33.2 # Matched with pydantic pyecharts==2.0.3 pyecharts-snapshot @ git+https://github.com/Deepthi-Chand/pyecharts-snapshot.git@8d6cadd055db6c919a1447064185d00d1b30ce01 python-dateutil==2.9.0.post0 -python-dotenv==1.0.1 
+python-dotenv==1.1.1 python-magic==0.4.27 pytz==2024.1 requests==2.31.0 @@ -60,27 +60,27 @@ six==1.16.0 sqlparse==0.4.4 strawberry-graphql==0.225.1 strawberry-graphql-django==0.37.1 -typing-extensions==4.9.0 # Fixed version that satisfies most dependencies +typing-extensions==4.14.0 tzdata==2024.1 urllib3==2.2.1 -uvicorn==0.27.1 # Downgraded for better compatibility +uvicorn==0.27.1 wrapt==1.16.0 zipp==3.18.1 snapshot-selenium==0.0.2 -selenium==4.16.0 # Downgraded for compatibility +selenium==4.16.0 # Security and Environment Management python-decouple==3.8 -django-debug-toolbar==4.2.0 # Downgraded for better Django compatibility +django-debug-toolbar==4.2.0 django-ratelimit==4.1.0 # Caching and Performance -django-redis==5.4.0 # Added for Redis caching -redis==5.0.1 # Added Redis client +django-redis==5.4.0 +redis==5.0.1 django-cacheops==7.0.2 # Code Quality and Type Checking -mypy==1.7.1 # Downgraded for compatibility +mypy==1.7.1 black==24.2.0 isort==5.13.2 pre-commit==3.6.2 @@ -88,10 +88,10 @@ flake8==7.0.0 pandas-stubs==2.2.0.240218 # Matched with pandas version # Data Comparison and Versioning -deepdiff==6.7.1 # For intelligent version detection +deepdiff==6.7.1 # Logging -structlog==24.1.0 # Downgraded for better compatibility +structlog==24.1.0 # Development and Debugging ipython==8.22.1 @@ -99,17 +99,17 @@ ipython==8.22.1 # API Documentation drf-yasg==1.21.7 drf-yasg[validation]==1.21.7 -djangorestframework-simplejwt==5.3.1 # Downgraded for better DRF compatibility +djangorestframework-simplejwt==5.3.1 # Keycloak Integration python-keycloak==5.5.1 django-keycloak-auth==1.0.0 # API Versioning and Throttling -django-filter==23.5 # Downgraded for better Django compatibility +django-filter==23.5 -django-stubs==4.2.7 # Downgraded for better compatibility -djangorestframework-stubs==3.14.5 # Matched with DRF version +django-stubs==4.2.7 +djangorestframework-stubs==3.14.5 #whitenoise for managing static files whitenoise==6.9.0 From 
0a82bfdc95ba86efa1ce77b6b495c18536481355 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 17:57:01 +0530 Subject: [PATCH 010/127] update typing_extensions --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bcfd71d..f232a46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -60,7 +60,7 @@ six==1.16.0 sqlparse==0.4.4 strawberry-graphql==0.225.1 strawberry-graphql-django==0.37.1 -typing-extensions==4.14.0 +typing-extensions==4.15.0 tzdata==2024.1 urllib3==2.2.1 uvicorn==0.27.1 From 7b0061c2591b86e049498a0831a45fb2876c5d79 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 18:03:19 +0530 Subject: [PATCH 011/127] update strawberry version --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index f232a46..f673797 100644 --- a/requirements.txt +++ b/requirements.txt @@ -58,8 +58,8 @@ requests==2.31.0 sqlalchemy==2.0.39 six==1.16.0 sqlparse==0.4.4 -strawberry-graphql==0.225.1 -strawberry-graphql-django==0.37.1 +strawberry-graphql==0.243.0 +strawberry-graphql-django==0.47.2 typing-extensions==4.15.0 tzdata==2024.1 urllib3==2.2.1 From 5ef6dc230e484608fb96fafa678638a880537e9d Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 20 Nov 2025 18:17:34 +0530 Subject: [PATCH 012/127] update pydantic and strawberry to match torch --- requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index f673797..56a9731 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,8 +46,8 @@ pyarrow==15.0.0 # For Parquet and Feather file support xlrd==2.0.1 # For Excel file support pillow==10.4.0 protobuf==5.28.3 -pydantic==2.11.7 # Updated for deepeval and ML packages compatibility -pydantic_core==2.33.2 # Matched with pydantic +pydantic==2.9.2 +pydantic_core==2.23.4 # Matched with pydantic pyecharts==2.0.3 pyecharts-snapshot @ 
git+https://github.com/Deepthi-Chand/pyecharts-snapshot.git@8d6cadd055db6c919a1447064185d00d1b30ce01 python-dateutil==2.9.0.post0 @@ -58,8 +58,8 @@ requests==2.31.0 sqlalchemy==2.0.39 six==1.16.0 sqlparse==0.4.4 -strawberry-graphql==0.243.0 -strawberry-graphql-django==0.47.2 +strawberry-graphql==0.235.2 +strawberry-graphql-django==0.42.0 typing-extensions==4.15.0 tzdata==2024.1 urllib3==2.2.1 From 898b3f6bb48e1c7e250ec2c9c38c2c6f9bfc7a93 Mon Sep 17 00:00:00 2001 From: dc Date: Fri, 21 Nov 2025 12:45:07 +0530 Subject: [PATCH 013/127] update docker install system dependencies without mixing old Debian repos --- Dockerfile | 45 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 30967ee..0579e6f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,12 +1,47 @@ FROM python:3.10 ENV PYTHONDONTWRITEBYTECODE=1 ENV PYTHONUNBUFFERED=1 -RUN echo 'deb http://archive.debian.org/debian stretch main contrib non-free' >> /etc/apt/sources.list && \ - apt-get update && \ + +# Install system dependencies without mixing old Debian repos +RUN apt-get update && \ apt-get autoremove -y && \ - apt-get install -y libssl1.0-dev curl git nano wget && \ - apt-get install -y gconf-service libasound2 libatk1.0-0 libc6 libcairo2 libcups2 libdbus-1-3 libexpat1 libfontconfig1 libgcc1 libgconf-2-4 libgdk-pixbuf2.0-0 libglib2.0-0 libgtk-3-0 libnspr4 libpango-1.0-0 libpangocairo-1.0-0 libstdc++6 libx11-6 libx11-xcb1 libxcb1 libxcomposite1 libxcursor1 libxdamage1 libxext6 libxfixes3 libxi6 libxrandr2 libxrender1 libxss1 libxtst6 ca-certificates fonts-liberation libappindicator1 libnss3 lsb-release xdg-utils wget && \ - rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/partial/* + apt-get install -y \ + curl \ + git \ + nano \ + wget \ + ca-certificates \ + fonts-liberation \ + libasound2 \ + libatk1.0-0 \ + libcairo2 \ + libcups2 \ + libdbus-1-3 \ + libexpat1 \ + libfontconfig1 \ + libgdk-pixbuf2.0-0 \ + 
libglib2.0-0 \ + libgtk-3-0 \ + libnspr4 \ + libnss3 \ + libpango-1.0-0 \ + libpangocairo-1.0-0 \ + libx11-6 \ + libx11-xcb1 \ + libxcb1 \ + libxcomposite1 \ + libxcursor1 \ + libxdamage1 \ + libxext6 \ + libxfixes3 \ + libxi6 \ + libxrandr2 \ + libxrender1 \ + libxss1 \ + libxtst6 \ + lsb-release \ + xdg-utils && \ + rm -rf /var/lib/apt/lists/* WORKDIR /code From 2c02851d831652c57a9afa339628fb4d00554079 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 25 Nov 2025 19:55:34 +0530 Subject: [PATCH 014/127] send user detaisl if introspection fails --- Dockerfile | 4 ++-- api/utils/keycloak_utils.py | 38 ++++++++++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0579e6f..19a0882 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ FROM python:3.10 ENV PYTHONDONTWRITEBYTECODE=1 ENV PYTHONUNBUFFERED=1 -# Install system dependencies without mixing old Debian repos +# Install system dependencies RUN apt-get update && \ apt-get autoremove -y && \ apt-get install -y \ @@ -19,7 +19,7 @@ RUN apt-get update && \ libdbus-1-3 \ libexpat1 \ libfontconfig1 \ - libgdk-pixbuf2.0-0 \ + libgdk-pixbuf-2.0-0 \ libglib2.0-0 \ libgtk-3-0 \ libnspr4 \ diff --git a/api/utils/keycloak_utils.py b/api/utils/keycloak_utils.py index 2daa5eb..5ffacfd 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ -71,9 +71,33 @@ def validate_token(self, token: str) -> Dict[str, Any]: logger.warning("Token is not active") return {} - # Get user info from the token - user_info = self.keycloak_openid.userinfo(token) - return user_info + # Try to get user info from the userinfo endpoint + # If that fails (403), fall back to token introspection data + try: + user_info = self.keycloak_openid.userinfo(token) + return user_info + except KeycloakError as userinfo_error: + # If userinfo fails (e.g., 403), extract user info from token introspection + logger.warning( + f"Userinfo endpoint failed ({userinfo_error}), using 
token introspection data" + ) + + # Build user info from introspection response + user_info = { + "sub": token_info.get("sub"), + "preferred_username": token_info.get("username") + or token_info.get("preferred_username"), + "email": token_info.get("email"), + "email_verified": token_info.get("email_verified", False), + "name": token_info.get("name"), + "given_name": token_info.get("given_name"), + "family_name": token_info.get("family_name"), + } + + # Remove None values + user_info = {k: v for k, v in user_info.items() if v is not None} + return user_info + except KeycloakError as e: logger.error(f"Error validating token: {e}") return {} @@ -129,9 +153,7 @@ def get_user_organizations(self, token: str) -> List[Dict[str, Any]]: if len(parts) >= 3: org_id = parts[1] role_name = parts[2] - organizations.append( - {"organization_id": org_id, "role": role_name} - ) + organizations.append({"organization_id": org_id, "role": role_name}) return organizations except KeycloakError as e: @@ -198,9 +220,7 @@ def sync_user_from_keycloak( # Process organizations from Keycloak for org_info in organizations: org_id = org_info.get("organization_id") - role = org_info.get( - "role", "viewer" - ) # Default to viewer if role not specified + role = org_info.get("role", "viewer") # Default to viewer if role not specified # Try to get the organization try: From 8111795b466319c6b0d3ed76729e2a22e9cc3670 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 25 Nov 2025 20:06:41 +0530 Subject: [PATCH 015/127] Fix: Pass token_info dict instead of token string to get_user_roles --- api/utils/keycloak_utils.py | 9 +++------ api/views/auth.py | 7 +++++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/utils/keycloak_utils.py b/api/utils/keycloak_utils.py index 5ffacfd..1c339c9 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ -122,22 +122,19 @@ def get_user_roles(self, token_info: dict) -> list[str]: return roles - def get_user_organizations(self, token: str) 
-> List[Dict[str, Any]]: + def get_user_organizations(self, token_info: dict) -> List[Dict[str, Any]]: """ - Get the organizations a user belongs to from their token. + Get the organizations a user belongs to from their token info. This assumes that organization information is stored in the token as client roles or in user attributes. Args: - token: The user's token + token_info: The decoded token information Returns: List of organization information """ try: - # Decode the token to get user info - token_info = self.keycloak_openid.decode_token(token) - # Get organization info from resource_access or attributes # This implementation depends on how organizations are represented in Keycloak # This is a simplified example - adjust based on your Keycloak configuration diff --git a/api/views/auth.py b/api/views/auth.py index c5c3bb2..6a90f26 100644 --- a/api/views/auth.py +++ b/api/views/auth.py @@ -32,9 +32,12 @@ def post(self, request: Request) -> Response: status=status.HTTP_401_UNAUTHORIZED, ) + # Get token introspection data for roles and organizations + token_info = keycloak_manager.keycloak_openid.introspect(keycloak_token) + # Get user roles and organizations from the token - roles = keycloak_manager.get_user_roles(keycloak_token) - organizations = keycloak_manager.get_user_organizations(keycloak_token) + roles = keycloak_manager.get_user_roles(token_info) + organizations = keycloak_manager.get_user_organizations(token_info) # Sync the user information with our database user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) From deaa938c76fc3794b2c1bb4d4d82a413ca554a72 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 11:42:55 +0530 Subject: [PATCH 016/127] Add separate methods for token_info to avoid breaking middleware --- api/utils/keycloak_utils.py | 82 +++++++++++++++++++++++++++++-- api/views/auth.py | 6 +-- authorization/middleware_utils.py | 16 ++---- 3 files changed, 84 insertions(+), 20 deletions(-) diff --git 
a/api/utils/keycloak_utils.py b/api/utils/keycloak_utils.py index 1c339c9..b6b30e6 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ -102,9 +102,15 @@ def validate_token(self, token: str) -> Dict[str, Any]: logger.error(f"Error validating token: {e}") return {} - def get_user_roles(self, token_info: dict) -> list[str]: + def get_user_roles_from_token_info(self, token_info: dict) -> list[str]: """ - Extract roles from a Keycloak token. + Extract roles from token introspection data. + + Args: + token_info: Token introspection response + + Returns: + List of role names """ roles: list[str] = [] @@ -122,19 +128,85 @@ def get_user_roles(self, token_info: dict) -> list[str]: return roles - def get_user_organizations(self, token_info: dict) -> List[Dict[str, Any]]: + def get_user_organizations_from_token_info(self, token_info: dict) -> List[Dict[str, Any]]: """ - Get the organizations a user belongs to from their token info. + Get organizations from token introspection data. + + Args: + token_info: Token introspection response + + Returns: + List of organization information + """ + try: + # Get organization info from resource_access or attributes + resource_access = token_info.get("resource_access", {}) + client_roles = resource_access.get(self.client_id, {}).get("roles", []) + + # Extract organization info from roles + organizations = [] + for role in client_roles: + if role.startswith("org_"): + parts = role.split("_") + if len(parts) >= 3: + org_id = parts[1] + role_name = parts[2] + organizations.append({"organization_id": org_id, "role": role_name}) + + return organizations + except Exception as e: + logger.error(f"Error getting user organizations: {e}") + return [] + + def get_user_roles(self, token: str) -> list[str]: + """ + Extract roles from a Keycloak token. 
+ + Args: + token: The user's token + + Returns: + List of role names + """ + try: + # Decode the token to get user info + token_info = self.keycloak_openid.decode_token(token) + + roles: list[str] = [] + + # Extract realm roles + realm_access = token_info.get("realm_access", {}) + if realm_access and "roles" in realm_access: + roles.extend(realm_access["roles"]) # type: ignore[no-any-return] + + # Extract client roles + resource_access = token_info.get("resource_access", {}) + client_id = settings.KEYCLOAK_CLIENT_ID + if resource_access and client_id in resource_access: + client_roles = resource_access[client_id].get("roles", []) + roles.extend(client_roles) + + return roles + except Exception as e: + logger.error(f"Error getting user roles: {e}") + return [] + + def get_user_organizations(self, token: str) -> List[Dict[str, Any]]: + """ + Get the organizations a user belongs to from their token. This assumes that organization information is stored in the token as client roles or in user attributes. 
Args: - token_info: The decoded token information + token: The user's token Returns: List of organization information """ try: + # Decode the token to get user info + token_info = self.keycloak_openid.decode_token(token) + # Get organization info from resource_access or attributes # This implementation depends on how organizations are represented in Keycloak # This is a simplified example - adjust based on your Keycloak configuration diff --git a/api/views/auth.py b/api/views/auth.py index 6a90f26..cc3e769 100644 --- a/api/views/auth.py +++ b/api/views/auth.py @@ -35,9 +35,9 @@ def post(self, request: Request) -> Response: # Get token introspection data for roles and organizations token_info = keycloak_manager.keycloak_openid.introspect(keycloak_token) - # Get user roles and organizations from the token - roles = keycloak_manager.get_user_roles(token_info) - organizations = keycloak_manager.get_user_organizations(token_info) + # Get user roles and organizations from the token introspection data + roles = keycloak_manager.get_user_roles_from_token_info(token_info) + organizations = keycloak_manager.get_user_organizations_from_token_info(token_info) # Sync the user information with our database user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) diff --git a/authorization/middleware_utils.py b/authorization/middleware_utils.py index fa70a62..de0433f 100644 --- a/authorization/middleware_utils.py +++ b/authorization/middleware_utils.py @@ -50,9 +50,7 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: # Use the raw header value, but check if it might be a raw token # (no 'Bearer ' prefix but still a valid JWT format) token = auth_header - logger.debug( - f"Using raw Authorization header as token, length: {len(token)}" - ) + logger.debug(f"Using raw Authorization header as token, length: {len(token)}") # If no token found, return anonymous user if not token: @@ -91,14 +89,10 @@ def get_user_from_keycloak_token(request: 
HttpRequest) -> User: return cast(User, AnonymousUser()) # Log the user info for debugging - logger.debug( - f"User info from token: {user_info.keys() if user_info else 'None'}" - ) + logger.debug(f"User info from token: {user_info.keys() if user_info else 'None'}") logger.debug(f"User sub: {user_info.get('sub', 'None')}") logger.debug(f"User email: {user_info.get('email', 'None')}") - logger.debug( - f"User preferred_username: {user_info.get('preferred_username', 'None')}" - ) + logger.debug(f"User preferred_username: {user_info.get('preferred_username', 'None')}") # Get user roles and organizations from the token roles = keycloak_manager.get_user_roles(token) @@ -113,9 +107,7 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: logger.warning("User synchronization failed, returning anonymous user") return cast(User, AnonymousUser()) - logger.debug( - f"Successfully authenticated user: {user.username} (ID: {user.id})" - ) + logger.debug(f"Successfully authenticated user: {user.username} (ID: {user.id})") # Return the authenticated user logger.debug(f"Returning authenticated user: {user.username}") From 5da0eb648974e7ad2b90e57fd50d2cfc85d3b3c9 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 11:48:52 +0530 Subject: [PATCH 017/127] Add Django JWT authentication support in middleware --- authorization/middleware_utils.py | 35 ++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/authorization/middleware_utils.py b/authorization/middleware_utils.py index de0433f..8ec6e88 100644 --- a/authorization/middleware_utils.py +++ b/authorization/middleware_utils.py @@ -7,6 +7,8 @@ from django.http import HttpRequest, HttpResponse from django.utils.functional import SimpleLazyObject from rest_framework.request import Request +from rest_framework_simplejwt.exceptions import InvalidToken, TokenError +from rest_framework_simplejwt.tokens import AccessToken from api.utils.debug_utils import debug_auth_headers, 
debug_token_validation from authorization.keycloak import keycloak_manager @@ -68,7 +70,28 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: # For debugging, print the raw token logger.debug(f"Raw token: {token}") + # First, try to validate as Django JWT token try: + logger.debug("Attempting to validate as Django JWT token") + access_token = AccessToken(token) + user_id = access_token.get("user_id") + + if user_id: + logger.debug(f"Valid Django JWT token for user_id: {user_id}") + try: + user = User.objects.get(id=user_id) + logger.debug(f"Successfully authenticated user via Django JWT: {user.username}") + return user + except User.DoesNotExist: + logger.warning(f"User with id {user_id} not found in database") + except (TokenError, InvalidToken) as e: + logger.debug(f"Not a valid Django JWT token: {e}, trying Keycloak validation") + except Exception as e: + logger.debug(f"Error validating Django JWT: {e}, trying Keycloak validation") + + # If Django JWT validation failed, try Keycloak token validation + try: + logger.debug("Attempting to validate as Keycloak token") # Try direct validation without any complex logic user_info = keycloak_manager.validate_token(token) @@ -102,16 +125,18 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: logger.debug(f"User organizations from token: {organizations}") # Sync the user information with our database - user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) - if not user: + synced_user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) + if not synced_user: logger.warning("User synchronization failed, returning anonymous user") return cast(User, AnonymousUser()) - logger.debug(f"Successfully authenticated user: {user.username} (ID: {user.id})") + logger.debug( + f"Successfully authenticated user: {synced_user.username} (ID: {synced_user.id})" + ) # Return the authenticated user - logger.debug(f"Returning authenticated user: {user.username}") 
- return user + logger.debug(f"Returning authenticated user: {synced_user.username}") + return synced_user except Exception as e: logger.error(f"Error in get_user_from_keycloak_token: {str(e)}") return cast(User, AnonymousUser()) From ca3935531bb1abd27738953a5aaacf3b3f29fe0c Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 12:05:36 +0530 Subject: [PATCH 018/127] Consolidate KeycloakManager to single file in api/utils/keycloak_utils.py --- api/utils/keycloak_utils.py | 129 +++++++- authorization/authentication.py | 6 +- authorization/backends.py | 11 +- authorization/keycloak.py | 525 +----------------------------- authorization/middleware_utils.py | 2 +- authorization/schema/mutation.py | 22 +- authorization/views.py | 10 +- 7 files changed, 147 insertions(+), 558 deletions(-) diff --git a/api/utils/keycloak_utils.py b/api/utils/keycloak_utils.py index b6b30e6..6bdb335 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ -56,7 +56,7 @@ def get_token(self, username: str, password: str) -> Dict[str, Any]: def validate_token(self, token: str) -> Dict[str, Any]: """ - Validate a Keycloak token and return the user info. + Validate a token (Django JWT or Keycloak) and return the user info. 
Args: token: The token to validate @@ -64,6 +64,41 @@ def validate_token(self, token: str) -> Dict[str, Any]: Returns: Dict containing the user information """ + from rest_framework_simplejwt.exceptions import InvalidToken, TokenError + from rest_framework_simplejwt.tokens import AccessToken + + # First, try to validate as Django JWT token + try: + logger.debug("Attempting to validate as Django JWT token") + access_token = AccessToken(token) # type: ignore[arg-type] + user_id = access_token.get("user_id") + + if user_id: + logger.debug(f"Valid Django JWT token for user_id: {user_id}") + try: + from authorization.models import User + + user = User.objects.get(id=user_id) + # Return user info in Keycloak format + return { + "sub": ( + str(user.keycloak_id) + if hasattr(user, "keycloak_id") and user.keycloak_id + else str(user.id) + ), + "preferred_username": user.username, + "email": user.email, + "given_name": user.first_name, + "family_name": user.last_name, + } + except User.DoesNotExist: + logger.warning(f"User with id {user_id} not found in database") + except (TokenError, InvalidToken) as e: + logger.debug(f"Not a valid Django JWT token: {e}, trying Keycloak validation") + except Exception as e: + logger.debug(f"Error validating Django JWT: {e}, trying Keycloak validation") + + # If Django JWT validation failed, try Keycloak token validation try: # Verify the token is valid token_info = self.keycloak_openid.introspect(token) @@ -229,6 +264,98 @@ def get_user_organizations(self, token: str) -> List[Dict[str, Any]]: logger.error(f"Error getting user organizations: {e}") return [] + def update_user_in_keycloak(self, user: User) -> bool: + """Update user details in Keycloak using admin credentials.""" + if not user.keycloak_id: + logger.warning("Cannot update user in Keycloak: No keycloak_id", user_id=str(user.id)) + return False + + try: + # Get admin credentials from settings + admin_username = getattr(settings, "KEYCLOAK_ADMIN_USERNAME", "") + admin_password = 
getattr(settings, "KEYCLOAK_ADMIN_PASSWORD", "") + # Log credential presence (not the actual values) + logger.info( + "Admin credentials check", + username_present=bool(admin_username), + password_present=bool(admin_password), + ) + + if not admin_username or not admin_password: + logger.error("Keycloak admin credentials not configured") + return False + + from keycloak import KeycloakOpenID + + # First get an admin token directly + keycloak_openid = KeycloakOpenID( + server_url=self.server_url, + client_id="admin-cli", # Special client for admin operations + realm_name="master", # Admin users are in master realm + verify=True, + ) + + # Get token + try: + token = keycloak_openid.token( + username=admin_username, + password=admin_password, + grant_type="password", + ) + access_token = token.get("access_token") + + if not access_token: + logger.error("Failed to get admin access token") + return False + + # Now use the token to update the user + import requests + + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + user_data = { + "firstName": user.first_name, + "lastName": user.last_name, + "email": user.email, + "emailVerified": True, + } + + # Direct API call to update user + base_url = self.server_url.rstrip("/") # Remove any trailing slash + response = requests.put( + f"{base_url}/admin/realms/{self.realm}/users/{user.keycloak_id}", + headers=headers, + json=user_data, + ) + + if response.status_code == 204: # Success for this endpoint + logger.info( + "Successfully updated user in Keycloak", + user_id=str(user.id), + keycloak_id=user.keycloak_id, + ) + return True + else: + logger.error( + f"Failed to update user in Keycloak: {response.status_code}: {response.text}", + user_id=str(user.id), + ) + return False + + except Exception as token_error: + logger.error( + f"Error getting admin token: {str(token_error)}", + user_id=str(user.id), + ) + return False + + except Exception as e: + logger.error(f"Error 
updating user in Keycloak: {str(e)}", user_id=str(user.id)) + return False + @transaction.atomic def sync_user_from_keycloak( self, diff --git a/authorization/authentication.py b/authorization/authentication.py index 49c6655..93c2b5e 100644 --- a/authorization/authentication.py +++ b/authorization/authentication.py @@ -5,7 +5,7 @@ from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request -from authorization.keycloak import keycloak_manager +from api.utils.keycloak_utils import keycloak_manager from authorization.models import User @@ -41,9 +41,7 @@ def authenticate(self, request: Request) -> Optional[Tuple[User, str]]: # Ensure we have a subject ID in the token if not user_info.get("sub"): - raise AuthenticationFailed( - "Token validation succeeded but missing subject ID" - ) + raise AuthenticationFailed("Token validation succeeded but missing subject ID") # Get user roles and organizations from the token roles = keycloak_manager.get_user_roles(token) diff --git a/authorization/backends.py b/authorization/backends.py index 1da2796..13792a9 100644 --- a/authorization/backends.py +++ b/authorization/backends.py @@ -4,7 +4,7 @@ from django.contrib.auth.backends import ModelBackend from django.http import HttpRequest -from authorization.keycloak import keycloak_manager +from api.utils.keycloak_utils import keycloak_manager from authorization.models import User @@ -14,10 +14,7 @@ class KeycloakAuthenticationBackend(ModelBackend): """ def authenticate( # type: ignore[override] - self, - request: Optional[HttpRequest] = None, - token: Optional[str] = None, - **kwargs: Any + self, request: Optional[HttpRequest] = None, token: Optional[str] = None, **kwargs: Any ) -> Optional[User]: """ Authenticate a user based on a Keycloak token. 
@@ -44,9 +41,7 @@ def authenticate( # type: ignore[override] # Get user roles and organizations from the token roles: List[str] = keycloak_manager.get_user_roles(token) - organizations: List[Dict[str, Any]] = keycloak_manager.get_user_organizations( - token - ) + organizations: List[Dict[str, Any]] = keycloak_manager.get_user_organizations(token) # Sync the user information with our database user: Optional[User] = keycloak_manager.sync_user_from_keycloak( diff --git a/authorization/keycloak.py b/authorization/keycloak.py index 6ea56e5..a42ce29 100644 --- a/authorization/keycloak.py +++ b/authorization/keycloak.py @@ -1,523 +1,6 @@ -from typing import Any, Dict, List, Optional, Type, TypeVar, cast +"""DEPRECATED: This module is deprecated. Use api.utils.keycloak_utils instead.""" -import structlog -from django.conf import settings -from django.db import transaction -from keycloak import KeycloakAdmin, KeycloakOpenID -from keycloak.exceptions import KeycloakError +# Import from the new location for backward compatibility +from api.utils.keycloak_utils import KeycloakManager, keycloak_manager -from api.models import Organization -from authorization.models import OrganizationMembership, Role, User - -logger = structlog.getLogger(__name__) - -# Type variables for model classes -T = TypeVar("T") - - -class KeycloakManager: - """ - Utility class to manage Keycloak integration with Django. - Handles token validation, user synchronization, and role mapping. 
- """ - - def __init__(self) -> None: - import structlog - - logger = structlog.getLogger(__name__) - - self.server_url: str = settings.KEYCLOAK_SERVER_URL - self.realm: str = settings.KEYCLOAK_REALM - self.client_id: str = settings.KEYCLOAK_CLIENT_ID - self.client_secret: str = settings.KEYCLOAK_CLIENT_SECRET - - # Log Keycloak connection details (without secrets) - logger.debug( - f"Initializing Keycloak connection to {self.server_url} " - f"for realm {self.realm} and client {self.client_id}" - ) - - try: - self.keycloak_openid: KeycloakOpenID = KeycloakOpenID( - server_url=self.server_url, - client_id=self.client_id, - realm_name=self.realm, - client_secret_key=self.client_secret, - ) - - logger.debug("Keycloak client initialized successfully") - except Exception as e: - logger.error(f"Failed to initialize Keycloak client: {e}") - # Use cast to satisfy the type checker - self.keycloak_openid = cast(KeycloakOpenID, object()) - - def get_token(self, username: str, password: str) -> Dict[str, Any]: - """ - Get a Keycloak token for a user. - - Args: - username: The username - password: The password - - Returns: - Dict containing the token information - """ - try: - return self.keycloak_openid.token(username, password) - except KeycloakError as e: - logger.error(f"Error getting token: {e}") - raise - - def validate_token(self, token: str) -> Dict[str, Any]: - """ - Validate a Keycloak token and return the user info. - Only validates by contacting Keycloak directly - no local validation. 
- - Args: - token: The token to validate - - Returns: - Dict containing the user information or empty dict if validation fails - """ - import structlog - - logger = structlog.getLogger(__name__) - - # Log token for debugging - logger.debug(f"Validating token of length: {len(token)}") - - # Only try to contact Keycloak directly - don't create users from local token decoding - try: - logger.debug("Attempting to get user info from Keycloak") - user_info = self.keycloak_openid.userinfo(token) - if user_info and isinstance(user_info, dict): - logger.debug("Successfully retrieved user info from Keycloak") - logger.debug(f"User info: {user_info}") - return user_info - else: - logger.warning("Keycloak returned empty or invalid user info") - return {} - except Exception as e: - logger.warning(f"Failed to get user info from Keycloak: {e}") - return {} - - def get_user_roles(self, token: str) -> List[str]: - """ - Get the roles for a user from their token. - - Args: - token: The user's token - - Returns: - List of role names - """ - import structlog - from django.conf import settings - - logger = structlog.getLogger(__name__) - - # Get roles directly from token - logger.debug("Extracting roles from token") - - logger.debug(f"Getting roles from token of length: {len(token)}") - - try: - # Decode the token to get the roles - try: - token_info: Dict[str, Any] = self.keycloak_openid.decode_token(token) - logger.debug("Successfully decoded token for roles") - except Exception as decode_error: - logger.warning(f"Failed to decode token for roles: {decode_error}") - # If we can't decode the token, try to get roles from introspection - try: - token_info = self.keycloak_openid.introspect(token) - logger.debug("Using introspection result for roles") - except Exception as introspect_error: - logger.error( - f"Failed to introspect token for roles: {introspect_error}" - ) - return [] - - # Extract roles from token info - realm_access: Dict[str, Any] = token_info.get("realm_access", {}) - 
roles = cast(List[str], realm_access.get("roles", [])) - - # Also check resource_access for client roles - resource_access = token_info.get("resource_access", {}) - client_roles = resource_access.get(self.client_id, {}).get("roles", []) - - # Combine realm and client roles - all_roles = list(set(roles + client_roles)) - logger.debug(f"Found roles: {all_roles}") - - return all_roles - except KeycloakError as e: - logger.error(f"Error getting user roles: {e}") - return [] - except Exception as e: - logger.error(f"Unexpected error getting user roles: {e}") - return [] - - def get_user_organizations(self, token: str) -> List[Dict[str, Any]]: - """ - Get the organizations a user belongs to from their token. - This assumes that organization information is stored in the token - as client roles or in user attributes. - - Args: - token: The user's token - - Returns: - List of organization information - """ - import structlog - from django.conf import settings - - logger = structlog.getLogger(__name__) - - logger.debug(f"Getting organizations from token of length: {len(token)}") - - try: - # Decode the token to get user info - token_info = {} - try: - token_info = self.keycloak_openid.decode_token(token) - logger.debug("Successfully decoded token for organizations") - except Exception as decode_error: - logger.warning( - f"Failed to decode token for organizations: {decode_error}" - ) - # If we can't decode the token, try to get info from introspection - try: - token_info = self.keycloak_openid.introspect(token) - logger.debug("Using introspection result for organizations") - except Exception as introspect_error: - logger.error( - f"Failed to introspect token for organizations: {introspect_error}" - ) - return [] - - # Get organization info from resource_access or attributes - # This implementation depends on how organizations are represented in Keycloak - resource_access = token_info.get("resource_access", {}) - client_roles = resource_access.get(self.client_id, 
{}).get("roles", []) - - logger.debug(f"Found client roles: {client_roles}") - - # Extract organization info from roles - # Format could be 'org__' or similar - organizations = [] - for role in client_roles: - if role.startswith("org_"): - parts = role.split("_") - if len(parts) >= 3: - org_id = parts[1] - role_name = parts[2] - organizations.append( - {"organization_id": org_id, "role": role_name} - ) - - # If no organizations found through roles, check user attributes - if not organizations and token_info.get("attributes"): - attrs = token_info.get("attributes", {}) - org_attrs = attrs.get("organizations", []) - - if isinstance(org_attrs, str): - org_attrs = [org_attrs] # Convert single string to list - - for org_attr in org_attrs: - try: - # Format could be 'org_id:role' - org_id, role = org_attr.split(":") - organizations.append({"organization_id": org_id, "role": role}) - except ValueError: - # If no role specified, use default - organizations.append( - {"organization_id": org_attr, "role": "viewer"} - ) - - logger.debug(f"Found organizations: {organizations}") - return organizations - except KeycloakError as e: - logger.error(f"Error getting user organizations: {e}") - return [] - except Exception as e: - logger.error(f"Unexpected error getting user organizations: {e}") - return [] - - @transaction.atomic - def sync_user_from_keycloak( - self, - user_info: Dict[str, Any], - roles: List[str], - organizations: List[Dict[str, Any]], - ) -> Optional[User]: - """ - Synchronize user information from Keycloak to Django. - Creates or updates the User record and organization memberships. 
- - Args: - user_info: User information from Keycloak - roles: User roles from Keycloak (not used when maintaining roles in DB) - organizations: User organizations from Keycloak - - Returns: - The synchronized User object or None if failed - """ - import structlog - - logger = structlog.getLogger(__name__) - - # Log the user info we're trying to sync - logger.debug(f"Attempting to sync user with info: {user_info}") - - try: - # Extract key user information - keycloak_id = user_info.get("sub") - email = user_info.get("email") - username = user_info.get("preferred_username") or email - - # Validate required fields - if not keycloak_id or not username: - logger.error("Missing required user information from Keycloak") - return None - - # Initialize variables - user = None - created = False - - # Try to find user by keycloak_id first - try: - user = User.objects.get(keycloak_id=keycloak_id) - logger.debug(f"Found existing user by keycloak_id: {user.username}") - - # Update user details - user.username = str(username) if username else "" # type: ignore[assignment] - user.email = str(email) if email else "" # type: ignore[assignment] - user.first_name = ( - str(user_info.get("given_name", "")) - if user_info.get("given_name") - else "" - ) - user.last_name = ( - str(user_info.get("family_name", "")) - if user_info.get("family_name") - else "" - ) - user.is_active = True - user.save() - except User.DoesNotExist: - # Try to find user by email - if email: - try: - user = User.objects.get(email=email) - logger.debug(f"Found existing user by email: {user.username}") - - # Update keycloak_id and other details - user.keycloak_id = str(keycloak_id) if keycloak_id else "" # type: ignore[assignment] - user.username = str(username) if username else "" # type: ignore[assignment] - user.first_name = ( - str(user_info.get("given_name", "")) - if user_info.get("given_name") - else "" - ) - user.last_name = ( - str(user_info.get("family_name", "")) - if user_info.get("family_name") - 
else "" - ) - user.is_active = True - user.save() - except User.DoesNotExist: - # Try to find user by username - try: - user = User.objects.get(username=username) - logger.debug( - f"Found existing user by username: {user.username}" - ) - - # Update keycloak_id and other details - user.keycloak_id = str(keycloak_id) if keycloak_id else "" # type: ignore[assignment] - user.email = str(email) if email else "" # type: ignore[assignment] - user.first_name = ( - str(user_info.get("given_name", "")) - if user_info.get("given_name") - else "" - ) - user.last_name = ( - str(user_info.get("family_name", "")) - if user_info.get("family_name") - else "" - ) - user.is_active = True - user.save() - except User.DoesNotExist: - # Create new user - logger.debug( - f"Creating new user with keycloak_id: {keycloak_id}" - ) - user = User.objects.create( - keycloak_id=str(keycloak_id) if keycloak_id else "", # type: ignore[arg-type] - username=str(username) if username else "", # type: ignore[arg-type] - email=str(email) if email else "", # type: ignore[arg-type] - first_name=( - str(user_info.get("given_name", "")) - if user_info.get("given_name") - else "" - ), - last_name=( - str(user_info.get("family_name", "")) - if user_info.get("family_name") - else "" - ), - is_active=True, - ) - created = True - - # If this is a new user, we'll keep default permissions - if created: - pass - - if user is not None: # Check that user is not None before saving - user.save() - - # If this is a new user and we want to sync organization memberships - # We'll only create new memberships for organizations found in Keycloak - # but we won't update existing memberships or remove any - if user is not None and created and organizations: - # Process organizations from Keycloak - only for new users - for org_info in organizations: - org_id: Optional[str] = org_info.get("organization_id") - if not org_id: - continue - - # Try to get the organization - try: - organization: Organization = 
Organization.objects.get(id=org_id) - - # For new users, assign the default viewer role - # The actual role management will be done in the application - default_role: Role = Role.objects.get(name="viewer") - - # Create the organization membership with default role - # Only if it doesn't already exist - OrganizationMembership.objects.get_or_create( - user=user, - organization=organization, - defaults={"role": default_role}, - ) - except Organization.DoesNotExist as e: - logger.error( - f"Error processing organization from Keycloak: {e}" - ) - except Role.DoesNotExist as e: - logger.error(f"Default viewer role not found: {e}") - - # We don't remove organization memberships that are no longer in Keycloak - # since we're maintaining roles in the database - - return user - except Exception as e: - logger.error(f"Error synchronizing user from Keycloak: {e}") - return None - - def update_user_in_keycloak(self, user: User) -> bool: - """Update user details in Keycloak using admin credentials.""" - if not user.keycloak_id: - logger.warning( - "Cannot update user in Keycloak: No keycloak_id", user_id=str(user.id) - ) - return False - - try: - # Get admin credentials from settings - admin_username = getattr(settings, "KEYCLOAK_ADMIN_USERNAME", "") - admin_password = getattr(settings, "KEYCLOAK_ADMIN_PASSWORD", "") - - # Log credential presence (not the actual values) - logger.info( - "Admin credentials check", - username_present=bool(admin_username), - password_present=bool(admin_password), - ) - - if not admin_username or not admin_password: - logger.error("Keycloak admin credentials not configured") - return False - - from keycloak import KeycloakOpenID - - # First get an admin token directly - keycloak_openid = KeycloakOpenID( - server_url=self.server_url, - client_id="admin-cli", # Special client for admin operations - realm_name="master", # Admin users are in master realm - verify=True, - ) - - # Get token - try: - token = keycloak_openid.token( - 
username=admin_username, - password=admin_password, - grant_type="password", - ) - access_token = token.get("access_token") - - if not access_token: - logger.error("Failed to get admin access token") - return False - - # Now use the token to update the user - import requests - - headers = { - "Authorization": f"Bearer {access_token}", - "Content-Type": "application/json", - } - - user_data = { - "firstName": user.first_name, - "lastName": user.last_name, - "email": user.email, - "emailVerified": True, - } - - # Direct API call to update user - base_url = self.server_url.rstrip("/") # Remove any trailing slash - response = requests.put( - f"{base_url}/admin/realms/{self.realm}/users/{user.keycloak_id}", - headers=headers, - json=user_data, - ) - - if response.status_code == 204: # Success for this endpoint - logger.info( - "Successfully updated user in Keycloak", - user_id=str(user.id), - keycloak_id=user.keycloak_id, - ) - return True - else: - logger.error( - f"Failed to update user in Keycloak: {response.status_code}: {response.text}", - user_id=str(user.id), - ) - return False - - except Exception as token_error: - logger.error( - f"Error getting admin token: {str(token_error)}", - user_id=str(user.id), - ) - return False - - except Exception as e: - logger.error( - f"Error updating user in Keycloak: {str(e)}", user_id=str(user.id) - ) - return False - - -# Create a singleton instance -keycloak_manager = KeycloakManager() +__all__ = ["KeycloakManager", "keycloak_manager"] diff --git a/authorization/middleware_utils.py b/authorization/middleware_utils.py index 8ec6e88..684f0a9 100644 --- a/authorization/middleware_utils.py +++ b/authorization/middleware_utils.py @@ -11,7 +11,7 @@ from rest_framework_simplejwt.tokens import AccessToken from api.utils.debug_utils import debug_auth_headers, debug_token_validation -from authorization.keycloak import keycloak_manager +from api.utils.keycloak_utils import keycloak_manager from authorization.models import User logger = 
structlog.getLogger(__name__) diff --git a/authorization/schema/mutation.py b/authorization/schema/mutation.py index ec22951..fdf51f0 100644 --- a/authorization/schema/mutation.py +++ b/authorization/schema/mutation.py @@ -15,6 +15,7 @@ MutationResponse, ) from api.utils.graphql_telemetry import trace_resolver +from api.utils.keycloak_utils import keycloak_manager from authorization.models import OrganizationMembership, Role, User from authorization.permissions import IsAuthenticated from authorization.schema.inputs import ( @@ -34,12 +35,9 @@ @strawberry.type class Mutation: @strawberry.mutation(permission_classes=[IsAuthenticated]) - @trace_resolver( - name="update_user", attributes={"component": "user", "operation": "mutation"} - ) + @trace_resolver(name="update_user", attributes={"component": "user", "operation": "mutation"}) def update_user(self, info: Info, input: UpdateUserInput) -> TypeUser: """Update user details and sync with Keycloak.""" - from authorization.keycloak import keycloak_manager user = info.context.user @@ -139,9 +137,7 @@ def add_user_to_organization( user=user, organization=organization ) # If we get here, the membership exists - raise DjangoValidationError( - "User is already a member of this organization." 
- ) + raise DjangoValidationError("User is already a member of this organization.") except OrganizationMembership.DoesNotExist: # Membership doesn't exist, so create it membership = OrganizationMembership.objects.create( @@ -211,13 +207,9 @@ def remove_user_from_organization( organization = info.context.context.get("organization") # Check if the membership already exists - membership = OrganizationMembership.objects.get( - user=user, organization=organization - ) + membership = OrganizationMembership.objects.get(user=user, organization=organization) membership.delete() - return SuccessResponse( - success=True, message="User removed from organization" - ) + return SuccessResponse(success=True, message="User removed from organization") except User.DoesNotExist: raise DjangoValidationError(f"User with ID {input.user_id} does not exist.") except Role.DoesNotExist: @@ -270,8 +262,6 @@ def assign_dataset_permission( ) if result: - return SuccessResponse( - success=True, message="Permission assigned successfully" - ) + return SuccessResponse(success=True, message="Permission assigned successfully") else: return SuccessResponse(success=False, message="Failed to assign permission") diff --git a/authorization/views.py b/authorization/views.py index 3c5aab9..ff20f9a 100644 --- a/authorization/views.py +++ b/authorization/views.py @@ -6,8 +6,8 @@ from rest_framework.views import APIView from rest_framework_simplejwt.tokens import RefreshToken +from api.utils.keycloak_utils import keycloak_manager from authorization.consent import UserConsent -from authorization.keycloak import keycloak_manager from authorization.models import User from authorization.serializers import UserConsentSerializer from authorization.services import AuthorizationService @@ -62,9 +62,7 @@ def post(self, request: Request) -> Response: logger.info(f"Syncing user with Keycloak ID: {user_info.get('sub')}") user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) if not user: - 
logger.error( - f"Failed to sync user with Keycloak ID: {user_info.get('sub')}" - ) + logger.error(f"Failed to sync user with Keycloak ID: {user_info.get('sub')}") return Response( {"error": "Failed to synchronize user information"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -176,9 +174,7 @@ def put(self, request: Request) -> Response: user=request.user, activity_tracking_enabled=False # type: ignore[misc] ) - serializer = UserConsentSerializer( - consent, data=request.data, context={"request": request} - ) + serializer = UserConsentSerializer(consent, data=request.data, context={"request": request}) if serializer.is_valid(): serializer.save() return Response(serializer.data) From 18fe6166b4627d6b01f82917ab737f760bc068f9 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 12:10:25 +0530 Subject: [PATCH 019/127] Fix authentication to handle Django JWT tokens without decoding as Keycloak tokens --- authorization/authentication.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/authorization/authentication.py b/authorization/authentication.py index 93c2b5e..c6aef18 100644 --- a/authorization/authentication.py +++ b/authorization/authentication.py @@ -43,16 +43,35 @@ def authenticate(self, request: Request) -> Optional[Tuple[User, str]]: if not user_info.get("sub"): raise AuthenticationFailed("Token validation succeeded but missing subject ID") + # Check if this is a Django JWT (already validated user) or Keycloak token + # Django JWTs have 'user_id' in the payload, Keycloak tokens don't + from rest_framework_simplejwt.exceptions import InvalidToken, TokenError + from rest_framework_simplejwt.tokens import AccessToken + + try: + # Try to decode as Django JWT + access_token = AccessToken(token) # type: ignore[arg-type] + user_id = access_token.get("user_id") + + if user_id: + # This is a Django JWT - user is already synced, just get from DB + user = User.objects.get(id=user_id) + return (user, token) + except 
(TokenError, InvalidToken, User.DoesNotExist): + # Not a Django JWT or user not found, continue with Keycloak flow + pass + + # This is a Keycloak token - sync the user # Get user roles and organizations from the token roles = keycloak_manager.get_user_roles(token) organizations = keycloak_manager.get_user_organizations(token) # Sync the user information with our database - user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) - if not user: + synced_user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) + if not synced_user: raise AuthenticationFailed("Failed to synchronize user information") - return (user, token) + return (synced_user, token) def authenticate_header(self, request: Request) -> str: """ From b0d2e1ba16368d43ce5aa0dcc3f3f51ef1c742f0 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 12:26:26 +0530 Subject: [PATCH 020/127] add ability to login as service account --- dataspace_sdk/auth.py | 91 +++++++++++++++++++++++++++++++++++++++++ dataspace_sdk/client.py | 28 +++++++++++++ 2 files changed, 119 insertions(+) diff --git a/dataspace_sdk/auth.py b/dataspace_sdk/auth.py index c04e162..4ecac6e 100644 --- a/dataspace_sdk/auth.py +++ b/dataspace_sdk/auth.py @@ -77,6 +77,39 @@ def login(self, username: str, password: str) -> Dict[str, Any]: # Login to DataSpace backend return self._login_with_keycloak_token(keycloak_token) + def login_as_service_account(self) -> Dict[str, Any]: + """ + Login using client credentials (service account). + + This method authenticates the client itself (not a user) using + the client_id and client_secret. Requires the Keycloak client + to have "Service Accounts Enabled". 
+ + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + """ + if not all( + [ + self.keycloak_url, + self.keycloak_realm, + self.keycloak_client_id, + self.keycloak_client_secret, + ] + ): + raise DataSpaceAuthError( + "Service account authentication requires keycloak_url, " + "keycloak_realm, keycloak_client_id, and keycloak_client_secret." + ) + + # Get Keycloak token using client credentials + keycloak_token = self._get_service_account_token() + + # Login to DataSpace backend + return self._login_with_keycloak_token(keycloak_token) + def _get_keycloak_token(self, username: str, password: str) -> str: """ Get Keycloak access token using username and password. @@ -140,6 +173,64 @@ def _get_keycloak_token(self, username: str, password: str) -> str: except requests.RequestException as e: raise DataSpaceAuthError(f"Network error during Keycloak authentication: {str(e)}") + def _get_service_account_token(self) -> str: + """ + Get Keycloak access token using client credentials (service account). 
+ + Returns: + Keycloak access token + + Raises: + DataSpaceAuthError: If authentication fails + """ + token_url = ( + f"{self.keycloak_url}/auth/realms/{self.keycloak_realm}/" + f"protocol/openid-connect/token" + ) + + data = { + "grant_type": "client_credentials", + "client_id": self.keycloak_client_id, + "client_secret": self.keycloak_client_secret, + } + + try: + response = requests.post( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code == 200: + token_data = response.json() + self.keycloak_access_token = token_data.get("access_token") + self.keycloak_refresh_token = token_data.get("refresh_token") + + # Calculate token expiration time + expires_in = token_data.get("expires_in", 300) + self.token_expires_at = time.time() + expires_in + + if not self.keycloak_access_token: + raise DataSpaceAuthError("No access token in Keycloak response") + + return self.keycloak_access_token + else: + error_data = response.json() + error_msg = error_data.get( + "error_description", + error_data.get("error", "Service account authentication failed"), + ) + raise DataSpaceAuthError( + f"Service account login failed: {error_msg}. " + f"Ensure 'Service Accounts Enabled' is ON in Keycloak client settings.", + status_code=response.status_code, + response=error_data, + ) + except requests.RequestException as e: + raise DataSpaceAuthError( + f"Network error during service account authentication: {str(e)}" + ) + def _refresh_keycloak_token(self) -> str: """ Refresh Keycloak access token using refresh token. diff --git a/dataspace_sdk/client.py b/dataspace_sdk/client.py index 8103a91..a2c3674 100644 --- a/dataspace_sdk/client.py +++ b/dataspace_sdk/client.py @@ -91,6 +91,34 @@ def login(self, username: str, password: str) -> dict: """ return self._auth.login(username, password) + def login_as_service_account(self) -> dict: + """ + Login using client credentials (service account). 
+ + This method authenticates the client itself using client_id and client_secret. + The Keycloak client must have "Service Accounts Enabled" turned ON. + + This is the recommended approach for backend services and automated tasks. + + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + + Example: + >>> client = DataSpaceClient( + ... base_url="https://api.dataspace.example.com", + ... keycloak_url="https://opub-kc.civicdatalab.in", + ... keycloak_realm="DataSpace", + ... keycloak_client_id="dataspace", + ... keycloak_client_secret="your-secret" + ... ) + >>> info = client.login_as_service_account() + >>> print("Authenticated as service account") + """ + return self._auth.login_as_service_account() + def login_with_token(self, keycloak_token: str) -> dict: """ Login using a pre-obtained Keycloak token. From 682b82b020b9d1d8a964856c2742c202aa37c2be Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 12:30:37 +0530 Subject: [PATCH 021/127] Fix AI model URL bug and update SDK documentation - Fix aimodels.get_by_id() to use super().get() to avoid recursion - Update README.md with new authentication methods - Document AI model calling functionality - Add comprehensive API reference - Update Quick Start examples with Keycloak configuration --- dataspace_sdk/resources/aimodels.py | 5 +- dataspace_sdk/resources/datasets.py | 2 +- dataspace_sdk/resources/usecases.py | 2 +- docs/sdk/README.md | 113 +++++++++++++++++++++++----- 4 files changed, 101 insertions(+), 21 deletions(-) diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index 0b0d56e..84c4262 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -61,7 +61,7 @@ def search( if sort: params["sort"] = sort - return self.get("/api/search/aimodel/", params=params) + return super().get("/api/search/aimodel/", params=params) def get_by_id(self, model_id: str) -> Dict[str, Any]: """ @@ 
-73,7 +73,8 @@ def get_by_id(self, model_id: str) -> Dict[str, Any]: Returns: Dictionary containing AI model information """ - return self.get(f"/api/aimodels/{model_id}/") + # Use parent class get method with full endpoint path + return super().get(f"/api/aimodels/{model_id}/") def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: """ diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index bc655e9..cad599b 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -57,7 +57,7 @@ def search( if sort: params["sort"] = sort - return self.get("/api/search/dataset/", params=params) + return super().get("/api/search/dataset/", params=params) def get_by_id(self, dataset_id: str) -> Dict[str, Any]: """ diff --git a/dataspace_sdk/resources/usecases.py b/dataspace_sdk/resources/usecases.py index 3d75a2f..6aca2cd 100644 --- a/dataspace_sdk/resources/usecases.py +++ b/dataspace_sdk/resources/usecases.py @@ -57,7 +57,7 @@ def search( if sort: params["sort"] = sort - return self.get("/api/search/usecase/", params=params) + return super().get("/api/search/usecase/", params=params) def get_by_id(self, usecase_id: int) -> Dict[str, Any]: """ diff --git a/docs/sdk/README.md b/docs/sdk/README.md index 8a251cf..19a6747 100644 --- a/docs/sdk/README.md +++ b/docs/sdk/README.md @@ -29,11 +29,20 @@ pip install -e ".[dev]" ```python from dataspace_sdk import DataSpaceClient -# Initialize the client -client = DataSpaceClient(base_url="https://api.dataspace.example.com") +# Initialize the client with Keycloak configuration +client = DataSpaceClient( + base_url="https://dev.api.civicdataspace.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace", + keycloak_client_secret="your_client_secret" +) -# Login with Keycloak token -user_info = client.login(keycloak_token="your_keycloak_token") +# Login with username and password +user_info = client.login( + 
username="your_email@example.com", + password="your_password" +) print(f"Logged in as: {user_info['user']['username']}") # Search for datasets @@ -49,36 +58,66 @@ dataset = client.datasets.get_by_id("dataset-uuid") print(f"Dataset: {dataset['title']}") # Get organization's resources -org_id = user_info['user']['organizations'][0]['id'] -org_datasets = client.datasets.get_organization_datasets(org_id) +if user_info['user']['organizations']: + org_id = user_info['user']['organizations'][0]['id'] + org_datasets = client.datasets.get_organization_datasets(org_id) ``` ## Features -- **Authentication**: Login with Keycloak tokens and automatic token refresh +- **Authentication**: Multiple authentication methods (username/password, Keycloak token, service account) +- **Automatic Token Management**: Automatic token refresh and re-login - **Datasets**: Search, retrieve, and list datasets with filtering and pagination -- **AI Models**: Search, retrieve, and list AI models with filtering +- **AI Models**: Search, retrieve, call, and list AI models with filtering - **Use Cases**: Search, retrieve, and list use cases with filtering - **Organization Resources**: Get resources specific to your organizations - **GraphQL & REST**: Supports both GraphQL and REST API endpoints +- **Error Handling**: Comprehensive exception handling with detailed error messages ## Authentication -### Login with Keycloak +The SDK supports three authentication methods: + +### 1. 
Username and Password (Recommended for Users) ```python from dataspace_sdk import DataSpaceClient -client = DataSpaceClient(base_url="https://api.dataspace.example.com") +client = DataSpaceClient( + base_url="https://dev.api.civicdataspace.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace", + keycloak_client_secret="your_client_secret" +) -# Login with Keycloak token -response = client.login(keycloak_token="your_keycloak_token") +# Login with username and password +user_info = client.login( + username="your_email@example.com", + password="your_password" +) # Access user information -print(response['user']['username']) -print(response['user']['organizations']) +print(user_info['user']['username']) +print(user_info['user']['organizations']) ``` +### 2. Keycloak Token (For Token Pass-through) + +```python +# Login with an existing Keycloak token +response = client.login_with_token(keycloak_token="your_keycloak_token") +``` + +### 3. 
Service Account (For Backend Services) + +```python +# Login as a service account using client credentials +service_info = client.login_as_service_account() +``` + +For detailed authentication documentation, see [AUTHENTICATION_COMPLETE.md](./AUTHENTICATION_COMPLETE.md) + ### Token Refresh ```python @@ -203,6 +242,36 @@ print(f"Provider: {model['provider']}") print(f"Endpoints: {len(model['endpoints'])}") ``` +### Call an AI Model + +```python +# Call an AI model with input text +result = client.aimodels.call_model( + model_id="model-uuid", + input_text="What is the capital of France?", + parameters={ + "temperature": 0.7, + "max_tokens": 100 + } +) + +if result['success']: + print(f"Output: {result['output']}") + print(f"Latency: {result['latency_ms']}ms") + print(f"Provider: {result['provider']}") +else: + print(f"Error: {result['error']}") + +# For long-running operations, use async call +task = client.aimodels.call_model_async( + model_id="model-uuid", + input_text="Generate a long document...", + parameters={"max_tokens": 2000} +) +print(f"Task ID: {task['task_id']}") +print(f"Status: {task['status']}") +``` + ### List All AI Models ```python @@ -249,6 +318,7 @@ results = client.usecases.search( ### Get Use Case by ID ```python +# Get use case by ID usecase = client.usecases.get_by_id(123) print(f"Title: {usecase['title']}") @@ -375,7 +445,9 @@ Main client for interacting with DataSpace API. **Methods:** -- `login(keycloak_token: str) -> dict`: Login with Keycloak token +- `login(username: str, password: str) -> dict`: Login with username and password +- `login_with_token(keycloak_token: str) -> dict`: Login with Keycloak token +- `login_as_service_account() -> dict`: Login as service account (client credentials) - `refresh_token() -> str`: Refresh access token - `get_user_info() -> dict`: Get current user information - `is_authenticated() -> bool`: Check authentication status @@ -395,10 +467,12 @@ Client for dataset operations. 
**Methods:** - `search(...)`: Search datasets with filters -- `get_by_id(dataset_id: str)`: Get dataset by UUID +- `get_by_id(dataset_id: str)`: Get dataset by UUID (GraphQL) - `list_all(...)`: List all datasets with pagination - `get_trending(limit: int)`: Get trending datasets - `get_organization_datasets(organization_id: str, ...)`: Get organization's datasets +- `get_resources(dataset_id: str)`: Get dataset resources +- `list_by_organization(organization_id: str, ...)`: List datasets by organization ### AIModelClient @@ -409,8 +483,13 @@ Client for AI model operations. - `search(...)`: Search AI models with filters - `get_by_id(model_id: str)`: Get AI model by UUID (REST) - `get_by_id_graphql(model_id: str)`: Get AI model by UUID (GraphQL) +- `call_model(model_id: str, input_text: str, parameters: dict)`: Call an AI model +- `call_model_async(model_id: str, input_text: str, parameters: dict)`: Call an AI model asynchronously - `list_all(...)`: List all AI models with pagination - `get_organization_models(organization_id: str, ...)`: Get organization's AI models +- `create(data: dict)`: Create a new AI model +- `update(model_id: str, data: dict)`: Update an AI model +- `delete_model(model_id: str)`: Delete an AI model ### UseCaseClient @@ -419,7 +498,7 @@ Client for use case operations. 
**Methods:**

- `search(...)`: Search use cases with filters
-- `get_by_id(usecase_id: int)`: Get use case by ID
+- `get_by_id(usecase_id: int)`: Get use case by ID (GraphQL)
- `list_all(...)`: List all use cases with pagination
- `get_organization_usecases(organization_id: str, ...)`: Get organization's use cases

From d4262770b4d340dd20426facf31ebd459ad451e2 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Thu, 27 Nov 2025 07:15:57 +0000
Subject: [PATCH 022/127] Bump SDK version to 0.3.2

---
 dataspace_sdk/__version__.py | 2 +-
 pyproject.toml               | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py
index c79f072..5355cb4 100644
--- a/dataspace_sdk/__version__.py
+++ b/dataspace_sdk/__version__.py
@@ -1,3 +1,3 @@
 """Version information for DataSpace SDK."""
 
-__version__ = "0.3.1"
+__version__ = "0.3.2"

diff --git a/pyproject.toml b/pyproject.toml
index 1e72842..c8338b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "dataspace-sdk"
-version = "0.3.1"
+version = "0.3.2"
 description = "Python SDK for DataSpace API"
 readme = "docs/sdk/README.md"
 requires-python = ">=3.8"
@@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311']
 include = '\.pyi?$'
 
 [tool.mypy]
-python_version = "0.3.1"
+python_version = "3.8"
 warn_return_any = true
 warn_unused_configs = true
 disallow_untyped_defs = false

From bacdcebae98ebf4c2a1216deb27d35f0698a431e Mon Sep 17 00:00:00 2001
From: dc
Date: Thu, 27 Nov 2025 13:13:33 +0530
Subject: [PATCH 023/127] add few sdk tests

---
 tests/test_aimodels.py | 106 +++++++++++++++++++++++++++++++++++++++++
 tests/test_auth.py     |  45 +++++++++++++++++
 tests/test_datasets.py |  45 +++++++++++++++++
 3 files changed, 196 insertions(+)

diff --git a/tests/test_aimodels.py b/tests/test_aimodels.py
index a184f3e..2629dc3 100644
--- a/tests/test_aimodels.py
+++ b/tests/test_aimodels.py
@@ -117,6 
+117,112 @@ def test_graphql_error_handling(self, mock_post: MagicMock) -> None: with self.assertRaises(DataSpaceAPIError): self.client.get_by_id_graphql("123") + @patch.object(AIModelClient, "post") + def test_call_model(self, mock_post: MagicMock) -> None: + """Test calling an AI model.""" + mock_post.return_value = { + "success": True, + "output": "Paris is the capital of France.", + "latency_ms": 150, + "provider": "OpenAI", + } + + result = self.client.call_model( + model_id="123", + input_text="What is the capital of France?", + parameters={"temperature": 0.7, "max_tokens": 100}, + ) + + self.assertTrue(result["success"]) + self.assertEqual(result["output"], "Paris is the capital of France.") + self.assertEqual(result["latency_ms"], 150) + mock_post.assert_called_once() + + @patch.object(AIModelClient, "post") + def test_call_model_async(self, mock_post: MagicMock) -> None: + """Test calling an AI model asynchronously.""" + mock_post.return_value = { + "task_id": "task-456", + "status": "PENDING", + "created_at": "2024-01-01T00:00:00Z", + } + + result = self.client.call_model_async( + model_id="123", + input_text="Generate a long document", + parameters={"max_tokens": 2000}, + ) + + self.assertEqual(result["task_id"], "task-456") + self.assertEqual(result["status"], "PENDING") + mock_post.assert_called_once() + + @patch.object(AIModelClient, "post") + def test_call_model_error(self, mock_post: MagicMock) -> None: + """Test AI model call with error.""" + mock_post.return_value = { + "success": False, + "error": "Model not available", + } + + result = self.client.call_model( + model_id="123", + input_text="Test input", + ) + + self.assertFalse(result["success"]) + self.assertEqual(result["error"], "Model not available") + + @patch.object(AIModelClient, "post") + def test_create_model(self, mock_post: MagicMock) -> None: + """Test creating an AI model.""" + mock_post.return_value = { + "id": "new-model-123", + "displayName": "New Model", + "modelType": "LLM", + 
} + + result = self.client.create( + { + "displayName": "New Model", + "modelType": "LLM", + "provider": "OpenAI", + } + ) + + self.assertEqual(result["id"], "new-model-123") + self.assertEqual(result["displayName"], "New Model") + mock_post.assert_called_once() + + @patch.object(AIModelClient, "put") + def test_update_model(self, mock_put: MagicMock) -> None: + """Test updating an AI model.""" + mock_put.return_value = { + "id": "123", + "displayName": "Updated Model", + "modelType": "LLM", + } + + result = self.client.update( + "123", + { + "displayName": "Updated Model", + }, + ) + + self.assertEqual(result["displayName"], "Updated Model") + mock_put.assert_called_once() + + @patch.object(AIModelClient, "delete") + def test_delete_model(self, mock_delete: MagicMock) -> None: + """Test deleting an AI model.""" + mock_delete.return_value = {"message": "Model deleted successfully"} + + result = self.client.delete_model("123") + + self.assertEqual(result["message"], "Model deleted successfully") + mock_delete.assert_called_once() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_auth.py b/tests/test_auth.py index 5a6676a..cf0212c 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -141,6 +141,51 @@ def test_is_authenticated(self) -> None: self.auth_client.access_token = "test_token" self.assertTrue(self.auth_client.is_authenticated()) + @patch("dataspace_sdk.auth.requests.post") + def test_login_as_service_account(self, mock_post: MagicMock) -> None: + """Test successful service account login.""" + # Create auth client with client secret + auth_client = AuthClient( + self.base_url, + keycloak_url=self.keycloak_url, + keycloak_realm=self.keycloak_realm, + keycloak_client_id=self.keycloak_client_id, + keycloak_client_secret="test_secret", + ) + + # Mock Keycloak token response + keycloak_response = MagicMock() + keycloak_response.status_code = 200 + keycloak_response.json.return_value = { + "access_token": "service_access_token", + 
"refresh_token": "service_refresh_token", + "expires_in": 300, + } + + # Mock DataSpace backend login response + backend_response = MagicMock() + backend_response.status_code = 200 + backend_response.json.return_value = { + "access": "test_access_token", + "refresh": "test_refresh_token", + "user": {"id": "service-123", "username": "service-account"}, + } + + mock_post.side_effect = [keycloak_response, backend_response] + + result = auth_client.login_as_service_account() + + self.assertEqual(auth_client.access_token, "test_access_token") + self.assertEqual(auth_client.refresh_token, "test_refresh_token") + self.assertIsNotNone(auth_client.user_info) + self.assertEqual(result["user"]["username"], "service-account") + self.assertEqual(mock_post.call_count, 2) + + def test_login_as_service_account_no_secret(self) -> None: + """Test service account login without client secret.""" + with self.assertRaises(DataSpaceAuthError): + self.auth_client.login_as_service_account() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 4219cb4..7ad2d6a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -89,6 +89,51 @@ def test_search_with_filters(self, mock_get: MagicMock) -> None: self.assertEqual(result["total"], 5) mock_get.assert_called_once() + @patch.object(DatasetClient, "get") + def test_get_resources(self, mock_get: MagicMock) -> None: + """Test get dataset resources.""" + mock_get.return_value = [ + { + "id": "res-1", + "title": "Resource 1", + "format": "CSV", + "url": "https://example.com/data.csv", + } + ] + + result = self.client.get_resources("dataset-123") + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["title"], "Resource 1") + mock_get.assert_called_once() + + @patch.object(DatasetClient, "post") + def test_list_by_organization(self, mock_post: MagicMock) -> None: + """Test list datasets by organization.""" + mock_post.return_value = { + "data": { + "datasets": [ + {"id": 
"1", "title": "Org Dataset 1"}, + {"id": "2", "title": "Org Dataset 2"}, + ] + } + } + + result = self.client.list_by_organization("org-123", limit=10) + + self.assertIsInstance(result, (list, dict)) + mock_post.assert_called_once() + + @patch.object(DatasetClient, "get") + def test_search_with_sorting(self, mock_get: MagicMock) -> None: + """Test dataset search with sorting.""" + mock_get.return_value = {"total": 3, "results": []} + + result = self.client.search(query="test", sort="recent", page=1, page_size=10) + + self.assertEqual(result["total"], 3) + mock_get.assert_called_once() + if __name__ == "__main__": unittest.main() From c5cf62d26bdf90ccf236aaa8efaea4d9658281a0 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 15:38:54 +0530 Subject: [PATCH 024/127] dont sync organizations form keycloak --- api/utils/keycloak_utils.py | 39 +++---------------------------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/api/utils/keycloak_utils.py b/api/utils/keycloak_utils.py index 6bdb335..a305759 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ -361,16 +361,16 @@ def sync_user_from_keycloak( self, user_info: Dict[str, Any], roles: List[str], - organizations: List[Dict[str, Any]], + organizations: Optional[List[Dict[str, Any]]] = None, ) -> Optional[User]: """ Synchronize user information from Keycloak to Django. - Creates or updates the User and UserOrganization records. + Creates or updates the User record. 
Args: user_info: User information from Keycloak roles: User roles from Keycloak - organizations: User organization memberships from Keycloak + organizations: Deprecated - organizations are managed in DataSpace Returns: The synchronized User object or None if failed @@ -406,39 +406,6 @@ def sync_user_from_keycloak( user.save() - # Update organization memberships - # First, get all existing organization memberships - existing_memberships = OrganizationMembership.objects.filter(user=user) - existing_org_ids = { - membership.organization_id for membership in existing_memberships # type: ignore[attr-defined] - } - - # Process organizations from Keycloak - for org_info in organizations: - org_id = org_info.get("organization_id") - role = org_info.get("role", "viewer") # Default to viewer if role not specified - - # Try to get the organization - try: - organization = Organization.objects.get(id=org_id) # type: ignore[misc] - - # Create or update the membership - OrganizationMembership.objects.update_or_create( - user=user, organization=organization, defaults={"role": role} - ) - - # Remove from the set of existing memberships - if org_id in existing_org_ids: - existing_org_ids.remove(org_id) - except Organization.DoesNotExist: - logger.warning(f"Organization with ID {org_id} does not exist") - - # Remove memberships that no longer exist in Keycloak - if existing_org_ids: - OrganizationMembership.objects.filter( - user=user, organization_id__in=existing_org_ids - ).delete() - return user except Exception as e: logger.error(f"Error synchronizing user from Keycloak: {e}") From 6c69ce35ef8f505a417148bd41c3b7c85b219eec Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 16:21:41 +0530 Subject: [PATCH 025/127] return empty org at the time of authentication --- api/utils/keycloak_utils.py | 86 +++++++++++++++---------------------- authorization/services.py | 21 ++++----- 2 files changed, 43 insertions(+), 64 deletions(-) diff --git a/api/utils/keycloak_utils.py 
b/api/utils/keycloak_utils.py index a305759..15f7f06 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ -79,6 +79,13 @@ def validate_token(self, token: str) -> Dict[str, Any]: from authorization.models import User user = User.objects.get(id=user_id) + user.save() + + # NOTE: Organizations are managed in DataSpace database, not Keycloak + # Organization memberships should be created/managed through DataSpace's + # organization management interface, not during Keycloak sync + + # Log for debugging Keycloak validation") # Return user info in Keycloak format return { "sub": ( @@ -167,31 +174,18 @@ def get_user_organizations_from_token_info(self, token_info: dict) -> List[Dict[ """ Get organizations from token introspection data. + NOTE: In DataSpace, organizations are ALWAYS managed in the database. + This method always returns an empty list. + Args: - token_info: Token introspection response + token_info: Token introspection response (not used) Returns: - List of organization information + Empty list - organizations are managed in DataSpace database """ - try: - # Get organization info from resource_access or attributes - resource_access = token_info.get("resource_access", {}) - client_roles = resource_access.get(self.client_id, {}).get("roles", []) - - # Extract organization info from roles - organizations = [] - for role in client_roles: - if role.startswith("org_"): - parts = role.split("_") - if len(parts) >= 3: - org_id = parts[1] - role_name = parts[2] - organizations.append({"organization_id": org_id, "role": role_name}) - - return organizations - except Exception as e: - logger.error(f"Error getting user organizations: {e}") - return [] + # Organizations are managed in DataSpace database, not Keycloak + logger.debug("Organizations are managed in DataSpace DB, returning empty list") + return [] def get_user_roles(self, token: str) -> list[str]: """ @@ -229,40 +223,24 @@ def get_user_roles(self, token: str) -> list[str]: def 
get_user_organizations(self, token: str) -> List[Dict[str, Any]]: """ Get the organizations a user belongs to from their token. - This assumes that organization information is stored in the token - as client roles or in user attributes. + + NOTE: In DataSpace, organizations and memberships are ALWAYS managed + in the database, NOT in Keycloak. This method always returns an empty list. + + Organizations should be retrieved using AuthorizationService.get_user_organizations() + which queries the OrganizationMembership table. Args: - token: The user's token + token: The user's token (not used) Returns: - List of organization information + Empty list - organizations are managed in DataSpace database """ - try: - # Decode the token to get user info - token_info = self.keycloak_openid.decode_token(token) - - # Get organization info from resource_access or attributes - # This implementation depends on how organizations are represented in Keycloak - # This is a simplified example - adjust based on your Keycloak configuration - resource_access = token_info.get("resource_access", {}) - client_roles = resource_access.get(self.client_id, {}).get("roles", []) - - # Extract organization info from roles - # Format could be 'org__' or similar - organizations = [] - for role in client_roles: - if role.startswith("org_"): - parts = role.split("_") - if len(parts) >= 3: - org_id = parts[1] - role_name = parts[2] - organizations.append({"organization_id": org_id, "role": role_name}) - - return organizations - except KeycloakError as e: - logger.error(f"Error getting user organizations: {e}") - return [] + # Organizations are managed in DataSpace database, not Keycloak + # Always return empty list - the actual organizations will be fetched + # from the database via AuthorizationService.get_user_organizations() + logger.debug("Organizations are managed in DataSpace DB, returning empty list") + return [] def update_user_in_keycloak(self, user: User) -> bool: """Update user details in 
Keycloak using admin credentials.""" @@ -367,10 +345,14 @@ def sync_user_from_keycloak( Synchronize user information from Keycloak to Django. Creates or updates the User record. + NOTE: Organizations are ALWAYS managed in DataSpace database, not Keycloak. + Organization memberships should be created/managed through DataSpace's + organization management interface. This method does NOT sync organizations. + Args: user_info: User information from Keycloak - roles: User roles from Keycloak - organizations: Deprecated - organizations are managed in DataSpace + roles: User roles from Keycloak (for is_staff/is_superuser only) + organizations: Ignored - organizations are managed in DataSpace database Returns: The synchronized User object or None if failed diff --git a/authorization/services.py b/authorization/services.py index 7c3e8fb..7ed117f 100644 --- a/authorization/services.py +++ b/authorization/services.py @@ -27,11 +27,14 @@ def get_user_organizations(user_id: int) -> List[Dict[str, Any]]: List of dictionaries containing organization info and user's role """ # Use explicit type annotation for the queryset result - memberships = OrganizationMembership.objects.filter( - user_id=user_id - ).select_related( + memberships = OrganizationMembership.objects.filter(user_id=user_id).select_related( "organization", "role" ) # type: ignore[attr-defined] + + logger.info( + f"Getting organizations for user_id={user_id}, found {memberships.count()} memberships" + ) + return [ { "id": membership.organization.id, # type: ignore[attr-defined] @@ -59,9 +62,7 @@ def get_user_datasets(user_id: int) -> List[Dict[str, Any]]: List of dictionaries containing dataset info and user's role """ # Use explicit type annotation for the queryset result - dataset_permissions = DatasetPermission.objects.filter( - user_id=user_id - ).select_related( + dataset_permissions = DatasetPermission.objects.filter(user_id=user_id).select_related( "dataset", "role" ) # type: ignore[attr-defined] return [ @@ 
-80,9 +81,7 @@ def get_user_datasets(user_id: int) -> List[Dict[str, Any]]: ] @staticmethod - def check_organization_permission( - user_id: int, organization_id: int, operation: str - ) -> bool: + def check_organization_permission(user_id: int, organization_id: int, operation: str) -> bool: """ Check if a user has permission to perform an operation on an organization. @@ -120,9 +119,7 @@ def check_organization_permission( return False @staticmethod - def check_dataset_permission( - user_id: int, dataset_id: Union[int, str], operation: str - ) -> bool: + def check_dataset_permission(user_id: int, dataset_id: Union[int, str], operation: str) -> bool: """ Check if a user has permission to perform an operation on a dataset. Checks both organization-level and dataset-specific permissions. From a229a5e550e55294cb17f9b68cd584341357c13d Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 18:25:56 +0530 Subject: [PATCH 026/127] fix tests --- tests/test_aimodels.py | 83 +++++++++++++++++------------------- tests/test_datasets.py | 97 ++++++++++++++++++++++-------------------- tests/test_usecases.py | 38 ++++++++--------- 3 files changed, 108 insertions(+), 110 deletions(-) diff --git a/tests/test_aimodels.py b/tests/test_aimodels.py index 2629dc3..eb6175e 100644 --- a/tests/test_aimodels.py +++ b/tests/test_aimodels.py @@ -20,10 +20,10 @@ def test_init(self) -> None: self.assertEqual(self.client.base_url, self.base_url) self.assertEqual(self.client.auth_client, self.auth_client) - @patch.object(AIModelClient, "get") - def test_search_models(self, mock_get: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_search_models(self, mock_request: MagicMock) -> None: """Test AI model search.""" - mock_get.return_value = { + mock_request.return_value = { "total": 5, "results": [{"id": "1", "displayName": "Test Model", "modelType": "LLM"}], } @@ -33,12 +33,12 @@ def test_search_models(self, mock_get: MagicMock) -> None: 
self.assertEqual(result["total"], 5) self.assertEqual(len(result["results"]), 1) self.assertEqual(result["results"][0]["displayName"], "Test Model") - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "get") - def test_get_model_by_id(self, mock_get: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_get_model_by_id(self, mock_request: MagicMock) -> None: """Test get AI model by ID.""" - mock_get.return_value = { + mock_request.return_value = { "id": "123", "displayName": "Test Model", "modelType": "LLM", @@ -90,37 +90,37 @@ def test_get_organization_models(self, mock_post: MagicMock) -> None: self.assertIsInstance(result, (list, dict)) mock_post.assert_called_once() - @patch.object(AIModelClient, "get") - def test_search_with_filters(self, mock_get: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_search_with_filters(self, mock_request: MagicMock) -> None: """Test AI model search with filters.""" - mock_get.return_value = {"total": 3, "results": []} + mock_request.return_value = {"total": 3, "results": []} result = self.client.search( query="language", tags=["nlp"], sectors=["tech"], + status="ACTIVE", model_type="LLM", provider="OpenAI", - status="ACTIVE", ) self.assertEqual(result["total"], 3) - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "post") - def test_graphql_error_handling(self, mock_post: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_graphql_error_handling(self, mock_request: MagicMock) -> None: """Test GraphQL error handling.""" from dataspace_sdk.exceptions import DataSpaceAPIError - mock_post.return_value = {"errors": [{"message": "GraphQL error"}]} + mock_request.return_value = {"errors": [{"message": "GraphQL error"}]} with self.assertRaises(DataSpaceAPIError): self.client.get_by_id_graphql("123") - @patch.object(AIModelClient, "post") - def 
test_call_model(self, mock_post: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_call_model(self, mock_request: MagicMock) -> None: """Test calling an AI model.""" - mock_post.return_value = { + mock_request.return_value = { "success": True, "output": "Paris is the capital of France.", "latency_ms": 150, @@ -136,12 +136,12 @@ def test_call_model(self, mock_post: MagicMock) -> None: self.assertTrue(result["success"]) self.assertEqual(result["output"], "Paris is the capital of France.") self.assertEqual(result["latency_ms"], 150) - mock_post.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "post") - def test_call_model_async(self, mock_post: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_call_model_async(self, mock_request: MagicMock) -> None: """Test calling an AI model asynchronously.""" - mock_post.return_value = { + mock_request.return_value = { "task_id": "task-456", "status": "PENDING", "created_at": "2024-01-01T00:00:00Z", @@ -155,12 +155,12 @@ def test_call_model_async(self, mock_post: MagicMock) -> None: self.assertEqual(result["task_id"], "task-456") self.assertEqual(result["status"], "PENDING") - mock_post.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "post") - def test_call_model_error(self, mock_post: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_call_model_error(self, mock_request: MagicMock) -> None: """Test AI model call with error.""" - mock_post.return_value = { + mock_request.return_value = { "success": False, "error": "Model not available", } @@ -173,10 +173,10 @@ def test_call_model_error(self, mock_post: MagicMock) -> None: self.assertFalse(result["success"]) self.assertEqual(result["error"], "Model not available") - @patch.object(AIModelClient, "post") - def test_create_model(self, mock_post: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def 
test_create_model(self, mock_request: MagicMock) -> None: """Test creating an AI model.""" - mock_post.return_value = { + mock_request.return_value = { "id": "new-model-123", "displayName": "New Model", "modelType": "LLM", @@ -192,36 +192,33 @@ def test_create_model(self, mock_post: MagicMock) -> None: self.assertEqual(result["id"], "new-model-123") self.assertEqual(result["displayName"], "New Model") - mock_post.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "put") - def test_update_model(self, mock_put: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_update_model(self, mock_request: MagicMock) -> None: """Test updating an AI model.""" - mock_put.return_value = { + mock_request.return_value = { "id": "123", "displayName": "Updated Model", - "modelType": "LLM", + "description": "Updated description", } result = self.client.update( - "123", - { - "displayName": "Updated Model", - }, + "123", {"displayName": "Updated Model", "description": "Updated description"} ) self.assertEqual(result["displayName"], "Updated Model") - mock_put.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "delete") - def test_delete_model(self, mock_delete: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_delete_model(self, mock_request: MagicMock) -> None: """Test deleting an AI model.""" - mock_delete.return_value = {"message": "Model deleted successfully"} + mock_request.return_value = {"message": "Model deleted successfully"} result = self.client.delete_model("123") self.assertEqual(result["message"], "Model deleted successfully") - mock_delete.assert_called_once() + mock_request.assert_called_once() if __name__ == "__main__": diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7ad2d6a..29ec3f8 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -20,10 +20,10 @@ def test_init(self) -> None: 
self.assertEqual(self.client.base_url, self.base_url) self.assertEqual(self.client.auth_client, self.auth_client) - @patch.object(DatasetClient, "get") - def test_search_datasets(self, mock_get: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_search_datasets(self, mock_request: MagicMock) -> None: """Test dataset search.""" - mock_get.return_value = { + mock_request.return_value = { "total": 10, "results": [{"id": "1", "title": "Test Dataset"}], } @@ -32,22 +32,22 @@ def test_search_datasets(self, mock_get: MagicMock) -> None: self.assertEqual(result["total"], 10) self.assertEqual(len(result["results"]), 1) - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(DatasetClient, "post") - def test_get_dataset_by_id(self, mock_post: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_get_dataset_by_id(self, mock_request: MagicMock) -> None: """Test get dataset by ID.""" - mock_post.return_value = {"data": {"dataset": {"id": "123", "title": "Test Dataset"}}} + mock_request.return_value = {"data": {"dataset": {"id": "123", "title": "Test Dataset"}}} result = self.client.get_by_id("123") self.assertEqual(result["id"], "123") self.assertEqual(result["title"], "Test Dataset") - @patch.object(DatasetClient, "post") - def test_list_all_datasets(self, mock_post: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_list_all_datasets(self, mock_request: MagicMock) -> None: """Test list all datasets.""" - mock_post.return_value = {"data": {"datasets": [{"id": "1", "title": "Dataset 1"}]}} + mock_request.return_value = {"data": {"datasets": [{"id": "1", "title": "Dataset 1"}]}} result = self.client.list_all(limit=10, offset=0) @@ -73,10 +73,10 @@ def test_get_organization_datasets(self, mock_post: MagicMock) -> None: self.assertIsInstance(result, (list, dict)) mock_post.assert_called_once() - @patch.object(DatasetClient, "get") - def test_search_with_filters(self, 
mock_get: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_search_with_filters(self, mock_request: MagicMock) -> None: """Test dataset search with filters.""" - mock_get.return_value = {"total": 5, "results": []} + mock_request.return_value = {"total": 5, "results": []} result = self.client.search( query="health", @@ -87,52 +87,57 @@ def test_search_with_filters(self, mock_get: MagicMock) -> None: ) self.assertEqual(result["total"], 5) - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(DatasetClient, "get") - def test_get_resources(self, mock_get: MagicMock) -> None: - """Test get dataset resources.""" - mock_get.return_value = [ - { - "id": "res-1", - "title": "Resource 1", - "format": "CSV", - "url": "https://example.com/data.csv", + @patch.object(DatasetClient, "_make_request") + def test_get_dataset_with_resources(self, mock_request: MagicMock) -> None: + """Test get dataset by ID which includes resources.""" + mock_request.return_value = { + "data": { + "dataset": { + "id": "dataset-123", + "title": "Test Dataset", + "resources": [ + { + "id": "res-1", + "title": "Resource 1", + "fileDetails": {"format": "CSV"}, + } + ], + } } - ] + } - result = self.client.get_resources("dataset-123") + result = self.client.get_by_id("dataset-123") - self.assertEqual(len(result), 1) - self.assertEqual(result[0]["title"], "Resource 1") - mock_get.assert_called_once() + self.assertEqual(result["id"], "dataset-123") + self.assertEqual(len(result["resources"]), 1) + self.assertEqual(result["resources"][0]["title"], "Resource 1") + mock_request.assert_called_once() - @patch.object(DatasetClient, "post") - def test_list_by_organization(self, mock_post: MagicMock) -> None: - """Test list datasets by organization.""" - mock_post.return_value = { - "data": { - "datasets": [ - {"id": "1", "title": "Org Dataset 1"}, - {"id": "2", "title": "Org Dataset 2"}, - ] - } - } + @patch.object(DatasetClient, "list_all") + def 
test_get_organization_datasets(self, mock_list_all: MagicMock) -> None: + """Test get datasets by organization.""" + mock_list_all.return_value = [ + {"id": "1", "title": "Org Dataset 1"}, + {"id": "2", "title": "Org Dataset 2"}, + ] - result = self.client.list_by_organization("org-123", limit=10) + result = self.client.get_organization_datasets("org-123", limit=10) - self.assertIsInstance(result, (list, dict)) - mock_post.assert_called_once() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + mock_list_all.assert_called_once_with(organization_id="org-123", limit=10, offset=0) - @patch.object(DatasetClient, "get") - def test_search_with_sorting(self, mock_get: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_search_with_sorting(self, mock_request: MagicMock) -> None: """Test dataset search with sorting.""" - mock_get.return_value = {"total": 3, "results": []} + mock_request.return_value = {"total": 3, "results": []} result = self.client.search(query="test", sort="recent", page=1, page_size=10) self.assertEqual(result["total"], 3) - mock_get.assert_called_once() + mock_request.assert_called_once() if __name__ == "__main__": diff --git a/tests/test_usecases.py b/tests/test_usecases.py index 6d966af..02d75f3 100644 --- a/tests/test_usecases.py +++ b/tests/test_usecases.py @@ -20,10 +20,10 @@ def test_init(self) -> None: self.assertEqual(self.client.base_url, self.base_url) self.assertEqual(self.client.auth_client, self.auth_client) - @patch.object(UseCaseClient, "get") - def test_search_usecases(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, "_make_request") + def test_search_usecases(self, mock_request: MagicMock) -> None: """Test use case search.""" - mock_get.return_value = { + mock_request.return_value = { "total": 8, "results": [ { @@ -40,7 +40,7 @@ def test_search_usecases(self, mock_get: MagicMock) -> None: self.assertEqual(result["total"], 8) self.assertEqual(len(result["results"]), 1) 
self.assertEqual(result["results"][0]["title"], "Test Use Case") - mock_get.assert_called_once() + mock_request.assert_called_once() @patch.object(UseCaseClient, "post") def test_get_usecase_by_id(self, mock_post: MagicMock) -> None: @@ -82,10 +82,10 @@ def test_get_organization_usecases(self, mock_post: MagicMock) -> None: self.assertIsInstance(result, (list, dict)) mock_post.assert_called_once() - @patch.object(UseCaseClient, "get") - def test_search_with_filters(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, "_make_request") + def test_search_with_filters(self, mock_request: MagicMock) -> None: """Test use case search with filters.""" - mock_get.return_value = {"total": 4, "results": []} + mock_request.return_value = {"total": 4, "results": []} result = self.client.search( query="monitoring", @@ -96,17 +96,17 @@ def test_search_with_filters(self, mock_get: MagicMock) -> None: ) self.assertEqual(result["total"], 4) - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(UseCaseClient, "get") - def test_search_with_sorting(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, "_make_request") + def test_search_with_sorting(self, mock_request: MagicMock) -> None: """Test use case search with sorting.""" - mock_get.return_value = {"total": 2, "results": []} + mock_request.return_value = {"total": 2, "results": []} result = self.client.search(query="test", sort="completed_on", page=1, page_size=10) self.assertEqual(result["total"], 2) - mock_get.assert_called_once() + mock_request.assert_called_once() @patch.object(UseCaseClient, "post") def test_graphql_error_handling(self, mock_post: MagicMock) -> None: @@ -118,21 +118,17 @@ def test_graphql_error_handling(self, mock_post: MagicMock) -> None: with self.assertRaises(DataSpaceAPIError): self.client.get_by_id(123) - @patch.object(UseCaseClient, "get") - def test_search_pagination(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, 
"_make_request") + def test_search_pagination(self, mock_request: MagicMock) -> None: """Test use case search with pagination.""" - mock_get.return_value = { - "total": 50, - "page": 2, - "page_size": 20, - "results": [], - } + mock_request.return_value = {"total": 50, "results": [], "page": 2, "page_size": 20} result = self.client.search(query="test", page=2, page_size=20) self.assertEqual(result["total"], 50) self.assertEqual(result["page"], 2) - mock_get.assert_called_once() + self.assertEqual(result["page_size"], 20) + mock_request.assert_called_once() if __name__ == "__main__": From 290f13ca8ce5b62a36159023e118140202d245ac Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 27 Nov 2025 18:27:53 +0530 Subject: [PATCH 027/127] fix typo on test --- tests/test_aimodels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_aimodels.py b/tests/test_aimodels.py index eb6175e..600e63e 100644 --- a/tests/test_aimodels.py +++ b/tests/test_aimodels.py @@ -49,7 +49,7 @@ def test_get_model_by_id(self, mock_request: MagicMock) -> None: self.assertEqual(result["id"], "123") self.assertEqual(result["displayName"], "Test Model") - mock_get.assert_called_once() + mock_request.assert_called_once() @patch.object(AIModelClient, "post") def test_get_model_by_id_graphql(self, mock_post: MagicMock) -> None: From 4d0f09db52635c28f99b5fdc7402e74349b009ad Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 27 Nov 2025 13:12:56 +0000 Subject: [PATCH 028/127] Bump SDK version to 0.3.3 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 5355cb4..51b7a2d 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.3.2" +__version__ = "0.3.3" diff --git a/pyproject.toml b/pyproject.toml index c8338b6..3fe3d32 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.3.2" +version = "0.3.3" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.3.2" +python_version = "0.3.3" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 145f530717320fa2112fe90fcacb7ca93c83bd1c Mon Sep 17 00:00:00 2001 From: dc Date: Fri, 28 Nov 2025 11:18:45 +0530 Subject: [PATCH 029/127] fix aimodels list_all query --- api/schema/aimodel_schema.py | 74 +++++++++++------------------------- 1 file changed, 22 insertions(+), 52 deletions(-) diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index fe6b120..081d533 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -38,9 +38,7 @@ def _update_aimodel_tags(model: AIModel, tags: Optional[List[str]]) -> None: return model.tags.clear() for tag in tags: - model.tags.add( - Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0] - ) + model.tags.add(Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0]) model.save() @@ -190,9 +188,9 @@ def ai_models( queryset = AIModel.objects.all() else: # For authenticated users, show their models and public models - queryset = AIModel.objects.filter( - user=user - ) | AIModel.objects.filter(is_public=True, is_active=True) + queryset = AIModel.objects.filter(user=user) | AIModel.objects.filter( + is_public=True, is_active=True + ) else: # For non-authenticated users, only show public active models queryset = AIModel.objects.filter(is_public=True, is_active=True) @@ -203,10 +201,12 @@ def ai_models( if order is not strawberry.UNSET: queryset = strawberry_django.ordering.apply(order, queryset, info) + queryset = queryset.distinct() + if pagination is not 
strawberry.UNSET: queryset = strawberry_django.pagination.apply(pagination, queryset) - return TypeAIModel.from_django_list(list(queryset.distinct())) + return TypeAIModel.from_django_list(list(queryset)) @strawberry.field @trace_resolver(name="get_ai_model", attributes={"component": "aimodel"}) @@ -279,9 +279,7 @@ class Mutation: "get_data": lambda result, **kwargs: { "model_id": str(result.id), "model_name": result.name, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), }, }, ) @@ -348,9 +346,7 @@ def create_ai_model( "get_data": lambda result, **kwargs: { "model_id": str(result.id), "model_name": result.name, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), }, }, ) @@ -372,13 +368,9 @@ def update_ai_model( user=user, organization=model.organization ).first() if not org_member or not org_member.role.can_change: - raise DjangoValidationError( - "You don't have permission to update this model." - ) + raise DjangoValidationError("You don't have permission to update this model.") else: - raise DjangoValidationError( - "You don't have permission to update this model." - ) + raise DjangoValidationError("You don't have permission to update this model.") # Update fields if input.name is not None: @@ -459,13 +451,9 @@ def delete_ai_model(self, info: Info, model_id: int) -> MutationResponse[bool]: user=user, organization=model.organization ).first() if not org_member or not org_member.role.can_delete: - raise DjangoValidationError( - "You don't have permission to delete this model." - ) + raise DjangoValidationError("You don't have permission to delete this model.") else: - raise DjangoValidationError( - "You don't have permission to delete this model." 
- ) + raise DjangoValidationError("You don't have permission to delete this model.") model.delete() return MutationResponse.success_response(True) @@ -492,9 +480,7 @@ def create_model_endpoint( try: model = AIModel.objects.get(id=input.model_id) except AIModel.DoesNotExist: - raise DjangoValidationError( - f"AI Model with ID {input.model_id} does not exist." - ) + raise DjangoValidationError(f"AI Model with ID {input.model_id} does not exist.") # Check permissions if not user.is_superuser and model.user != user: @@ -513,9 +499,7 @@ def create_model_endpoint( # If this is primary, unset other primary endpoints if input.is_primary: - ModelEndpoint.objects.filter(model=model, is_primary=True).update( - is_primary=False - ) + ModelEndpoint.objects.filter(model=model, is_primary=True).update(is_primary=False) endpoint = ModelEndpoint.objects.create( model=model, @@ -533,9 +517,7 @@ def create_model_endpoint( is_active=input.is_active, ) - return MutationResponse.success_response( - TypeModelEndpoint.from_django(endpoint) - ) + return MutationResponse.success_response(TypeModelEndpoint.from_django(endpoint)) @strawberry.mutation @BaseMutation.mutation( @@ -559,9 +541,7 @@ def update_model_endpoint( try: endpoint = ModelEndpoint.objects.get(id=input.id) except ModelEndpoint.DoesNotExist: - raise DjangoValidationError( - f"Model Endpoint with ID {input.id} does not exist." - ) + raise DjangoValidationError(f"Model Endpoint with ID {input.id} does not exist.") model = endpoint.model @@ -576,9 +556,7 @@ def update_model_endpoint( "You don't have permission to update this endpoint." ) else: - raise DjangoValidationError( - "You don't have permission to update this endpoint." 
- ) + raise DjangoValidationError("You don't have permission to update this endpoint.") # Update fields if input.url is not None: @@ -612,9 +590,7 @@ def update_model_endpoint( endpoint.is_primary = True endpoint.save() - return MutationResponse.success_response( - TypeModelEndpoint.from_django(endpoint) - ) + return MutationResponse.success_response(TypeModelEndpoint.from_django(endpoint)) @strawberry.mutation @BaseMutation.mutation( @@ -629,18 +605,14 @@ def update_model_endpoint( }, }, ) - def delete_model_endpoint( - self, info: Info, endpoint_id: int - ) -> MutationResponse[bool]: + def delete_model_endpoint(self, info: Info, endpoint_id: int) -> MutationResponse[bool]: """Delete a model endpoint.""" user = info.context.user try: endpoint = ModelEndpoint.objects.get(id=endpoint_id) except ModelEndpoint.DoesNotExist: - raise DjangoValidationError( - f"Model Endpoint with ID {endpoint_id} does not exist." - ) + raise DjangoValidationError(f"Model Endpoint with ID {endpoint_id} does not exist.") model = endpoint.model @@ -655,9 +627,7 @@ def delete_model_endpoint( "You don't have permission to delete this endpoint." ) else: - raise DjangoValidationError( - "You don't have permission to delete this endpoint." 
- ) + raise DjangoValidationError("You don't have permission to delete this endpoint.") endpoint.delete() return MutationResponse.success_response(True) From aeb21bb82128731c0e6298b77541a1defdb4ac5a Mon Sep 17 00:00:00 2001 From: dc Date: Fri, 28 Nov 2025 16:36:16 +0530 Subject: [PATCH 030/127] generate new unique name if name or description not passed --- api/schema/aimodel_schema.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index 081d533..dfdde3d 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -74,11 +74,11 @@ def _update_aimodel_geographies(model: AIModel, geographies: List[str]) -> None: class CreateAIModelInput: """Input for creating a new AI Model.""" - name: str - display_name: str - description: str model_type: AIModelTypeEnum provider: AIModelProviderEnum + name: Optional[str] = None + display_name: Optional[str] = None + description: Optional[str] = None version: Optional[str] = None provider_model_id: Optional[str] = None supports_streaming: bool = False @@ -290,6 +290,12 @@ def create_ai_model( organization = info.context.context.get("organization") user = info.context.user + # Generate default values if not provided (similar to dataset creation) + timestamp = datetime.datetime.now().strftime("%d %b %Y - %H:%M:%S") + name = input.name or f"untitled-ai-model-{timestamp}" + display_name = input.display_name or f"Untitled AI Model - {timestamp}" + description = input.description or "" + # Prepare supported_languages supported_languages = input.supported_languages or [] @@ -302,10 +308,10 @@ def create_ai_model( try: model = AIModel.objects.create( - name=input.name, - display_name=input.display_name, + name=name, + display_name=display_name, version=input.version or "", - description=input.description, + description=description, model_type=input.model_type, provider=input.provider, provider_model_id=input.provider_model_id 
or "", From 78a29b44a5bdbb5bd81bbdbd8585ac3d71245333 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 1 Dec 2025 15:55:13 +0530 Subject: [PATCH 031/127] remove uuid from aimodel urls --- api/urls.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/urls.py b/api/urls.py index 57f5c30..98a80a8 100644 --- a/api/urls.py +++ b/api/urls.py @@ -38,17 +38,17 @@ path("search/aimodel/", search_aimodel.SearchAIModel.as_view(), name="search_aimodel"), path("search/unified/", search_unified.UnifiedSearch.as_view(), name="search_unified"), path( - "aimodels//", + "aimodels//", aimodel_detail.AIModelDetailView.as_view(), name="aimodel_detail", ), path( - "aimodels//call/", + "aimodels//call/", aimodel_execution.call_aimodel, name="aimodel_call", ), path( - "aimodels//call-async/", + "aimodels//call-async/", aimodel_execution.call_aimodel_async, name="aimodel_call_async", ), From b4236503ebe1c6892b4db781bfcb67d2303ccbb7 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 23 Dec 2025 16:03:13 +0530 Subject: [PATCH 032/127] load env variables on run time --- DataSpace/settings.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/DataSpace/settings.py b/DataSpace/settings.py index feca514..b9a19bf 100644 --- a/DataSpace/settings.py +++ b/DataSpace/settings.py @@ -22,12 +22,15 @@ from .cache_settings import * -env = environ.Env(DEBUG=(bool, False)) -DEBUG = env.bool("DEBUG", default=True) # Build paths inside the project like this: BASE_DIR / 'subdir'. 
BASE_DIR = Path(__file__).resolve().parent.parent + +# Load .env file FIRST, before reading any env variables +env = environ.Env(DEBUG=(bool, False)) environ.Env.read_env(os.path.join(BASE_DIR, ".env")) +DEBUG = env.bool("DEBUG", default=True) + # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/4.0/howto/deployment/checklist/ @@ -301,9 +304,7 @@ # Swagger Settings SWAGGER_SETTINGS = { - "SECURITY_DEFINITIONS": { - "Bearer": {"type": "apiKey", "name": "Authorization", "in": "header"} - } + "SECURITY_DEFINITIONS": {"Bearer": {"type": "apiKey", "name": "Authorization", "in": "header"}} } # Structured Logging Configuration @@ -370,17 +371,11 @@ # OpenTelemetry Sampling Configuration OTEL_TRACES_SAMPLER = "parentbased_traceidratio" -OTEL_TRACES_SAMPLER_ARG = os.getenv( - "OTEL_TRACES_SAMPLER_ARG", "1.0" -) # Sample 100% in dev +OTEL_TRACES_SAMPLER_ARG = os.getenv("OTEL_TRACES_SAMPLER_ARG", "1.0") # Sample 100% in dev # OpenTelemetry Metrics Configuration -OTEL_METRIC_EXPORT_INTERVAL_MILLIS = int( - os.getenv("OTEL_METRIC_EXPORT_INTERVAL_MILLIS", "30000") -) -OTEL_METRIC_EXPORT_TIMEOUT_MILLIS = int( - os.getenv("OTEL_METRIC_EXPORT_TIMEOUT_MILLIS", "30000") -) +OTEL_METRIC_EXPORT_INTERVAL_MILLIS = int(os.getenv("OTEL_METRIC_EXPORT_INTERVAL_MILLIS", "30000")) +OTEL_METRIC_EXPORT_TIMEOUT_MILLIS = int(os.getenv("OTEL_METRIC_EXPORT_TIMEOUT_MILLIS", "30000")) # OpenTelemetry Instrumentation Configuration OTEL_PYTHON_DJANGO_INSTRUMENT = True From fa76e4cf5ce6558ee469822a7c9792a5bbf281f4 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 23 Dec 2025 16:07:14 +0530 Subject: [PATCH 033/127] update user kc id if user with email already exists --- api/utils/keycloak_utils.py | 41 +++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/api/utils/keycloak_utils.py b/api/utils/keycloak_utils.py index 15f7f06..15cd6fd 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ 
-366,17 +366,36 @@ def sync_user_from_keycloak( logger.error("Missing required user information from Keycloak") return None - # Get or create the user - user, created = User.objects.update_or_create( - keycloak_id=keycloak_id, - defaults={ - "username": username, - "email": email, - "first_name": user_info.get("given_name", ""), - "last_name": user_info.get("family_name", ""), - "is_active": True, - }, - ) + # First, try to find user by keycloak_id + user = User.objects.filter(keycloak_id=keycloak_id).first() + + if not user and email: + # If not found by keycloak_id, check if user exists with same email + # and update their keycloak_id (handles pre-existing users) + user = User.objects.filter(email=email).first() + if user: + logger.info( + f"Found existing user with email {email}, updating keycloak_id to {keycloak_id}" + ) + + if user: + # Update existing user + user.keycloak_id = keycloak_id + user.username = username + user.email = email + user.first_name = user_info.get("given_name", "") or user.first_name + user.last_name = user_info.get("family_name", "") or user.last_name + user.is_active = True + else: + # Create new user + user = User( + keycloak_id=keycloak_id, + username=username, + email=email, + first_name=user_info.get("given_name", ""), + last_name=user_info.get("family_name", ""), + is_active=True, + ) # Update user roles based on Keycloak roles if "admin" in roles: From 32c5be63a29bee6892a43f7465a030fb05452ecd Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 23 Dec 2025 17:11:11 +0530 Subject: [PATCH 034/127] add updated api model mutations --- api/models/AIModelVersion.py | 169 ++++++++++++++ api/models/__init__.py | 1 + api/schema/aimodel_schema.py | 436 +++++++++++++++++++++++++++++++++++ api/types/type_aimodel.py | 119 +++++++++- api/utils/enums.py | 1 + 5 files changed, 723 insertions(+), 3 deletions(-) create mode 100644 api/models/AIModelVersion.py diff --git a/api/models/AIModelVersion.py b/api/models/AIModelVersion.py new file mode 100644 
index 0000000..f66893c --- /dev/null +++ b/api/models/AIModelVersion.py @@ -0,0 +1,169 @@ +"""AI Model Version model for version-specific configuration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from django.db import models + +from api.utils.enums import AIModelStatus + +if TYPE_CHECKING: + from django.db.models import QuerySet + + +class AIModelVersion(models.Model): + """ + Version of an AI Model with its own configuration. + Each version can have multiple providers. + """ + + ai_model = models.ForeignKey( + "api.AIModel", + on_delete=models.CASCADE, + related_name="versions", + ) + version = models.CharField(max_length=50, help_text="Version number (e.g., 1.0.0)") + version_notes = models.TextField(blank=True, help_text="Changelog/notes for this version") + + # Version-specific capabilities + supports_streaming = models.BooleanField(default=False) + max_tokens = models.IntegerField(null=True, blank=True, help_text="Maximum tokens supported") + supported_languages = models.JSONField( + default=list, help_text="List of supported language codes" + ) + input_schema = models.JSONField(default=dict, help_text="Expected input format and parameters") + output_schema = models.JSONField(default=dict, help_text="Expected output format") + metadata = models.JSONField(default=dict, help_text="Additional version-specific metadata") + + # Status + status = models.CharField( + max_length=20, + choices=AIModelStatus.choices, + default=AIModelStatus.REGISTERED, + ) + is_latest = models.BooleanField(default=False, help_text="Whether this is the latest version") + + # Timestamps + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + published_at = models.DateTimeField(null=True, blank=True) + + class Meta: + unique_together = ["ai_model", "version"] + ordering = ["-created_at"] + verbose_name = "AI Model Version" + verbose_name_plural = "AI Model Versions" + + def __str__(self): + return 
f"{self.ai_model.name} v{self.version}" + + def save(self, *args, **kwargs): + # If this is set as latest, unset others + if self.is_latest: + AIModelVersion.objects.filter(ai_model=self.ai_model, is_latest=True).exclude( + pk=self.pk + ).update(is_latest=False) + super().save(*args, **kwargs) + + def copy_providers_from(self, source_version: AIModelVersion) -> None: + """ + Copy all providers from another version. + Used when creating a new version. + """ + for provider in source_version.providers.all(): # type: ignore[attr-defined] + # Create a copy of the provider + VersionProvider.objects.create( + version=self, + provider=provider.provider, # type: ignore[attr-defined] + provider_model_id=provider.provider_model_id, # type: ignore[attr-defined] + is_primary=provider.is_primary, # type: ignore[attr-defined] + is_active=provider.is_active, # type: ignore[attr-defined] + hf_use_pipeline=provider.hf_use_pipeline, # type: ignore[attr-defined] + hf_auth_token=provider.hf_auth_token, # type: ignore[attr-defined] + hf_model_class=provider.hf_model_class, # type: ignore[attr-defined] + hf_attn_implementation=provider.hf_attn_implementation, # type: ignore[attr-defined] + framework=provider.framework, # type: ignore[attr-defined] + config=provider.config, # type: ignore[attr-defined] + ) + + +class VersionProvider(models.Model): + """ + Provider configuration for a specific version. + A version can have multiple providers (HF, Custom, OpenAI, etc.) + Only ONE can be primary per version. 
+ """ + + from api.utils.enums import AIModelFramework, AIModelProvider, HFModelClass + + version = models.ForeignKey( + AIModelVersion, + on_delete=models.CASCADE, + related_name="providers", + ) + + # Provider info + provider = models.CharField(max_length=50, choices=AIModelProvider.choices) + provider_model_id = models.CharField( + max_length=255, + blank=True, + help_text="Provider's model identifier (e.g., gpt-4, claude-3-opus)", + ) + is_primary = models.BooleanField( + default=False, help_text="Whether this is the primary provider for the version" + ) + is_active = models.BooleanField(default=True) + + # Huggingface-specific fields + hf_use_pipeline = models.BooleanField(default=False, help_text="Use Pipeline inference API") + hf_auth_token = models.CharField( + max_length=255, + blank=True, + null=True, + help_text="Huggingface Auth Token for gated models", + ) + hf_model_class = models.CharField( + max_length=100, + choices=HFModelClass.choices, + blank=True, + null=True, + help_text="Specify model head to use", + ) + hf_attn_implementation = models.CharField( + max_length=255, + blank=True, + default="flash_attention_2", + help_text="Attention Function", + ) + framework = models.CharField( + max_length=10, + choices=AIModelFramework.choices, + blank=True, + null=True, + help_text="Framework (PyTorch or TensorFlow)", + ) + + # Provider-specific configuration + config = models.JSONField(default=dict, help_text="Provider-specific configuration") + + # Timestamps + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["-is_primary", "-created_at"] + verbose_name = "Version Provider" + verbose_name_plural = "Version Providers" + + def __str__(self): + primary_str = " (Primary)" if self.is_primary else "" + return f"{self.version} - {self.provider}{primary_str}" + + def save(self, *args, **kwargs): + # Ensure only one primary per version + if self.is_primary: + 
VersionProvider.objects.filter(version=self.version, is_primary=True).exclude( + pk=self.pk + ).update(is_primary=False) + super().save(*args, **kwargs) diff --git a/api/models/__init__.py b/api/models/__init__.py index a93f804..809a711 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,5 +1,6 @@ from api.models.AccessModel import AccessModel, AccessModelResource from api.models.AIModel import AIModel, ModelAPIKey, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.models.Catalog import Catalog from api.models.Collaborative import Collaborative from api.models.CollaborativeMetadata import CollaborativeMetadata diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index dfdde3d..ceb4a23 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -12,6 +12,7 @@ from strawberry_django.pagination import OffsetPaginationInput from api.models.AIModel import AIModel, ModelAPIKey, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.models.Dataset import Tag from api.schema.base_mutation import BaseMutation, MutationResponse from api.schema.extensions import TrackActivity, TrackModelActivity @@ -21,10 +22,14 @@ AIModelProviderEnum, AIModelStatusEnum, AIModelTypeEnum, + AIModelVersionFilter, + AIModelVersionOrder, EndpointAuthTypeEnum, EndpointHTTPMethodEnum, TypeAIModel, + TypeAIModelVersion, TypeModelEndpoint, + TypeVersionProvider, ) from api.utils.graphql_telemetry import trace_resolver from authorization.graphql_permissions import IsAuthenticated @@ -157,6 +162,72 @@ class UpdateModelEndpointInput: rate_limit_per_minute: Optional[int] = None +@strawberry.input +class CreateAIModelVersionInput: + """Input for creating a new AI Model Version.""" + + model_id: int + version: str + version_notes: Optional[str] = "" + supports_streaming: bool = False + max_tokens: Optional[int] = None + supported_languages: Optional[List[str]] = None 
+ input_schema: Optional[strawberry.scalars.JSON] = None + output_schema: Optional[strawberry.scalars.JSON] = None + metadata: Optional[strawberry.scalars.JSON] = None + copy_from_version_id: Optional[int] = None + + +@strawberry.input +class UpdateAIModelVersionInput: + """Input for updating an AI Model Version.""" + + id: int + version: Optional[str] = None + version_notes: Optional[str] = None + supports_streaming: Optional[bool] = None + max_tokens: Optional[int] = None + supported_languages: Optional[List[str]] = None + input_schema: Optional[strawberry.scalars.JSON] = None + output_schema: Optional[strawberry.scalars.JSON] = None + metadata: Optional[strawberry.scalars.JSON] = None + status: Optional[AIModelStatusEnum] = None + is_latest: Optional[bool] = None + + +@strawberry.input +class CreateVersionProviderInput: + """Input for creating a new Version Provider.""" + + version_id: int + provider: AIModelProviderEnum + provider_model_id: Optional[str] = "" + is_primary: bool = False + is_active: bool = True + hf_use_pipeline: bool = False + hf_auth_token: Optional[str] = None + hf_model_class: Optional[str] = None + hf_attn_implementation: Optional[str] = "flash_attention_2" + framework: Optional[str] = None + config: Optional[strawberry.scalars.JSON] = None + + +@strawberry.input +class UpdateVersionProviderInput: + """Input for updating a Version Provider.""" + + id: int + provider_model_id: Optional[str] = None + is_primary: Optional[bool] = None + is_active: Optional[bool] = None + hf_use_pipeline: Optional[bool] = None + hf_auth_token: Optional[str] = None + hf_model_class: Optional[str] = None + hf_attn_implementation: Optional[str] = None + framework: Optional[str] = None + config: Optional[strawberry.scalars.JSON] = None + + @strawberry.type class Query: """Queries for AI Models.""" @@ -637,3 +708,368 @@ def delete_model_endpoint(self, info: Info, endpoint_id: int) -> MutationRespons endpoint.delete() return MutationResponse.success_response(True) + 
+ # ==================== VERSION MUTATIONS ==================== + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="create_ai_model_version", + trace_attributes={"component": "aimodel"}, + ) + def create_ai_model_version( + self, info: Info, input: CreateAIModelVersionInput + ) -> MutationResponse[TypeAIModelVersion]: + """Create a new AI model version. Optionally copy providers from another version.""" + user = info.context.user + + try: + model = AIModel.objects.get(id=input.model_id) + except AIModel.DoesNotExist: + raise DjangoValidationError(f"AI Model with ID {input.model_id} does not exist.") + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to add versions to this model." + ) + else: + raise DjangoValidationError( + "You don't have permission to add versions to this model." 
+ ) + + # Create the version + version = AIModelVersion.objects.create( + ai_model=model, + version=input.version, + version_notes=input.version_notes or "", + supports_streaming=input.supports_streaming, + max_tokens=input.max_tokens, + supported_languages=input.supported_languages or [], + input_schema=input.input_schema or {}, + output_schema=input.output_schema or {}, + metadata=input.metadata or {}, + status="DRAFT", + is_latest=True, + ) + + # If copy_from_version_id is provided, copy all providers + if input.copy_from_version_id: + try: + source_version = AIModelVersion.objects.get(id=input.copy_from_version_id) + version.copy_providers_from(source_version) + except AIModelVersion.DoesNotExist: + pass # Silently ignore if source version doesn't exist + + return MutationResponse.success_response(TypeAIModelVersion.from_django(version)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="update_ai_model_version", + trace_attributes={"component": "aimodel"}, + ) + def update_ai_model_version( + self, info: Info, input: UpdateAIModelVersionInput + ) -> MutationResponse[TypeAIModelVersion]: + """Update an AI model version.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=input.id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError(f"AI Model Version with ID {input.id} does not exist.") + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError("You don't have permission to update this version.") + else: + raise DjangoValidationError("You don't have permission to update this version.") + + # Update fields + if input.version is not None: + version.version = input.version + if input.version_notes is not 
None: + version.version_notes = input.version_notes + if input.supports_streaming is not None: + version.supports_streaming = input.supports_streaming + if input.max_tokens is not None: + version.max_tokens = input.max_tokens + if input.supported_languages is not None: + version.supported_languages = input.supported_languages + if input.input_schema is not None: + version.input_schema = input.input_schema + if input.output_schema is not None: + version.output_schema = input.output_schema + if input.metadata is not None: + version.metadata = input.metadata + if input.status is not None: + version.status = input.status + if input.is_latest is not None: + version.is_latest = input.is_latest + + version.save() + return MutationResponse.success_response(TypeAIModelVersion.from_django(version)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="delete_ai_model_version", + trace_attributes={"component": "aimodel"}, + ) + def delete_ai_model_version(self, info: Info, version_id: int) -> MutationResponse[bool]: + """Delete an AI model version.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=version_id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError(f"AI Model Version with ID {version_id} does not exist.") + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_delete: + raise DjangoValidationError("You don't have permission to delete this version.") + else: + raise DjangoValidationError("You don't have permission to delete this version.") + + version.delete() + return MutationResponse.success_response(True) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="publish_ai_model_version", + 
trace_attributes={"component": "aimodel"}, + ) + def publish_ai_model_version( + self, info: Info, version_id: int + ) -> MutationResponse[TypeAIModelVersion]: + """Publish an AI model version and set it as latest.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=version_id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError(f"AI Model Version with ID {version_id} does not exist.") + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to publish this version." + ) + else: + raise DjangoValidationError("You don't have permission to publish this version.") + + from django.utils import timezone + + version.status = "ACTIVE" + version.is_latest = True + version.published_at = timezone.now() + version.save() + + return MutationResponse.success_response(TypeAIModelVersion.from_django(version)) + + # ==================== PROVIDER MUTATIONS ==================== + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="create_version_provider", + trace_attributes={"component": "aimodel"}, + ) + def create_version_provider( + self, info: Info, input: CreateVersionProviderInput + ) -> MutationResponse[TypeVersionProvider]: + """Create a new provider for a version.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=input.version_id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError( + f"AI Model Version with ID {input.version_id} does not exist." 
+ ) + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to add providers to this version." + ) + else: + raise DjangoValidationError( + "You don't have permission to add providers to this version." + ) + + provider = VersionProvider.objects.create( + version=version, + provider=input.provider, + provider_model_id=input.provider_model_id or "", + is_primary=input.is_primary, + is_active=input.is_active, + hf_use_pipeline=input.hf_use_pipeline, + hf_auth_token=input.hf_auth_token, + hf_model_class=input.hf_model_class, + hf_attn_implementation=input.hf_attn_implementation or "flash_attention_2", + framework=input.framework, + config=input.config or {}, + ) + + return MutationResponse.success_response(TypeVersionProvider.from_django(provider)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="update_version_provider", + trace_attributes={"component": "aimodel"}, + ) + def update_version_provider( + self, info: Info, input: UpdateVersionProviderInput + ) -> MutationResponse[TypeVersionProvider]: + """Update a version provider.""" + user = info.context.user + + try: + provider = VersionProvider.objects.get(id=input.id) + except VersionProvider.DoesNotExist: + raise DjangoValidationError(f"Version Provider with ID {input.id} does not exist.") + + model = provider.version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to update this 
provider." + ) + else: + raise DjangoValidationError("You don't have permission to update this provider.") + + # Update fields + if input.provider_model_id is not None: + provider.provider_model_id = input.provider_model_id + if input.is_primary is not None: + provider.is_primary = input.is_primary + if input.is_active is not None: + provider.is_active = input.is_active + if input.hf_use_pipeline is not None: + provider.hf_use_pipeline = input.hf_use_pipeline + if input.hf_auth_token is not None: + provider.hf_auth_token = input.hf_auth_token + if input.hf_model_class is not None: + provider.hf_model_class = input.hf_model_class + if input.hf_attn_implementation is not None: + provider.hf_attn_implementation = input.hf_attn_implementation + if input.framework is not None: + provider.framework = input.framework + if input.config is not None: + provider.config = input.config + + provider.save() + return MutationResponse.success_response(TypeVersionProvider.from_django(provider)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="delete_version_provider", + trace_attributes={"component": "aimodel"}, + ) + def delete_version_provider(self, info: Info, provider_id: int) -> MutationResponse[bool]: + """Delete a version provider.""" + user = info.context.user + + try: + provider = VersionProvider.objects.get(id=provider_id) + except VersionProvider.DoesNotExist: + raise DjangoValidationError(f"Version Provider with ID {provider_id} does not exist.") + + model = provider.version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_delete: + raise DjangoValidationError( + "You don't have permission to delete this provider." 
+ ) + else: + raise DjangoValidationError("You don't have permission to delete this provider.") + + provider.delete() + return MutationResponse.success_response(True) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="set_primary_provider", + trace_attributes={"component": "aimodel"}, + ) + def set_primary_provider( + self, info: Info, provider_id: int + ) -> MutationResponse[TypeVersionProvider]: + """Set a provider as the primary provider for its version.""" + user = info.context.user + + try: + provider = VersionProvider.objects.get(id=provider_id) + except VersionProvider.DoesNotExist: + raise DjangoValidationError(f"Version Provider with ID {provider_id} does not exist.") + + model = provider.version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to update this provider." 
+ ) + else: + raise DjangoValidationError("You don't have permission to update this provider.") + + provider.is_primary = True + provider.save() # The save method will unset other primaries + + return MutationResponse.success_response(TypeVersionProvider.from_django(provider)) diff --git a/api/types/type_aimodel.py b/api/types/type_aimodel.py index 5e6be99..1d3092e 100644 --- a/api/types/type_aimodel.py +++ b/api/types/type_aimodel.py @@ -11,17 +11,20 @@ from strawberry.types import Info from api.models.AIModel import AIModel, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.types.base_type import BaseType from api.types.type_dataset import TypeTag from api.types.type_geo import TypeGeo from api.types.type_organization import TypeOrganization from api.types.type_sector import TypeSector from api.utils.enums import ( + AIModelFramework, AIModelProvider, AIModelStatus, AIModelType, EndpointAuthType, EndpointHTTPMethod, + HFModelClass, ) from authorization.types import TypeUser @@ -34,6 +37,8 @@ AIModelProviderEnum = strawberry.enum(AIModelProvider) # type: ignore EndpointAuthTypeEnum = strawberry.enum(EndpointAuthType) # type: ignore EndpointHTTPMethodEnum = strawberry.enum(EndpointHTTPMethod) # type: ignore +AIModelFrameworkEnum = strawberry.enum(AIModelFramework) # type: ignore +HFModelClassEnum = strawberry.enum(HFModelClass) # type: ignore @strawberry.type @@ -65,9 +70,7 @@ def success_rate(self) -> Optional[float]: """Calculate success rate.""" if self.total_requests == 0: return None - return ( - (self.total_requests - self.failed_requests) / self.total_requests - ) * 100 + return ((self.total_requests - self.failed_requests) / self.total_requests) * 100 @strawberry_django.filter(AIModel) @@ -187,3 +190,113 @@ def primary_endpoint(self) -> Optional[TypeModelEndpoint]: if endpoint: return TypeModelEndpoint.from_django(endpoint) return None + + @strawberry.field(description="Get all versions of this AI model.") + def 
versions(self) -> List["TypeAIModelVersion"]: + """Get all versions of this AI model.""" + try: + queryset = self.versions.all() # type: ignore + return TypeAIModelVersion.from_django_list(list(queryset)) + except Exception: + return [] + + @strawberry.field(description="Get the latest version of this AI model.") + def latest_version(self) -> Optional["TypeAIModelVersion"]: + """Get the latest version of this AI model.""" + try: + version = self.versions.filter(is_latest=True).first() # type: ignore + if not version: + version = self.versions.order_by("-created_at").first() # type: ignore + if version: + return TypeAIModelVersion.from_django(version) + return None + except Exception: + return None + + +@strawberry.type +class TypeVersionProvider(BaseType): + """GraphQL type for VersionProvider.""" + + id: int + provider: AIModelProviderEnum + provider_model_id: Optional[str] + is_primary: bool + is_active: bool + hf_use_pipeline: bool + hf_auth_token: Optional[str] + hf_model_class: Optional[str] + hf_attn_implementation: Optional[str] + framework: Optional[str] + config: strawberry.scalars.JSON + created_at: datetime + updated_at: datetime + + +@strawberry_django.filter(AIModelVersion) +class AIModelVersionFilter: + """Filter for AI Model Version.""" + + id: Optional[int] + status: Optional[AIModelStatusEnum] + is_latest: Optional[bool] + + +@strawberry_django.order(AIModelVersion) +class AIModelVersionOrder: + """Order for AI Model Version.""" + + version: strawberry.auto + created_at: strawberry.auto + updated_at: strawberry.auto + + +@strawberry.type +class TypeAIModelVersion(BaseType): + """GraphQL type for AI Model Version.""" + + id: int + version: str + version_notes: Optional[str] + supports_streaming: bool + max_tokens: Optional[int] + supported_languages: strawberry.scalars.JSON + input_schema: strawberry.scalars.JSON + output_schema: strawberry.scalars.JSON + metadata: strawberry.scalars.JSON + status: AIModelStatusEnum + is_latest: bool + created_at: 
datetime + updated_at: datetime + published_at: Optional[datetime] + + @strawberry.field + def providers(self) -> List[TypeVersionProvider]: + """Get all providers for this version.""" + try: + django_instance = cast(AIModelVersion, self) + queryset = django_instance.providers.all() + return TypeVersionProvider.from_django_list(list(queryset)) + except Exception: + return [] + + @strawberry.field + def primary_provider(self) -> Optional[TypeVersionProvider]: + """Get the primary provider for this version.""" + try: + django_instance = cast(AIModelVersion, self) + provider = django_instance.providers.filter(is_primary=True).first() + if provider: + return TypeVersionProvider.from_django(provider) + return None + except Exception: + return None + + @strawberry.field + def ai_model(self) -> Optional[TypeAIModel]: + """Get the parent AI model.""" + try: + django_instance = cast(AIModelVersion, self) + return TypeAIModel.from_django(django_instance.ai_model) + except Exception: + return None diff --git a/api/utils/enums.py b/api/utils/enums.py index 18ccbab..bc2a258 100644 --- a/api/utils/enums.py +++ b/api/utils/enums.py @@ -169,6 +169,7 @@ class AIModelType(models.TextChoices): class AIModelStatus(models.TextChoices): + DRAFT = "DRAFT" REGISTERED = "REGISTERED" VALIDATING = "VALIDATING" ACTIVE = "ACTIVE" From c1ccddd932c985444cb7e8461ab730a5c81353cd Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 24 Dec 2025 18:45:20 +0530 Subject: [PATCH 035/127] accept list of ids instead of names --- api/schema/aimodel_schema.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index ceb4a23..205f24f 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -61,14 +61,14 @@ def _update_aimodel_sectors(model: AIModel, sectors: List[str]) -> None: model.save() -def _update_aimodel_geographies(model: AIModel, geographies: List[str]) -> None: +def _update_aimodel_geographies(model: 
AIModel, geographies: List[int]) -> None: """Helper function to update geographies for an AI model.""" from api.models import Geography model.geographies.clear() - for geography_name in geographies: + for geography_id in geographies: try: - geography = Geography.objects.get(name__iexact=geography_name) + geography = Geography.objects.get(id=geography_id) model.geographies.add(geography) except Geography.DoesNotExist: pass @@ -93,7 +93,7 @@ class CreateAIModelInput: output_schema: Optional[strawberry.scalars.JSON] = None tags: Optional[List[str]] = None sectors: Optional[List[str]] = None - geographies: Optional[List[str]] = None + geographies: Optional[List[int]] = None metadata: Optional[strawberry.scalars.JSON] = None is_public: bool = False @@ -117,7 +117,7 @@ class UpdateAIModelInput: output_schema: Optional[strawberry.scalars.JSON] = None tags: Optional[List[str]] = None sectors: Optional[List[str]] = None - geographies: Optional[List[str]] = None + geographies: Optional[List[int]] = None metadata: Optional[strawberry.scalars.JSON] = None is_public: Optional[bool] = None is_active: Optional[bool] = None From 7396f9f516f0e2c7a8778580977875fb8fa4a70f Mon Sep 17 00:00:00 2001 From: dc Date: Fri, 26 Dec 2025 14:21:15 +0530 Subject: [PATCH 036/127] add lifecycle stage to aimodel --- api/models/AIModelVersion.py | 10 ++++++++-- api/schema/aimodel_schema.py | 9 ++++++++- api/types/type_aimodel.py | 11 ++++++++--- api/utils/enums.py | 10 ++++++++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/api/models/AIModelVersion.py b/api/models/AIModelVersion.py index f66893c..f679502 100644 --- a/api/models/AIModelVersion.py +++ b/api/models/AIModelVersion.py @@ -6,7 +6,7 @@ from django.db import models -from api.utils.enums import AIModelStatus +from api.utils.enums import AIModelLifecycleStage, AIModelStatus if TYPE_CHECKING: from django.db.models import QuerySet @@ -36,12 +36,18 @@ class AIModelVersion(models.Model): output_schema = 
models.JSONField(default=dict, help_text="Expected output format") metadata = models.JSONField(default=dict, help_text="Additional version-specific metadata") - # Status + # Status & Lifecycle status = models.CharField( max_length=20, choices=AIModelStatus.choices, default=AIModelStatus.REGISTERED, ) + lifecycle_stage = models.CharField( + max_length=20, + choices=AIModelLifecycleStage.choices, + default=AIModelLifecycleStage.DEVELOPMENT, + help_text="Current lifecycle stage of this version", + ) is_latest = models.BooleanField(default=False, help_text="Whether this is the latest version") # Timestamps diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index 205f24f..1d9b8b7 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -18,6 +18,7 @@ from api.schema.extensions import TrackActivity, TrackModelActivity from api.types.type_aimodel import ( AIModelFilter, + AIModelLifecycleStageEnum, AIModelOrder, AIModelProviderEnum, AIModelStatusEnum, @@ -169,6 +170,7 @@ class CreateAIModelVersionInput: model_id: int version: str version_notes: Optional[str] = "" + lifecycle_stage: Optional[AIModelLifecycleStageEnum] = None supports_streaming: bool = False max_tokens: Optional[int] = None supported_languages: Optional[List[str]] = None @@ -176,6 +178,7 @@ class CreateAIModelVersionInput: output_schema: Optional[strawberry.scalars.JSON] = None metadata: Optional[strawberry.scalars.JSON] = None copy_from_version_id: Optional[int] = None + is_latest: Optional[bool] = None @strawberry.input @@ -185,6 +188,7 @@ class UpdateAIModelVersionInput: id: int version: Optional[str] = None version_notes: Optional[str] = None + lifecycle_stage: Optional[AIModelLifecycleStageEnum] = None supports_streaming: Optional[bool] = None max_tokens: Optional[int] = None supported_languages: Optional[List[str]] = None @@ -748,6 +752,7 @@ def create_ai_model_version( ai_model=model, version=input.version, version_notes=input.version_notes or "", + 
lifecycle_stage=input.lifecycle_stage.value if input.lifecycle_stage else "DEVELOPMENT", # type: ignore[misc] supports_streaming=input.supports_streaming, max_tokens=input.max_tokens, supported_languages=input.supported_languages or [], @@ -755,7 +760,7 @@ def create_ai_model_version( output_schema=input.output_schema or {}, metadata=input.metadata or {}, status="DRAFT", - is_latest=True, + is_latest=input.is_latest if input.is_latest is not None else True, ) # If copy_from_version_id is provided, copy all providers @@ -803,6 +808,8 @@ def update_ai_model_version( version.version = input.version if input.version_notes is not None: version.version_notes = input.version_notes + if input.lifecycle_stage is not None: + version.lifecycle_stage = input.lifecycle_stage.value # type: ignore[misc] if input.supports_streaming is not None: version.supports_streaming = input.supports_streaming if input.max_tokens is not None: diff --git a/api/types/type_aimodel.py b/api/types/type_aimodel.py index 1d3092e..6105bb4 100644 --- a/api/types/type_aimodel.py +++ b/api/types/type_aimodel.py @@ -19,6 +19,7 @@ from api.types.type_sector import TypeSector from api.utils.enums import ( AIModelFramework, + AIModelLifecycleStage, AIModelProvider, AIModelStatus, AIModelType, @@ -39,6 +40,7 @@ EndpointHTTPMethodEnum = strawberry.enum(EndpointHTTPMethod) # type: ignore AIModelFrameworkEnum = strawberry.enum(AIModelFramework) # type: ignore HFModelClassEnum = strawberry.enum(HFModelClass) # type: ignore +AIModelLifecycleStageEnum = strawberry.enum(AIModelLifecycleStage) # type: ignore @strawberry.type @@ -195,7 +197,8 @@ def primary_endpoint(self) -> Optional[TypeModelEndpoint]: def versions(self) -> List["TypeAIModelVersion"]: """Get all versions of this AI model.""" try: - queryset = self.versions.all() # type: ignore + django_instance = cast(AIModel, self) + queryset = django_instance.versions.all() return TypeAIModelVersion.from_django_list(list(queryset)) except Exception: return [] @@ 
-204,9 +207,10 @@ def versions(self) -> List["TypeAIModelVersion"]: def latest_version(self) -> Optional["TypeAIModelVersion"]: """Get the latest version of this AI model.""" try: - version = self.versions.filter(is_latest=True).first() # type: ignore + django_instance = cast(AIModel, self) + version = django_instance.versions.filter(is_latest=True).first() if not version: - version = self.versions.order_by("-created_at").first() # type: ignore + version = django_instance.versions.order_by("-created_at").first() if version: return TypeAIModelVersion.from_django(version) return None @@ -265,6 +269,7 @@ class TypeAIModelVersion(BaseType): output_schema: strawberry.scalars.JSON metadata: strawberry.scalars.JSON status: AIModelStatusEnum + lifecycle_stage: AIModelLifecycleStageEnum is_latest: bool created_at: datetime updated_at: datetime diff --git a/api/utils/enums.py b/api/utils/enums.py index bc2a258..770ad38 100644 --- a/api/utils/enums.py +++ b/api/utils/enums.py @@ -189,6 +189,16 @@ class AIModelProvider(models.TextChoices): HUGGINGFACE = "HUGGINGFACE" +class AIModelLifecycleStage(models.TextChoices): + DEVELOPMENT = "DEVELOPMENT", "Development" + TESTING = "TESTING", "Testing" + BETA = "BETA", "Beta Testing" + STAGING = "STAGING", "Staging" + PRODUCTION = "PRODUCTION", "Production" + DEPRECATED = "DEPRECATED", "Deprecated" + RETIRED = "RETIRED", "Retired" + + class AIModelFramework(models.TextChoices): PYTORCH = "pt", "PyTorch" TENSORFLOW = "tf", "TensorFlow" From 9b32577d5afb99ce9d3e9816dc2689a9d5b5bfa4 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 29 Dec 2025 13:06:53 +0530 Subject: [PATCH 037/127] update sdk to reflect new aimodel schema --- dataspace_sdk/__version__.py | 2 +- dataspace_sdk/resources/aimodels.py | 547 ++++++++++++++++++++++++++-- 2 files changed, 527 insertions(+), 22 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 51b7a2d..16ce66b 100644 --- a/dataspace_sdk/__version__.py +++ 
b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.3.3" +__version__ = "0.4.0" diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index 84c4262..57536ba 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -94,19 +94,6 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: displayName description modelType - provider - version - providerModelId - hfUsePipeline - hfAuthToken - hfModelClass - hfAttnImplementation - framework - supportsStreaming - maxTokens - supportedLanguages - inputSchema - outputSchema status isPublic createdAt @@ -127,13 +114,34 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: id name } - endpoints { + versions { id - name - url - httpMethod - authType - isActive + version + versionNotes + lifecycleStage + isLatest + supportsStreaming + maxTokens + supportedLanguages + inputSchema + outputSchema + status + createdAt + updatedAt + publishedAt + providers { + id + provider + providerModelId + isPrimary + isActive + hfUsePipeline + hfAuthToken + hfModelClass + hfAttnImplementation + framework + config + } } } } @@ -184,8 +192,6 @@ def list_all( displayName description modelType - provider - version status isPublic createdAt @@ -198,6 +204,19 @@ def list_all( id value } + versions { + id + version + lifecycleStage + isLatest + status + providers { + id + provider + providerModelId + isPrimary + } + } } } """ @@ -343,3 +362,489 @@ def call_model_async( f"/api/aimodels/{model_id}/call-async/", json_data={"input_text": input_text, "parameters": parameters or {}}, ) + + # ==================== Version Management ==================== + + def get_versions(self, model_id: int) -> List[Dict[str, Any]]: + """ + Get all versions for an AI model. 
+ + Args: + model_id: ID of the AI model + + Returns: + List of version dictionaries + """ + query = """ + query GetModelVersions($filters: AIModelFilter) { + aiModels(filters: $filters) { + versions { + id + version + versionNotes + lifecycleStage + isLatest + supportsStreaming + maxTokens + supportedLanguages + status + createdAt + updatedAt + publishedAt + providers { + id + provider + providerModelId + isPrimary + isActive + } + } + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"filters": {"id": model_id}}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + models = response.get("data", {}).get("aiModels", []) + if models: + result: List[Dict[str, Any]] = models[0].get("versions", []) + return result + return [] + + def create_version( + self, + model_id: int, + version: str, + lifecycle_stage: str = "DEVELOPMENT", + is_latest: bool = False, + copy_from_version_id: Optional[int] = None, + version_notes: Optional[str] = None, + supports_streaming: bool = False, + max_tokens: Optional[int] = None, + supported_languages: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + Create a new version for an AI model. + + Args: + model_id: ID of the AI model + version: Version string (e.g., "1.0", "2.1") + lifecycle_stage: One of DEVELOPMENT, TESTING, BETA, STAGING, PRODUCTION, DEPRECATED, RETIRED + is_latest: Whether this should be the primary version + copy_from_version_id: Optional version ID to copy providers from + version_notes: Optional notes about this version + supports_streaming: Whether this version supports streaming + max_tokens: Maximum tokens supported + supported_languages: List of supported language codes + + Returns: + Dictionary containing created version information + """ + mutation = """ + mutation CreateAIModelVersion($input: CreateAIModelVersionInput!) 
{ + createAiModelVersion(input: $input) { + success + data { + id + version + lifecycleStage + isLatest + status + } + errors + } + } + """ + + input_data: Dict[str, Any] = { + "modelId": model_id, + "version": version, + "lifecycleStage": lifecycle_stage, + "isLatest": is_latest, + "supportsStreaming": supports_streaming, + } + + if copy_from_version_id: + input_data["copyFromVersionId"] = copy_from_version_id + if version_notes: + input_data["versionNotes"] = version_notes + if max_tokens: + input_data["maxTokens"] = max_tokens + if supported_languages: + input_data["supportedLanguages"] = supported_languages + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("createAiModelVersion", {}) + return result + + def update_version( + self, + version_id: int, + version: Optional[str] = None, + lifecycle_stage: Optional[str] = None, + is_latest: Optional[bool] = None, + version_notes: Optional[str] = None, + status: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Update an AI model version. + + Args: + version_id: ID of the version to update + version: New version string + lifecycle_stage: New lifecycle stage + is_latest: Whether this should be the primary version + version_notes: New version notes + status: New status + + Returns: + Dictionary containing updated version information + """ + mutation = """ + mutation UpdateAIModelVersion($input: UpdateAIModelVersionInput!) 
{ + updateAiModelVersion(input: $input) { + success + data { + id + version + lifecycleStage + isLatest + status + } + errors + } + } + """ + + input_data: Dict[str, Any] = {"id": version_id} + + if version is not None: + input_data["version"] = version + if lifecycle_stage is not None: + input_data["lifecycleStage"] = lifecycle_stage + if is_latest is not None: + input_data["isLatest"] = is_latest + if version_notes is not None: + input_data["versionNotes"] = version_notes + if status is not None: + input_data["status"] = status + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("updateAiModelVersion", {}) + return result + + # ==================== Provider Management ==================== + + def get_version_providers(self, version_id: int) -> List[Dict[str, Any]]: + """ + Get all providers for a specific version. + + Args: + version_id: ID of the version + + Returns: + List of provider dictionaries + """ + query = """ + query GetVersionProviders($versionId: Int!) 
{ + aiModelVersion(id: $versionId) { + providers { + id + provider + providerModelId + isPrimary + isActive + hfUsePipeline + hfAuthToken + hfModelClass + hfAttnImplementation + framework + config + } + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"versionId": version_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + version_data = response.get("data", {}).get("aiModelVersion", {}) + result: List[Dict[str, Any]] = version_data.get("providers", []) if version_data else [] + return result + + def create_provider( + self, + version_id: int, + provider: str, + provider_model_id: str, + is_primary: bool = False, + hf_use_pipeline: bool = False, + hf_model_class: Optional[str] = None, + hf_auth_token: Optional[str] = None, + hf_attn_implementation: Optional[str] = None, + framework: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Create a new provider for a version. + + Args: + version_id: ID of the version + provider: Provider type (OPENAI, LLAMA_OLLAMA, LLAMA_TOGETHER, LLAMA_REPLICATE, + LLAMA_CUSTOM, CUSTOM, HUGGINGFACE) + provider_model_id: Model ID at the provider (e.g., "gpt-4", "meta-llama/Llama-2-7b") + is_primary: Whether this is the primary provider + hf_use_pipeline: For HuggingFace - whether to use pipeline API + hf_model_class: For HuggingFace - model class (e.g., "AutoModelForCausalLM") + hf_auth_token: For HuggingFace - auth token for gated models + hf_attn_implementation: For HuggingFace - attention implementation + framework: Framework (pt, tf) + config: Additional configuration (apiKey, baseUrl, authType, etc.) + + Returns: + Dictionary containing created provider information + """ + mutation = """ + mutation CreateVersionProvider($input: CreateVersionProviderInput!) 
{ + createVersionProvider(input: $input) { + success + data { + id + provider + providerModelId + isPrimary + isActive + } + errors + } + } + """ + + input_data: Dict[str, Any] = { + "versionId": version_id, + "provider": provider, + "providerModelId": provider_model_id, + "isPrimary": is_primary, + "hfUsePipeline": hf_use_pipeline, + } + + if hf_model_class: + input_data["hfModelClass"] = hf_model_class + if hf_auth_token: + input_data["hfAuthToken"] = hf_auth_token + if hf_attn_implementation: + input_data["hfAttnImplementation"] = hf_attn_implementation + if framework: + input_data["framework"] = framework + if config: + input_data["config"] = config + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("createVersionProvider", {}) + return result + + def update_provider( + self, + provider_id: int, + provider_model_id: Optional[str] = None, + is_primary: Optional[bool] = None, + hf_use_pipeline: Optional[bool] = None, + hf_model_class: Optional[str] = None, + hf_auth_token: Optional[str] = None, + hf_attn_implementation: Optional[str] = None, + framework: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Update a provider. 
+ + Args: + provider_id: ID of the provider to update + provider_model_id: New model ID at the provider + is_primary: Whether this is the primary provider + hf_use_pipeline: For HuggingFace - whether to use pipeline API + hf_model_class: For HuggingFace - model class + hf_auth_token: For HuggingFace - auth token + hf_attn_implementation: For HuggingFace - attention implementation + framework: Framework (pt, tf) + config: Additional configuration + + Returns: + Dictionary containing updated provider information + """ + mutation = """ + mutation UpdateVersionProvider($input: UpdateVersionProviderInput!) { + updateVersionProvider(input: $input) { + success + data { + id + provider + providerModelId + isPrimary + isActive + } + errors + } + } + """ + + input_data: Dict[str, Any] = {"id": provider_id} + + if provider_model_id is not None: + input_data["providerModelId"] = provider_model_id + if is_primary is not None: + input_data["isPrimary"] = is_primary + if hf_use_pipeline is not None: + input_data["hfUsePipeline"] = hf_use_pipeline + if hf_model_class is not None: + input_data["hfModelClass"] = hf_model_class + if hf_auth_token is not None: + input_data["hfAuthToken"] = hf_auth_token + if hf_attn_implementation is not None: + input_data["hfAttnImplementation"] = hf_attn_implementation + if framework is not None: + input_data["framework"] = framework + if config is not None: + input_data["config"] = config + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("updateVersionProvider", {}) + return result + + def delete_provider(self, provider_id: int) -> Dict[str, Any]: + """ + Delete a provider. 
+ + Args: + provider_id: ID of the provider to delete + + Returns: + Dictionary containing deletion response + """ + mutation = """ + mutation DeleteVersionProvider($providerId: Int!) { + deleteVersionProvider(providerId: $providerId) { + success + errors + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"providerId": provider_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("deleteVersionProvider", {}) + return result + + # ==================== Helper Methods ==================== + + def get_primary_version(self, model_id: int) -> Optional[Dict[str, Any]]: + """ + Get the primary (latest) version for an AI model. + + Args: + model_id: ID of the AI model + + Returns: + Dictionary containing the primary version, or None if no versions exist + """ + versions = self.get_versions(model_id) + for version in versions: + if version.get("isLatest"): + return version + return versions[0] if versions else None + + def get_primary_provider(self, version_id: int) -> Optional[Dict[str, Any]]: + """ + Get the primary provider for a version. 
+ + Args: + version_id: ID of the version + + Returns: + Dictionary containing the primary provider, or None if no providers exist + """ + providers = self.get_version_providers(version_id) + for provider in providers: + if provider.get("isPrimary"): + return provider + return providers[0] if providers else None From 42c2e97925f9c5990d670994bd26995f40306d1e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Dec 2025 08:58:42 +0000 Subject: [PATCH 038/127] Bump SDK version to 0.4.0 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3fe3d32..97ef11a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.3.3" +version = "0.4.0" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.3.3" +python_version = "0.4.0" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 8dc944911ee2242b15b86043492d4cdffbbb86ce Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 30 Dec 2025 12:54:17 +0530 Subject: [PATCH 039/127] update ai model elastic document --- api/views/search_aimodel.py | 63 +++++++-- search/documents/aimodel_document.py | 187 +++++++++++++++++++++++++-- 2 files changed, 225 insertions(+), 25 deletions(-) diff --git a/api/views/search_aimodel.py b/api/views/search_aimodel.py index 63f8b7b..b7fc606 100644 --- a/api/views/search_aimodel.py +++ b/api/views/search_aimodel.py @@ -21,10 +21,15 @@ class AIModelDocumentSerializer(serializers.ModelSerializer): """Serializer for AIModel document.""" tags = serializers.ListField() + sectors = serializers.ListField() + geographies = serializers.ListField() supported_languages = serializers.ListField() is_individual_model = serializers.BooleanField() 
has_active_endpoints = serializers.BooleanField() endpoint_count = serializers.IntegerField() + version_count = serializers.IntegerField() + lifecycle_stage = serializers.CharField() + all_providers = serializers.ListField() class OrganizationSerializer(serializers.Serializer): name = serializers.CharField() @@ -42,9 +47,31 @@ class EndpointSerializer(serializers.Serializer): is_primary = serializers.BooleanField() is_active = serializers.BooleanField() + class ProviderSerializer(serializers.Serializer): + id = serializers.IntegerField() + provider = serializers.CharField() + provider_model_id = serializers.CharField() + is_primary = serializers.BooleanField() + is_active = serializers.BooleanField() + + class VersionSerializer(serializers.Serializer): + id = serializers.IntegerField() + version = serializers.CharField() + version_notes = serializers.CharField(allow_blank=True) + lifecycle_stage = serializers.CharField() + is_latest = serializers.BooleanField() + status = serializers.CharField() + supports_streaming = serializers.BooleanField() + max_tokens = serializers.IntegerField(allow_null=True) + supported_languages = serializers.ListField() + created_at = serializers.DateTimeField() + updated_at = serializers.DateTimeField() + providers = serializers.ListField() + organization = OrganizationSerializer(allow_null=True) user = UserSerializer(allow_null=True) endpoints = EndpointSerializer(many=True) + versions = VersionSerializer(many=True) class Meta: model = AIModel @@ -61,6 +88,8 @@ class Meta: "is_public", "is_active", "tags", + "sectors", + "geographies", "supported_languages", "supports_streaming", "max_tokens", @@ -74,9 +103,13 @@ class Meta: "is_individual_model", "has_active_endpoints", "endpoint_count", + "version_count", + "lifecycle_stage", + "all_providers", "organization", "user", "endpoints", + "versions", ] @@ -91,9 +124,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.searchable_fields: List[str] 
self.aggregations: Dict[str, str] - self.searchable_fields, self.aggregations = ( - self.get_searchable_and_aggregations() - ) + self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations() self.logger = structlog.get_logger(__name__) @trace_method( @@ -115,10 +146,14 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: "provider": "terms", "status": "terms", "tags.raw": "terms", + "sectors.raw": "terms", + "geographies.raw": "terms", "supported_languages": "terms", "is_public": "terms", "is_active": "terms", "supports_streaming": "terms", + "lifecycle_stage": "terms", + "all_providers": "terms", } return searchable_fields, aggregations @@ -134,19 +169,13 @@ def add_aggregations(self, search: Search) -> Search: ) return search - @trace_method( - name="generate_q_expression", attributes={"component": "search_aimodel"} - ) - def generate_q_expression( - self, query: str - ) -> Optional[Union[ESQuery, List[ESQuery]]]: + @trace_method(name="generate_q_expression", attributes={"component": "search_aimodel"}) + def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]: """Generate Elasticsearch Query expression.""" if query: queries: List[ESQuery] = [] for field in self.searchable_fields: - queries.append( - ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}) - ) + queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}})) else: queries = [ESQ("match_all")] @@ -164,11 +193,21 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: search = search.filter("terms", **{raw_filter: filter_values}) else: search = search.filter("term", **{filter_key: filters[filter_key]}) + elif filter_key in ["sectors", "geographies"]: + # Handle multi-value filters for sectors and geographies + raw_filter = filter_key + ".raw" + if raw_filter in self.aggregations: + filter_values = filters[filter_key].split(",") + search = search.filter("terms", **{raw_filter: 
filter_values}) + else: + search = search.filter("term", **{filter_key: filters[filter_key]}) elif filter_key in [ "model_type", "provider", "status", "supported_languages", + "lifecycle_stage", + "all_providers", ]: # Handle single or multi-value filters filter_values = filters[filter_key].split(",") diff --git a/search/documents/aimodel_document.py b/search/documents/aimodel_document.py index 164f2d5..4566de5 100644 --- a/search/documents/aimodel_document.py +++ b/search/documents/aimodel_document.py @@ -5,6 +5,7 @@ from django_elasticsearch_dsl import Document, Index, KeywordField, fields from api.models.AIModel import AIModel, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.models.Dataset import Tag from api.models.Geography import Geography from api.models.Organization import Organization @@ -45,11 +46,14 @@ class AIModelDocument(Document): ) version = fields.KeywordField() + provider = fields.KeywordField() + provider_model_id = fields.KeywordField() + supported_languages = fields.KeywordField(multi=True) + supports_streaming = fields.BooleanField() + max_tokens = fields.IntegerField() # Model configuration model_type = fields.KeywordField() - provider = fields.KeywordField() - provider_model_id = fields.KeywordField() # Status and visibility status = fields.KeywordField() @@ -89,13 +93,6 @@ class AIModelDocument(Document): multi=True, ) - # Supported languages (stored as JSON array in model) - supported_languages = fields.KeywordField(multi=True) - - # Capabilities - supports_streaming = fields.BooleanField() - max_tokens = fields.IntegerField() - # Performance metrics average_latency_ms = fields.FloatField() success_rate = fields.FloatField() @@ -130,10 +127,38 @@ class AIModelDocument(Document): } ) + versions = fields.NestedField( + properties={ + "id": fields.IntegerField(), + "version": fields.KeywordField(), + "version_notes": fields.TextField(analyzer=html_strip), + "lifecycle_stage": 
fields.KeywordField(), + "is_latest": fields.BooleanField(), + "status": fields.KeywordField(), + "supports_streaming": fields.BooleanField(), + "max_tokens": fields.IntegerField(), + "supported_languages": fields.KeywordField(multi=True), + "created_at": fields.DateField(), + "updated_at": fields.DateField(), + "providers": fields.NestedField( + properties={ + "id": fields.IntegerField(), + "provider": fields.KeywordField(), + "provider_model_id": fields.KeywordField(), + "is_primary": fields.BooleanField(), + "is_active": fields.BooleanField(), + } + ), + } + ) + # Computed fields is_individual_model = fields.BooleanField() has_active_endpoints = fields.BooleanField() endpoint_count = fields.IntegerField() + version_count = fields.IntegerField() + lifecycle_stage = fields.KeywordField() # Primary version's lifecycle stage + all_providers = fields.KeywordField(multi=True) # All unique providers across versions def prepare_organization(self, instance: AIModel) -> Optional[Dict[str, str]]: """Prepare organization data for indexing, including logo URL.""" @@ -150,9 +175,7 @@ def prepare_user(self, instance: AIModel) -> Optional[Dict[str, str]]: "name": instance.user.full_name, "bio": instance.user.bio or "", "profile_picture": ( - instance.user.profile_picture.url - if instance.user.profile_picture - else "" + instance.user.profile_picture.url if instance.user.profile_picture else "" ), } return None @@ -184,6 +207,131 @@ def prepare_endpoint_count(self, instance: AIModel) -> int: """Count the number of endpoints.""" return instance.endpoints.count() + def prepare_versions(self, instance: AIModel) -> List[Dict[str, Any]]: + """Prepare versions data for indexing.""" + versions_data: List[Dict[str, Any]] = [] + for version in instance.versions.all(): # type: ignore[attr-defined] + version_obj: AIModelVersion = version # type: ignore[assignment] + providers_data: List[Dict[str, Any]] = [] + for provider in version_obj.providers.all(): # type: ignore[attr-defined] + 
provider_obj: VersionProvider = provider # type: ignore[assignment] + providers_data.append( + { + "id": provider_obj.id, + "provider": provider_obj.provider, + "provider_model_id": provider_obj.provider_model_id, + "is_primary": provider_obj.is_primary, + "is_active": provider_obj.is_active, + } + ) + versions_data.append( + { + "id": version_obj.id, + "version": version_obj.version, + "version_notes": version_obj.version_notes or "", + "lifecycle_stage": version_obj.lifecycle_stage, + "is_latest": version_obj.is_latest, + "status": version_obj.status, + "supports_streaming": version_obj.supports_streaming, + "max_tokens": version_obj.max_tokens, + "supported_languages": version_obj.supported_languages or [], + "created_at": version_obj.created_at, + "updated_at": version_obj.updated_at, + "providers": providers_data, + } + ) + return versions_data + + def _get_primary_version(self, instance: AIModel) -> Optional[AIModelVersion]: + """Get the primary (latest) version of the model.""" + primary = instance.versions.filter(is_latest=True).first() # type: ignore[attr-defined] + if not primary: + primary = instance.versions.first() # type: ignore[attr-defined] + return primary # type: ignore[return-value] + + def _get_primary_provider(self, version: Optional[AIModelVersion]) -> Optional[VersionProvider]: + """Get the primary provider of a version.""" + if not version: + return None + primary = version.providers.filter(is_primary=True).first() # type: ignore[attr-defined] + if not primary: + primary = version.providers.first() # type: ignore[attr-defined] + return primary # type: ignore[return-value] + + def prepare_version(self, instance: AIModel) -> str: + """Prepare version from primary version for backward compatibility.""" + primary_version = self._get_primary_version(instance) + if primary_version: + return str(primary_version.version) + # Fallback to legacy field on AIModel + return instance.version or "" + + def prepare_provider(self, instance: AIModel) -> str: 
+ """Prepare provider from primary version's primary provider for backward compatibility.""" + primary_version = self._get_primary_version(instance) + primary_provider = self._get_primary_provider(primary_version) + if primary_provider: + return str(primary_provider.provider) + # Fallback to legacy field on AIModel + return instance.provider or "" + + def prepare_provider_model_id(self, instance: AIModel) -> str: + """Prepare provider_model_id from primary version's primary provider.""" + primary_version = self._get_primary_version(instance) + primary_provider = self._get_primary_provider(primary_version) + if primary_provider: + return primary_provider.provider_model_id or "" + # Fallback to legacy field on AIModel + return instance.provider_model_id or "" + + def prepare_supported_languages(self, instance: AIModel) -> List[str]: + """Prepare supported_languages from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version and primary_version.supported_languages: + return list(primary_version.supported_languages) + # Fallback to legacy field on AIModel + return list(instance.supported_languages or []) + + def prepare_supports_streaming(self, instance: AIModel) -> bool: + """Prepare supports_streaming from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version: + return bool(primary_version.supports_streaming) + # Fallback to legacy field on AIModel + return bool(instance.supports_streaming) + + def prepare_max_tokens(self, instance: AIModel) -> Optional[int]: + """Prepare max_tokens from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version and primary_version.max_tokens: + return int(primary_version.max_tokens) + # Fallback to legacy field on AIModel + return instance.max_tokens + + def prepare_version_count(self, instance: AIModel) -> int: + """Count the number of versions.""" + return instance.versions.count() # type: ignore[attr-defined] 
+ + def prepare_lifecycle_stage(self, instance: AIModel) -> str: + """Get lifecycle stage from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version: + return str(primary_version.lifecycle_stage) + return "DEVELOPMENT" + + def prepare_all_providers(self, instance: AIModel) -> List[str]: + """Get all unique providers across all versions.""" + providers: set[str] = set() + for version in instance.versions.all(): # type: ignore[attr-defined] + version_obj: AIModelVersion = version # type: ignore[assignment] + for provider in version_obj.providers.all(): # type: ignore[attr-defined] + provider_obj: VersionProvider = provider # type: ignore[assignment] + providers.add(str(provider_obj.provider)) + # Also include legacy provider if set + if instance.provider: + providers.add(str(instance.provider)) + return list(providers) + def should_index_object(self, obj: AIModel) -> bool: """ Check if the object should be indexed. @@ -225,7 +373,14 @@ def get_queryset(self) -> Any: def get_instances_from_related( self, related_instance: Union[ - ModelEndpoint, Organization, User, Tag, Sector, Geography + ModelEndpoint, + Organization, + User, + Tag, + Sector, + Geography, + AIModelVersion, + VersionProvider, ], ) -> Optional[Union[AIModel, List[AIModel]]]: """Get AIModel instances from related models.""" @@ -241,6 +396,10 @@ def get_instances_from_related( return list(related_instance.ai_models.all()) elif isinstance(related_instance, Geography): return list(related_instance.ai_models.all()) + elif isinstance(related_instance, AIModelVersion): + return related_instance.ai_model + elif isinstance(related_instance, VersionProvider): + return related_instance.version.ai_model return None class Django: @@ -262,4 +421,6 @@ class Django: Tag, Sector, Geography, + AIModelVersion, + VersionProvider, ] From df53a0dcf7fa0ba1dec2fe5c0a550f1a41d361e5 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 30 Dec 2025 21:00:35 +0530 Subject: [PATCH 040/127] add 
provider specific fields to enable api based access via these providers and update corresponding schema --- api/models/AIModelVersion.py | 117 +++++++++++++++++++- api/schema/aimodel_schema.py | 100 ++++++++++++++++++ api/services/model_api_client.py | 176 +++++++++++++++++-------------- api/services/model_hf_client.py | 116 +++++++++++++------- api/types/type_aimodel.py | 23 ++++ 5 files changed, 411 insertions(+), 121 deletions(-) diff --git a/api/models/AIModelVersion.py b/api/models/AIModelVersion.py index f679502..e9b108d 100644 --- a/api/models/AIModelVersion.py +++ b/api/models/AIModelVersion.py @@ -81,15 +81,34 @@ def copy_providers_from(self, source_version: AIModelVersion) -> None: # Create a copy of the provider VersionProvider.objects.create( version=self, + # Provider info provider=provider.provider, # type: ignore[attr-defined] provider_model_id=provider.provider_model_id, # type: ignore[attr-defined] is_primary=provider.is_primary, # type: ignore[attr-defined] is_active=provider.is_active, # type: ignore[attr-defined] + # API Endpoint Configuration + api_endpoint_url=provider.api_endpoint_url, # type: ignore[attr-defined] + api_http_method=provider.api_http_method, # type: ignore[attr-defined] + api_timeout_seconds=provider.api_timeout_seconds, # type: ignore[attr-defined] + # Authentication Configuration + api_auth_type=provider.api_auth_type, # type: ignore[attr-defined] + api_auth_header_name=provider.api_auth_header_name, # type: ignore[attr-defined] + api_key=provider.api_key, # type: ignore[attr-defined] + api_key_prefix=provider.api_key_prefix, # type: ignore[attr-defined] + # Request/Response Configuration + api_headers=provider.api_headers, # type: ignore[attr-defined] + api_request_template=provider.api_request_template, # type: ignore[attr-defined] + api_response_path=provider.api_response_path, # type: ignore[attr-defined] + # HuggingFace fields hf_use_pipeline=provider.hf_use_pipeline, # type: ignore[attr-defined] 
hf_auth_token=provider.hf_auth_token, # type: ignore[attr-defined] hf_model_class=provider.hf_model_class, # type: ignore[attr-defined] hf_attn_implementation=provider.hf_attn_implementation, # type: ignore[attr-defined] + hf_trust_remote_code=provider.hf_trust_remote_code, # type: ignore[attr-defined] + hf_torch_dtype=provider.hf_torch_dtype, # type: ignore[attr-defined] + hf_device_map=provider.hf_device_map, # type: ignore[attr-defined] framework=provider.framework, # type: ignore[attr-defined] + # Additional config config=provider.config, # type: ignore[attr-defined] ) @@ -101,7 +120,13 @@ class VersionProvider(models.Model): Only ONE can be primary per version. """ - from api.utils.enums import AIModelFramework, AIModelProvider, HFModelClass + from api.utils.enums import ( + AIModelFramework, + AIModelProvider, + EndpointAuthType, + EndpointHTTPMethod, + HFModelClass, + ) version = models.ForeignKey( AIModelVersion, @@ -121,7 +146,75 @@ class VersionProvider(models.Model): ) is_active = models.BooleanField(default=True) + # ============================================ + # API Endpoint Configuration (for API-based providers) + # ============================================ + api_endpoint_url = models.URLField( + max_length=500, + blank=True, + null=True, + help_text="API endpoint URL (e.g., https://api.openai.com/v1/chat/completions)", + ) + api_http_method = models.CharField( + max_length=10, + choices=EndpointHTTPMethod.choices, + default=EndpointHTTPMethod.POST, + help_text="HTTP method for API calls", + ) + api_timeout_seconds = models.IntegerField( + default=60, + help_text="Request timeout in seconds", + ) + + # ============================================ + # Authentication Configuration + # ============================================ + api_auth_type = models.CharField( + max_length=20, + choices=EndpointAuthType.choices, + default=EndpointAuthType.BEARER, + help_text="Authentication type for API calls", + ) + api_auth_header_name = models.CharField( + 
max_length=100, + default="Authorization", + help_text="Header name for authentication (e.g., Authorization, X-API-Key)", + ) + api_key = models.CharField( + max_length=500, + blank=True, + null=True, + help_text="API key or token (encrypted at rest)", + ) + api_key_prefix = models.CharField( + max_length=50, + blank=True, + default="Bearer", + help_text="Prefix for API key (e.g., Bearer, Token)", + ) + + # ============================================ + # Request/Response Configuration + # ============================================ + api_headers = models.JSONField( + default=dict, + blank=True, + help_text="Additional headers to include in requests", + ) + api_request_template = models.JSONField( + default=dict, + blank=True, + help_text="Request body template with placeholders like {input}, {prompt}", + ) + api_response_path = models.CharField( + max_length=255, + blank=True, + help_text="JSON path to extract response (e.g., choices[0].message.content)", + ) + + # ============================================ # Huggingface-specific fields + # ============================================ hf_use_pipeline = models.BooleanField(default=False, help_text="Use Pipeline inference API") hf_auth_token = models.CharField( max_length=255, @@ -142,6 +235,22 @@ class VersionProvider(models.Model): default="flash_attention_2", help_text="Attention Function", ) + hf_trust_remote_code = models.BooleanField( + default=True, + help_text="Trust remote code when loading model", + ) + hf_torch_dtype = models.CharField( + max_length=50, + blank=True, + default="auto", + help_text="Torch dtype (auto, float16, float32, bfloat16)", + ) + hf_device_map = models.CharField( + max_length=50, + blank=True, + default="auto", + help_text="Device map for model loading (auto, cpu, cuda)", + ) framework = models.CharField( max_length=10, choices=AIModelFramework.choices, @@ -150,8 +259,10 @@ class VersionProvider(models.Model): help_text="Framework (PyTorch or TensorFlow)", ) - # Provider-specific 
configuration - config = models.JSONField(default=dict, help_text="Provider-specific configuration") + # ============================================ + # Provider-specific configuration (catch-all) + # ============================================ + config = models.JSONField(default=dict, help_text="Additional provider-specific configuration") # Timestamps created_at = models.DateTimeField(auto_now_add=True) diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index 1d9b8b7..786ecd0 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -208,11 +208,34 @@ class CreateVersionProviderInput: provider_model_id: Optional[str] = "" is_primary: bool = False is_active: bool = True + + # API Endpoint Configuration + api_endpoint_url: Optional[str] = None + api_http_method: Optional[EndpointHTTPMethodEnum] = None + api_timeout_seconds: int = 60 + + # Authentication Configuration + api_auth_type: Optional[EndpointAuthTypeEnum] = None + api_auth_header_name: str = "Authorization" + api_key: Optional[str] = None + api_key_prefix: str = "Bearer" + + # Request/Response Configuration + api_headers: Optional[strawberry.scalars.JSON] = None + api_request_template: Optional[strawberry.scalars.JSON] = None + api_response_path: Optional[str] = None + + # HuggingFace Configuration hf_use_pipeline: bool = False hf_auth_token: Optional[str] = None hf_model_class: Optional[str] = None hf_attn_implementation: Optional[str] = "flash_attention_2" + hf_trust_remote_code: bool = True + hf_torch_dtype: Optional[str] = "auto" + hf_device_map: Optional[str] = "auto" framework: Optional[str] = None + + # Additional config config: Optional[strawberry.scalars.JSON] = None @@ -224,11 +247,34 @@ class UpdateVersionProviderInput: provider_model_id: Optional[str] = None is_primary: Optional[bool] = None is_active: Optional[bool] = None + + # API Endpoint Configuration + api_endpoint_url: Optional[str] = None + api_http_method: Optional[EndpointHTTPMethodEnum] = 
None + api_timeout_seconds: Optional[int] = None + + # Authentication Configuration + api_auth_type: Optional[EndpointAuthTypeEnum] = None + api_auth_header_name: Optional[str] = None + api_key: Optional[str] = None + api_key_prefix: Optional[str] = None + + # Request/Response Configuration + api_headers: Optional[strawberry.scalars.JSON] = None + api_request_template: Optional[strawberry.scalars.JSON] = None + api_response_path: Optional[str] = None + + # HuggingFace Configuration hf_use_pipeline: Optional[bool] = None hf_auth_token: Optional[str] = None hf_model_class: Optional[str] = None hf_attn_implementation: Optional[str] = None + hf_trust_remote_code: Optional[bool] = None + hf_torch_dtype: Optional[str] = None + hf_device_map: Optional[str] = None framework: Optional[str] = None + + # Additional config config: Optional[strawberry.scalars.JSON] = None @@ -946,11 +992,29 @@ def create_version_provider( provider_model_id=input.provider_model_id or "", is_primary=input.is_primary, is_active=input.is_active, + # API Endpoint Configuration + api_endpoint_url=input.api_endpoint_url, + api_http_method=input.api_http_method or "POST", + api_timeout_seconds=input.api_timeout_seconds, + # Authentication Configuration + api_auth_type=input.api_auth_type or "BEARER", + api_auth_header_name=input.api_auth_header_name, + api_key=input.api_key, + api_key_prefix=input.api_key_prefix, + # Request/Response Configuration + api_headers=input.api_headers or {}, + api_request_template=input.api_request_template or {}, + api_response_path=input.api_response_path or "", + # HuggingFace Configuration hf_use_pipeline=input.hf_use_pipeline, hf_auth_token=input.hf_auth_token, hf_model_class=input.hf_model_class, hf_attn_implementation=input.hf_attn_implementation or "flash_attention_2", + hf_trust_remote_code=input.hf_trust_remote_code, + hf_torch_dtype=input.hf_torch_dtype or "auto", + hf_device_map=input.hf_device_map or "auto", framework=input.framework, + # Additional config 
config=input.config or {}, ) @@ -995,6 +1059,34 @@ def update_version_provider( provider.is_primary = input.is_primary if input.is_active is not None: provider.is_active = input.is_active + + # API Endpoint Configuration + if input.api_endpoint_url is not None: + provider.api_endpoint_url = input.api_endpoint_url + if input.api_http_method is not None: + provider.api_http_method = input.api_http_method + if input.api_timeout_seconds is not None: + provider.api_timeout_seconds = input.api_timeout_seconds + + # Authentication Configuration + if input.api_auth_type is not None: + provider.api_auth_type = input.api_auth_type + if input.api_auth_header_name is not None: + provider.api_auth_header_name = input.api_auth_header_name + if input.api_key is not None: + provider.api_key = input.api_key + if input.api_key_prefix is not None: + provider.api_key_prefix = input.api_key_prefix + + # Request/Response Configuration + if input.api_headers is not None: + provider.api_headers = input.api_headers + if input.api_request_template is not None: + provider.api_request_template = input.api_request_template + if input.api_response_path is not None: + provider.api_response_path = input.api_response_path + + # HuggingFace Configuration if input.hf_use_pipeline is not None: provider.hf_use_pipeline = input.hf_use_pipeline if input.hf_auth_token is not None: @@ -1003,8 +1095,16 @@ def update_version_provider( provider.hf_model_class = input.hf_model_class if input.hf_attn_implementation is not None: provider.hf_attn_implementation = input.hf_attn_implementation + if input.hf_trust_remote_code is not None: + provider.hf_trust_remote_code = input.hf_trust_remote_code + if input.hf_torch_dtype is not None: + provider.hf_torch_dtype = input.hf_torch_dtype + if input.hf_device_map is not None: + provider.hf_device_map = input.hf_device_map if input.framework is not None: provider.framework = input.framework + + # Additional config if input.config is not None: provider.config = 
input.config diff --git a/api/services/model_api_client.py b/api/services/model_api_client.py index 94a39d8..1c4f29e 100644 --- a/api/services/model_api_client.py +++ b/api/services/model_api_client.py @@ -1,9 +1,9 @@ """ API Client for making requests to AI model endpoints. Supports various authentication methods and providers. +Works with VersionProvider configuration. """ -import json import time from typing import Any, Dict, Optional @@ -12,52 +12,62 @@ from django.utils import timezone from tenacity import retry, stop_after_attempt, wait_exponential # type: ignore -from api.models import AIModel, ModelAPIKey, ModelEndpoint +from api.models.AIModelVersion import VersionProvider class ModelAPIClient: - """Client for interacting with AI model APIs""" - - def __init__(self, model: AIModel): - self.model = model - self.endpoint = model.get_primary_endpoint() - if not self.endpoint: - raise ValueError(f"No primary endpoint configured for model {model.name}") - - # Get API key if available - self.api_key = self._get_api_key() - - def _get_api_key(self) -> Optional[str]: - """Get the active API key for this model""" - api_key_obj = self.model.api_keys.filter(is_active=True).first() - if api_key_obj: - return api_key_obj.get_key() - return None + """Client for interacting with AI model APIs via VersionProvider configuration.""" + + def __init__(self, provider: VersionProvider): + """ + Initialize the API client with a VersionProvider. 
+ + Args: + provider: VersionProvider instance with API configuration + """ + self.provider = provider + self.version = provider.version + self.model = provider.version.ai_model + + # Validate that we have an API endpoint configured + if not provider.api_endpoint_url: + raise ValueError( + f"No API endpoint URL configured for provider {provider.provider} " + f"on model {self.model.name}" + ) def _build_headers(self) -> Dict[str, str]: - """Build request headers including authentication""" - headers = {"Content-Type": "application/json", **self.endpoint.headers} + """Build request headers including authentication.""" + headers: Dict[str, str] = { + "Content-Type": "application/json", + **(self.provider.api_headers or {}), + } # Add authentication header - if self.api_key and self.endpoint.auth_type != "NONE": - if self.endpoint.auth_type == "BEARER": - headers[self.endpoint.auth_header_name] = f"Bearer {self.api_key}" - elif self.endpoint.auth_type == "API_KEY": - headers[self.endpoint.auth_header_name] = self.api_key - elif self.endpoint.auth_type == "CUSTOM": - # Custom headers should be in endpoint.headers + if self.provider.api_key and self.provider.api_auth_type != "NONE": + auth_value = self.provider.api_key + if self.provider.api_auth_type == "BEARER": + auth_value = f"{self.provider.api_key_prefix} {self.provider.api_key}".strip() + elif self.provider.api_auth_type == "API_KEY": + # Just use the key directly pass + elif self.provider.api_auth_type == "BASIC": + import base64 + + auth_value = f"Basic {base64.b64encode(self.provider.api_key.encode()).decode()}" + + headers[self.provider.api_auth_header_name] = auth_value return headers def _build_request_body( self, input_text: str, parameters: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - """Build request body based on provider and template""" + """Build request body based on provider and template.""" body: Dict[str, Any] - if self.endpoint.request_template: + if self.provider.api_request_template: # 
Use custom template - template_copy = self.endpoint.request_template.copy() + template_copy = self.provider.api_request_template.copy() # Replace placeholders body = self._replace_placeholders(template_copy, input_text, parameters or {}) else: @@ -93,75 +103,77 @@ def _replace_placeholders(self, template: Dict, input_text: str, parameters: Dic return result def _get_default_template(self, input_text: str, parameters: Dict) -> Dict[str, Any]: - """Get default request template based on provider""" - provider = self.model.provider.upper() + """Get default request template based on provider.""" + provider_type = self.provider.provider.upper() + model_id = self.provider.provider_model_id + max_tokens = self.version.max_tokens or 1000 - if provider == "OPENAI": + if provider_type == "OPENAI": return { - "model": self.model.provider_model_id or "gpt-3.5-turbo", + "model": model_id or "gpt-3.5-turbo", "messages": [{"role": "user", "content": input_text}], "temperature": parameters.get("temperature", 0.7), - "max_tokens": parameters.get("max_tokens", self.model.max_tokens or 1000), + "max_tokens": parameters.get("max_tokens", max_tokens), } - elif "LLAMA" in provider: + elif "LLAMA" in provider_type: # Llama models - format depends on provider - if "OLLAMA" in provider: + if "OLLAMA" in provider_type: return { - "model": self.model.provider_model_id or "llama2", + "model": model_id or "llama2", "prompt": input_text, "stream": False, "options": { "temperature": parameters.get("temperature", 0.7), - "num_predict": parameters.get("max_tokens", self.model.max_tokens or 1000), + "num_predict": parameters.get("max_tokens", max_tokens), }, } - elif "TOGETHER" in provider: + elif "TOGETHER" in provider_type: return { - "model": self.model.provider_model_id or "togethercomputer/llama-2-7b-chat", + "model": model_id or "togethercomputer/llama-2-7b-chat", "prompt": input_text, "temperature": parameters.get("temperature", 0.7), - "max_tokens": parameters.get("max_tokens", 
self.model.max_tokens or 1000), + "max_tokens": parameters.get("max_tokens", max_tokens), } - elif "REPLICATE" in provider: + elif "REPLICATE" in provider_type: return { - "version": self.model.provider_model_id, + "version": model_id, "input": { "prompt": input_text, "temperature": parameters.get("temperature", 0.7), - "max_length": parameters.get("max_tokens", self.model.max_tokens or 1000), + "max_length": parameters.get("max_tokens", max_tokens), }, } else: # Generic Llama format (OpenAI-compatible) return { - "model": self.model.provider_model_id or "llama-2-7b-chat", + "model": model_id or "llama-2-7b-chat", "messages": [{"role": "user", "content": input_text}], "temperature": parameters.get("temperature", 0.7), - "max_tokens": parameters.get("max_tokens", self.model.max_tokens or 1000), + "max_tokens": parameters.get("max_tokens", max_tokens), } else: # Generic template for custom APIs return {"input": input_text, "parameters": parameters} def _extract_response(self, response_data: Dict) -> str: - """Extract text response from API response""" - if self.endpoint.response_path: + """Extract text response from API response.""" + if self.provider.api_response_path: # Use custom response path - result: Any = self._get_nested_value(response_data, self.endpoint.response_path) + result: Any = self._get_nested_value(response_data, self.provider.api_response_path) return str(result) # Default extraction based on provider - provider = self.model.provider.upper() + provider_type = self.provider.provider.upper() try: - if provider == "OPENAI": + if provider_type == "OPENAI": return str(response_data["choices"][0]["message"]["content"]) - elif "LLAMA" in provider: - if "OLLAMA" in provider: + elif "LLAMA" in provider_type: + if "OLLAMA" in provider_type: return str(response_data["response"]) - elif "TOGETHER" in provider: + elif "TOGETHER" in provider_type: return str(response_data["output"]["choices"][0]["text"]) - elif "REPLICATE" in provider: + elif "REPLICATE" in 
provider_type: # Replicate returns array of strings output = response_data.get("output", []) return "".join(output) if isinstance(output, list) else str(output) @@ -204,23 +216,23 @@ def _get_nested_value(self, data: Dict, path: str) -> Any: return current - async def _update_endpoint_success(self) -> None: - """Update endpoint statistics on successful call (async-safe)""" + async def _update_provider_success(self) -> None: + """Update provider statistics on successful call (async-safe).""" def _update() -> None: - self.endpoint.total_requests += 1 - self.endpoint.last_success_at = timezone.now() - self.endpoint.save(update_fields=["total_requests", "last_success_at"]) + # Update provider's updated_at timestamp + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() - async def _update_endpoint_failure(self) -> None: - """Update endpoint statistics on failed call (async-safe)""" + async def _update_provider_failure(self) -> None: + """Update provider statistics on failed call (async-safe).""" def _update() -> None: - self.endpoint.failed_requests += 1 - self.endpoint.last_failure_at = timezone.now() - self.endpoint.save(update_fields=["failed_requests", "last_failure_at"]) + # Update provider's updated_at timestamp + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() @@ -228,28 +240,30 @@ def _update() -> None: async def call_async( self, input_text: str, parameters: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - """Make async API call to the model""" + """Make async API call to the model via the provider configuration.""" start_time = time.time() headers = self._build_headers() body = self._build_request_body(input_text, parameters) + endpoint_url: str = self.provider.api_endpoint_url # type: ignore[assignment] + try: - async with httpx.AsyncClient(timeout=self.endpoint.timeout_seconds) as client: - if 
self.endpoint.http_method == "POST": - response = await client.post(self.endpoint.url, headers=headers, json=body) - elif self.endpoint.http_method == "GET": - response = await client.get(self.endpoint.url, headers=headers, params=body) + async with httpx.AsyncClient(timeout=self.provider.api_timeout_seconds) as client: + if self.provider.api_http_method == "POST": + response = await client.post(endpoint_url, headers=headers, json=body) + elif self.provider.api_http_method == "GET": + response = await client.get(endpoint_url, headers=headers, params=body) else: - raise ValueError(f"Unsupported HTTP method: {self.endpoint.http_method}") + raise ValueError(f"Unsupported HTTP method: {self.provider.api_http_method}") response.raise_for_status() response_data = response.json() latency_ms = (time.time() - start_time) * 1000 - # Update endpoint statistics (async-safe) - await self._update_endpoint_success() + # Update provider statistics (async-safe) + await self._update_provider_success() # Extract response text output_text = self._extract_response(response_data) @@ -260,26 +274,32 @@ async def call_async( "raw_response": response_data, "latency_ms": latency_ms, "status_code": response.status_code, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } except httpx.HTTPStatusError as e: # Update failure statistics (async-safe) - await self._update_endpoint_failure() + await self._update_provider_failure() return { "success": False, "error": f"HTTP {e.response.status_code}: {e.response.text}", "status_code": e.response.status_code, "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } except Exception as e: # Update failure statistics (async-safe) - await self._update_endpoint_failure() + await self._update_provider_failure() return { "success": False, "error": str(e), "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + 
"model_id": self.provider.provider_model_id, } def call(self, input_text: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: diff --git a/api/services/model_hf_client.py b/api/services/model_hf_client.py index 826d7c7..e28fba9 100644 --- a/api/services/model_hf_client.py +++ b/api/services/model_hf_client.py @@ -1,7 +1,7 @@ """ Hugging Face Client for running local or remote model inference. -Supports both pipeline-based and model-class-based inference, -and integrates with Django model management (AIModel). +Supports both pipeline-based and model-class-based inference. +Works with VersionProvider configuration. """ import time @@ -24,15 +24,36 @@ pipeline, ) -from api.models import AIModel +from api.models.AIModelVersion import VersionProvider class ModelHFClient: - """Client for interacting with Hugging Face models.""" + """Client for interacting with Hugging Face models via VersionProvider configuration.""" - def __init__(self, model: AIModel): - self.model = model + def __init__(self, provider: VersionProvider): + """ + Initialize the HuggingFace client with a VersionProvider. 
+ + Args: + provider: VersionProvider instance with HuggingFace configuration + """ + self.provider = provider + self.version = provider.version + self.model = provider.version.ai_model self.device = self._get_device() + + # Validate provider type + if provider.provider != "HUGGINGFACE": + raise ValueError( + f"ModelHFClient requires HUGGINGFACE provider, got {provider.provider}" + ) + + # Validate model ID is set + if not provider.provider_model_id: + raise ValueError( + f"No provider_model_id configured for HuggingFace provider on model {self.model.name}" + ) + self.model_map = { "AutoModelForCausalLM": AutoModelForCausalLM, "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM, @@ -60,37 +81,58 @@ def _get_device(self) -> int: """Select device (0 for GPU if available, else CPU).""" return 0 if torch.cuda.is_available() else -1 + def _get_torch_dtype(self) -> Any: + """Get torch dtype based on provider configuration.""" + dtype_map = { + "auto": "auto", + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } + dtype_str = self.provider.hf_torch_dtype or "auto" + if dtype_str == "auto": + return torch.float16 if torch.cuda.is_available() else torch.float32 + return dtype_map.get(dtype_str, torch.float32) + def _load_pipeline(self) -> Any: """Initialize a Hugging Face pipeline.""" + framework = self.provider.framework or "pt" return pipeline( task=self.task_map.get(self.model.model_type, "text-generation"), - model=self.model.provider_model_id, - framework="pt", + model=self.provider.provider_model_id, + framework=framework, device=self.device, - trust_remote_code=True, - use_auth_token=self.model.hf_auth_token or None, + trust_remote_code=self.provider.hf_trust_remote_code, + token=self.provider.hf_auth_token or None, ) def _load_model_and_tokenizer(self) -> Tuple[Any, Any]: """Load model and tokenizer for manual inference.""" tokenizer = AutoTokenizer.from_pretrained( - self.model.provider_model_id, - trust_remote_code=True, - 
use_auth_token=self.model.hf_auth_token or None, + self.provider.provider_model_id, + trust_remote_code=self.provider.hf_trust_remote_code, + token=self.provider.hf_auth_token or None, ) model_class = self.model_map.get( - self.model.hf_model_class or "AutoModelForCausalLM", AutoModelForCausalLM - ) - model = model_class.from_pretrained( - pretrained_model_name_or_path=self.model.provider_model_id, - trust_remote_code=True, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - use_auth_token=self.model.hf_auth_token or None, - attn_implementation=self.model.hf_attn_implementation, - device_map="auto", + self.provider.hf_model_class or "AutoModelForCausalLM", AutoModelForCausalLM ) + # Build model loading kwargs + model_kwargs: Dict[str, Any] = { + "pretrained_model_name_or_path": self.provider.provider_model_id, + "trust_remote_code": self.provider.hf_trust_remote_code, + "torch_dtype": self._get_torch_dtype(), + "token": self.provider.hf_auth_token or None, + "device_map": self.provider.hf_device_map or "auto", + } + + # Add attention implementation if specified + if self.provider.hf_attn_implementation: + model_kwargs["attn_implementation"] = self.provider.hf_attn_implementation + + model = model_class.from_pretrained(**model_kwargs) + return model, tokenizer def _generate_from_model(self, model: Any, tokenizer: Any, input_text: str) -> str: @@ -109,26 +151,20 @@ def _generate_from_model(self, model: Any, tokenizer: Any, input_text: str) -> s return decoded_text async def _update_success(self) -> None: - """Update endpoint statistics for successful inference.""" + """Update provider statistics for successful inference.""" def _update() -> None: - endpoint = self.model.get_primary_endpoint() - if endpoint: - endpoint.total_requests += 1 - endpoint.last_success_at = timezone.now() - endpoint.save(update_fields=["total_requests", "last_success_at"]) + self.provider.updated_at = timezone.now() + 
self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() async def _update_failure(self) -> None: - """Update endpoint statistics for failed inference.""" + """Update provider statistics for failed inference.""" def _update() -> None: - endpoint = self.model.get_primary_endpoint() - if endpoint: - endpoint.failed_requests += 1 - endpoint.last_failure_at = timezone.now() - endpoint.save(update_fields=["failed_requests", "last_failure_at"]) + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() @@ -136,11 +172,11 @@ def _update() -> None: async def call_async(self, input_text: str) -> Dict[str, Any]: """ Run asynchronous inference on the model. - Supports both pipeline and manual model modes. + Supports both pipeline and manual model modes based on provider configuration. """ start_time = time.time() try: - if self.model.hf_use_pipeline: + if self.provider.hf_use_pipeline: pipe = self._load_pipeline() result = pipe(input_text) output_text = result if isinstance(result, str) else str(result) @@ -155,8 +191,8 @@ async def call_async(self, input_text: str) -> Dict[str, Any]: "success": True, "output": output_text, "latency_ms": latency_ms, - "provider": "HuggingFace", - "model_id": self.model.provider_model_id, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } except Exception as e: @@ -165,8 +201,8 @@ async def call_async(self, input_text: str) -> Dict[str, Any]: "success": False, "error": str(e), "latency_ms": (time.time() - start_time) * 1000, - "provider": "HuggingFace", - "model_id": self.model.provider_model_id, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } def call(self, input_text: str) -> Dict[str, Any]: diff --git a/api/types/type_aimodel.py b/api/types/type_aimodel.py index 6105bb4..7adf63e 100644 --- a/api/types/type_aimodel.py +++ b/api/types/type_aimodel.py @@ -227,11 +227,34 @@ class 
TypeVersionProvider(BaseType): provider_model_id: Optional[str] is_primary: bool is_active: bool + + # API Endpoint Configuration + api_endpoint_url: Optional[str] + api_http_method: EndpointHTTPMethodEnum + api_timeout_seconds: int + + # Authentication Configuration + api_auth_type: EndpointAuthTypeEnum + api_auth_header_name: str + api_key: Optional[str] + api_key_prefix: Optional[str] + + # Request/Response Configuration + api_headers: strawberry.scalars.JSON + api_request_template: strawberry.scalars.JSON + api_response_path: Optional[str] + + # HuggingFace Configuration hf_use_pipeline: bool hf_auth_token: Optional[str] hf_model_class: Optional[str] hf_attn_implementation: Optional[str] + hf_trust_remote_code: bool + hf_torch_dtype: Optional[str] + hf_device_map: Optional[str] framework: Optional[str] + + # Additional config config: strawberry.scalars.JSON created_at: datetime updated_at: datetime From 3512b4cfb4cd24b083b372b9d260324067517b07 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 31 Dec 2025 11:49:46 +0530 Subject: [PATCH 041/127] update aimodel document and search api --- dataspace_sdk/resources/aimodels.py | 141 ++++++++++++++++++++++++++- search/documents/aimodel_document.py | 18 ++++ 2 files changed, 158 insertions(+), 1 deletion(-) diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index 57536ba..8dda93c 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -135,10 +135,25 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: providerModelId isPrimary isActive + # API Configuration + apiEndpointUrl + apiHttpMethod + apiTimeoutSeconds + apiAuthType + apiAuthHeaderName + apiKey + apiKeyPrefix + apiHeaders + apiRequestTemplate + apiResponsePath + # HuggingFace Configuration hfUsePipeline hfAuthToken hfModelClass hfAttnImplementation + hfTrustRemoteCode + hfTorchDtype + hfDeviceMap framework config } @@ -589,10 +604,25 @@ def get_version_providers(self, version_id: 
int) -> List[Dict[str, Any]]: providerModelId isPrimary isActive + # API Configuration + apiEndpointUrl + apiHttpMethod + apiTimeoutSeconds + apiAuthType + apiAuthHeaderName + apiKey + apiKeyPrefix + apiHeaders + apiRequestTemplate + apiResponsePath + # HuggingFace Configuration hfUsePipeline hfAuthToken hfModelClass hfAttnImplementation + hfTrustRemoteCode + hfTorchDtype + hfDeviceMap framework config } @@ -623,10 +653,25 @@ def create_provider( provider: str, provider_model_id: str, is_primary: bool = False, + # API Configuration + api_endpoint_url: Optional[str] = None, + api_http_method: str = "POST", + api_timeout_seconds: int = 60, + api_auth_type: str = "BEARER", + api_auth_header_name: str = "Authorization", + api_key: Optional[str] = None, + api_key_prefix: str = "Bearer", + api_headers: Optional[Dict[str, str]] = None, + api_request_template: Optional[Dict[str, Any]] = None, + api_response_path: Optional[str] = None, + # HuggingFace Configuration hf_use_pipeline: bool = False, hf_model_class: Optional[str] = None, hf_auth_token: Optional[str] = None, hf_attn_implementation: Optional[str] = None, + hf_trust_remote_code: bool = True, + hf_torch_dtype: Optional[str] = "auto", + hf_device_map: Optional[str] = "auto", framework: Optional[str] = None, config: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: @@ -639,12 +684,25 @@ def create_provider( LLAMA_CUSTOM, CUSTOM, HUGGINGFACE) provider_model_id: Model ID at the provider (e.g., "gpt-4", "meta-llama/Llama-2-7b") is_primary: Whether this is the primary provider + api_endpoint_url: Full URL for the API endpoint + api_http_method: HTTP method (POST, GET) + api_timeout_seconds: Request timeout in seconds + api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE) + api_auth_header_name: Header name for authentication + api_key: API key or token + api_key_prefix: Prefix for the API key (e.g., "Bearer") + api_headers: Additional headers as dict + api_request_template: Request body 
template as dict + api_response_path: JSON path to extract response text hf_use_pipeline: For HuggingFace - whether to use pipeline API hf_model_class: For HuggingFace - model class (e.g., "AutoModelForCausalLM") hf_auth_token: For HuggingFace - auth token for gated models hf_attn_implementation: For HuggingFace - attention implementation + hf_trust_remote_code: For HuggingFace - trust remote code + hf_torch_dtype: For HuggingFace - torch dtype (auto, float16, bfloat16) + hf_device_map: For HuggingFace - device map (auto, cuda, cpu) framework: Framework (pt, tf) - config: Additional configuration (apiKey, baseUrl, authType, etc.) + config: Additional configuration Returns: Dictionary containing created provider information @@ -670,15 +728,40 @@ def create_provider( "provider": provider, "providerModelId": provider_model_id, "isPrimary": is_primary, + # API Configuration + "apiHttpMethod": api_http_method, + "apiTimeoutSeconds": api_timeout_seconds, + "apiAuthType": api_auth_type, + "apiAuthHeaderName": api_auth_header_name, + "apiKeyPrefix": api_key_prefix, + # HuggingFace Configuration "hfUsePipeline": hf_use_pipeline, + "hfTrustRemoteCode": hf_trust_remote_code, } + # Optional API fields + if api_endpoint_url: + input_data["apiEndpointUrl"] = api_endpoint_url + if api_key: + input_data["apiKey"] = api_key + if api_headers: + input_data["apiHeaders"] = api_headers + if api_request_template: + input_data["apiRequestTemplate"] = api_request_template + if api_response_path: + input_data["apiResponsePath"] = api_response_path + + # Optional HuggingFace fields if hf_model_class: input_data["hfModelClass"] = hf_model_class if hf_auth_token: input_data["hfAuthToken"] = hf_auth_token if hf_attn_implementation: input_data["hfAttnImplementation"] = hf_attn_implementation + if hf_torch_dtype: + input_data["hfTorchDtype"] = hf_torch_dtype + if hf_device_map: + input_data["hfDeviceMap"] = hf_device_map if framework: input_data["framework"] = framework if config: @@ -705,10 
+788,25 @@ def update_provider( provider_id: int, provider_model_id: Optional[str] = None, is_primary: Optional[bool] = None, + # API Configuration + api_endpoint_url: Optional[str] = None, + api_http_method: Optional[str] = None, + api_timeout_seconds: Optional[int] = None, + api_auth_type: Optional[str] = None, + api_auth_header_name: Optional[str] = None, + api_key: Optional[str] = None, + api_key_prefix: Optional[str] = None, + api_headers: Optional[Dict[str, str]] = None, + api_request_template: Optional[Dict[str, Any]] = None, + api_response_path: Optional[str] = None, + # HuggingFace Configuration hf_use_pipeline: Optional[bool] = None, hf_model_class: Optional[str] = None, hf_auth_token: Optional[str] = None, hf_attn_implementation: Optional[str] = None, + hf_trust_remote_code: Optional[bool] = None, + hf_torch_dtype: Optional[str] = None, + hf_device_map: Optional[str] = None, framework: Optional[str] = None, config: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: @@ -719,10 +817,23 @@ def update_provider( provider_id: ID of the provider to update provider_model_id: New model ID at the provider is_primary: Whether this is the primary provider + api_endpoint_url: Full URL for the API endpoint + api_http_method: HTTP method (POST, GET) + api_timeout_seconds: Request timeout in seconds + api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE) + api_auth_header_name: Header name for authentication + api_key: API key or token + api_key_prefix: Prefix for the API key (e.g., "Bearer") + api_headers: Additional headers as dict + api_request_template: Request body template as dict + api_response_path: JSON path to extract response text hf_use_pipeline: For HuggingFace - whether to use pipeline API hf_model_class: For HuggingFace - model class hf_auth_token: For HuggingFace - auth token hf_attn_implementation: For HuggingFace - attention implementation + hf_trust_remote_code: For HuggingFace - trust remote code + hf_torch_dtype: For 
HuggingFace - torch dtype + hf_device_map: For HuggingFace - device map framework: Framework (pt, tf) config: Additional configuration @@ -751,6 +862,28 @@ def update_provider( input_data["providerModelId"] = provider_model_id if is_primary is not None: input_data["isPrimary"] = is_primary + # API Configuration + if api_endpoint_url is not None: + input_data["apiEndpointUrl"] = api_endpoint_url + if api_http_method is not None: + input_data["apiHttpMethod"] = api_http_method + if api_timeout_seconds is not None: + input_data["apiTimeoutSeconds"] = api_timeout_seconds + if api_auth_type is not None: + input_data["apiAuthType"] = api_auth_type + if api_auth_header_name is not None: + input_data["apiAuthHeaderName"] = api_auth_header_name + if api_key is not None: + input_data["apiKey"] = api_key + if api_key_prefix is not None: + input_data["apiKeyPrefix"] = api_key_prefix + if api_headers is not None: + input_data["apiHeaders"] = api_headers + if api_request_template is not None: + input_data["apiRequestTemplate"] = api_request_template + if api_response_path is not None: + input_data["apiResponsePath"] = api_response_path + # HuggingFace Configuration if hf_use_pipeline is not None: input_data["hfUsePipeline"] = hf_use_pipeline if hf_model_class is not None: @@ -759,6 +892,12 @@ def update_provider( input_data["hfAuthToken"] = hf_auth_token if hf_attn_implementation is not None: input_data["hfAttnImplementation"] = hf_attn_implementation + if hf_trust_remote_code is not None: + input_data["hfTrustRemoteCode"] = hf_trust_remote_code + if hf_torch_dtype is not None: + input_data["hfTorchDtype"] = hf_torch_dtype + if hf_device_map is not None: + input_data["hfDeviceMap"] = hf_device_map if framework is not None: input_data["framework"] = framework if config is not None: diff --git a/search/documents/aimodel_document.py b/search/documents/aimodel_document.py index 4566de5..4d5ae95 100644 --- a/search/documents/aimodel_document.py +++ 
b/search/documents/aimodel_document.py @@ -147,6 +147,15 @@ class AIModelDocument(Document): "provider_model_id": fields.KeywordField(), "is_primary": fields.BooleanField(), "is_active": fields.BooleanField(), + # API Configuration + "api_endpoint_url": fields.KeywordField(), + "api_http_method": fields.KeywordField(), + "api_timeout_seconds": fields.IntegerField(), + "api_auth_type": fields.KeywordField(), + # HuggingFace Configuration + "hf_use_pipeline": fields.BooleanField(), + "hf_model_class": fields.KeywordField(), + "framework": fields.KeywordField(), } ), } @@ -222,6 +231,15 @@ def prepare_versions(self, instance: AIModel) -> List[Dict[str, Any]]: "provider_model_id": provider_obj.provider_model_id, "is_primary": provider_obj.is_primary, "is_active": provider_obj.is_active, + # API Configuration + "api_endpoint_url": provider_obj.api_endpoint_url, + "api_http_method": provider_obj.api_http_method, + "api_timeout_seconds": provider_obj.api_timeout_seconds, + "api_auth_type": provider_obj.api_auth_type, + # HuggingFace Configuration + "hf_use_pipeline": provider_obj.hf_use_pipeline, + "hf_model_class": provider_obj.hf_model_class, + "framework": provider_obj.framework, } ) versions_data.append( From 215d6857f811b9b86e1af2d76f1741e055ba5127 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 5 Jan 2026 08:46:16 +0000 Subject: [PATCH 042/127] Bump SDK version to 0.4.1 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 16ce66b..7a1746b 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.0" +__version__ = "0.4.1" diff --git a/pyproject.toml b/pyproject.toml index 97ef11a..99f6e2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = 
"dataspace-sdk" -version = "0.4.0" +version = "0.4.1" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.0" +python_version = "0.4.1" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 878847c574267c273dd62a5cedaf8f45bfd9d3a5 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 5 Jan 2026 14:36:08 +0530 Subject: [PATCH 043/127] use primary version always when calling call_model --- api/views/aimodel_execution.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/api/views/aimodel_execution.py b/api/views/aimodel_execution.py index 9c65d01..b5dc0d8 100644 --- a/api/views/aimodel_execution.py +++ b/api/views/aimodel_execution.py @@ -3,9 +3,8 @@ Handles model inference requests via ModelAPIClient and ModelHFClient. """ -from typing import Any - import logging +from typing import Any from rest_framework import status from rest_framework.decorators import api_view, permission_classes @@ -67,13 +66,34 @@ def call_aimodel(request: Request, model_id: str) -> Response: parameters = request.data.get("parameters", {}) + # Get the primary version and provider + primary_version = model.versions.filter(is_latest=True).first() + if not primary_version: + primary_version = model.versions.first() + + if not primary_version: + return Response( + {"error": "No version found for this model"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + primary_provider = primary_version.providers.filter(is_primary=True, is_active=True).first() + if not primary_provider: + primary_provider = primary_version.providers.filter(is_active=True).first() + + if not primary_provider: + return Response( + {"error": "No active provider found for this model"}, + status=status.HTTP_400_BAD_REQUEST, + ) + # Route to appropriate client based on provider result: Any - if 
model.provider == "HUGGINGFACE": - hf_client = ModelHFClient(model) + if primary_provider.provider == "HUGGINGFACE": + hf_client = ModelHFClient(primary_provider) result = hf_client.call(input_text) else: - api_client = ModelAPIClient(model) + api_client = ModelAPIClient(primary_provider) result = api_client.call(input_text, parameters) return Response(result, status=status.HTTP_200_OK) From 0cb6fc6c4d0f5c2825d2d404e3c07e12f41cb9d2 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 5 Jan 2026 15:27:59 +0530 Subject: [PATCH 044/127] add expample template and support model_id replacement --- api/services/model_api_client.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/api/services/model_api_client.py b/api/services/model_api_client.py index 1c4f29e..ca2851f 100644 --- a/api/services/model_api_client.py +++ b/api/services/model_api_client.py @@ -81,8 +81,22 @@ def _replace_placeholders(self, template: Dict, input_text: str, parameters: Dic result: Dict[str, Any] = {} for key, value in template.items(): if isinstance(value, str): + # template = { + # "model": "{model_id}", + # "messages": [{"role": "user", "content": "{input}"}] + # "temperature": {temperature}, + # "max_tokens": {max_tokens} + # } + # parameters = { + # "temperature": 0.7, + # "max_tokens": 1000 + # } + # replaced_value = value.replace("{input}", input_text) replaced_value = replaced_value.replace("{prompt}", input_text) + replaced_value = replaced_value.replace( + "{model_id}", self.provider.provider_model_id + ) # Replace parameter placeholders for param_key, param_value in parameters.items(): replaced_value = replaced_value.replace(f"{{{param_key}}}", str(param_value)) From 07670698ba6cb8ff04ca6296c1b9531d3c4fcf96 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 5 Jan 2026 17:09:19 +0530 Subject: [PATCH 045/127] check for content type in audit response --- api/services/model_api_client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/api/services/model_api_client.py 
b/api/services/model_api_client.py index ca2851f..3c7bca2 100644 --- a/api/services/model_api_client.py +++ b/api/services/model_api_client.py @@ -272,6 +272,15 @@ async def call_async( raise ValueError(f"Unsupported HTTP method: {self.provider.api_http_method}") response.raise_for_status() + + # Check if response is JSON + content_type = response.headers.get("content-type", "") + if "application/json" not in content_type: + raise ValueError( + f"Expected JSON response but got {content_type}. " + f"Check if the API endpoint URL is correct: {endpoint_url}" + ) + response_data = response.json() latency_ms = (time.time() - start_time) * 1000 From bca505a484157636daf1eb548a4b4d3aed40b494 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 6 Jan 2026 13:07:29 +0530 Subject: [PATCH 046/127] change ai model api call to synchronous call --- api/services/model_api_client.py | 77 ++++++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/api/services/model_api_client.py b/api/services/model_api_client.py index 3c7bca2..4c84b03 100644 --- a/api/services/model_api_client.py +++ b/api/services/model_api_client.py @@ -326,20 +326,69 @@ async def call_async( } def call(self, input_text: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """Make synchronous API call to the model""" - import asyncio + """Make synchronous API call to the model using httpx sync client.""" + start_time = time.time() + + headers = self._build_headers() + body = self._build_request_body(input_text, parameters) + endpoint_url: str = self.provider.api_endpoint_url # type: ignore[assignment] - # Run async call in sync context try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # If loop is already running, use nest_asyncio - import nest_asyncio # type: ignore + with httpx.Client(timeout=self.provider.api_timeout_seconds) as client: + if self.provider.api_http_method == "POST": + response = client.post(endpoint_url, headers=headers, json=body) + 
elif self.provider.api_http_method == "GET": + response = client.get(endpoint_url, headers=headers, params=body) + else: + raise ValueError(f"Unsupported HTTP method: {self.provider.api_http_method}") - nest_asyncio.apply() - return loop.run_until_complete(self.call_async(input_text, parameters)) - else: - return asyncio.run(self.call_async(input_text, parameters)) - except RuntimeError: - # No event loop, create one - return asyncio.run(self.call_async(input_text, parameters)) + response.raise_for_status() + + # Check if response is JSON + content_type = response.headers.get("content-type", "") + if "application/json" not in content_type: + raise ValueError( + f"Expected JSON response but got {content_type}. " + f"Check if the API endpoint URL is correct: {endpoint_url}" + ) + + response_data = response.json() + latency_ms = (time.time() - start_time) * 1000 + + # Update provider timestamp (sync) + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) + + # Extract response text + output_text = self._extract_response(response_data) + + return { + "success": True, + "output": output_text, + "raw_response": response_data, + "latency_ms": latency_ms, + "status_code": response.status_code, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, + } + + except httpx.HTTPStatusError as e: + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) + + return { + "success": False, + "error": f"HTTP {e.response.status_code}: {e.response.text}", + "status_code": e.response.status_code, + "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, + } + except Exception as e: + return { + "success": False, + "error": str(e), + "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, + } From 
45315f3ce43fafca1a78d424081b1e8b1fa007ef Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 6 Jan 2026 21:15:33 +0530 Subject: [PATCH 047/127] add sector client to sdk --- api/types/type_sector.py | 20 ++++- dataspace_sdk/client.py | 2 + dataspace_sdk/resources/__init__.py | 3 +- dataspace_sdk/resources/sectors.py | 128 ++++++++++++++++++++++++++++ 4 files changed, 151 insertions(+), 2 deletions(-) create mode 100644 dataspace_sdk/resources/sectors.py diff --git a/api/types/type_sector.py b/api/types/type_sector.py index 40e2d89..4ffee0e 100644 --- a/api/types/type_sector.py +++ b/api/types/type_sector.py @@ -9,7 +9,7 @@ from api.models import Sector from api.types.base_type import BaseType -from api.utils.enums import DatasetStatus +from api.utils.enums import AIModelStatus, DatasetStatus @strawberry.enum @@ -54,6 +54,24 @@ def min_dataset_count(self, queryset: Any, value: Optional[int], prefix: str) -> # Return queryset with filter return queryset, Q(**{f"{prefix}_dataset_count__gte": value}) + @strawberry_django.filter_field + def min_aimodel_count(self, queryset: Any, value: Optional[int], prefix: str) -> tuple[Any, Q]: # type: ignore + # Skip filtering if no value provided + if value is None: + return queryset, Q() + + # Annotate queryset with dataset count + queryset = queryset.annotate( + _aimodel_count=Count( + "ai_models", + filter=Q(ai_models__status=AIModelStatus.ACTIVE), + distinct=True, + ) + ) + + # Return queryset with filter + return queryset, Q(**{f"{prefix}_aimodel_count__gte": value}) + @strawberry_django.filter_field def max_dataset_count(self, queryset: Any, value: Optional[int], prefix: str) -> tuple[Any, Q]: # type: ignore # Skip filtering if no value provided diff --git a/dataspace_sdk/client.py b/dataspace_sdk/client.py index a2c3674..5d1dfce 100644 --- a/dataspace_sdk/client.py +++ b/dataspace_sdk/client.py @@ -5,6 +5,7 @@ from dataspace_sdk.auth import AuthClient from dataspace_sdk.resources.aimodels import AIModelClient from 
dataspace_sdk.resources.datasets import DatasetClient +from dataspace_sdk.resources.sectors import SectorClient from dataspace_sdk.resources.usecases import UseCaseClient @@ -64,6 +65,7 @@ def __init__( self.datasets = DatasetClient(self.base_url, self._auth) self.aimodels = AIModelClient(self.base_url, self._auth) self.usecases = UseCaseClient(self.base_url, self._auth) + self.sectors = SectorClient(self.base_url, self._auth) def login(self, username: str, password: str) -> dict: """ diff --git a/dataspace_sdk/resources/__init__.py b/dataspace_sdk/resources/__init__.py index dc02cf6..aa742c7 100644 --- a/dataspace_sdk/resources/__init__.py +++ b/dataspace_sdk/resources/__init__.py @@ -2,6 +2,7 @@ from dataspace_sdk.resources.aimodels import AIModelClient from dataspace_sdk.resources.datasets import DatasetClient +from dataspace_sdk.resources.sectors import SectorClient from dataspace_sdk.resources.usecases import UseCaseClient -__all__ = ["DatasetClient", "AIModelClient", "UseCaseClient"] +__all__ = ["DatasetClient", "AIModelClient", "UseCaseClient", "SectorClient"] diff --git a/dataspace_sdk/resources/sectors.py b/dataspace_sdk/resources/sectors.py new file mode 100644 index 0000000..c1fec70 --- /dev/null +++ b/dataspace_sdk/resources/sectors.py @@ -0,0 +1,128 @@ +"""Sector resource client for DataSpace SDK.""" + +from typing import Any, Dict, List, Optional + +from dataspace_sdk.base import BaseAPIClient + + +class SectorClient(BaseAPIClient): + """Client for interacting with Sector resources.""" + + def list_all( + self, + search: Optional[str] = None, + min_dataset_count: Optional[int] = None, + min_aimodel_count: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """ + List all sectors with optional filters using GraphQL. 
+ + Args: + search: Search query for name/description + min_dataset_count: Filter sectors with at least this many published datasets + min_aimodel_count: Filter sectors with at least this many active AI models + limit: Number of results to return + offset: Number of results to skip + + Returns: + List of sector dictionaries + """ + query = """ + query ListSectors($filters: SectorFilter, $pagination: OffsetPaginationInput) { + sectors(filters: $filters, pagination: $pagination) { + id + name + slug + description + } + } + """ + + filters: Dict[str, Any] = {} + if search: + filters["search"] = search + if min_dataset_count is not None: + filters["min_dataset_count"] = min_dataset_count + if min_aimodel_count is not None: + filters["min_aimodel_count"] = min_aimodel_count + + variables: Dict[str, Any] = { + "pagination": {"limit": limit, "offset": offset}, + } + if filters: + variables["filters"] = filters + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": variables, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + data = response.get("data", {}) + sectors_result: List[Dict[str, Any]] = ( + data.get("sectors", []) if isinstance(data, dict) else [] + ) + return sectors_result + + def get_by_id(self, sector_id: str) -> Dict[str, Any]: + """ + Get a sector by ID using GraphQL. + + Args: + sector_id: UUID of the sector + + Returns: + Dictionary containing sector information + """ + query = """ + query GetSector($id: UUID!) 
{ + sector(id: $id) { + id + name + slug + description + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"id": sector_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("sector", {}) + return result + + def get_sectors_with_aimodels( + self, + limit: int = 100, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """ + Get sectors that have at least one active AI model. + + Args: + limit: Number of results to return + offset: Number of results to skip + + Returns: + List of sector dictionaries with AI models + """ + return self.list_all(min_aimodel_count=1, limit=limit, offset=offset) From 05ce67ef92e7731627d44b9017efbdb85a826947 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 6 Jan 2026 15:49:12 +0000 Subject: [PATCH 048/127] Bump SDK version to 0.4.2 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 7a1746b..b5b25d0 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.1" +__version__ = "0.4.2" diff --git a/pyproject.toml b/pyproject.toml index 99f6e2c..99ce165 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.1" +version = "0.4.2" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.1" +python_version = "0.4.2" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 
44e975d15d912939dc77d17c05f220af325b23e3 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 6 Jan 2026 22:23:34 +0530 Subject: [PATCH 049/127] add sectors and geographies to aimodel in sdk --- dataspace_sdk/resources/aimodels.py | 9 +++++++++ dataspace_sdk/resources/sectors.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index 8dda93c..93b04c1 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -219,6 +219,15 @@ def list_all( id value } + sectors { + id + name + slug + } + geographies { + id + name + } versions { id version diff --git a/dataspace_sdk/resources/sectors.py b/dataspace_sdk/resources/sectors.py index c1fec70..0436f18 100644 --- a/dataspace_sdk/resources/sectors.py +++ b/dataspace_sdk/resources/sectors.py @@ -44,9 +44,9 @@ def list_all( if search: filters["search"] = search if min_dataset_count is not None: - filters["min_dataset_count"] = min_dataset_count + filters["minDatasetCount"] = min_dataset_count if min_aimodel_count is not None: - filters["min_aimodel_count"] = min_aimodel_count + filters["minAimodelCount"] = min_aimodel_count variables: Dict[str, Any] = { "pagination": {"limit": limit, "offset": offset}, From a5d244096bd8c51e6fc533e02d5277caa8c09706 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 6 Jan 2026 16:55:30 +0000 Subject: [PATCH 050/127] Bump SDK version to 0.4.3 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index b5b25d0..49f6723 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.2" +__version__ = "0.4.3" diff --git a/pyproject.toml b/pyproject.toml index 99ce165..c70875c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ 
build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.2" +version = "0.4.3" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.2" +python_version = "0.4.3" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From c123f5a0a8e771973089f41ac2ccd32c2dc9dac2 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 7 Jan 2026 11:49:54 +0530 Subject: [PATCH 051/127] add aimodel count to sector type --- api/types/type_sector.py | 4 ++++ dataspace_sdk/resources/sectors.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/api/types/type_sector.py b/api/types/type_sector.py index 4ffee0e..c4f801b 100644 --- a/api/types/type_sector.py +++ b/api/types/type_sector.py @@ -128,3 +128,7 @@ class TypeSector(BaseType): @strawberry.field def dataset_count(self: Any) -> int: return int(self.datasets.filter(status=DatasetStatus.PUBLISHED).count()) + + @strawberry.field + def aimodel_count(self: Any) -> int: + return int(self.ai_models.filter(status=AIModelStatus.ACTIVE).count()) diff --git a/dataspace_sdk/resources/sectors.py b/dataspace_sdk/resources/sectors.py index 0436f18..e5e0700 100644 --- a/dataspace_sdk/resources/sectors.py +++ b/dataspace_sdk/resources/sectors.py @@ -36,6 +36,8 @@ def list_all( name slug description + datasetCount + aimodelCount } } """ @@ -90,6 +92,8 @@ def get_by_id(self, sector_id: str) -> Dict[str, Any]: name slug description + datasetCount + aimodelCount } } """ From d32f99f7a724b9067117d8cdf89ecfddb021085f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 7 Jan 2026 06:25:18 +0000 Subject: [PATCH 052/127] Bump SDK version to 0.4.4 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py 
index 49f6723..3e6b331 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.3" +__version__ = "0.4.4" diff --git a/pyproject.toml b/pyproject.toml index c70875c..0a5fdaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.3" +version = "0.4.4" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.3" +python_version = "0.4.4" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 65707d9285323f7a41c223d80b8909f6c61fb08f Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 12 Jan 2026 18:21:55 +0530 Subject: [PATCH 053/127] add prompt dataset type as child type of dataset --- api/models/Dataset.py | 25 +-- api/models/PromptDataset.py | 113 +++++++++++ api/models/__init__.py | 1 + api/schema/dataset_schema.py | 275 +++++++++++++++++---------- api/types/type_dataset.py | 39 +++- api/types/type_prompt_metadata.py | 61 ++++++ api/utils/enums.py | 22 +++ api/views/search_dataset.py | 93 +++++---- dataspace_sdk/resources/datasets.py | 198 +++++++++++++++++++ search/documents/dataset_document.py | 64 +++++-- 10 files changed, 730 insertions(+), 161 deletions(-) create mode 100644 api/models/PromptDataset.py create mode 100644 api/types/type_prompt_metadata.py diff --git a/api/models/Dataset.py b/api/models/Dataset.py index 4981cdd..664fb81 100644 --- a/api/models/Dataset.py +++ b/api/models/Dataset.py @@ -5,7 +5,12 @@ from django.db.models import Sum from django.utils.text import slugify -from api.utils.enums import DatasetAccessType, DatasetLicense, DatasetStatus +from api.utils.enums import ( + DatasetAccessType, + DatasetLicense, + DatasetStatus, + DatasetType, +) if 
TYPE_CHECKING: from api.models.DataSpace import DataSpace @@ -59,9 +64,7 @@ class Dataset(models.Model): max_length=50, default=DatasetStatus.DRAFT, choices=DatasetStatus.choices ) sectors = models.ManyToManyField("api.Sector", blank=True, related_name="datasets") - geographies = models.ManyToManyField( - "api.Geography", blank=True, related_name="datasets" - ) + geographies = models.ManyToManyField("api.Geography", blank=True, related_name="datasets") access_type = models.CharField( max_length=50, default=DatasetAccessType.PUBLIC, @@ -72,6 +75,11 @@ class Dataset(models.Model): default=DatasetLicense.CC_BY_4_0_ATTRIBUTION, choices=DatasetLicense.choices, ) + dataset_type = models.CharField( + max_length=50, + default=DatasetType.DATA, + choices=DatasetType.choices, + ) def save(self, *args: Any, **kwargs: Any) -> None: if not self.slug: @@ -138,10 +146,7 @@ def has_charts(self) -> bool: @property def download_count(self) -> int: return ( - self.resources.aggregate(total_downloads=Sum("download_count"))[ - "total_downloads" - ] - or 0 + self.resources.aggregate(total_downloads=Sum("download_count"))["total_downloads"] or 0 ) @property @@ -176,9 +181,7 @@ def trending_score(self) -> float: return float(base_score) * 0.1 # Calculate recency factor (more recent = higher score) - recent_downloads = ( - recent_resources.aggregate(total=Sum("download_count"))["total"] or 0 - ) + recent_downloads = recent_resources.aggregate(total=Sum("download_count"))["total"] or 0 # Calculate trending score: base score + (recent downloads * recency factor) recency_factor = 2.0 # Weight for recent downloads diff --git a/api/models/PromptDataset.py b/api/models/PromptDataset.py new file mode 100644 index 0000000..0946059 --- /dev/null +++ b/api/models/PromptDataset.py @@ -0,0 +1,113 @@ +"""PromptDataset model - extends Dataset with prompt-specific fields.""" + +from django.db import models + +from api.models.Dataset import Dataset +from api.utils.enums import DatasetType, PromptTaskType 
+ + +class PromptDataset(Dataset): + """ + PromptDataset extends Dataset with prompt-specific fields. + + Uses Django multi-table inheritance - PromptDataset has all fields + from Dataset plus additional prompt-specific fields. The parent + Dataset is automatically created and linked via a OneToOne relationship. + + This means PromptDataset: + - Has all Dataset fields (title, description, tags, sectors, etc.) + - Can have DatasetMetadata entries (via inherited relationship) + - Can have Resources (prompt files instead of data files) + - Has additional prompt-specific fields below + """ + + # Prompt task type (e.g., text generation, classification, etc.) + task_type = models.CharField( + max_length=100, + choices=PromptTaskType.choices, + blank=True, + null=True, + ) + + # Target language(s) for the prompts + target_languages = models.JSONField( + blank=True, + null=True, + help_text="List of target languages for the prompts (e.g., ['en', 'hi', 'ta'])", + ) + + # Domain/category of prompts + domain = models.CharField( + max_length=200, + blank=True, + null=True, + help_text="Domain or category (e.g., healthcare, education, legal)", + ) + + # Target AI model types + target_model_types = models.JSONField( + blank=True, + null=True, + help_text="List of AI model types these prompts are designed for", + ) + + # Prompt format/template information + prompt_format = models.CharField( + max_length=100, + blank=True, + null=True, + help_text="Format of prompts (e.g., instruction, chat, completion)", + ) + + # Whether prompts include system instructions + has_system_prompt = models.BooleanField( + default=False, + help_text="Whether the prompts include system-level instructions", + ) + + # Whether prompts include example responses + has_example_responses = models.BooleanField( + default=False, + help_text="Whether the prompts include example/expected responses", + ) + + # Average prompt length (for filtering/search) + avg_prompt_length = models.IntegerField( + blank=True, + 
null=True, + help_text="Average character length of prompts in this dataset", + ) + + # Number of prompts in the dataset + prompt_count = models.IntegerField( + blank=True, + null=True, + help_text="Total number of prompts in this dataset", + ) + + # Use case description + use_case = models.TextField( + blank=True, + null=True, + help_text="Description of intended use cases for these prompts", + ) + + # Evaluation criteria or metrics + evaluation_criteria = models.JSONField( + blank=True, + null=True, + help_text="Criteria or metrics for evaluating prompt effectiveness", + ) + + def save(self, *args, **kwargs): + # Ensure dataset_type is always PROMPT for PromptDataset + self.dataset_type = DatasetType.PROMPT + super().save(*args, **kwargs) + + def __str__(self) -> str: + return f"PromptDataset: {self.title}" + + class Meta: + db_table = "prompt_dataset" + verbose_name = "Prompt Dataset" + verbose_name_plural = "Prompt Datasets" diff --git a/api/models/__init__.py b/api/models/__init__.py index 809a711..e9e0d77 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -13,6 +13,7 @@ from api.models.Geography import Geography from api.models.Metadata import Metadata from api.models.Organization import Organization +from api.models.PromptDataset import PromptDataset from api.models.Resource import ( Resource, ResourceDataTable, diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index 458aa2a..cb57573 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -15,6 +15,7 @@ Geography, Metadata, Organization, + PromptDataset, Resource, ResourceChartDetails, ResourceChartImage, @@ -31,12 +32,15 @@ from api.schema.extensions import TrackActivity, TrackModelActivity from api.types.type_dataset import DatasetFilter, DatasetOrder, TypeDataset from api.types.type_organization import TypeOrganization +from api.types.type_prompt_metadata import TypePromptDataset, prompt_task_type_enum from api.types.type_resource_chart 
import TypeResourceChart from api.types.type_resource_chart_image import TypeResourceChartImage from api.utils.enums import ( DatasetAccessType, DatasetLicense, DatasetStatus, + DatasetType, + PromptTaskType, UseCaseStatus, ) from api.utils.graphql_telemetry import trace_resolver @@ -50,6 +54,8 @@ DatasetAccessTypeENUM = strawberry.enum(DatasetAccessType) # type: ignore DatasetLicenseENUM = strawberry.enum(DatasetLicense) # type: ignore +DatasetTypeENUM = strawberry.enum(DatasetType) # type: ignore +PromptTaskTypeENUM = strawberry.enum(PromptTaskType) # type: ignore # Create permission classes dynamically with different operations @@ -104,9 +110,7 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return False if user.is_superuser: return True - dataset_perm = DatasetPermission.objects.filter( - user=user, dataset=dataset - ).first() + dataset_perm = DatasetPermission.objects.filter(user=user, dataset=dataset).first() if dataset_perm: return dataset_perm.role.can_view org_perm = OrganizationMembership.objects.filter( @@ -154,9 +158,7 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: if dataset.user and dataset.user == user: return True # Check if user has specific dataset permissions - dataset_perm = DatasetPermission.objects.filter( - user=user, dataset=dataset - ).first() + dataset_perm = DatasetPermission.objects.filter(user=user, dataset=dataset).first() if dataset_perm: return dataset_perm.role.can_view return False @@ -232,9 +234,7 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return True # Check dataset-specific permissions - dataset_perm = DatasetPermission.objects.filter( - user=user, dataset=dataset - ).first() + dataset_perm = DatasetPermission.objects.filter(user=user, dataset=dataset).first() return dataset_perm and dataset_perm.role.can_change and dataset.status == DatasetStatus.DRAFT.value # type: ignore except Dataset.DoesNotExist: @@ -256,9 +256,7 @@ class 
UpdateMetadataInput: sectors: Optional[List[uuid.UUID]] = None geographies: Optional[List[int]] = None access_type: Optional[DatasetAccessTypeENUM] = DatasetAccessTypeENUM.PUBLIC - license: Optional[DatasetLicenseENUM] = ( - DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE - ) + license: Optional[DatasetLicenseENUM] = DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE @strawberry.input @@ -268,9 +266,49 @@ class UpdateDatasetInput: description: Optional[str] = None tags: Optional[List[str]] = None access_type: Optional[DatasetAccessTypeENUM] = DatasetAccessTypeENUM.PUBLIC - license: Optional[DatasetLicenseENUM] = ( - DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE - ) + license: Optional[DatasetLicenseENUM] = DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE + + +@strawberry.input +class CreateDatasetInput: + """Input for creating a new dataset with optional type specification.""" + + dataset_type: Optional[DatasetTypeENUM] = DatasetTypeENUM.DATA + + +@strawberry.input +class PromptMetadataInput: + """Input for prompt-specific metadata.""" + + task_type: Optional[PromptTaskTypeENUM] = None + target_languages: Optional[List[str]] = None + domain: Optional[str] = None + target_model_types: Optional[List[str]] = None + prompt_format: Optional[str] = None + has_system_prompt: Optional[bool] = False + has_example_responses: Optional[bool] = False + avg_prompt_length: Optional[int] = None + prompt_count: Optional[int] = None + use_case: Optional[str] = None + evaluation_criteria: Optional[strawberry.scalars.JSON] = None + + +@strawberry.input +class UpdatePromptMetadataInput: + """Input for updating prompt-specific metadata.""" + + dataset: uuid.UUID + task_type: Optional[PromptTaskTypeENUM] = None + target_languages: Optional[List[str]] = None + domain: Optional[str] = None + target_model_types: Optional[List[str]] = None + prompt_format: Optional[str] = None + has_system_prompt: Optional[bool] = None + has_example_responses: Optional[bool] = 
None + avg_prompt_length: Optional[int] = None + prompt_count: Optional[int] = None + use_case: Optional[str] = None + evaluation_criteria: Optional[strawberry.scalars.JSON] = None @trace_resolver(name="add_update_dataset_metadata", attributes={"component": "dataset"}) @@ -285,9 +323,7 @@ def _add_update_dataset_metadata( metadata_field = Metadata.objects.get(id=metadata_input_item.id) if not metadata_field.enabled: _delete_existing_metadata(dataset) - raise ValueError( - f"Metadata with ID {metadata_input_item.id} is not enabled." - ) + raise ValueError(f"Metadata with ID {metadata_input_item.id} is not enabled.") ds_metadata = DatasetMetadata( dataset=dataset, metadata_item=metadata_field, @@ -296,9 +332,7 @@ def _add_update_dataset_metadata( ds_metadata.save() except Metadata.DoesNotExist as e: _delete_existing_metadata(dataset) - raise ValueError( - f"Metadata with ID {metadata_input_item.id} does not exist." - ) + raise ValueError(f"Metadata with ID {metadata_input_item.id} does not exist.") @trace_resolver(name="update_dataset_tags", attributes={"component": "dataset"}) @@ -307,9 +341,7 @@ def _update_dataset_tags(dataset: Dataset, tags: Optional[List[str]]) -> None: return dataset.tags.clear() for tag in tags: - dataset.tags.add( - Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0] - ) + dataset.tags.add(Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0]) dataset.save() @@ -330,9 +362,7 @@ def _add_update_dataset_sectors(dataset: Dataset, sectors: List[uuid.UUID]) -> N dataset.save() -@trace_resolver( - name="add_update_dataset_geographies", attributes={"component": "dataset"} -) +@trace_resolver(name="add_update_dataset_geographies", attributes={"component": "dataset"}) def _add_update_dataset_geographies(dataset: Dataset, geography_ids: List[int]) -> None: """Update geographies for a dataset.""" dataset.geographies.clear() @@ -424,28 +454,24 @@ def get_chart_data( dataset = Dataset.objects.get(id=dataset_id) # 
Fetch ResourceChartImage for the dataset chart_images = list( - ResourceChartImage.objects.filter(dataset_id=dataset_id).order_by( - "modified" - ) + ResourceChartImage.objects.filter(dataset_id=dataset_id).order_by("modified") + ) + resource_ids = Resource.objects.filter(dataset_id=dataset_id).values_list( + "id", flat=True ) - resource_ids = Resource.objects.filter( - dataset_id=dataset_id - ).values_list("id", flat=True) except Dataset.DoesNotExist: raise ValueError(f"Dataset with ID {dataset_id} does not exist.") else: organization = info.context.context.get("organization") if organization: chart_images = list( - ResourceChartImage.objects.filter( - dataset__organization=organization - ).order_by("modified") + ResourceChartImage.objects.filter(dataset__organization=organization).order_by( + "modified" + ) ) else: chart_images = list( - ResourceChartImage.objects.filter(dataset__user=user).order_by( - "modified" - ) + ResourceChartImage.objects.filter(dataset__user=user).order_by("modified") ) if organization: resource_ids = Resource.objects.filter( @@ -458,9 +484,7 @@ def get_chart_data( # Fetch ResourceChartDetails based on the related Resources chart_details = list( - ResourceChartDetails.objects.filter(resource_id__in=resource_ids).order_by( - "modified" - ) + ResourceChartDetails.objects.filter(resource_id__in=resource_ids).order_by("modified") ) # Convert to Strawberry types after getting lists @@ -485,26 +509,14 @@ def get_chart_data( def get_publishers(self, info: Info) -> List[Union[TypeOrganization, TypeUser]]: """Get all publishers (both individual publishers and organizations) who have published datasets.""" # Get all published datasets - published_datasets = Dataset.objects.filter( - status=DatasetStatus.PUBLISHED.value - ) - published_ds_organizations = published_datasets.values_list( - "organization_id", flat=True - ) - published_usecases = UseCase.objects.filter( - status=UseCaseStatus.PUBLISHED.value - ) - published_uc_organizations = 
published_usecases.values_list( - "organization_id", flat=True - ) - published_organizations = set(published_ds_organizations) | set( - published_uc_organizations - ) + published_datasets = Dataset.objects.filter(status=DatasetStatus.PUBLISHED.value) + published_ds_organizations = published_datasets.values_list("organization_id", flat=True) + published_usecases = UseCase.objects.filter(status=UseCaseStatus.PUBLISHED.value) + published_uc_organizations = published_usecases.values_list("organization_id", flat=True) + published_organizations = set(published_ds_organizations) | set(published_uc_organizations) # Get unique organizations that have published datasets - org_publishers = Organization.objects.filter( - id__in=published_organizations - ).distinct() + org_publishers = Organization.objects.filter(id__in=published_organizations).distinct() published_ds_users = published_datasets.values_list("user_id", flat=True) published_uc_users = published_usecases.values_list("user_id", flat=True) @@ -533,27 +545,53 @@ class Mutation: "get_data": lambda result, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "dataset_type": result.dataset_type, + "organization": (str(result.organization.id) if result.organization else None), }, }, ) - def add_dataset(self, info: Info) -> MutationResponse[TypeDataset]: + def add_dataset( + self, info: Info, create_input: Optional[CreateDatasetInput] = None + ) -> MutationResponse[TypeDataset]: # Get organization from context organization = info.context.context.get("organization") dataspace = info.context.context.get("dataspace") user = info.context.user - dataset = Dataset.objects.create( - organization=organization, - dataspace=dataspace, - title=f"New dataset {datetime.datetime.now().strftime('%d %b %Y - %H:%M:%S')}", - description="", - user=user, - access_type=DatasetAccessType.PUBLIC, - 
license=DatasetLicense.CC_BY_4_0_ATTRIBUTION, - ) + # Determine dataset type + dataset_type = DatasetType.DATA + if create_input and create_input.dataset_type: + dataset_type = create_input.dataset_type + + # Create title based on dataset type + type_label = "prompt dataset" if dataset_type == DatasetType.PROMPT else "dataset" + title = f"New {type_label} {datetime.datetime.now().strftime('%d %b %Y - %H:%M:%S')}" + + # Create PromptDataset or regular Dataset based on type + dataset: Dataset + if dataset_type == DatasetType.PROMPT: + dataset = PromptDataset.objects.create( + organization=organization, + dataspace=dataspace, + title=title, + description="", + user=user, + access_type=DatasetAccessType.PUBLIC, + license=DatasetLicense.CC_BY_4_0_ATTRIBUTION, + # dataset_type is set automatically in PromptDataset.save() + ) + else: + dataset = Dataset.objects.create( + organization=organization, + dataspace=dataspace, + title=title, + description="", + user=user, + access_type=DatasetAccessType.PUBLIC, + license=DatasetLicense.CC_BY_4_0_ATTRIBUTION, + dataset_type=dataset_type, + ) + DatasetPermission.objects.create( user=user, dataset=dataset, role=Role.objects.get(name="owner") ) @@ -568,9 +606,7 @@ def add_dataset(self, info: Info) -> MutationResponse[TypeDataset]: "get_data": lambda result, update_metadata_input=None, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), "updated_fields": { "metadata": True, "description": bool( @@ -592,9 +628,7 @@ def add_update_dataset_metadata( except Dataset.DoesNotExist as e: raise DjangoValidationError(f"Dataset with ID {dataset_id} does not exist.") if dataset.status != DatasetStatus.DRAFT.value: - raise DjangoValidationError( - f"Dataset with ID {dataset_id} is not in draft status." 
- ) + raise DjangoValidationError(f"Dataset with ID {dataset_id} is not in draft status.") if update_metadata_input.description: dataset.description = update_metadata_input.description @@ -621,17 +655,12 @@ def add_update_dataset_metadata( get_data=lambda result, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), "updated_fields": { "title": kwargs.get("update_dataset_input").title is not None, - "description": kwargs.get("update_dataset_input").description - is not None, - "access_type": kwargs.get("update_dataset_input").access_type - is not None, - "license": kwargs.get("update_dataset_input").license - is not None, + "description": kwargs.get("update_dataset_input").description is not None, + "access_type": kwargs.get("update_dataset_input").access_type is not None, + "license": kwargs.get("update_dataset_input").license is not None, "tags": kwargs.get("update_dataset_input").tags is not None, }, }, @@ -642,9 +671,7 @@ def add_update_dataset_metadata( name="update_dataset", attributes={"component": "dataset", "operation": "mutation"}, ) - def update_dataset( - self, info: Info, update_dataset_input: UpdateDatasetInput - ) -> TypeDataset: + def update_dataset(self, info: Info, update_dataset_input: UpdateDatasetInput) -> TypeDataset: dataset_id = update_dataset_input.dataset try: dataset = Dataset.objects.get(id=dataset_id) @@ -675,9 +702,7 @@ def update_dataset( get_data=lambda result, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), }, ) ], @@ -706,9 +731,7 @@ def publish_dataset(self, info: Info, dataset_id: uuid.UUID) -> TypeDataset: get_data=lambda result, **kwargs: { 
"dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), "updated_fields": {"status": "DRAFT", "action": "unpublished"}, }, ) @@ -753,3 +776,61 @@ def delete_dataset(self, info: Info, dataset_id: uuid.UUID) -> bool: return True except Dataset.DoesNotExist as e: raise ValueError(f"Dataset with ID {dataset_id} does not exist.") + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[UpdateDatasetPermission], + track_activity={ + "verb": "updated", + "get_data": lambda result, **kwargs: { + "dataset_id": str(result.id), + "dataset_title": result.title, + "updated_fields": {"prompt_metadata": True}, + }, + }, + trace_name="update_prompt_metadata", + trace_attributes={"component": "dataset"}, + ) + def update_prompt_metadata( + self, info: Info, update_input: UpdatePromptMetadataInput + ) -> MutationResponse[TypePromptDataset]: + """Update prompt-specific metadata for a prompt dataset.""" + dataset_id = update_input.dataset + + # Get the PromptDataset directly (it's a child of Dataset via multi-table inheritance) + try: + prompt_dataset = PromptDataset.objects.get(dataset_ptr_id=dataset_id) + except PromptDataset.DoesNotExist: + raise DjangoValidationError( + f"Dataset with ID {dataset_id} is not a prompt dataset or does not exist." 
+ ) + + if prompt_dataset.status != DatasetStatus.DRAFT.value: + raise DjangoValidationError(f"Dataset with ID {dataset_id} is not in draft status.") + + # Update fields if provided + if update_input.task_type is not None: + prompt_dataset.task_type = update_input.task_type + if update_input.target_languages is not None: + prompt_dataset.target_languages = update_input.target_languages + if update_input.domain is not None: + prompt_dataset.domain = update_input.domain + if update_input.target_model_types is not None: + prompt_dataset.target_model_types = update_input.target_model_types + if update_input.prompt_format is not None: + prompt_dataset.prompt_format = update_input.prompt_format + if update_input.has_system_prompt is not None: + prompt_dataset.has_system_prompt = update_input.has_system_prompt + if update_input.has_example_responses is not None: + prompt_dataset.has_example_responses = update_input.has_example_responses + if update_input.avg_prompt_length is not None: + prompt_dataset.avg_prompt_length = update_input.avg_prompt_length + if update_input.prompt_count is not None: + prompt_dataset.prompt_count = update_input.prompt_count + if update_input.use_case is not None: + prompt_dataset.use_case = update_input.use_case + if update_input.evaluation_criteria is not None: + prompt_dataset.evaluation_criteria = update_input.evaluation_criteria + + prompt_dataset.save() + return MutationResponse.success_response(TypePromptDataset.from_django(prompt_dataset)) diff --git a/api/types/type_dataset.py b/api/types/type_dataset.py index 0934a45..1e0bcae 100644 --- a/api/types/type_dataset.py +++ b/api/types/type_dataset.py @@ -8,19 +8,20 @@ from strawberry.enum import EnumType from strawberry.types import Info -from api.models import Dataset, DatasetMetadata, Resource, Tag +from api.models import Dataset, DatasetMetadata, PromptDataset, Resource, Tag from api.types.base_type import BaseType from api.types.type_dataset_metadata import TypeDatasetMetadata from 
api.types.type_geo import TypeGeo from api.types.type_organization import TypeOrganization from api.types.type_resource import TypeResource from api.types.type_sector import TypeSector -from api.utils.enums import DatasetStatus +from api.utils.enums import DatasetStatus, DatasetType, PromptTaskType from authorization.types import TypeUser logger = structlog.get_logger("dataspace.type_dataset") dataset_status: EnumType = strawberry.enum(DatasetStatus) # type: ignore +dataset_type_enum: EnumType = strawberry.enum(DatasetType) # type: ignore @strawberry_django.filter(Dataset) @@ -29,6 +30,7 @@ class DatasetFilter: id: Optional[uuid.UUID] status: Optional[dataset_status] + dataset_type: Optional[dataset_type_enum] @strawberry_django.order(Dataset) @@ -55,6 +57,7 @@ class TypeDataset(BaseType): description: Optional[str] slug: str status: dataset_status + dataset_type: dataset_type_enum organization: Optional["TypeOrganization"] created: datetime modified: datetime @@ -107,6 +110,30 @@ def metadata(self) -> List["TypeDatasetMetadata"]: except (AttributeError, DatasetMetadata.DoesNotExist): return [] + @strawberry.field + def prompt_metadata(self) -> Optional[strawberry.scalars.JSON]: + """Get prompt-specific metadata for this dataset (only for PROMPT type datasets).""" + try: + # Check if this dataset is a PromptDataset (via multi-table inheritance) + prompt_dataset = PromptDataset.objects.filter(dataset_ptr_id=self.id).first() + if prompt_dataset: + return { + "task_type": prompt_dataset.task_type, + "target_languages": prompt_dataset.target_languages, + "domain": prompt_dataset.domain, + "target_model_types": prompt_dataset.target_model_types, + "prompt_format": prompt_dataset.prompt_format, + "has_system_prompt": prompt_dataset.has_system_prompt, + "has_example_responses": prompt_dataset.has_example_responses, + "avg_prompt_length": prompt_dataset.avg_prompt_length, + "prompt_count": prompt_dataset.prompt_count, + "use_case": prompt_dataset.use_case, + 
"evaluation_criteria": prompt_dataset.evaluation_criteria, + } + return None + except (AttributeError, PromptDataset.DoesNotExist): + return None + @strawberry.field def resources(self) -> List["TypeResource"]: """Get resources for this dataset.""" @@ -129,9 +156,7 @@ def formats(self: Any) -> List[str]: except (AttributeError, Resource.DoesNotExist): return [] - @strawberry.field( - description="Get similar datasets for this dataset from elasticsearch index." - ) + @strawberry.field(description="Get similar datasets for this dataset from elasticsearch index.") def similar_datasets(self: Any) -> List["TypeDataset"]: # type: ignore """Get similar datasets for this dataset from elasticsearch index.""" try: @@ -182,9 +207,7 @@ def similar_datasets(self: Any) -> List["TypeDataset"]: # type: ignore sectors = [sector.name for sector in dataset.sectors.all().select_related()] # type: ignore if sectors: - should_queries.append( - ESQ("terms", **{"sectors.raw": sectors, "boost": 2.0}) - ) + should_queries.append(ESQ("terms", **{"sectors.raw": sectors, "boost": 2.0})) # Add metadata similarity # Dataset.metadata is the related_name for DatasetMetadata diff --git a/api/types/type_prompt_metadata.py b/api/types/type_prompt_metadata.py new file mode 100644 index 0000000..1fcba48 --- /dev/null +++ b/api/types/type_prompt_metadata.py @@ -0,0 +1,61 @@ +"""GraphQL type for PromptDataset.""" + +import uuid +from datetime import datetime +from typing import List, Optional + +import strawberry +import strawberry_django +from strawberry.enum import EnumType +from strawberry.types import Info + +from api.models.PromptDataset import PromptDataset +from api.types.base_type import BaseType +from api.types.type_dataset import TypeDataset +from api.utils.enums import PromptTaskType + +prompt_task_type_enum: EnumType = strawberry.enum(PromptTaskType) # type: ignore + + +@strawberry_django.type( + PromptDataset, + fields="__all__", +) +class TypePromptDataset(TypeDataset): + """ + GraphQL type 
for PromptDataset. + + Extends TypeDataset with prompt-specific fields. + Inherits all Dataset fields plus adds prompt-specific ones. + """ + + # Prompt-specific fields + task_type: Optional[prompt_task_type_enum] + target_languages: Optional[List[str]] + domain: Optional[str] + target_model_types: Optional[List[str]] + prompt_format: Optional[str] + has_system_prompt: bool + has_example_responses: bool + avg_prompt_length: Optional[int] + prompt_count: Optional[int] + use_case: Optional[str] + evaluation_criteria: Optional[strawberry.scalars.JSON] + + +# Keep TypePromptMetadata as an alias for backward compatibility in nested queries +@strawberry.type +class TypePromptMetadata: + """Prompt-specific metadata fields (for embedding in TypeDataset).""" + + task_type: Optional[prompt_task_type_enum] + target_languages: Optional[List[str]] + domain: Optional[str] + target_model_types: Optional[List[str]] + prompt_format: Optional[str] + has_system_prompt: bool + has_example_responses: bool + avg_prompt_length: Optional[int] + prompt_count: Optional[int] + use_case: Optional[str] + evaluation_criteria: Optional[strawberry.scalars.JSON] diff --git a/api/utils/enums.py b/api/utils/enums.py index 770ad38..6c009d2 100644 --- a/api/utils/enums.py +++ b/api/utils/enums.py @@ -86,6 +86,28 @@ class DatasetStatus(models.TextChoices): ARCHIVED = "ARCHIVED" +class DatasetType(models.TextChoices): + DATA = "DATA" + PROMPT = "PROMPT" + + +class PromptTaskType(models.TextChoices): + TEXT_GENERATION = "TEXT_GENERATION" + TEXT_CLASSIFICATION = "TEXT_CLASSIFICATION" + QUESTION_ANSWERING = "QUESTION_ANSWERING" + SUMMARIZATION = "SUMMARIZATION" + TRANSLATION = "TRANSLATION" + SENTIMENT_ANALYSIS = "SENTIMENT_ANALYSIS" + NAMED_ENTITY_RECOGNITION = "NAMED_ENTITY_RECOGNITION" + CODE_GENERATION = "CODE_GENERATION" + CONVERSATION = "CONVERSATION" + INSTRUCTION_FOLLOWING = "INSTRUCTION_FOLLOWING" + REASONING = "REASONING" + CREATIVE_WRITING = "CREATIVE_WRITING" + DATA_EXTRACTION = 
"DATA_EXTRACTION" + OTHER = "OTHER" + + class DatasetAccessType(models.TextChoices): PUBLIC = "PUBLIC" PRIVATE = "PRIVATE" diff --git a/api/views/search_dataset.py b/api/views/search_dataset.py index d1064e3..9dafe09 100644 --- a/api/views/search_dataset.py +++ b/api/views/search_dataset.py @@ -56,6 +56,21 @@ def to_internal_value(self, data: Dict[str, Any]) -> Dict[str, Any]: return cast(Dict[str, Any], super().to_internal_value(data)) +class PromptMetadataSerializer(serializers.Serializer): + """Serializer for PromptMetadata in search results.""" + + task_type = serializers.CharField(allow_null=True) + target_languages = serializers.ListField(child=serializers.CharField(), allow_null=True) + domain = serializers.CharField(allow_null=True) + target_model_types = serializers.ListField(child=serializers.CharField(), allow_null=True) + prompt_format = serializers.CharField(allow_null=True) + has_system_prompt = serializers.BooleanField(allow_null=True) + has_example_responses = serializers.BooleanField(allow_null=True) + avg_prompt_length = serializers.IntegerField(allow_null=True) + prompt_count = serializers.IntegerField(allow_null=True) + use_case = serializers.CharField(allow_null=True) + + class DatasetDocumentSerializer(serializers.ModelSerializer): """Serializer for Dataset document.""" @@ -70,6 +85,8 @@ class DatasetDocumentSerializer(serializers.ModelSerializer): download_count = serializers.IntegerField() trending_score = serializers.FloatField(required=False) is_individual_dataset = serializers.BooleanField() + dataset_type = serializers.CharField(required=False, default="DATA") + prompt_metadata = PromptMetadataSerializer(required=False, allow_null=True) class OrganizationSerializer(serializers.Serializer): name = serializers.CharField() @@ -93,6 +110,7 @@ class Meta: "created", "modified", "status", + "dataset_type", "metadata", "tags", "sectors", @@ -105,6 +123,7 @@ class Meta: "is_individual_dataset", "organization", "user", + "prompt_metadata", ] @@ 
-119,9 +138,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.searchable_fields: List[str] self.aggregations: Dict[str, str] - self.searchable_fields, self.aggregations = ( - self.get_searchable_and_aggregations() - ) + self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations() self.logger = structlog.get_logger(__name__) @trace_method( @@ -148,6 +165,7 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: "formats.raw": "terms", "catalogs.raw": "terms", "geographies.raw": "terms", + "dataset_type": "terms", } for metadata in enabled_metadata: # type: Metadata if metadata.filterable: @@ -175,11 +193,7 @@ def add_aggregations(self, search: Search) -> Search: metadata_bucket = search.aggs.bucket("metadata", "nested", path="metadata") composite_sources = [ - { - "metadata_label": { - "terms": {"field": "metadata.metadata_item.label"} - } - }, + {"metadata_label": {"terms": {"field": "metadata.metadata_item.label"}}}, {"metadata_value": {"terms": {"field": "metadata.value"}}}, ] composite_agg = A( @@ -191,13 +205,7 @@ def add_aggregations(self, search: Search) -> Search: "filter", { # type: ignore[arg-type] "bool": { - "must": [ - { - "terms": { - "metadata.metadata_item.label": filterable_metadata - } - } - ] + "must": [{"terms": {"metadata.metadata_item.label": filterable_metadata}}] } }, ) @@ -207,19 +215,13 @@ def add_aggregations(self, search: Search) -> Search: return search - @trace_method( - name="generate_q_expression", attributes={"component": "search_dataset"} - ) - def generate_q_expression( - self, query: str - ) -> Optional[Union[ESQuery, List[ESQuery]]]: + @trace_method(name="generate_q_expression", attributes={"component": "search_dataset"}) + def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]: """Generate Elasticsearch Query expression.""" if query: queries: List[ESQuery] = [] for field in self.searchable_fields: - if 
field.startswith("resources.name") or field.startswith( - "resources.description" - ): + if field.startswith("resources.name") or field.startswith("resources.description"): queries.append( ESQ( "nested", @@ -230,18 +232,14 @@ def generate_q_expression( ESQ("wildcard", **{field: {"value": f"*{query}*"}}), ESQ( "fuzzy", - **{ - field: {"value": query, "fuzziness": "AUTO"} - }, + **{field: {"value": query, "fuzziness": "AUTO"}}, ), ], ), ) ) else: - queries.append( - ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}) - ) + queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}})) else: queries = [ESQ("match_all")] @@ -256,6 +254,37 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: for filter in filters: if filter in excluded_labels: continue + elif filter == "dataset_type": + # Filter by dataset type (DATA or PROMPT) + search = search.filter("term", dataset_type=filters[filter]) + elif filter == "task_type": + # Filter by prompt task type (nested in prompt_metadata) + search = search.filter( + "nested", + path="prompt_metadata", + query={ + "bool": {"must": {"term": {"prompt_metadata.task_type": filters[filter]}}} + }, + ) + elif filter == "domain": + # Filter by prompt domain (nested in prompt_metadata) + search = search.filter( + "nested", + path="prompt_metadata", + query={"bool": {"must": {"term": {"prompt_metadata.domain": filters[filter]}}}}, + ) + elif filter == "target_languages": + # Filter by target languages (nested in prompt_metadata) + filter_values = filters[filter].split(",") + search = search.filter( + "nested", + path="prompt_metadata", + query={ + "bool": { + "must": {"terms": {"prompt_metadata.target_languages": filter_values}} + } + }, + ) elif filter in ["tags", "sectors", "formats", "catalogs", "geographies"]: raw_filter = filter + ".raw" if raw_filter in self.aggregations: @@ -274,9 +303,7 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: search = 
search.filter( "nested", path="metadata", - query={ - "bool": {"must": {"term": {f"metadata.value": filters[filter]}}} - }, + query={"bool": {"must": {"term": {f"metadata.value": filters[filter]}}}}, ) return search diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index cad599b..907eb73 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -16,6 +16,7 @@ def search( geographies: Optional[List[str]] = None, status: Optional[str] = None, access_type: Optional[str] = None, + dataset_type: Optional[str] = None, sort: Optional[str] = None, page: int = 1, page_size: int = 10, @@ -30,6 +31,7 @@ def search( geographies: Filter by geographies status: Filter by status (DRAFT, PUBLISHED, etc.) access_type: Filter by access type (OPEN, RESTRICTED, etc.) + dataset_type: Filter by dataset type (DATA, PROMPT) sort: Sort order (recent, alphabetical) page: Page number (1-indexed) page_size: Number of results per page @@ -54,6 +56,8 @@ def search( params["status"] = status if access_type: params["access_type"] = access_type + if dataset_type: + params["dataset_type"] = dataset_type if sort: params["sort"] = sort @@ -231,3 +235,197 @@ def get_organization_datasets( limit=limit, offset=offset, ) + + def create(self, dataset_type: str = "DATA") -> Dict[str, Any]: + """ + Create a new dataset using GraphQL. 
+ + Args: + dataset_type: Type of dataset to create (DATA or PROMPT) + + Returns: + Dictionary containing the created dataset information + """ + query = """ + mutation AddDataset($createInput: CreateDatasetInput) { + addDataset(createInput: $createInput) { + success + errors + data { + id + title + description + status + datasetType + createdAt + updatedAt + } + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"createInput": {"datasetType": dataset_type}}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("addDataset", {}) + return result + + def search_prompts( + self, + query: Optional[str] = None, + task_type: Optional[str] = None, + domain: Optional[str] = None, + target_languages: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + sectors: Optional[List[str]] = None, + sort: Optional[str] = None, + page: int = 1, + page_size: int = 10, + ) -> Dict[str, Any]: + """ + Search for prompt datasets specifically. + + Args: + query: Search query string + task_type: Filter by prompt task type (TEXT_GENERATION, QUESTION_ANSWERING, etc.) + domain: Filter by domain (healthcare, education, etc.) 
+ target_languages: Filter by target languages + tags: Filter by tags + sectors: Filter by sectors + sort: Sort order (recent, alphabetical) + page: Page number (1-indexed) + page_size: Number of results per page + + Returns: + Dictionary containing search results and metadata + """ + params: Dict[str, Any] = { + "page": page, + "page_size": page_size, + "dataset_type": "PROMPT", + } + + if query: + params["q"] = query + if task_type: + params["task_type"] = task_type + if domain: + params["domain"] = domain + if target_languages: + params["target_languages"] = ",".join(target_languages) + if tags: + params["tags"] = ",".join(tags) + if sectors: + params["sectors"] = ",".join(sectors) + if sort: + params["sort"] = sort + + return super().get("/api/search/dataset/", params=params) + + def update_prompt_metadata( + self, + dataset_id: str, + task_type: Optional[str] = None, + target_languages: Optional[List[str]] = None, + domain: Optional[str] = None, + target_model_types: Optional[List[str]] = None, + prompt_format: Optional[str] = None, + has_system_prompt: Optional[bool] = None, + has_example_responses: Optional[bool] = None, + avg_prompt_length: Optional[int] = None, + prompt_count: Optional[int] = None, + use_case: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Update prompt-specific metadata for a prompt dataset. 
+ + Args: + dataset_id: UUID of the prompt dataset + task_type: Type of prompt task + target_languages: List of target languages + domain: Domain/category of prompts + target_model_types: List of target AI model types + prompt_format: Format of prompts + has_system_prompt: Whether prompts include system instructions + has_example_responses: Whether prompts include example responses + avg_prompt_length: Average prompt length + prompt_count: Total number of prompts + use_case: Description of intended use cases + + Returns: + Dictionary containing the updated prompt metadata + """ + query = """ + mutation UpdatePromptMetadata($updateInput: UpdatePromptMetadataInput!) { + updatePromptMetadata(updateInput: $updateInput) { + success + errors + data { + id + title + description + status + datasetType + taskType + targetLanguages + domain + targetModelTypes + promptFormat + hasSystemPrompt + hasExampleResponses + avgPromptLength + promptCount + useCase + } + } + } + """ + + variables: Dict[str, Any] = {"dataset": dataset_id} + + if task_type is not None: + variables["taskType"] = task_type + if target_languages is not None: + variables["targetLanguages"] = target_languages + if domain is not None: + variables["domain"] = domain + if target_model_types is not None: + variables["targetModelTypes"] = target_model_types + if prompt_format is not None: + variables["promptFormat"] = prompt_format + if has_system_prompt is not None: + variables["hasSystemPrompt"] = has_system_prompt + if has_example_responses is not None: + variables["hasExampleResponses"] = has_example_responses + if avg_prompt_length is not None: + variables["avgPromptLength"] = avg_prompt_length + if prompt_count is not None: + variables["promptCount"] = prompt_count + if use_case is not None: + variables["useCase"] = use_case + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"updateInput": variables}, + }, + ) + + if "errors" in response: + from 
dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("updatePromptMetadata", {}) + return result diff --git a/search/documents/dataset_document.py b/search/documents/dataset_document.py index 8578fb4..85e60b0 100644 --- a/search/documents/dataset_document.py +++ b/search/documents/dataset_document.py @@ -9,10 +9,11 @@ Geography, Metadata, Organization, + PromptDataset, Resource, Sector, ) -from api.utils.enums import DatasetStatus +from api.utils.enums import DatasetStatus, DatasetType from authorization.models import User from DataSpace import settings from search.documents.analysers import html_strip, ngram_analyser @@ -29,9 +30,7 @@ class DatasetDocument(Document): properties={ "value": KeywordField(multi=True), "raw": KeywordField(multi=True), - "metadata_item": fields.ObjectField( - properties={"label": KeywordField(multi=False)} - ), + "metadata_item": fields.ObjectField(properties={"label": KeywordField(multi=False)}), } ) @@ -131,6 +130,22 @@ class DatasetDocument(Document): download_count = fields.IntegerField(attr="download_count") trending_score = fields.FloatField(attr="trending_score") + # Prompt-specific metadata (nested object, only populated for PROMPT type datasets) + prompt_metadata = fields.NestedField( + properties={ + "task_type": KeywordField(), + "target_languages": KeywordField(multi=True), + "domain": KeywordField(), + "target_model_types": KeywordField(multi=True), + "prompt_format": KeywordField(), + "has_system_prompt": fields.BooleanField(), + "has_example_responses": fields.BooleanField(), + "avg_prompt_length": fields.IntegerField(), + "prompt_count": fields.IntegerField(), + "use_case": fields.TextField(analyzer=html_strip), + } + ) + def prepare_metadata(self, instance: Dataset) -> List[Dict[str, Any]]: """Preprocess comma-separated metadata values into arrays.""" processed_metadata: List[Dict[str, Any]] = [] @@ 
-167,13 +182,36 @@ def prepare_user(self, instance: Dataset) -> Optional[Dict[str, str]]: "name": instance.user.full_name, "bio": instance.user.bio or "", "profile_picture": ( - instance.user.profile_picture.url - if instance.user.profile_picture - else "" + instance.user.profile_picture.url if instance.user.profile_picture else "" ), } return None + def prepare_prompt_metadata(self, instance: Dataset) -> Optional[Dict[str, Any]]: + """Prepare prompt metadata for indexing (only for PROMPT type datasets).""" + if instance.dataset_type != DatasetType.PROMPT: + return None + + try: + # With multi-table inheritance, check if this Dataset has a PromptDataset child + prompt_dataset = PromptDataset.objects.filter(dataset_ptr_id=instance.id).first() + if prompt_dataset: + return { + "task_type": prompt_dataset.task_type, + "target_languages": prompt_dataset.target_languages or [], + "domain": prompt_dataset.domain, + "target_model_types": prompt_dataset.target_model_types or [], + "prompt_format": prompt_dataset.prompt_format, + "has_system_prompt": prompt_dataset.has_system_prompt, + "has_example_responses": prompt_dataset.has_example_responses, + "avg_prompt_length": prompt_dataset.avg_prompt_length, + "prompt_count": prompt_dataset.prompt_count, + "use_case": prompt_dataset.use_case, + } + except PromptDataset.DoesNotExist: + pass + return None + def should_index_object(self, obj: Dataset) -> bool: """Check if the object should be indexed.""" return obj.status == DatasetStatus.PUBLISHED @@ -191,11 +229,7 @@ def delete(self, *args: Any, **kwargs: Any) -> None: def get_queryset(self) -> Any: """Get the queryset for indexing.""" - return ( - super(DatasetDocument, self) - .get_queryset() - .filter(status=DatasetStatus.PUBLISHED) - ) + return super(DatasetDocument, self).get_queryset().filter(status=DatasetStatus.PUBLISHED) def get_instances_from_related( self, @@ -203,6 +237,7 @@ def get_instances_from_related( Resource, Metadata, DatasetMetadata, + PromptDataset, Sector, 
Organization, User, @@ -218,6 +253,9 @@ def get_instances_from_related( return [obj.dataset for obj in ds_metadata_objects] # type: ignore elif isinstance(related_instance, DatasetMetadata): return related_instance.dataset + elif isinstance(related_instance, PromptDataset): + # PromptDataset IS a Dataset (multi-table inheritance), cast to Dataset + return Dataset.objects.get(pk=related_instance.pk) elif isinstance(related_instance, Sector): return list(related_instance.datasets.all()) elif isinstance(related_instance, Organization): @@ -239,12 +277,14 @@ class Django: "id", "created", "modified", + "dataset_type", ] related_models = [ Resource, Metadata, DatasetMetadata, + PromptDataset, Sector, Organization, User, From 451d40331ff9d046efe10463404af7866657fc37 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 12 Jan 2026 19:34:07 +0530 Subject: [PATCH 054/127] add enums to prompt metadata options --- api/models/PromptDataset.py | 6 ++-- api/types/type_prompt_metadata.py | 30 ++++++++++------ api/utils/enums.py | 57 +++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 13 deletions(-) diff --git a/api/models/PromptDataset.py b/api/models/PromptDataset.py index 0946059..e7fc737 100644 --- a/api/models/PromptDataset.py +++ b/api/models/PromptDataset.py @@ -3,7 +3,7 @@ from django.db import models from api.models.Dataset import Dataset -from api.utils.enums import DatasetType, PromptTaskType +from api.utils.enums import DatasetType, PromptDomain, PromptFormat, PromptTaskType class PromptDataset(Dataset): @@ -39,6 +39,7 @@ class PromptDataset(Dataset): # Domain/category of prompts domain = models.CharField( max_length=200, + choices=PromptDomain.choices, blank=True, null=True, help_text="Domain or category (e.g., healthcare, education, legal)", @@ -48,12 +49,13 @@ class PromptDataset(Dataset): target_model_types = models.JSONField( blank=True, null=True, - help_text="List of AI model types these prompts are designed for", + help_text="List of AI model types 
these prompts are designed for (e.g., ['GPT', 'LLAMA'])", ) # Prompt format/template information prompt_format = models.CharField( max_length=100, + choices=PromptFormat.choices, blank=True, null=True, help_text="Format of prompts (e.g., instruction, chat, completion)", diff --git a/api/types/type_prompt_metadata.py b/api/types/type_prompt_metadata.py index 1fcba48..74c082d 100644 --- a/api/types/type_prompt_metadata.py +++ b/api/types/type_prompt_metadata.py @@ -12,9 +12,19 @@ from api.models.PromptDataset import PromptDataset from api.types.base_type import BaseType from api.types.type_dataset import TypeDataset -from api.utils.enums import PromptTaskType +from api.utils.enums import ( + PromptDomain, + PromptFormat, + PromptTaskType, + TargetLanguage, + TargetModelType, +) prompt_task_type_enum: EnumType = strawberry.enum(PromptTaskType) # type: ignore +prompt_domain_enum: EnumType = strawberry.enum(PromptDomain) # type: ignore +prompt_format_enum: EnumType = strawberry.enum(PromptFormat) # type: ignore +target_language_enum: EnumType = strawberry.enum(TargetLanguage) # type: ignore +target_model_type_enum: EnumType = strawberry.enum(TargetModelType) # type: ignore @strawberry_django.type( @@ -31,15 +41,14 @@ class TypePromptDataset(TypeDataset): # Prompt-specific fields task_type: Optional[prompt_task_type_enum] - target_languages: Optional[List[str]] - domain: Optional[str] - target_model_types: Optional[List[str]] - prompt_format: Optional[str] + target_languages: Optional[List[target_language_enum]] + domain: Optional[prompt_domain_enum] + target_model_types: Optional[List[target_model_type_enum]] + prompt_format: Optional[prompt_format_enum] has_system_prompt: bool has_example_responses: bool avg_prompt_length: Optional[int] prompt_count: Optional[int] - use_case: Optional[str] evaluation_criteria: Optional[strawberry.scalars.JSON] @@ -49,13 +58,12 @@ class TypePromptMetadata: """Prompt-specific metadata fields (for embedding in TypeDataset).""" task_type: 
Optional[prompt_task_type_enum] - target_languages: Optional[List[str]] - domain: Optional[str] - target_model_types: Optional[List[str]] - prompt_format: Optional[str] + target_languages: Optional[List[target_language_enum]] + domain: Optional[prompt_domain_enum] + target_model_types: Optional[List[target_model_type_enum]] + prompt_format: Optional[prompt_format_enum] has_system_prompt: bool has_example_responses: bool avg_prompt_length: Optional[int] prompt_count: Optional[int] - use_case: Optional[str] evaluation_criteria: Optional[strawberry.scalars.JSON] diff --git a/api/utils/enums.py b/api/utils/enums.py index 6c009d2..b02d27d 100644 --- a/api/utils/enums.py +++ b/api/utils/enums.py @@ -108,6 +108,63 @@ class PromptTaskType(models.TextChoices): OTHER = "OTHER" +class PromptDomain(models.TextChoices): + HEALTHCARE = "HEALTHCARE" + EDUCATION = "EDUCATION" + LEGAL = "LEGAL" + FINANCE = "FINANCE" + AGRICULTURE = "AGRICULTURE" + ENVIRONMENT = "ENVIRONMENT" + GOVERNMENT = "GOVERNMENT" + TECHNOLOGY = "TECHNOLOGY" + SCIENCE = "SCIENCE" + SOCIAL_SERVICES = "SOCIAL_SERVICES" + TRANSPORTATION = "TRANSPORTATION" + ENERGY = "ENERGY" + GENERAL = "GENERAL" + OTHER = "OTHER" + + +class PromptFormat(models.TextChoices): + INSTRUCTION = "INSTRUCTION" + CHAT = "CHAT" + COMPLETION = "COMPLETION" + FEW_SHOT = "FEW_SHOT" + CHAIN_OF_THOUGHT = "CHAIN_OF_THOUGHT" + ZERO_SHOT = "ZERO_SHOT" + OTHER = "OTHER" + + +class TargetLanguage(models.TextChoices): + ENGLISH = "en" + HINDI = "hi" + TAMIL = "ta" + TELUGU = "te" + BENGALI = "bn" + MARATHI = "mr" + GUJARATI = "gu" + KANNADA = "kn" + MALAYALAM = "ml" + PUNJABI = "pa" + ODIA = "or" + ASSAMESE = "as" + URDU = "ur" + OTHER = "other" + + +class TargetModelType(models.TextChoices): + GPT = "GPT" + LLAMA = "LLAMA" + MISTRAL = "MISTRAL" + GEMINI = "GEMINI" + CLAUDE = "CLAUDE" + FALCON = "FALCON" + BLOOM = "BLOOM" + INDIC_LLM = "INDIC_LLM" + CUSTOM = "CUSTOM" + OTHER = "OTHER" + + class DatasetAccessType(models.TextChoices): PUBLIC = 
"PUBLIC" PRIVATE = "PRIVATE" From 5013d88bf31ec4a06a24769b25731d5a93844a52 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 12 Jan 2026 19:41:41 +0530 Subject: [PATCH 055/127] fix prompt dataset update permission check --- api/schema/dataset_schema.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index cb57573..93c851b 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -204,6 +204,8 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: update_dataset_input = kwargs.get("update_dataset_input") if not update_dataset_input: update_dataset_input = kwargs.get("update_metadata_input") + if not update_dataset_input: + update_dataset_input = kwargs.get("update_input") if not update_dataset_input or not hasattr(update_dataset_input, "dataset"): return False From fa9abcca8c5047e6ef31c6471ad10a69d60345b8 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 12 Jan 2026 20:10:30 +0530 Subject: [PATCH 056/127] update language enum --- api/utils/enums.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/api/utils/enums.py b/api/utils/enums.py index b02d27d..e070455 100644 --- a/api/utils/enums.py +++ b/api/utils/enums.py @@ -136,20 +136,20 @@ class PromptFormat(models.TextChoices): class TargetLanguage(models.TextChoices): - ENGLISH = "en" - HINDI = "hi" - TAMIL = "ta" - TELUGU = "te" - BENGALI = "bn" - MARATHI = "mr" - GUJARATI = "gu" - KANNADA = "kn" - MALAYALAM = "ml" - PUNJABI = "pa" - ODIA = "or" - ASSAMESE = "as" - URDU = "ur" - OTHER = "other" + ENGLISH = "ENGLISH" + HINDI = "HINDI" + TAMIL = "TAMIL" + TELUGU = "TELUGU" + BENGALI = "BENGALI" + MARATHI = "MARATHI" + GUJARATI = "GUJARATI" + KANNADA = "KANNADA" + MALAYALAM = "MALAYALAM" + PUNJABI = "PUNJABI" + ODIA = "ODIA" + ASSAMESE = "ASSAMESE" + URDU = "URDU" + OTHER = "OTHER" class TargetModelType(models.TextChoices): From c37dc81adc23f44d96c56c9e7c608b940aaa5cc5 Mon Sep 
17 00:00:00 2001 From: dc Date: Tue, 13 Jan 2026 12:01:52 +0530 Subject: [PATCH 057/127] change dataset_type field to keyword field --- search/documents/dataset_document.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/search/documents/dataset_document.py b/search/documents/dataset_document.py index 85e60b0..203b1b4 100644 --- a/search/documents/dataset_document.py +++ b/search/documents/dataset_document.py @@ -126,6 +126,8 @@ class DatasetDocument(Document): is_individual_dataset = fields.BooleanField(attr="is_individual_dataset") + dataset_type = fields.KeywordField() + has_charts = fields.BooleanField(attr="has_charts") download_count = fields.IntegerField(attr="download_count") trending_score = fields.FloatField(attr="trending_score") @@ -277,7 +279,6 @@ class Django: "id", "created", "modified", - "dataset_type", ] related_models = [ From c770b7a6c2fcadd1da289820dc7d1cff81f3215e Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 13 Jan 2026 13:00:59 +0530 Subject: [PATCH 058/127] move prompt file specific details to separate model --- api/models/PromptDataset.py | 44 +------- api/models/PromptResource.py | 69 ++++++++++++ api/models/__init__.py | 1 + api/schema/dataset_schema.py | 153 ++++++++++++++++----------- api/types/type_dataset.py | 6 -- api/types/type_prompt_metadata.py | 15 +-- api/types/type_resource.py | 48 ++++++--- search/documents/dataset_document.py | 12 --- 8 files changed, 203 insertions(+), 145 deletions(-) create mode 100644 api/models/PromptResource.py diff --git a/api/models/PromptDataset.py b/api/models/PromptDataset.py index e7fc737..72f426d 100644 --- a/api/models/PromptDataset.py +++ b/api/models/PromptDataset.py @@ -3,7 +3,7 @@ from django.db import models from api.models.Dataset import Dataset -from api.utils.enums import DatasetType, PromptDomain, PromptFormat, PromptTaskType +from api.utils.enums import DatasetType, PromptDomain, PromptTaskType class PromptDataset(Dataset): @@ -52,48 +52,6 @@ class 
PromptDataset(Dataset): help_text="List of AI model types these prompts are designed for (e.g., ['GPT', 'LLAMA'])", ) - # Prompt format/template information - prompt_format = models.CharField( - max_length=100, - choices=PromptFormat.choices, - blank=True, - null=True, - help_text="Format of prompts (e.g., instruction, chat, completion)", - ) - - # Whether prompts include system instructions - has_system_prompt = models.BooleanField( - default=False, - help_text="Whether the prompts include system-level instructions", - ) - - # Whether prompts include example responses - has_example_responses = models.BooleanField( - default=False, - help_text="Whether the prompts include example/expected responses", - ) - - # Average prompt length (for filtering/search) - avg_prompt_length = models.IntegerField( - blank=True, - null=True, - help_text="Average character length of prompts in this dataset", - ) - - # Number of prompts in the dataset - prompt_count = models.IntegerField( - blank=True, - null=True, - help_text="Total number of prompts in this dataset", - ) - - # Use case description - use_case = models.TextField( - blank=True, - null=True, - help_text="Description of intended use cases for these prompts", - ) - # Evaluation criteria or metrics evaluation_criteria = models.JSONField( blank=True, diff --git a/api/models/PromptResource.py b/api/models/PromptResource.py new file mode 100644 index 0000000..dc9d442 --- /dev/null +++ b/api/models/PromptResource.py @@ -0,0 +1,69 @@ +"""PromptResource model - extends Resource with prompt-specific fields.""" + +from django.db import models + +from api.models.Resource import Resource +from api.utils.enums import PromptFormat + + +class PromptResource(models.Model): + """ + PromptResource adds prompt-specific metadata to a Resource. + + This is a OneToOne extension of Resource (not multi-table inheritance) + to store prompt file-specific fields like format, system prompt presence, etc. 
+ + """ + + resource = models.OneToOneField( + Resource, + on_delete=models.CASCADE, + primary_key=True, + related_name="prompt_details", + ) + + # Prompt format/template information + prompt_format = models.CharField( + max_length=100, + choices=PromptFormat.choices, + blank=True, + null=True, + help_text="Format of prompts in this file (e.g., instruction, chat, completion)", + ) + + # Whether prompts include system instructions + has_system_prompt = models.BooleanField( + default=False, + help_text="Whether the prompts in this file include system-level instructions", + ) + + # Whether prompts include example responses + has_example_responses = models.BooleanField( + default=False, + help_text="Whether the prompts in this file include example/expected responses", + ) + + # Average prompt length (for filtering/search) + avg_prompt_length = models.IntegerField( + blank=True, + null=True, + help_text="Average character length of prompts in this file", + ) + + # Number of prompts in this file + prompt_count = models.IntegerField( + blank=True, + null=True, + help_text="Total number of prompts in this file", + ) + + created = models.DateTimeField(auto_now_add=True) + modified = models.DateTimeField(auto_now=True) + + def __str__(self) -> str: + return f"PromptResource: {self.resource.name}" + + class Meta: + db_table = "prompt_resource" + verbose_name = "Prompt Resource" + verbose_name_plural = "Prompt Resources" diff --git a/api/models/__init__.py b/api/models/__init__.py index e9e0d77..14ac679 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -14,6 +14,7 @@ from api.models.Metadata import Metadata from api.models.Organization import Organization from api.models.PromptDataset import PromptDataset +from api.models.PromptResource import PromptResource from api.models.Resource import ( Resource, ResourceDataTable, diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index 93c851b..c267e2f 100644 --- a/api/schema/dataset_schema.py +++ 
b/api/schema/dataset_schema.py @@ -33,6 +33,7 @@ from api.types.type_dataset import DatasetFilter, DatasetOrder, TypeDataset from api.types.type_organization import TypeOrganization from api.types.type_prompt_metadata import TypePromptDataset, prompt_task_type_enum +from api.types.type_resource import TypeResource from api.types.type_resource_chart import TypeResourceChart from api.types.type_resource_chart_image import TypeResourceChartImage from api.utils.enums import ( @@ -297,20 +298,26 @@ class PromptMetadataInput: @strawberry.input class UpdatePromptMetadataInput: - """Input for updating prompt-specific metadata.""" + """Input for updating prompt-specific metadata (dataset-level fields).""" dataset: uuid.UUID task_type: Optional[PromptTaskTypeENUM] = None target_languages: Optional[List[str]] = None domain: Optional[str] = None target_model_types: Optional[List[str]] = None + evaluation_criteria: Optional[strawberry.scalars.JSON] = None + + +@strawberry.input +class UpdatePromptResourceInput: + """Input for updating prompt-specific resource metadata (file-level fields).""" + + resource: uuid.UUID prompt_format: Optional[str] = None has_system_prompt: Optional[bool] = None has_example_responses: Optional[bool] = None avg_prompt_length: Optional[int] = None prompt_count: Optional[int] = None - use_case: Optional[str] = None - evaluation_criteria: Optional[strawberry.scalars.JSON] = None @trace_resolver(name="add_update_dataset_metadata", attributes={"component": "dataset"}) @@ -695,6 +702,88 @@ def update_dataset(self, info: Info, update_dataset_input: UpdateDatasetInput) - _update_dataset_tags(dataset, update_dataset_input.tags) return TypeDataset.from_django(dataset) + @strawberry.mutation( + permission_classes=[UpdateDatasetPermission], + ) + @trace_resolver( + name="update_prompt_metadata", + attributes={"component": "dataset"}, + ) + def update_prompt_metadata( + self, info: Info, update_input: UpdatePromptMetadataInput + ) -> 
MutationResponse[TypePromptDataset]: + """Update prompt-specific metadata for a prompt dataset (dataset-level fields).""" + dataset_id = update_input.dataset + + # Get the PromptDataset directly (it's a child of Dataset via multi-table inheritance) + try: + prompt_dataset = PromptDataset.objects.get(dataset_ptr_id=dataset_id) + except PromptDataset.DoesNotExist: + raise DjangoValidationError( + f"Dataset with ID {dataset_id} is not a prompt dataset or does not exist." + ) + + if prompt_dataset.status != DatasetStatus.DRAFT.value: + raise DjangoValidationError(f"Dataset with ID {dataset_id} is not in draft status.") + + # Update dataset-level fields if provided + if update_input.task_type is not None: + prompt_dataset.task_type = update_input.task_type + if update_input.target_languages is not None: + prompt_dataset.target_languages = update_input.target_languages + if update_input.domain is not None: + prompt_dataset.domain = update_input.domain + if update_input.target_model_types is not None: + prompt_dataset.target_model_types = update_input.target_model_types + if update_input.evaluation_criteria is not None: + prompt_dataset.evaluation_criteria = update_input.evaluation_criteria + + prompt_dataset.save() + return MutationResponse.success_response(TypePromptDataset.from_django(prompt_dataset)) + + @strawberry.mutation( + permission_classes=[IsAuthenticated], + ) + @trace_resolver( + name="update_prompt_resource", + attributes={"component": "resource"}, + ) + def update_prompt_resource( + self, info: Info, update_input: UpdatePromptResourceInput + ) -> MutationResponse[TypeResource]: + """Update prompt-specific metadata for a resource (file-level fields).""" + from api.models import PromptResource, Resource + + resource_id = update_input.resource + + # Get the Resource + try: + resource = Resource.objects.get(id=resource_id) + except Resource.DoesNotExist: + raise DjangoValidationError(f"Resource with ID {resource_id} does not exist.") + + # Check if the dataset 
is in draft status + if resource.dataset.status != DatasetStatus.DRAFT.value: + raise DjangoValidationError(f"Cannot update resource - dataset is not in draft status.") + + # Get or create PromptResource + prompt_resource, created = PromptResource.objects.get_or_create(resource=resource) + + # Update file-level fields if provided + if update_input.prompt_format is not None: + prompt_resource.prompt_format = update_input.prompt_format + if update_input.has_system_prompt is not None: + prompt_resource.has_system_prompt = update_input.has_system_prompt + if update_input.has_example_responses is not None: + prompt_resource.has_example_responses = update_input.has_example_responses + if update_input.avg_prompt_length is not None: + prompt_resource.avg_prompt_length = update_input.avg_prompt_length + if update_input.prompt_count is not None: + prompt_resource.prompt_count = update_input.prompt_count + + prompt_resource.save() + return MutationResponse.success_response(TypeResource.from_django(resource)) + @strawberry_django.mutation( handle_django_errors=True, permission_classes=[PublishDatasetPermission], # type: ignore[list-item] @@ -778,61 +867,3 @@ def delete_dataset(self, info: Info, dataset_id: uuid.UUID) -> bool: return True except Dataset.DoesNotExist as e: raise ValueError(f"Dataset with ID {dataset_id} does not exist.") - - @strawberry.mutation - @BaseMutation.mutation( - permission_classes=[UpdateDatasetPermission], - track_activity={ - "verb": "updated", - "get_data": lambda result, **kwargs: { - "dataset_id": str(result.id), - "dataset_title": result.title, - "updated_fields": {"prompt_metadata": True}, - }, - }, - trace_name="update_prompt_metadata", - trace_attributes={"component": "dataset"}, - ) - def update_prompt_metadata( - self, info: Info, update_input: UpdatePromptMetadataInput - ) -> MutationResponse[TypePromptDataset]: - """Update prompt-specific metadata for a prompt dataset.""" - dataset_id = update_input.dataset - - # Get the PromptDataset 
directly (it's a child of Dataset via multi-table inheritance) - try: - prompt_dataset = PromptDataset.objects.get(dataset_ptr_id=dataset_id) - except PromptDataset.DoesNotExist: - raise DjangoValidationError( - f"Dataset with ID {dataset_id} is not a prompt dataset or does not exist." - ) - - if prompt_dataset.status != DatasetStatus.DRAFT.value: - raise DjangoValidationError(f"Dataset with ID {dataset_id} is not in draft status.") - - # Update fields if provided - if update_input.task_type is not None: - prompt_dataset.task_type = update_input.task_type - if update_input.target_languages is not None: - prompt_dataset.target_languages = update_input.target_languages - if update_input.domain is not None: - prompt_dataset.domain = update_input.domain - if update_input.target_model_types is not None: - prompt_dataset.target_model_types = update_input.target_model_types - if update_input.prompt_format is not None: - prompt_dataset.prompt_format = update_input.prompt_format - if update_input.has_system_prompt is not None: - prompt_dataset.has_system_prompt = update_input.has_system_prompt - if update_input.has_example_responses is not None: - prompt_dataset.has_example_responses = update_input.has_example_responses - if update_input.avg_prompt_length is not None: - prompt_dataset.avg_prompt_length = update_input.avg_prompt_length - if update_input.prompt_count is not None: - prompt_dataset.prompt_count = update_input.prompt_count - if update_input.use_case is not None: - prompt_dataset.use_case = update_input.use_case - if update_input.evaluation_criteria is not None: - prompt_dataset.evaluation_criteria = update_input.evaluation_criteria - - prompt_dataset.save() - return MutationResponse.success_response(TypePromptDataset.from_django(prompt_dataset)) diff --git a/api/types/type_dataset.py b/api/types/type_dataset.py index 1e0bcae..b8e71ec 100644 --- a/api/types/type_dataset.py +++ b/api/types/type_dataset.py @@ -122,12 +122,6 @@ def prompt_metadata(self) -> 
Optional[strawberry.scalars.JSON]: "target_languages": prompt_dataset.target_languages, "domain": prompt_dataset.domain, "target_model_types": prompt_dataset.target_model_types, - "prompt_format": prompt_dataset.prompt_format, - "has_system_prompt": prompt_dataset.has_system_prompt, - "has_example_responses": prompt_dataset.has_example_responses, - "avg_prompt_length": prompt_dataset.avg_prompt_length, - "prompt_count": prompt_dataset.prompt_count, - "use_case": prompt_dataset.use_case, "evaluation_criteria": prompt_dataset.evaluation_criteria, } return None diff --git a/api/types/type_prompt_metadata.py b/api/types/type_prompt_metadata.py index 74c082d..d208f40 100644 --- a/api/types/type_prompt_metadata.py +++ b/api/types/type_prompt_metadata.py @@ -39,16 +39,11 @@ class TypePromptDataset(TypeDataset): Inherits all Dataset fields plus adds prompt-specific ones. """ - # Prompt-specific fields + # Dataset-level prompt fields task_type: Optional[prompt_task_type_enum] target_languages: Optional[List[target_language_enum]] domain: Optional[prompt_domain_enum] target_model_types: Optional[List[target_model_type_enum]] - prompt_format: Optional[prompt_format_enum] - has_system_prompt: bool - has_example_responses: bool - avg_prompt_length: Optional[int] - prompt_count: Optional[int] evaluation_criteria: Optional[strawberry.scalars.JSON] @@ -61,9 +56,15 @@ class TypePromptMetadata: target_languages: Optional[List[target_language_enum]] domain: Optional[prompt_domain_enum] target_model_types: Optional[List[target_model_type_enum]] + evaluation_criteria: Optional[strawberry.scalars.JSON] + + +@strawberry.type +class TypePromptResourceDetails: + """Prompt-specific fields for a resource/file.""" + prompt_format: Optional[prompt_format_enum] has_system_prompt: bool has_example_responses: bool avg_prompt_length: Optional[int] prompt_count: Optional[int] - evaluation_criteria: Optional[strawberry.scalars.JSON] diff --git a/api/types/type_resource.py 
b/api/types/type_resource.py index 6b541fa..00749c2 100644 --- a/api/types/type_resource.py +++ b/api/types/type_resource.py @@ -8,6 +8,7 @@ from strawberry_django import type from api.models import ( + PromptResource, Resource, ResourceFileDetails, ResourceMetadata, @@ -104,9 +105,7 @@ def metadata(self) -> List[TypeResourceMetadata]: # return [] @strawberry.field - @trace_resolver( - name="get_resource_file_details", attributes={"component": "resource"} - ) + @trace_resolver(name="get_resource_file_details", attributes={"component": "resource"}) def file_details(self) -> Optional[TypeFileDetails]: """Get file details for this resource. @@ -138,9 +137,7 @@ def schema(self) -> List[TypeResourceSchema]: return [] @strawberry.field - @trace_resolver( - name="get_resource_preview_data", attributes={"component": "resource"} - ) + @trace_resolver(name="get_resource_preview_data", attributes={"component": "resource"}) def preview_data(self) -> PreviewData: """Get preview data for the resource. 
@@ -151,9 +148,11 @@ def preview_data(self) -> PreviewData: file_details = getattr(self, "resourcefiledetails", None) if not file_details or not getattr(self, "preview_details", None): return PreviewData(columns=[], rows=[]) - if not getattr( - self, "preview_enabled", False - ) or not file_details.format.lower() in ["csv", "xls", "xlsx"]: + if not getattr(self, "preview_enabled", False) or not file_details.format.lower() in [ + "csv", + "xls", + "xlsx", + ]: return PreviewData(columns=[], rows=[]) try: @@ -169,9 +168,7 @@ def preview_data(self) -> PreviewData: return PreviewData(columns=[], rows=[]) @strawberry.field - @trace_resolver( - name="get_resource_no_of_entries", attributes={"component": "resource"} - ) + @trace_resolver(name="get_resource_no_of_entries", attributes={"component": "resource"}) def no_of_entries(self) -> int: """Get the number of entries in the resource.""" try: @@ -179,10 +176,7 @@ def no_of_entries(self) -> int: if not file_details: return 0 - if ( - not hasattr(file_details, "format") - or file_details.format.lower() != "csv" - ): + if not hasattr(file_details, "format") or file_details.format.lower() != "csv": return 0 try: @@ -193,3 +187,25 @@ def no_of_entries(self) -> int: except Exception as e: logger.error(f"Error getting number of entries: {str(e)}") return 0 + + @strawberry.field + @trace_resolver(name="get_prompt_details", attributes={"component": "resource"}) + def prompt_details(self) -> Optional[strawberry.scalars.JSON]: + """Get prompt-specific details for this resource (only for prompt datasets). 
+ + Returns: + Optional[JSON]: Prompt details if they exist, None otherwise + """ + try: + prompt_resource = PromptResource.objects.filter(resource_id=self.id).first() + if prompt_resource: + return { + "prompt_format": prompt_resource.prompt_format, + "has_system_prompt": prompt_resource.has_system_prompt, + "has_example_responses": prompt_resource.has_example_responses, + "avg_prompt_length": prompt_resource.avg_prompt_length, + "prompt_count": prompt_resource.prompt_count, + } + return None + except (AttributeError, PromptResource.DoesNotExist): + return None diff --git a/search/documents/dataset_document.py b/search/documents/dataset_document.py index 203b1b4..83f8d6e 100644 --- a/search/documents/dataset_document.py +++ b/search/documents/dataset_document.py @@ -139,12 +139,6 @@ class DatasetDocument(Document): "target_languages": KeywordField(multi=True), "domain": KeywordField(), "target_model_types": KeywordField(multi=True), - "prompt_format": KeywordField(), - "has_system_prompt": fields.BooleanField(), - "has_example_responses": fields.BooleanField(), - "avg_prompt_length": fields.IntegerField(), - "prompt_count": fields.IntegerField(), - "use_case": fields.TextField(analyzer=html_strip), } ) @@ -203,12 +197,6 @@ def prepare_prompt_metadata(self, instance: Dataset) -> Optional[Dict[str, Any]] "target_languages": prompt_dataset.target_languages or [], "domain": prompt_dataset.domain, "target_model_types": prompt_dataset.target_model_types or [], - "prompt_format": prompt_dataset.prompt_format, - "has_system_prompt": prompt_dataset.has_system_prompt, - "has_example_responses": prompt_dataset.has_example_responses, - "avg_prompt_length": prompt_dataset.avg_prompt_length, - "prompt_count": prompt_dataset.prompt_count, - "use_case": prompt_dataset.use_case, } except PromptDataset.DoesNotExist: pass From 3757cf7099a75f623471de863ed614ee0dcbb591 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 13 Jan 2026 13:25:45 +0530 Subject: [PATCH 059/127] handle null values 
in indexing data --- api/utils/data_indexing.py | 56 ++++++++++---------------------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/api/utils/data_indexing.py b/api/utils/data_indexing.py index 7933d68..ce30f0c 100644 --- a/api/utils/data_indexing.py +++ b/api/utils/data_indexing.py @@ -30,9 +30,7 @@ def get_sql_type(pandas_dtype: str) -> str: return "TEXT" -def create_table_for_resource( - resource: Resource, df: pd.DataFrame -) -> Optional[ResourceDataTable]: +def create_table_for_resource(resource: Resource, df: pd.DataFrame) -> Optional[ResourceDataTable]: """Create a database table for the resource data and index it.""" try: # Create ResourceDataTable entry first to get the table name @@ -65,9 +63,7 @@ def create_table_for_resource( df.to_csv(csv_data, index=False, header=False) csv_data.seek(0) - copy_sql = ( - f'COPY "{temp_table}" ({",".join(quoted_columns)}) FROM STDIN WITH CSV' - ) + copy_sql = f'COPY "{temp_table}" ({",".join(quoted_columns)}) FROM STDIN WITH CSV' cursor.copy_expert(copy_sql, csv_data) # Insert from temp to main table with validation @@ -102,14 +98,10 @@ def index_resource_data(resource: Resource) -> Optional[ResourceDataTable]: try: file_details = resource.resourcefiledetails if not file_details: - logger.info( - f"Resource {resource_id} has no file details, skipping indexing" - ) + logger.info(f"Resource {resource_id} has no file details, skipping indexing") return None except Exception as e: - logger.error( - f"Failed to access file details for resource {resource_id}: {str(e)}" - ) + logger.error(f"Failed to access file details for resource {resource_id}: {str(e)}") return None # Check file format @@ -131,9 +123,7 @@ def index_resource_data(resource: Resource) -> Optional[ResourceDataTable]: ) return None except Exception as e: - logger.error( - f"Failed to determine format for resource {resource_id}: {str(e)}" - ) + logger.error(f"Failed to determine format for resource {resource_id}: {str(e)}") return None # 
Load tabular data with timeout protection @@ -144,9 +134,7 @@ def index_resource_data(resource: Resource) -> Optional[ResourceDataTable]: @contextmanager def timeout(seconds: int) -> Generator[None, None, None]: def handler(signum: int, frame: Any) -> None: - raise TimeoutError( - f"Loading data timed out after {seconds} seconds" - ) + raise TimeoutError(f"Loading data timed out after {seconds} seconds") # Set the timeout handler original_handler = signal.getsignal(signal.SIGALRM) @@ -163,9 +151,7 @@ def handler(signum: int, frame: Any) -> None: with timeout(60): # 60 second timeout for loading data df = load_tabular_data(file_details.file.path, format) except TimeoutError as te: - logger.error( - f"Timeout while loading data for resource {resource_id}: {str(te)}" - ) + logger.error(f"Timeout while loading data for resource {resource_id}: {str(te)}") return None except Exception: # Fallback without timeout if signal.SIGALRM is not available (e.g., on Windows) @@ -204,9 +190,7 @@ def handler(signum: int, frame: Any) -> None: # Rename all but the first occurrence for i, idx in enumerate(indices[1:], 1): df.columns.values[idx] = f"{col}_{i}" - logger.warning( - f"Renamed duplicate columns in resource {resource_id}" - ) + logger.warning(f"Renamed duplicate columns in resource {resource_id}") except Exception as e: logger.error( f"Failed to sanitize column names for resource {resource_id}: {str(e)}" @@ -229,9 +213,7 @@ def handler(signum: int, frame: Any) -> None: existing_table = ResourceDataTable.objects.get(resource=resource) try: with connections[DATA_DB].cursor() as cursor: - cursor.execute( - f'DROP TABLE IF EXISTS "{existing_table.table_name}"' - ) + cursor.execute(f'DROP TABLE IF EXISTS "{existing_table.table_name}"') except Exception as drop_error: logger.error( f"Failed to drop existing table for resource {resource_id}: {str(drop_error)}" @@ -292,15 +274,11 @@ def handler(signum: int, frame: Any) -> None: # For description, preserve existing if available, 
otherwise auto-generate description = f"Description of column {col}" if col in existing_schemas: - existing_description = existing_schemas[col][ - "description" - ] + existing_description = existing_schemas[col]["description"] # Check for None and non-auto-generated descriptions if existing_description is not None: description = existing_description - logger.debug( - f"Preserved custom description for column {col}" - ) + logger.debug(f"Preserved custom description for column {col}") # Create the schema entry ResourceSchema.objects.create( @@ -393,9 +371,7 @@ def get_row_count(resource: Resource) -> int: import traceback error_tb = traceback.format_exc() - logger.error( - f"Error getting row count for resource {resource.id}:\n{str(e)}\n{error_tb}" - ) + logger.error(f"Error getting row count for resource {resource.id}:\n{str(e)}\n{error_tb}") return 0 @@ -429,9 +405,7 @@ def get_preview_data(resource: Resource) -> Optional[PreviewData]: try: if is_all_entries: # For safety, always limit the number of rows returned even for 'all entries' - cursor.execute( - f'SELECT * FROM "{data_table.table_name}" LIMIT 1000' - ) + cursor.execute(f'SELECT * FROM "{data_table.table_name}" LIMIT 1000') else: # Ensure we have valid integer values for the calculation start = int(start_entry) if start_entry is not None else 0 @@ -443,8 +417,8 @@ def get_preview_data(resource: Resource) -> Optional[PreviewData]: columns = [desc[0] for desc in cursor.description] data = cursor.fetchall() - # Convert tuples to lists - rows = [list(row) for row in data] + # Convert tuples to lists and sanitize None values to empty strings + rows = [[cell if cell is not None else "" for cell in row] for row in data] return PreviewData(columns=columns, rows=rows) except Exception as query_error: logger.error( From 7723dc23c32f5d75c4bb83e8e0b9abd28969e5ea Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 13 Jan 2026 13:32:10 +0530 Subject: [PATCH 060/127] retrun TypePromptResourceDetails instead of json --- 
api/types/type_resource.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/api/types/type_resource.py b/api/types/type_resource.py index 00749c2..c0caeec 100644 --- a/api/types/type_resource.py +++ b/api/types/type_resource.py @@ -18,6 +18,7 @@ from api.types.base_type import BaseType from api.types.type_file_details import TypeFileDetails from api.types.type_preview_data import PreviewData +from api.types.type_prompt_metadata import TypePromptResourceDetails from api.types.type_resource_metadata import TypeResourceMetadata from api.utils.data_indexing import get_preview_data, get_row_count from api.utils.graphql_telemetry import trace_resolver @@ -190,22 +191,22 @@ def no_of_entries(self) -> int: @strawberry.field @trace_resolver(name="get_prompt_details", attributes={"component": "resource"}) - def prompt_details(self) -> Optional[strawberry.scalars.JSON]: + def prompt_details(self) -> Optional[TypePromptResourceDetails]: """Get prompt-specific details for this resource (only for prompt datasets). 
Returns: - Optional[JSON]: Prompt details if they exist, None otherwise + Optional[TypePromptResourceDetails]: Prompt details if they exist, None otherwise """ try: prompt_resource = PromptResource.objects.filter(resource_id=self.id).first() if prompt_resource: - return { - "prompt_format": prompt_resource.prompt_format, - "has_system_prompt": prompt_resource.has_system_prompt, - "has_example_responses": prompt_resource.has_example_responses, - "avg_prompt_length": prompt_resource.avg_prompt_length, - "prompt_count": prompt_resource.prompt_count, - } + return TypePromptResourceDetails( + prompt_format=prompt_resource.prompt_format, + has_system_prompt=prompt_resource.has_system_prompt, + has_example_responses=prompt_resource.has_example_responses, + avg_prompt_length=prompt_resource.avg_prompt_length, + prompt_count=prompt_resource.prompt_count, + ) return None except (AttributeError, PromptResource.DoesNotExist): return None From b589e8bbb2d66b01fd0772e5fe1f00a202f33e73 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 13 Jan 2026 13:36:19 +0530 Subject: [PATCH 061/127] move TypePromptResourceDetails to separate file to fix circular dependency --- api/types/type_prompt_metadata.py | 13 ------------- api/types/type_prompt_resource_details.py | 22 ++++++++++++++++++++++ api/types/type_resource.py | 2 +- 3 files changed, 23 insertions(+), 14 deletions(-) create mode 100644 api/types/type_prompt_resource_details.py diff --git a/api/types/type_prompt_metadata.py b/api/types/type_prompt_metadata.py index d208f40..337ebe0 100644 --- a/api/types/type_prompt_metadata.py +++ b/api/types/type_prompt_metadata.py @@ -14,7 +14,6 @@ from api.types.type_dataset import TypeDataset from api.utils.enums import ( PromptDomain, - PromptFormat, PromptTaskType, TargetLanguage, TargetModelType, @@ -22,7 +21,6 @@ prompt_task_type_enum: EnumType = strawberry.enum(PromptTaskType) # type: ignore prompt_domain_enum: EnumType = strawberry.enum(PromptDomain) # type: ignore -prompt_format_enum: 
EnumType = strawberry.enum(PromptFormat) # type: ignore target_language_enum: EnumType = strawberry.enum(TargetLanguage) # type: ignore target_model_type_enum: EnumType = strawberry.enum(TargetModelType) # type: ignore @@ -57,14 +55,3 @@ class TypePromptMetadata: domain: Optional[prompt_domain_enum] target_model_types: Optional[List[target_model_type_enum]] evaluation_criteria: Optional[strawberry.scalars.JSON] - - -@strawberry.type -class TypePromptResourceDetails: - """Prompt-specific fields for a resource/file.""" - - prompt_format: Optional[prompt_format_enum] - has_system_prompt: bool - has_example_responses: bool - avg_prompt_length: Optional[int] - prompt_count: Optional[int] diff --git a/api/types/type_prompt_resource_details.py b/api/types/type_prompt_resource_details.py new file mode 100644 index 0000000..dca6681 --- /dev/null +++ b/api/types/type_prompt_resource_details.py @@ -0,0 +1,22 @@ +"""GraphQL type for prompt resource details.""" + +from typing import Optional + +import strawberry +from strawberry.enum import EnumType + +from api.utils.enums import PromptFormat + +# Create the enum for GraphQL schema +prompt_format_enum: EnumType = strawberry.enum(PromptFormat) # type: ignore + + +@strawberry.type +class TypePromptResourceDetails: + """Prompt-specific fields for a resource/file.""" + + prompt_format: Optional[prompt_format_enum] + has_system_prompt: bool + has_example_responses: bool + avg_prompt_length: Optional[int] + prompt_count: Optional[int] diff --git a/api/types/type_resource.py b/api/types/type_resource.py index c0caeec..2053202 100644 --- a/api/types/type_resource.py +++ b/api/types/type_resource.py @@ -18,7 +18,7 @@ from api.types.base_type import BaseType from api.types.type_file_details import TypeFileDetails from api.types.type_preview_data import PreviewData -from api.types.type_prompt_metadata import TypePromptResourceDetails +from api.types.type_prompt_resource_details import TypePromptResourceDetails from 
api.types.type_resource_metadata import TypeResourceMetadata from api.utils.data_indexing import get_preview_data, get_row_count from api.utils.graphql_telemetry import trace_resolver From 2cf75d0dd8978b87b6897b6cee5810057c014917 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 13 Jan 2026 13:52:05 +0530 Subject: [PATCH 062/127] add emum query --- api/schema/dataset_schema.py | 2 ++ api/types/type_prompt_metadata.py | 12 ------------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index c267e2f..edb3959 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -41,6 +41,7 @@ DatasetLicense, DatasetStatus, DatasetType, + PromptFormat, PromptTaskType, UseCaseStatus, ) @@ -57,6 +58,7 @@ DatasetLicenseENUM = strawberry.enum(DatasetLicense) # type: ignore DatasetTypeENUM = strawberry.enum(DatasetType) # type: ignore PromptTaskTypeENUM = strawberry.enum(PromptTaskType) # type: ignore +PromptFormatENUM = strawberry.enum(PromptFormat) # type: ignore # Create permission classes dynamically with different operations diff --git a/api/types/type_prompt_metadata.py b/api/types/type_prompt_metadata.py index 337ebe0..4754aaf 100644 --- a/api/types/type_prompt_metadata.py +++ b/api/types/type_prompt_metadata.py @@ -43,15 +43,3 @@ class TypePromptDataset(TypeDataset): domain: Optional[prompt_domain_enum] target_model_types: Optional[List[target_model_type_enum]] evaluation_criteria: Optional[strawberry.scalars.JSON] - - -# Keep TypePromptMetadata as an alias for backward compatibility in nested queries -@strawberry.type -class TypePromptMetadata: - """Prompt-specific metadata fields (for embedding in TypeDataset).""" - - task_type: Optional[prompt_task_type_enum] - target_languages: Optional[List[target_language_enum]] - domain: Optional[prompt_domain_enum] - target_model_types: Optional[List[target_model_type_enum]] - evaluation_criteria: Optional[strawberry.scalars.JSON] From 
afcdf0c1efed326554cb159a690c440b041b203d Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 13:38:40 +0530 Subject: [PATCH 063/127] add purpose to prompt metadata --- api/models/PromptDataset.py | 11 ++++++++++- api/schema/dataset_schema.py | 7 ++++++- api/utils/enums.py | 7 +++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/api/models/PromptDataset.py b/api/models/PromptDataset.py index 72f426d..89cd1d9 100644 --- a/api/models/PromptDataset.py +++ b/api/models/PromptDataset.py @@ -3,7 +3,7 @@ from django.db import models from api.models.Dataset import Dataset -from api.utils.enums import DatasetType, PromptDomain, PromptTaskType +from api.utils.enums import DatasetType, PromptDomain, PromptPurpose, PromptTaskType class PromptDataset(Dataset): @@ -59,6 +59,15 @@ class PromptDataset(Dataset): help_text="Criteria or metrics for evaluating prompt effectiveness", ) + # Purpose of the prompt dataset + purpose = models.CharField( + max_length=200, + choices=PromptPurpose.choices, + blank=True, + null=True, + help_text="Purpose of the prompt dataset", + ) + def save(self, *args, **kwargs): # Ensure dataset_type is always PROMPT for PromptDataset self.dataset_type = DatasetType.PROMPT diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index edb3959..932558a 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -42,6 +42,7 @@ DatasetStatus, DatasetType, PromptFormat, + PromptPurpose, PromptTaskType, UseCaseStatus, ) @@ -59,6 +60,7 @@ DatasetTypeENUM = strawberry.enum(DatasetType) # type: ignore PromptTaskTypeENUM = strawberry.enum(PromptTaskType) # type: ignore PromptFormatENUM = strawberry.enum(PromptFormat) # type: ignore +PromptPurposeENUM = strawberry.enum(PromptPurpose) # type: ignore # Create permission classes dynamically with different operations @@ -294,8 +296,8 @@ class PromptMetadataInput: has_example_responses: Optional[bool] = False avg_prompt_length: Optional[int] = None prompt_count: 
Optional[int] = None - use_case: Optional[str] = None evaluation_criteria: Optional[strawberry.scalars.JSON] = None + purpose: Optional[PromptPurposeENUM] = None @strawberry.input @@ -308,6 +310,7 @@ class UpdatePromptMetadataInput: domain: Optional[str] = None target_model_types: Optional[List[str]] = None evaluation_criteria: Optional[strawberry.scalars.JSON] = None + purpose: Optional[PromptPurposeENUM] = None @strawberry.input @@ -739,6 +742,8 @@ def update_prompt_metadata( prompt_dataset.target_model_types = update_input.target_model_types if update_input.evaluation_criteria is not None: prompt_dataset.evaluation_criteria = update_input.evaluation_criteria + if update_input.purpose is not None: + prompt_dataset.purpose = update_input.purpose prompt_dataset.save() return MutationResponse.success_response(TypePromptDataset.from_django(prompt_dataset)) diff --git a/api/utils/enums.py b/api/utils/enums.py index e070455..826cd77 100644 --- a/api/utils/enums.py +++ b/api/utils/enums.py @@ -108,6 +108,13 @@ class PromptTaskType(models.TextChoices): OTHER = "OTHER" +class PromptPurpose(models.TextChoices): + RESEARCH = "RESEARCH" + EDUCATION = "EDUCATION" + EVALUATION = "EVALUATION" + OTHER = "OTHER" + + class PromptDomain(models.TextChoices): HEALTHCARE = "HEALTHCARE" EDUCATION = "EDUCATION" From b4e7d4d452632b1771392ab51b5b28f694bdbcc9 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 13:40:11 +0530 Subject: [PATCH 064/127] add prompt api to sdk --- dataspace_sdk/resources/datasets.py | 170 ++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index 907eb73..f2d2e6f 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -280,6 +280,176 @@ def create(self, dataset_type: str = "DATA") -> Dict[str, Any]: result: Dict[str, Any] = response.get("data", {}).get("addDataset", {}) return result + def get_prompt_by_id(self, dataset_id: 
str) -> Dict[str, Any]: + """ + Get a prompt dataset by ID with prompt-specific metadata. + + Args: + dataset_id: UUID of the prompt dataset + + Returns: + Dictionary containing prompt dataset information including prompt metadata + """ + query = """ + query GetPromptDataset($id: UUID!) { + dataset(id: $id) { + id + title + description + status + accessType + license + datasetType + createdAt + updatedAt + organization { + id + name + description + } + tags { + id + value + } + sectors { + id + name + } + geographies { + id + name + } + resources { + id + title + description + fileDetails + schema + createdAt + updatedAt + promptFormat + hasSystemPrompt + hasExampleResponses + avgPromptLength + promptCount + } + promptMetadata { + taskType + targetLanguages + domain + targetModelTypes + + } + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"id": dataset_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("dataset", {}) + return result + + def list_prompts( + self, + status: Optional[str] = None, + task_type: Optional[str] = None, + domain: Optional[str] = None, + organization_id: Optional[str] = None, + limit: int = 10, + offset: int = 0, + ) -> Any: + """ + List prompt datasets with pagination using GraphQL. + + Args: + status: Filter by status (DRAFT, PUBLISHED, etc.) 
+ task_type: Filter by prompt task type + domain: Filter by domain + organization_id: Filter by organization + limit: Number of results to return + offset: Number of results to skip + + Returns: + List of prompt datasets + """ + query = """ + query ListPromptDatasets($filters: DatasetFilter, $pagination: OffsetPaginationInput) { + datasets(filters: $filters, pagination: $pagination) { + id + title + description + status + accessType + datasetType + createdAt + updatedAt + organization { + id + name + } + tags { + id + value + } + promptMetadata { + taskType + targetLanguages + domain + targetModelTypes + + } + resources { + id + title + fileDetails + promptFormat + hasSystemPrompt + hasExampleResponses + promptCount + } + } + } + """ + + filters: Dict[str, Any] = {"datasetType": "PROMPT"} + if status: + filters["status"] = status + if organization_id: + filters["organization"] = {"id": {"exact": organization_id}} + + variables: Dict[str, Any] = { + "pagination": {"limit": limit, "offset": offset}, + "filters": filters, + } + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": variables, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + data = response.get("data", {}) + datasets_result: Any = data.get("datasets", []) if isinstance(data, dict) else [] + return datasets_result + def search_prompts( self, query: Optional[str] = None, From 8f381b6f9ac6507018baa78387751164d7b25dd0 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 08:11:45 +0000 Subject: [PATCH 065/127] Bump SDK version to 0.4.5 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 3e6b331..51039da 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ 
"""Version information for DataSpace SDK.""" -__version__ = "0.4.4" +__version__ = "0.4.5" diff --git a/pyproject.toml b/pyproject.toml index 0a5fdaa..47e2d5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.4" +version = "0.4.5" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.4" +python_version = "0.4.5" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 144a6116e2133f78f85bc2af5e4a7a3c467b2b53 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 09:45:35 +0000 Subject: [PATCH 066/127] Bump SDK version to 0.4.7 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 51039da..8f42edf 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.5" +__version__ = "0.4.7" diff --git a/pyproject.toml b/pyproject.toml index 47e2d5b..74b9c85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.5" +version = "0.4.7" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.5" +python_version = "0.4.7" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From e9b92ef65887588c00a8d7622dc7022b3ab84670 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 14:25:56 +0530 Subject: [PATCH 067/127] fix query to fetch prompt 
datasets --- dataspace_sdk/resources/datasets.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index f2d2e6f..a670133 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -392,8 +392,7 @@ def list_prompts( status accessType datasetType - createdAt - updatedAt + created organization { id name @@ -402,21 +401,20 @@ def list_prompts( id value } - promptMetadata { - taskType - targetLanguages - domain - targetModelTypes - - } + promptMetadata resources { id - title - fileDetails - promptFormat - hasSystemPrompt - hasExampleResponses - promptCount + name + fileDetails { + format + size + } + promptDetails { + promptFormat + hasSystemPrompt + hasExampleResponses + promptCount + } } } } From 85f6e5a4fd2d0a8dc07ab7b0855c88cfe7ada712 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 15:53:42 +0530 Subject: [PATCH 068/127] fix dataset queries --- dataspace_sdk/resources/datasets.py | 55 +++++++++++++---------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index a670133..259a473 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -82,8 +82,8 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: status accessType license - createdAt - updatedAt + created + updated organization { id name @@ -103,12 +103,13 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: } resources { id - title + name description - fileDetails + fileDetails { + format + size + } schema - createdAt - updatedAt } } } @@ -158,8 +159,8 @@ def list_all( status accessType license - createdAt - updatedAt + created + updated organization { id name @@ -257,8 +258,8 @@ def create(self, dataset_type: str = "DATA") -> Dict[str, Any]: description status datasetType - createdAt - updatedAt + created + 
updated } } } @@ -300,8 +301,8 @@ def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: accessType license datasetType - createdAt - updatedAt + created + updated organization { id name @@ -321,25 +322,19 @@ def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: } resources { id - title - description - fileDetails - schema - createdAt - updatedAt - promptFormat - hasSystemPrompt - hasExampleResponses - avgPromptLength - promptCount - } - promptMetadata { - taskType - targetLanguages - domain - targetModelTypes - + name + fileDetails { + format + size + } + promptDetails { + promptFormat + hasSystemPrompt + hasExampleResponses + promptCount + } } + promptMetadata } } """ From 8d451c984350100228287aff6bcef821e1624a96 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 10:24:40 +0000 Subject: [PATCH 069/127] Bump SDK version to 0.4.8 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 8f42edf..d199d88 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.7" +__version__ = "0.4.8" diff --git a/pyproject.toml b/pyproject.toml index 74b9c85..9c1b6aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.7" +version = "0.4.8" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.7" +python_version = "0.4.8" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 25bfa0dc712aa011bce7c96046fb28f55e57a156 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 16:35:03 +0530 Subject: [PATCH 070/127] 
use direct filters and pagination for datasets query --- api/schema/dataset_schema.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index 932558a..c7429f6 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -387,11 +387,7 @@ def _add_update_dataset_geographies(dataset: Dataset, geography_ids: List[int]) @strawberry.type class Query: - @strawberry_django.field( - filters=DatasetFilter, - pagination=True, - order=DatasetOrder, - ) + @strawberry_django.field @trace_resolver(name="datasets", attributes={"component": "dataset"}) def datasets( self, @@ -399,7 +395,7 @@ def datasets( filters: Optional[DatasetFilter] = strawberry.UNSET, pagination: Optional[OffsetPaginationInput] = strawberry.UNSET, order: Optional[DatasetOrder] = strawberry.UNSET, - ) -> List[TypeDataset]: + ) -> Any: """Get all datasets.""" organization = info.context.context.get("organization") user = info.context.user @@ -430,16 +426,7 @@ def datasets( # For non-authenticated users, return empty queryset queryset = Dataset.objects.none() - if filters is not strawberry.UNSET: - queryset = strawberry_django.filters.apply(filters, queryset, info) - - if order is not strawberry.UNSET: - queryset = strawberry_django.ordering.apply(order, queryset, info) - - if pagination is not strawberry.UNSET: - queryset = strawberry_django.pagination.apply(pagination, queryset) - - return TypeDataset.from_django_list(queryset) + return queryset @strawberry.field( permission_classes=[AllowPublishedDatasets], # type: ignore[list-item] From 89c8ea0ccdc59ab49abb90de4c84995c36570ff3 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 16:39:18 +0530 Subject: [PATCH 071/127] revert dataset query changes --- api/schema/dataset_schema.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index 
c7429f6..b7c670f 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -387,7 +387,11 @@ def _add_update_dataset_geographies(dataset: Dataset, geography_ids: List[int]) @strawberry.type class Query: - @strawberry_django.field + @strawberry_django.field( + filters=DatasetFilter, + pagination=True, + order=DatasetOrder, + ) @trace_resolver(name="datasets", attributes={"component": "dataset"}) def datasets( self, @@ -395,7 +399,7 @@ def datasets( filters: Optional[DatasetFilter] = strawberry.UNSET, pagination: Optional[OffsetPaginationInput] = strawberry.UNSET, order: Optional[DatasetOrder] = strawberry.UNSET, - ) -> Any: + ) -> list[TypeDataset]: """Get all datasets.""" organization = info.context.context.get("organization") user = info.context.user @@ -426,7 +430,16 @@ def datasets( # For non-authenticated users, return empty queryset queryset = Dataset.objects.none() - return queryset + if filters is not strawberry.UNSET: + queryset = strawberry_django.filters.apply(filters, queryset, info) + + if order is not strawberry.UNSET: + queryset = strawberry_django.ordering.apply(order, queryset, info) + + if pagination is not strawberry.UNSET: + queryset = strawberry_django.pagination.apply(pagination, queryset) + + return TypeDataset.from_django_list(queryset) @strawberry.field( permission_classes=[AllowPublishedDatasets], # type: ignore[list-item] From 14c27f9ffc9244acb07dbec7ed25fb72a39e0e3e Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 18:13:17 +0530 Subject: [PATCH 072/127] try strawberry.field instead of strawberry_django.field --- api/schema/dataset_schema.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index b7c670f..0b7482f 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -387,18 +387,14 @@ def _add_update_dataset_geographies(dataset: Dataset, geography_ids: List[int]) @strawberry.type class 
Query: - @strawberry_django.field( - filters=DatasetFilter, - pagination=True, - order=DatasetOrder, - ) + @strawberry.field @trace_resolver(name="datasets", attributes={"component": "dataset"}) def datasets( self, info: Info, - filters: Optional[DatasetFilter] = strawberry.UNSET, - pagination: Optional[OffsetPaginationInput] = strawberry.UNSET, - order: Optional[DatasetOrder] = strawberry.UNSET, + filters: Optional[DatasetFilter] = None, + pagination: Optional[OffsetPaginationInput] = None, + order: Optional[DatasetOrder] = None, ) -> list[TypeDataset]: """Get all datasets.""" organization = info.context.context.get("organization") @@ -430,13 +426,16 @@ def datasets( # For non-authenticated users, return empty queryset queryset = Dataset.objects.none() - if filters is not strawberry.UNSET: + # Apply filters FIRST (before any slicing) + if filters is not None: queryset = strawberry_django.filters.apply(filters, queryset, info) - if order is not strawberry.UNSET: + # Apply ordering SECOND + if order is not None: queryset = strawberry_django.ordering.apply(order, queryset, info) - if pagination is not strawberry.UNSET: + # Apply pagination LAST (this will slice the queryset) + if pagination is not None: queryset = strawberry_django.pagination.apply(pagination, queryset) return TypeDataset.from_django_list(queryset) From a2e26a44677c1b4d40eef35d17e02dbd4976eca0 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 23:40:47 +0530 Subject: [PATCH 073/127] add ability to include published datasets in datasets query --- api/schema/dataset_schema.py | 4 ++++ dataspace_sdk/resources/datasets.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index 0b7482f..6591c7a 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -395,6 +395,7 @@ def datasets( filters: Optional[DatasetFilter] = None, pagination: Optional[OffsetPaginationInput] = None, order: Optional[DatasetOrder] = None, + 
include_public: Optional[bool] = False, ) -> list[TypeDataset]: """Get all datasets.""" organization = info.context.context.get("organization") @@ -438,6 +439,9 @@ def datasets( if pagination is not None: queryset = strawberry_django.pagination.apply(pagination, queryset) + if include_public: + queryset = queryset | Dataset.objects.filter(status=DatasetStatus.PUBLISHED) + return TypeDataset.from_django_list(queryset) @strawberry.field( diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index 259a473..6c6ae10 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -361,6 +361,7 @@ def list_prompts( task_type: Optional[str] = None, domain: Optional[str] = None, organization_id: Optional[str] = None, + include_public: Optional[bool] = False, limit: int = 10, offset: int = 0, ) -> Any: @@ -372,6 +373,7 @@ def list_prompts( task_type: Filter by prompt task type domain: Filter by domain organization_id: Filter by organization + include_public: Include public datasets limit: Number of results to return offset: Number of results to skip @@ -424,6 +426,7 @@ def list_prompts( variables: Dict[str, Any] = { "pagination": {"limit": limit, "offset": offset}, "filters": filters, + "include_public": include_public, } response = self.post( From 4d92cead484ebd0e8881d66730102052b3877a4b Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 14 Jan 2026 23:48:01 +0530 Subject: [PATCH 074/127] Include public datasets if requested before filters/pagination --- api/schema/dataset_schema.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index 6591c7a..e0adc04 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -427,6 +427,10 @@ def datasets( # For non-authenticated users, return empty queryset queryset = Dataset.objects.none() + # Include public datasets if requested (BEFORE filters/pagination) + if 
include_public: + queryset = queryset | Dataset.objects.filter(status=DatasetStatus.PUBLISHED) + # Apply filters FIRST (before any slicing) if filters is not None: queryset = strawberry_django.filters.apply(filters, queryset, info) @@ -439,9 +443,6 @@ def datasets( if pagination is not None: queryset = strawberry_django.pagination.apply(pagination, queryset) - if include_public: - queryset = queryset | Dataset.objects.filter(status=DatasetStatus.PUBLISHED) - return TypeDataset.from_django_list(queryset) @strawberry.field( From b44e886dbcdda17893dd7055baee537651373f93 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 18:20:26 +0000 Subject: [PATCH 075/127] Bump SDK version to 0.4.9 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index d199d88..0064c7a 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.8" +__version__ = "0.4.9" diff --git a/pyproject.toml b/pyproject.toml index 9c1b6aa..5af4c4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.8" +version = "0.4.9" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.8" +python_version = "0.4.9" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 700dc77d6c8a9bc1634c87941c53291d8f0f8999 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 15 Jan 2026 00:11:26 +0530 Subject: [PATCH 076/127] fix datasets query ti include variable passed --- dataspace_sdk/resources/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index 6c6ae10..085a072 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -381,8 +381,8 @@ def list_prompts( List of prompt datasets """ query = """ - query ListPromptDatasets($filters: DatasetFilter, $pagination: OffsetPaginationInput) { - datasets(filters: $filters, pagination: $pagination) { + query ListPromptDatasets($filters: DatasetFilter, $pagination: OffsetPaginationInput, $include_public: Boolean) { + datasets(filters: $filters, pagination: $pagination, includePublic: $include_public) { id title description From b912828bbaba0f8f7ddb67e33184784b10c6be4f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 19:09:34 +0000 Subject: [PATCH 077/127] Bump SDK version to 0.4.10 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 0064c7a..fd553b7 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.9" +__version__ = "0.4.10" diff --git a/pyproject.toml b/pyproject.toml index 5af4c4c..7159b48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.9" +version = "0.4.10" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.9" +python_version = "0.4.10" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 3724fd7b9bb6c634c03662283d2df99180d4f3d6 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 15 Jan 2026 00:49:30 +0530 Subject: [PATCH 078/127] fix dataset query in sdk --- 
dataspace_sdk/resources/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index 085a072..28acf45 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -293,7 +293,7 @@ def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: """ query = """ query GetPromptDataset($id: UUID!) { - dataset(id: $id) { + get_dataset(id: $id) { id title description @@ -352,7 +352,7 @@ def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") - result: Dict[str, Any] = response.get("data", {}).get("dataset", {}) + result: Dict[str, Any] = response.get("data", {}).get("get_dataset", {}) return result def list_prompts( From b52f523de939135eb38868c0adf0761a14081d67 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 19:20:43 +0000 Subject: [PATCH 079/127] Bump SDK version to 0.4.11 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index fd553b7..465fa21 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.10" +__version__ = "0.4.11" diff --git a/pyproject.toml b/pyproject.toml index 7159b48..d8bcceb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.10" +version = "0.4.11" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.10" +python_version = "0.4.11" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 
4a24cbdd4ae183b8511520335326de1d40432dd9 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 15 Jan 2026 01:01:44 +0530 Subject: [PATCH 080/127] fix dataset query in sdk --- dataspace_sdk/resources/datasets.py | 34 +++++++++++++++++++---------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index 28acf45..b23eac6 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -75,20 +75,24 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: """ query = """ query GetDataset($id: UUID!) { - dataset(id: $id) { + getDataset(datasetId: $id) { id title description status - accessType - license + datasetType created - updated + modified + downloadCount organization { id name description } + user { + id + name + } tags { id value @@ -109,7 +113,12 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: format size } - schema + schema { + id + fieldName + format + description + } } } } @@ -128,7 +137,7 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") - result: Dict[str, Any] = response.get("data", {}).get("dataset", {}) + result: Dict[str, Any] = response.get("data", {}).get("getDataset", {}) return result def list_all( @@ -293,21 +302,24 @@ def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: """ query = """ query GetPromptDataset($id: UUID!) 
{ - get_dataset(id: $id) { + getDataset(datasetId: $id) { id title description status - accessType - license datasetType created - updated + modified + downloadCount organization { id name description } + user { + id + name + } tags { id value @@ -352,7 +364,7 @@ def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") - result: Dict[str, Any] = response.get("data", {}).get("get_dataset", {}) + result: Dict[str, Any] = response.get("data", {}).get("getDataset", {}) return result def list_prompts( From b486be6be70e629dbf8ef3d9f919ac33490da248 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 19:33:13 +0000 Subject: [PATCH 081/127] Bump SDK version to 0.4.12 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 465fa21..9e46a1c 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.11" +__version__ = "0.4.12" diff --git a/pyproject.toml b/pyproject.toml index d8bcceb..6e72108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.11" +version = "0.4.12" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.11" +python_version = "0.4.12" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From fe310753baaab8b477121d44685822c7c2a673a5 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 15 Jan 2026 01:05:40 +0530 Subject: [PATCH 082/127] fix tests --- tests/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/test_datasets.py b/tests/test_datasets.py index 29ec3f8..7465397 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -37,7 +37,7 @@ def test_search_datasets(self, mock_request: MagicMock) -> None: @patch.object(DatasetClient, "_make_request") def test_get_dataset_by_id(self, mock_request: MagicMock) -> None: """Test get dataset by ID.""" - mock_request.return_value = {"data": {"dataset": {"id": "123", "title": "Test Dataset"}}} + mock_request.return_value = {"data": {"getDataset": {"id": "123", "title": "Test Dataset"}}} result = self.client.get_by_id("123") @@ -94,7 +94,7 @@ def test_get_dataset_with_resources(self, mock_request: MagicMock) -> None: """Test get dataset by ID which includes resources.""" mock_request.return_value = { "data": { - "dataset": { + "getDataset": { "id": "dataset-123", "title": "Test Dataset", "resources": [ From 5e083c305ee1e66e43fac8f588521fa0d01b92a7 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 15 Jan 2026 02:13:29 +0530 Subject: [PATCH 083/127] remove username from queries --- dataspace_sdk/resources/datasets.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index b23eac6..24dee57 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -91,7 +91,6 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: } user { id - name } tags { id @@ -318,7 +317,6 @@ def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: } user { id - name } tags { id From 2b86b044fb8b4632057ac03d1791ac23168f4970 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 20:44:23 +0000 Subject: [PATCH 084/127] Bump SDK version to 0.4.13 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 9e46a1c..127e875 100644 --- a/dataspace_sdk/__version__.py +++ 
b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.12" +__version__ = "0.4.13" diff --git a/pyproject.toml b/pyproject.toml index 6e72108..698730b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.12" +version = "0.4.13" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.12" +python_version = "0.4.13" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 60131ff4aae7183dbbf2718637e2911cfa9d64f8 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 19 Jan 2026 15:41:46 +0530 Subject: [PATCH 085/127] add more org details on userinfo api --- api/views/auth.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/api/views/auth.py b/api/views/auth.py index cc3e769..80291b1 100644 --- a/api/views/auth.py +++ b/api/views/auth.py @@ -96,6 +96,12 @@ def get(self, request: Request) -> Response: "id": org.organization.id, # type: ignore[attr-defined] "name": org.organization.name, # type: ignore[attr-defined] "role": org.role.name, # type: ignore[attr-defined] + "description": org.organization.description, # type: ignore[attr-defined] + "logo": org.organization.logo, # type: ignore[attr-defined] + "is_active": org.organization.is_active, # type: ignore[attr-defined] + "is_public": org.organization.is_public, # type: ignore[attr-defined] + "created_at": org.organization.created_at, # type: ignore[attr-defined] + "updated_at": org.organization.updated_at, # type: ignore[attr-defined] } for org in user.organizationmembership_set.all() # type: ignore[union-attr, arg-type] ], From 7cf48a1ce8d655d25f77a38403641d38879deeda Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 19 Jan 2026 15:57:46 +0530 Subject: [PATCH 086/127] remove 
invalid attributes --- api/views/auth.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/views/auth.py b/api/views/auth.py index 80291b1..8b3d981 100644 --- a/api/views/auth.py +++ b/api/views/auth.py @@ -98,8 +98,7 @@ def get(self, request: Request) -> Response: "role": org.role.name, # type: ignore[attr-defined] "description": org.organization.description, # type: ignore[attr-defined] "logo": org.organization.logo, # type: ignore[attr-defined] - "is_active": org.organization.is_active, # type: ignore[attr-defined] - "is_public": org.organization.is_public, # type: ignore[attr-defined] + "homepage": org.organization.homepage, # type: ignore[attr-defined] "created_at": org.organization.created_at, # type: ignore[attr-defined] "updated_at": org.organization.updated_at, # type: ignore[attr-defined] } From 0065b0c0a73cc67f05343da3bf775e26bbb69766 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 19 Jan 2026 15:59:24 +0530 Subject: [PATCH 087/127] correct typo in attributes --- api/views/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/views/auth.py b/api/views/auth.py index 8b3d981..b42982d 100644 --- a/api/views/auth.py +++ b/api/views/auth.py @@ -99,8 +99,8 @@ def get(self, request: Request) -> Response: "description": org.organization.description, # type: ignore[attr-defined] "logo": org.organization.logo, # type: ignore[attr-defined] "homepage": org.organization.homepage, # type: ignore[attr-defined] - "created_at": org.organization.created_at, # type: ignore[attr-defined] - "updated_at": org.organization.updated_at, # type: ignore[attr-defined] + "created": org.organization.created, # type: ignore[attr-defined] + "updated": org.organization.modified, # type: ignore[attr-defined] } for org in user.organizationmembership_set.all() # type: ignore[union-attr, arg-type] ], From 3ee95eac72a0d8486f1f5fcabbd233ddb9e54d01 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 19 Jan 2026 16:03:09 +0530 Subject: [PATCH 088/127] fix 
org logo field return in user info api --- api/views/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/views/auth.py b/api/views/auth.py index b42982d..ef050d0 100644 --- a/api/views/auth.py +++ b/api/views/auth.py @@ -97,7 +97,7 @@ def get(self, request: Request) -> Response: "name": org.organization.name, # type: ignore[attr-defined] "role": org.role.name, # type: ignore[attr-defined] "description": org.organization.description, # type: ignore[attr-defined] - "logo": org.organization.logo, # type: ignore[attr-defined] + "logo": org.organization.logo.url if org.organization.logo else None, # type: ignore[attr-defined] "homepage": org.organization.homepage, # type: ignore[attr-defined] "created": org.organization.created, # type: ignore[attr-defined] "updated": org.organization.modified, # type: ignore[attr-defined] From 9bd86a333c38dc01f6396182e2bccad14b5349b5 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 19 Jan 2026 19:41:16 +0530 Subject: [PATCH 089/127] add support to call model with a specific version --- api/views/aimodel_execution.py | 18 ++++++++++++---- dataspace_sdk/resources/aimodels.py | 32 +++++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/api/views/aimodel_execution.py b/api/views/aimodel_execution.py index b5dc0d8..b13ca42 100644 --- a/api/views/aimodel_execution.py +++ b/api/views/aimodel_execution.py @@ -65,11 +65,21 @@ def call_aimodel(request: Request, model_id: str) -> Response: ) parameters = request.data.get("parameters", {}) + version_id = request.data.get("version_id") - # Get the primary version and provider - primary_version = model.versions.filter(is_latest=True).first() - if not primary_version: - primary_version = model.versions.first() + # Get the version - either specific version or primary (latest) + if version_id: + primary_version = model.versions.filter(id=version_id).first() + if not primary_version: + return Response( + {"error": f"Version with ID {version_id} 
not found for this model"}, + status=status.HTTP_400_BAD_REQUEST, + ) + else: + # Fall back to primary (latest) version + primary_version = model.versions.filter(is_latest=True).first() + if not primary_version: + primary_version = model.versions.first() if not primary_version: return Response( diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index 93b04c1..bcf95b3 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -337,7 +337,11 @@ def delete_model(self, model_id: str) -> Dict[str, Any]: return self.delete(f"/api/aimodels/{model_id}/") def call_model( - self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None + self, + model_id: str, + input_text: str, + parameters: Optional[Dict[str, Any]] = None, + version_id: Optional[int] = None, ) -> Dict[str, Any]: """ Call an AI model with input text using the appropriate client (API or HuggingFace). @@ -346,6 +350,7 @@ def call_model( model_id: UUID of the AI model input_text: Input text to process parameters: Optional parameters for the model call (temperature, max_tokens, etc.) + version_id: Optional specific version ID to call (defaults to primary/latest version) Returns: Dictionary containing model response: @@ -358,13 +363,24 @@ def call_model( ... } """ + payload: Dict[str, Any] = { + "input_text": input_text, + "parameters": parameters or {}, + } + if version_id is not None: + payload["version_id"] = version_id + return self.post( f"/api/aimodels/{model_id}/call/", - json_data={"input_text": input_text, "parameters": parameters or {}}, + json_data=payload, ) def call_model_async( - self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None + self, + model_id: str, + input_text: str, + parameters: Optional[Dict[str, Any]] = None, + version_id: Optional[int] = None, ) -> Dict[str, Any]: """ Call an AI model asynchronously (returns task ID for long-running operations). 
@@ -373,6 +389,7 @@ def call_model_async( model_id: UUID of the AI model input_text: Input text to process parameters: Optional parameters for the model call + version_id: Optional specific version ID to call (defaults to primary/latest version) Returns: Dictionary containing task information: @@ -382,9 +399,16 @@ def call_model_async( "model_id": str } """ + payload: Dict[str, Any] = { + "input_text": input_text, + "parameters": parameters or {}, + } + if version_id is not None: + payload["version_id"] = version_id + return self.post( f"/api/aimodels/{model_id}/call-async/", - json_data={"input_text": input_text, "parameters": parameters or {}}, + json_data=payload, ) # ==================== Version Management ==================== From 4bdffdfcfd5b9bbdaacf4eafb141165d6ef159a8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 19 Jan 2026 14:13:09 +0000 Subject: [PATCH 090/127] Bump SDK version to 0.4.14 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 127e875..8db8788 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.13" +__version__ = "0.4.14" diff --git a/pyproject.toml b/pyproject.toml index 698730b..5d95ce4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.13" +version = "0.4.14" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.13" +python_version = "0.4.14" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 062299f0e92e920a4161d45382fe52f40f083f5b Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 
22 Jan 2026 11:53:00 +0530 Subject: [PATCH 091/127] add auditer management apis and auditor client in sdk --- api/urls.py | 12 ++ api/views/auditor.py | 315 ++++++++++++++++++++++++++++ dataspace_sdk/client.py | 2 + dataspace_sdk/resources/__init__.py | 3 +- dataspace_sdk/resources/auditors.py | 264 +++++++++++++++++++++++ 5 files changed, 595 insertions(+), 1 deletion(-) create mode 100644 api/views/auditor.py create mode 100644 dataspace_sdk/resources/auditors.py diff --git a/api/urls.py b/api/urls.py index 98a80a8..86cbb9e 100644 --- a/api/urls.py +++ b/api/urls.py @@ -9,6 +9,7 @@ from api.views import ( aimodel_detail, aimodel_execution, + auditor, auth, download, generate_dynamic_chart, @@ -32,6 +33,17 @@ path("auth/keycloak/login/", auth.KeycloakLoginView.as_view(), name="keycloak_login"), path("auth/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"), path("auth/user/info/", auth.UserInfoView.as_view(), name="user_info"), + # Auditor management endpoints + path( + "organizations//auditors/", + auditor.OrganizationAuditorsView.as_view(), + name="organization_auditors", + ), + path( + "users/search-by-email/", + auditor.SearchUserByEmailView.as_view(), + name="search_user_by_email", + ), # API endpoints path("search/dataset/", search_dataset.SearchDataset.as_view(), name="search_dataset"), path("search/usecase/", search_usecase.SearchUseCase.as_view(), name="search_usecase"), diff --git a/api/views/auditor.py b/api/views/auditor.py new file mode 100644 index 0000000..e3ba4f3 --- /dev/null +++ b/api/views/auditor.py @@ -0,0 +1,315 @@ +"""REST API views for auditor management.""" + +import logging +from typing import Any, Dict, List, Optional + +from django.db import transaction +from rest_framework import status, views +from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request +from rest_framework.response import Response + +from api.models import Organization +from authorization.models import 
OrganizationMembership, Role, User + +logger = logging.getLogger(__name__) + + +class OrganizationAuditorsView(views.APIView): + """ + View for managing auditors in an organization. + + GET: List all auditors for an organization + POST: Add a user as auditor to an organization (by user_id or email) + DELETE: Remove an auditor from an organization + """ + + permission_classes = [IsAuthenticated] + + def _get_organization(self, organization_id: str) -> Optional[Organization]: + """Get organization by ID.""" + try: + return Organization.objects.get(id=organization_id) + except Organization.DoesNotExist: + return None + + def _check_admin_permission(self, user: User, organization: Organization) -> bool: + """Check if user has admin permission for the organization.""" + if user.is_superuser: + return True + try: + membership = OrganizationMembership.objects.get(user=user, organization=organization) + # Admin role has can_change permission + return membership.role.can_change # type: ignore[return-value] + except OrganizationMembership.DoesNotExist: + return False + + def _get_auditor_role(self) -> Optional[Role]: + """Get the auditor role.""" + try: + return Role.objects.get(name="auditor") + except Role.DoesNotExist: + logger.error("Auditor role not found. 
Please run migrations.") + return None + + def get(self, request: Request, organization_id: str) -> Response: + """Get all auditors for an organization.""" + organization = self._get_organization(organization_id) + if not organization: + return Response( + {"error": f"Organization with ID {organization_id} not found"}, + status=status.HTTP_404_NOT_FOUND, + ) + + # Check if user has permission to view organization members + if not self._check_admin_permission(request.user, organization): # type: ignore[arg-type] + return Response( + {"error": "You don't have permission to view auditors for this organization"}, + status=status.HTTP_403_FORBIDDEN, + ) + + auditor_role = self._get_auditor_role() + if not auditor_role: + return Response( + {"error": "Auditor role not configured"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + # Get all auditors for this organization + auditor_memberships = OrganizationMembership.objects.filter( + organization=organization, role=auditor_role + ).select_related("user") + + auditors: List[Dict[str, Any]] = [] + for membership in auditor_memberships: # type: OrganizationMembership + user: User = membership.user # type: ignore[assignment] + auditors.append( + { + "id": str(user.id), + "username": user.username, + "email": user.email, + "first_name": user.first_name, + "last_name": user.last_name, + "profile_picture": user.profile_picture.url if user.profile_picture else None, + "joined_at": ( + membership.created_at.isoformat() if membership.created_at else None + ), + } + ) + + return Response( + { + "organization_id": str(organization.id), + "organization_name": organization.name, + "auditors": auditors, + "count": len(auditors), + } + ) + + @transaction.atomic + def post(self, request: Request, organization_id: str) -> Response: + """ + Add a user as auditor to an organization. 
+ + Request body can contain either: + - user_id: ID of an existing user + - email: Email of a user to add (will look up user by email) + """ + organization = self._get_organization(organization_id) + if not organization: + return Response( + {"error": f"Organization with ID {organization_id} not found"}, + status=status.HTTP_404_NOT_FOUND, + ) + + # Check if user has admin permission + if not self._check_admin_permission(request.user, organization): # type: ignore[arg-type] + return Response( + {"error": "You don't have permission to add auditors to this organization"}, + status=status.HTTP_403_FORBIDDEN, + ) + + auditor_role = self._get_auditor_role() + if not auditor_role: + return Response( + {"error": "Auditor role not configured"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + user_id = request.data.get("user_id") + email = request.data.get("email") + + if not user_id and not email: + return Response( + {"error": "Either user_id or email is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Find the user + target_user: Optional[User] = None + if user_id: + try: + target_user = User.objects.get(id=user_id) + except User.DoesNotExist: + return Response( + {"error": f"User with ID {user_id} not found"}, + status=status.HTTP_404_NOT_FOUND, + ) + elif email: + try: + target_user = User.objects.get(email=email) + except User.DoesNotExist: + return Response( + { + "error": f"User with email {email} not found. The user must have an account in CivicDataSpace first." 
+ }, + status=status.HTTP_404_NOT_FOUND, + ) + + if not target_user: + return Response( + {"error": "Could not find user"}, + status=status.HTTP_404_NOT_FOUND, + ) + + # Check if user is already a member of the organization + existing_membership = OrganizationMembership.objects.filter( + user=target_user, organization=organization + ).first() + + if existing_membership: + if existing_membership.role == auditor_role: + return Response( + {"error": "User is already an auditor for this organization"}, + status=status.HTTP_400_BAD_REQUEST, + ) + else: + # User has a different role, update to auditor + # Note: This might not be desired behavior - you may want to keep existing role + # For now, we'll add them as auditor (they can have multiple roles in future) + return Response( + {"error": f"User is already a member of this organization with role '{existing_membership.role.name}'"}, # type: ignore[attr-defined] + status=status.HTTP_400_BAD_REQUEST, + ) + + # Create the membership + membership = OrganizationMembership.objects.create( + user=target_user, + organization=organization, + role=auditor_role, + ) + + logger.info( + f"Added user {target_user.username} as auditor to organization {organization.name}" + ) + + return Response( + { + "success": True, + "message": f"User {target_user.username} added as auditor", + "auditor": { + "id": target_user.id, + "username": target_user.username, + "email": target_user.email, + "first_name": target_user.first_name, + "last_name": target_user.last_name, + "joined_at": membership.created_at.isoformat(), + }, + }, + status=status.HTTP_201_CREATED, + ) + + @transaction.atomic + def delete(self, request: Request, organization_id: str) -> Response: + """Remove an auditor from an organization.""" + organization = self._get_organization(organization_id) + if not organization: + return Response( + {"error": f"Organization with ID {organization_id} not found"}, + status=status.HTTP_404_NOT_FOUND, + ) + + # Check if user has admin 
permission + if not self._check_admin_permission(request.user, organization): # type: ignore[arg-type] + return Response( + {"error": "You don't have permission to remove auditors from this organization"}, + status=status.HTTP_403_FORBIDDEN, + ) + + user_id = request.data.get("user_id") + if not user_id: + return Response( + {"error": "user_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + auditor_role = self._get_auditor_role() + if not auditor_role: + return Response( + {"error": "Auditor role not configured"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + try: + membership = OrganizationMembership.objects.get( + user_id=user_id, organization=organization, role=auditor_role + ) + username = membership.user.username # type: ignore[attr-defined] + membership.delete() + + logger.info(f"Removed auditor {username} from organization {organization.name}") + + return Response( + { + "success": True, + "message": f"Auditor {username} removed from organization", + } + ) + except OrganizationMembership.DoesNotExist: + return Response( + {"error": "User is not an auditor for this organization"}, + status=status.HTTP_404_NOT_FOUND, + ) + + +class SearchUserByEmailView(views.APIView): + """ + Search for a user by email. + Used to find users before adding them as auditors. 
+ """ + + permission_classes = [IsAuthenticated] + + def get(self, request: Request) -> Response: + """Search for a user by email.""" + email = request.query_params.get("email") + if not email: + return Response( + {"error": "email query parameter is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + user = User.objects.get(email=email) + return Response( + { + "found": True, + "user": { + "id": user.id, + "username": user.username, + "email": user.email, + "first_name": user.first_name, + "last_name": user.last_name, + "profile_picture": ( + user.profile_picture.url if user.profile_picture else None + ), + }, + } + ) + except User.DoesNotExist: + return Response( + { + "found": False, + "message": f"No user found with email {email}", + } + ) diff --git a/dataspace_sdk/client.py b/dataspace_sdk/client.py index 5d1dfce..d59c90b 100644 --- a/dataspace_sdk/client.py +++ b/dataspace_sdk/client.py @@ -4,6 +4,7 @@ from dataspace_sdk.auth import AuthClient from dataspace_sdk.resources.aimodels import AIModelClient +from dataspace_sdk.resources.auditors import AuditorClient from dataspace_sdk.resources.datasets import DatasetClient from dataspace_sdk.resources.sectors import SectorClient from dataspace_sdk.resources.usecases import UseCaseClient @@ -66,6 +67,7 @@ def __init__( self.aimodels = AIModelClient(self.base_url, self._auth) self.usecases = UseCaseClient(self.base_url, self._auth) self.sectors = SectorClient(self.base_url, self._auth) + self.auditors = AuditorClient(self.base_url, self._auth) def login(self, username: str, password: str) -> dict: """ diff --git a/dataspace_sdk/resources/__init__.py b/dataspace_sdk/resources/__init__.py index aa742c7..e0f6472 100644 --- a/dataspace_sdk/resources/__init__.py +++ b/dataspace_sdk/resources/__init__.py @@ -1,8 +1,9 @@ """Resource clients for DataSpace SDK.""" from dataspace_sdk.resources.aimodels import AIModelClient +from dataspace_sdk.resources.auditors import AuditorClient from 
dataspace_sdk.resources.datasets import DatasetClient from dataspace_sdk.resources.sectors import SectorClient from dataspace_sdk.resources.usecases import UseCaseClient -__all__ = ["DatasetClient", "AIModelClient", "UseCaseClient", "SectorClient"] +__all__ = ["DatasetClient", "AIModelClient", "UseCaseClient", "SectorClient", "AuditorClient"] diff --git a/dataspace_sdk/resources/auditors.py b/dataspace_sdk/resources/auditors.py new file mode 100644 index 0000000..2a521a4 --- /dev/null +++ b/dataspace_sdk/resources/auditors.py @@ -0,0 +1,264 @@ +"""Auditor management resource for DataSpace SDK.""" + +from typing import Any, Dict, Optional + +import requests + +from dataspace_sdk.exceptions import DataSpaceAuthError + + +class AuditorClient: + """ + Client for managing auditors in organizations. + + Auditors are users with the 'auditor' role in an organization, + who can audit AI models registered by that organization. + """ + + def __init__(self, base_url: str, auth_client: Any): + """ + Initialize the auditor client. + + Args: + base_url: Base URL of the DataSpace API + auth_client: Authentication client instance + """ + self._base_url = base_url.rstrip("/") + self._auth = auth_client + + def _get_headers(self) -> Dict[str, str]: + """Get request headers with authentication.""" + headers = {"Content-Type": "application/json"} + if self._auth and self._auth.access_token: + headers["Authorization"] = f"Bearer {self._auth.access_token}" + return headers + + def get_organization_auditors(self, organization_id: str) -> Dict[str, Any]: + """ + Get all auditors for an organization. 
+ + Args: + organization_id: UUID of the organization + + Returns: + Dictionary containing: + - organization_id: str + - organization_name: str + - auditors: List of auditor dictionaries + - count: int + + Raises: + DataSpaceAuthError: If not authenticated or permission denied + + Example: + >>> result = client.auditors.get_organization_auditors("org-uuid") + >>> for auditor in result["auditors"]: + ... print(f"{auditor['username']} - {auditor['email']}") + """ + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/organizations/{organization_id}/auditors/" + + response = requests.get( + url, + headers=self._get_headers(), + ) + + if response.status_code == 200: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + elif response.status_code == 403: + raise DataSpaceAuthError("Permission denied: You must be an admin of this organization") + elif response.status_code == 404: + raise DataSpaceAuthError(f"Organization {organization_id} not found") + else: + error_data: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Failed to get auditors"), + status_code=response.status_code, + response=error_data, + ) + + def add_auditor( + self, + organization_id: str, + user_id: Optional[str] = None, + email: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Add a user as auditor to an organization. + + You must provide either user_id or email. If the user is found, + they will be added as an auditor to the organization. 
+ + Args: + organization_id: UUID of the organization + user_id: Optional user ID to add as auditor + email: Optional email to look up user and add as auditor + + Returns: + Dictionary containing: + - success: bool + - message: str + - auditor: Dictionary with auditor details + + Raises: + DataSpaceAuthError: If not authenticated, permission denied, or user not found + ValueError: If neither user_id nor email is provided + + Example: + >>> # Add by user ID + >>> result = client.auditors.add_auditor("org-uuid", user_id="user-uuid") + >>> + >>> # Add by email + >>> result = client.auditors.add_auditor("org-uuid", email="auditor@example.com") + """ + if not user_id and not email: + raise ValueError("Either user_id or email must be provided") + + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/organizations/{organization_id}/auditors/" + + payload: Dict[str, str] = {} + if user_id: + payload["user_id"] = user_id + if email: + payload["email"] = email + + response = requests.post( + url, + json=payload, + headers=self._get_headers(), + ) + + if response.status_code == 201: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + elif response.status_code == 403: + raise DataSpaceAuthError("Permission denied: You must be an admin of this organization") + elif response.status_code == 404: + error_data = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Organization or user not found"), + status_code=response.status_code, + response=error_data, + ) + elif response.status_code == 400: + error_data = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Invalid request"), + status_code=response.status_code, + response=error_data, + ) + else: + error_data = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Failed to add auditor"), + status_code=response.status_code, + response=error_data, + ) + + def 
remove_auditor(self, organization_id: str, user_id: str) -> Dict[str, Any]: + """ + Remove an auditor from an organization. + + Args: + organization_id: UUID of the organization + user_id: ID of the user to remove as auditor + + Returns: + Dictionary containing: + - success: bool + - message: str + + Raises: + DataSpaceAuthError: If not authenticated, permission denied, or user not an auditor + + Example: + >>> result = client.auditors.remove_auditor("org-uuid", "user-uuid") + >>> print(result["message"]) + """ + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/organizations/{organization_id}/auditors/" + + response = requests.delete( + url, + json={"user_id": user_id}, + headers=self._get_headers(), + ) + + if response.status_code == 200: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + elif response.status_code == 403: + raise DataSpaceAuthError("Permission denied: You must be an admin of this organization") + elif response.status_code == 404: + error_data: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Organization or auditor not found"), + status_code=response.status_code, + response=error_data, + ) + else: + error_data_other: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data_other.get("error", "Failed to remove auditor"), + status_code=response.status_code, + response=error_data_other, + ) + + def search_user_by_email(self, email: str) -> Dict[str, Any]: + """ + Search for a user by email address. + + This is useful to check if a user exists before adding them as an auditor. 
+ + Args: + email: Email address to search for + + Returns: + Dictionary containing: + - found: bool + - user: Dictionary with user details (if found) + - message: str (if not found) + + Raises: + DataSpaceAuthError: If not authenticated + + Example: + >>> result = client.auditors.search_user_by_email("user@example.com") + >>> if result["found"]: + ... print(f"Found user: {result['user']['username']}") + ... else: + ... print("User not found") + """ + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/users/search-by-email/" + + response = requests.get( + url, + params={"email": email}, + headers=self._get_headers(), + ) + + if response.status_code == 200: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + else: + error_data: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Failed to search user"), + status_code=response.status_code, + response=error_data, + ) From ebb7104e4ed605751194636e341e0d60f2a496ad Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 22 Jan 2026 06:24:35 +0000 Subject: [PATCH 092/127] Bump SDK version to 0.4.15 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index 8db8788..df7ee54 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.14" +__version__ = "0.4.15" diff --git a/pyproject.toml b/pyproject.toml index 5d95ce4..fdc75ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.14" +version = "0.4.15" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = 
['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.14" +python_version = "0.4.15" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From c88e806b4fcbbf45cea795ecf6160bb8d0806e46 Mon Sep 17 00:00:00 2001 From: dc Date: Fri, 30 Jan 2026 16:59:17 +0530 Subject: [PATCH 093/127] add whitelisting feature for ratelimiting --- api/middleware/rate_limit.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/api/middleware/rate_limit.py b/api/middleware/rate_limit.py index 83cd605..e53bf83 100644 --- a/api/middleware/rate_limit.py +++ b/api/middleware/rate_limit.py @@ -1,7 +1,9 @@ import logging +import os import time from typing import Any, Callable, Optional, cast +from django.conf import settings from django.core.cache import cache from django.http import HttpRequest, HttpResponse from redis.exceptions import RedisError @@ -9,12 +11,32 @@ logger = logging.getLogger(__name__) +def get_whitelisted_ips() -> set[str]: + """Get the set of whitelisted IPs from settings or environment variable. 
+ + Can be configured via: + - RATE_LIMIT_WHITELIST_IPS in Django settings (list) + - RATE_LIMIT_WHITELIST_IPS environment variable (comma-separated string) + """ + # First check Django settings + whitelist = getattr(settings, "RATE_LIMIT_WHITELIST_IPS", None) + if whitelist: + return set(whitelist) + + # Fall back to environment variable + env_whitelist = os.environ.get("RATE_LIMIT_WHITELIST_IPS", "") + if env_whitelist: + return {ip.strip() for ip in env_whitelist.split(",") if ip.strip()} + + return set() + + class HttpResponseTooManyRequests(HttpResponse): status_code = 429 def rate_limit_middleware( - get_response: Callable[[HttpRequest], HttpResponse] + get_response: Callable[[HttpRequest], HttpResponse], ) -> Callable[[HttpRequest], HttpResponse]: """Rate limiting middleware that uses a simple cache-based counter.""" @@ -78,10 +100,18 @@ def check_rate_limit(request: HttpRequest) -> bool: return True # Allow request on unexpected error def middleware(request: HttpRequest) -> HttpResponse: + client_ip = get_client_ip(request) + whitelisted_ips = get_whitelisted_ips() + + # Skip rate limiting for whitelisted IPs + if client_ip in whitelisted_ips: + logger.debug(f"Skipping rate limit for whitelisted IP: {client_ip}") + return get_response(request) + if not check_rate_limit(request): logger.warning( f"Rate limited - Method: {request.method}, " - f"Path: {request.path}, IP: {get_client_ip(request)}" + f"Path: {request.path}, IP: {client_ip}" ) return HttpResponseTooManyRequests() From b9467ae814e6a5beb279553aa7e8e0816b51553c Mon Sep 17 00:00:00 2001 From: dc Date: Fri, 6 Feb 2026 18:33:34 +0530 Subject: [PATCH 094/127] make json fields in aimodel nullable --- api/models/AIModel.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/api/models/AIModel.py b/api/models/AIModel.py index a34641d..ff0fd5e 100644 --- a/api/models/AIModel.py +++ b/api/models/AIModel.py @@ -92,12 +92,18 @@ class AIModel(models.Model): ) supported_languages = 
models.JSONField( default=list, + blank=True, + null=True, help_text="List of supported language codes (e.g., ['en', 'es', 'fr'])", ) # Input/Output Schema - input_schema = models.JSONField(default=dict, help_text="Expected input format and parameters") - output_schema = models.JSONField(default=dict, help_text="Expected output format") + input_schema = models.JSONField( + default=dict, blank=True, null=True, help_text="Expected input format and parameters" + ) + output_schema = models.JSONField( + default=dict, blank=True, null=True, help_text="Expected output format" + ) # Metadata tags = models.ManyToManyField("api.Tag", blank=True) @@ -105,6 +111,8 @@ class AIModel(models.Model): geographies = models.ManyToManyField("api.Geography", blank=True, related_name="ai_models") metadata = models.JSONField( default=dict, + blank=True, + null=True, help_text="Additional metadata (training data info, limitations, etc.)", ) From 8d54f974df2948ea3a7b50046520c501e7fc8bff Mon Sep 17 00:00:00 2001 From: dc Date: Fri, 6 Feb 2026 18:33:34 +0530 Subject: [PATCH 095/127] make json fields in aimodel nullable --- api/models/AIModel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/models/AIModel.py b/api/models/AIModel.py index ff0fd5e..1a4b3a4 100644 --- a/api/models/AIModel.py +++ b/api/models/AIModel.py @@ -82,8 +82,6 @@ class AIModel(models.Model): blank=True, related_name="ai_models", ) - # API Configuration - # Endpoints are stored in separate ModelEndpoint table for flexibility # Model Capabilities supports_streaming = models.BooleanField(default=False) From be13c1391c95c3534396310c2f8fa9e5683acda9 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 9 Feb 2026 14:41:31 +0530 Subject: [PATCH 096/127] make nullable fields optional in type definitions --- api/types/type_aimodel.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/api/types/type_aimodel.py b/api/types/type_aimodel.py index 7adf63e..6e94dfc 100644 --- a/api/types/type_aimodel.py 
+++ b/api/types/type_aimodel.py @@ -120,10 +120,10 @@ class TypeAIModel(BaseType): user: Optional["TypeUser"] supports_streaming: bool max_tokens: Optional[int] - supported_languages: strawberry.scalars.JSON - input_schema: strawberry.scalars.JSON - output_schema: strawberry.scalars.JSON - metadata: strawberry.scalars.JSON + supported_languages: Optional[strawberry.scalars.JSON] + input_schema: Optional[strawberry.scalars.JSON] + output_schema: Optional[strawberry.scalars.JSON] + metadata: Optional[strawberry.scalars.JSON] status: AIModelStatusEnum is_public: bool is_active: bool @@ -287,10 +287,10 @@ class TypeAIModelVersion(BaseType): version_notes: Optional[str] supports_streaming: bool max_tokens: Optional[int] - supported_languages: strawberry.scalars.JSON - input_schema: strawberry.scalars.JSON - output_schema: strawberry.scalars.JSON - metadata: strawberry.scalars.JSON + supported_languages: Optional[strawberry.scalars.JSON] + input_schema: Optional[strawberry.scalars.JSON] + output_schema: Optional[strawberry.scalars.JSON] + metadata: Optional[strawberry.scalars.JSON] status: AIModelStatusEnum lifecycle_stage: AIModelLifecycleStageEnum is_latest: bool From 6e308d874191c1f8d868f28659294330de82df70 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 10 Feb 2026 13:50:03 +0530 Subject: [PATCH 097/127] dont return public models if user is authenticated --- api/schema/aimodel_schema.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index 786ecd0..c57f4c1 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -308,10 +308,8 @@ def ai_models( if user.is_superuser: queryset = AIModel.objects.all() else: - # For authenticated users, show their models and public models - queryset = AIModel.objects.filter(user=user) | AIModel.objects.filter( - is_public=True, is_active=True - ) + # For authenticated users, show their models + queryset = 
AIModel.objects.filter(user=user, organization=None) else: # For non-authenticated users, only show public active models queryset = AIModel.objects.filter(is_public=True, is_active=True) From ebc2b3928d1b1cb42882c59a7756aa6161177845 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 10 Feb 2026 15:18:27 +0530 Subject: [PATCH 098/127] add unique number to slug to avoid conflicts --- api/models/Resource.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/api/models/Resource.py b/api/models/Resource.py index 3f5e5be..35562f2 100644 --- a/api/models/Resource.py +++ b/api/models/Resource.py @@ -56,7 +56,14 @@ class Resource(models.Model): version = models.CharField(max_length=50, default="v1.0") def save(self, *args: Any, **kwargs: Any) -> None: - self.slug = slugify(self.name) + if not self.slug: + base_slug = slugify(self.name) + slug = base_slug + counter = 1 + while Resource.objects.filter(slug=slug).exclude(pk=self.pk).exists(): + slug = f"{base_slug}-{counter}" + counter += 1 + self.slug = slug super().save(*args, **kwargs) def __str__(self) -> str: @@ -64,9 +71,7 @@ def __str__(self) -> str: class ResourceFileDetails(models.Model): - resource = models.OneToOneField( - Resource, on_delete=models.CASCADE, null=False, blank=False - ) + resource = models.OneToOneField(Resource, on_delete=models.CASCADE, null=False, blank=False) file = models.FileField(upload_to="resources/", max_length=300) size = models.FloatField(blank=True, null=True) created = models.DateTimeField(auto_now_add=True) @@ -84,9 +89,7 @@ class ResourceDataTable(models.Model): """Model to store indexed CSV data for a resource.""" id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - resource = models.OneToOneField( - Resource, on_delete=models.CASCADE, null=False, blank=False - ) + resource = models.OneToOneField(Resource, on_delete=models.CASCADE, null=False, blank=False) table_name = models.CharField(max_length=255, 
unique=True) created = models.DateTimeField(auto_now_add=True) modified = models.DateTimeField(auto_now=True) @@ -109,9 +112,7 @@ def save(self, *args, **kwargs): class ResourceVersion(models.Model): - resource = models.ForeignKey( - Resource, on_delete=models.CASCADE, related_name="versions" - ) + resource = models.ForeignKey(Resource, on_delete=models.CASCADE, related_name="versions") version_number = models.CharField(max_length=50) commit_hash = models.CharField(max_length=64, null=True) created_at = models.DateTimeField(auto_now_add=True) @@ -167,9 +168,9 @@ def version_resource_with_dvc(sender, instance: ResourceFileDetails, created, ** return # Get the latest version - last_version: Optional[ResourceVersion] = ( - instance.resource.versions.order_by("-created_at").first() - ) + last_version: Optional[ResourceVersion] = instance.resource.versions.order_by( + "-created_at" + ).first() # Handle case when there are no versions yet if last_version is None: @@ -193,9 +194,7 @@ def version_resource_with_dvc(sender, instance: ResourceFileDetails, created, ** # Use DVC to get the previous version try: # Try to checkout the previous version using DVC - rel_path = Path(instance.file.path).relative_to( - settings.DVC_REPO_PATH - ) + rel_path = Path(instance.file.path).relative_to(settings.DVC_REPO_PATH) tag_name = f"{instance.resource.name}-{last_version.version_number}" # Save current file to temp location @@ -245,9 +244,7 @@ def version_resource_with_dvc(sender, instance: ResourceFileDetails, created, ** # Update using DVC dvc_file = dvc.track_resource(instance.file.path, chunked=use_chunked) - message = ( - f"Update resource: {instance.resource.name} to version {new_version}" - ) + message = f"Update resource: {instance.resource.name} to version {new_version}" dvc.commit_version(dvc_file, message) dvc.tag_version(f"{instance.resource.name}-{new_version}") From b36245fa8082da23f6415761ec56ad7303f419f2 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 10 Feb 2026 16:27:27 
+0530 Subject: [PATCH 099/127] update sdk client to accept headers --- dataspace_sdk/base.py | 4 ++++ dataspace_sdk/client.py | 16 ++++++++++++++++ dataspace_sdk/resources/auditors.py | 3 +++ 3 files changed, 23 insertions(+) diff --git a/dataspace_sdk/base.py b/dataspace_sdk/base.py index bdf38e1..9f53e2d 100644 --- a/dataspace_sdk/base.py +++ b/dataspace_sdk/base.py @@ -25,6 +25,7 @@ def __init__(self, base_url: str, auth_client: Any = None): """ self.base_url = base_url.rstrip("/") self.auth_client = auth_client + self.default_headers: Dict[str, str] = {} def _get_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: """ @@ -38,6 +39,9 @@ def _get_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> D """ headers = {"Content-Type": "application/json"} + if self.default_headers: + headers.update(self.default_headers) + if self.auth_client and self.auth_client.is_authenticated(): headers["Authorization"] = f"Bearer {self.auth_client.access_token}" diff --git a/dataspace_sdk/client.py b/dataspace_sdk/client.py index d59c90b..5f85eb2 100644 --- a/dataspace_sdk/client.py +++ b/dataspace_sdk/client.py @@ -187,6 +187,22 @@ def is_authenticated(self) -> bool: """ return self._auth.is_authenticated() + def set_organization(self, organization_id: str) -> None: + """ + Set the organization header for all subsequent API requests. + + The DataSpace backend reads the 'organization' header to scope + queries to a specific organization. 
+ + Args: + organization_id: Organization ID to include in requests + """ + self.datasets.default_headers["organization"] = organization_id + self.aimodels.default_headers["organization"] = organization_id + self.usecases.default_headers["organization"] = organization_id + self.sectors.default_headers["organization"] = organization_id + self.auditors.default_headers["organization"] = organization_id + @property def user(self) -> Optional[dict]: """ diff --git a/dataspace_sdk/resources/auditors.py b/dataspace_sdk/resources/auditors.py index 2a521a4..6fbd547 100644 --- a/dataspace_sdk/resources/auditors.py +++ b/dataspace_sdk/resources/auditors.py @@ -25,10 +25,13 @@ def __init__(self, base_url: str, auth_client: Any): """ self._base_url = base_url.rstrip("/") self._auth = auth_client + self.default_headers: Dict[str, str] = {} def _get_headers(self) -> Dict[str, str]: """Get request headers with authentication.""" headers = {"Content-Type": "application/json"} + if self.default_headers: + headers.update(self.default_headers) if self._auth and self._auth.access_token: headers["Authorization"] = f"Bearer {self._auth.access_token}" return headers From d92133275f9da05cbce05cc4bfee87ad94e9a003 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 10 Feb 2026 11:03:38 +0000 Subject: [PATCH 100/127] Bump SDK version to 0.4.16 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index df7ee54..b1b5b52 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.15" +__version__ = "0.4.16" diff --git a/pyproject.toml b/pyproject.toml index fdc75ea..61d9932 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.15" +version = "0.4.16" 
description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.15" +python_version = "0.4.16" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From e7f45a11c9f4b1247c66a243189359c3ab4d9e9d Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 10 Feb 2026 16:52:36 +0530 Subject: [PATCH 101/127] add fallback to fetch org by id if slug not present --- api/utils/middleware.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/api/utils/middleware.py b/api/utils/middleware.py index 145a592..c3f5e81 100644 --- a/api/utils/middleware.py +++ b/api/utils/middleware.py @@ -40,9 +40,7 @@ class CustomHttpRequest(HttpRequest): class ContextMiddleware: - def __init__( - self, get_response: Callable[[CustomHttpRequest], HttpResponse] - ) -> None: + def __init__(self, get_response: Callable[[CustomHttpRequest], HttpResponse]) -> None: self.get_response = get_response def __call__(self, request: CustomHttpRequest) -> HttpResponse: @@ -71,7 +69,10 @@ def __call__(self, request: CustomHttpRequest) -> HttpResponse: if organization_slug is None: organization: Optional[Organization] = None else: - organization = get_object_or_404(Organization, slug=organization_slug) + try: + organization = Organization.objects.get(slug=organization_slug) + except Organization.DoesNotExist: + organization = get_object_or_404(Organization, id=organization_slug) if dataspace_slug is None: dataspace: Optional[DataSpace] = None else: From 3ea4238f53c46b80bdd2cdb19f50222e06a071a0 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 10 Feb 2026 19:26:52 +0530 Subject: [PATCH 102/127] add domain to aimodel --- api/admin.py | 3 ++- api/models/AIModel.py | 8 ++++++++ api/schema/aimodel_schema.py | 8 +++++++- api/types/type_aimodel.py | 4 ++++ api/views/aimodel_detail.py | 5 +++-- api/views/search_aimodel.py | 3 +++ 
dataspace_sdk/resources/aimodels.py | 6 ++++++ search/documents/aimodel_document.py | 1 + 8 files changed, 34 insertions(+), 4 deletions(-) diff --git a/api/admin.py b/api/admin.py index 2ca2667..3ce529e 100644 --- a/api/admin.py +++ b/api/admin.py @@ -59,6 +59,7 @@ class AIModelAdmin(admin.ModelAdmin): list_filter = ( "provider", "model_type", + "domain", "status", "is_public", "is_active", @@ -85,7 +86,7 @@ class AIModelAdmin(admin.ModelAdmin): "Schema", {"fields": ("input_schema", "output_schema"), "classes": ("collapse",)}, ), - ("Metadata", {"fields": ("tags", "metadata"), "classes": ("collapse",)}), + ("Metadata", {"fields": ("tags", "domain", "metadata"), "classes": ("collapse",)}), ("Status & Visibility", {"fields": ("status", "is_public", "is_active")}), ( "Performance Metrics", diff --git a/api/models/AIModel.py b/api/models/AIModel.py index 1a4b3a4..9646e62 100644 --- a/api/models/AIModel.py +++ b/api/models/AIModel.py @@ -12,6 +12,7 @@ EndpointAuthType, EndpointHTTPMethod, HFModelClass, + PromptDomain, ) User = get_user_model() @@ -107,6 +108,13 @@ class AIModel(models.Model): tags = models.ManyToManyField("api.Tag", blank=True) sectors = models.ManyToManyField("api.Sector", blank=True, related_name="ai_models") geographies = models.ManyToManyField("api.Geography", blank=True, related_name="ai_models") + domain = models.CharField( + max_length=200, + choices=PromptDomain.choices, + blank=True, + null=True, + help_text="Domain or category (e.g., healthcare, education, legal)", + ) metadata = models.JSONField( default=dict, blank=True, diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index c57f4c1..f2075a2 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -1,6 +1,6 @@ """GraphQL schema for AI Model.""" -# mypy: disable-error-code=union-attr +# mypy: disable-error-code="union-attr,misc" import datetime from typing import List, Optional @@ -27,6 +27,7 @@ AIModelVersionOrder, EndpointAuthTypeEnum, 
EndpointHTTPMethodEnum, + PromptDomainEnum, TypeAIModel, TypeAIModelVersion, TypeModelEndpoint, @@ -95,6 +96,7 @@ class CreateAIModelInput: tags: Optional[List[str]] = None sectors: Optional[List[str]] = None geographies: Optional[List[int]] = None + domain: Optional[PromptDomainEnum] = None metadata: Optional[strawberry.scalars.JSON] = None is_public: bool = False @@ -119,6 +121,7 @@ class UpdateAIModelInput: tags: Optional[List[str]] = None sectors: Optional[List[str]] = None geographies: Optional[List[int]] = None + domain: Optional[PromptDomainEnum] = None metadata: Optional[strawberry.scalars.JSON] = None is_public: Optional[bool] = None is_active: Optional[bool] = None @@ -441,6 +444,7 @@ def create_ai_model( supported_languages=supported_languages, input_schema=input_schema, output_schema=output_schema, + domain=input.domain if input.domain else None, metadata=metadata, is_public=input.is_public, status="REGISTERED", @@ -518,6 +522,8 @@ def update_ai_model( model.input_schema = input.input_schema if input.output_schema is not None: model.output_schema = input.output_schema + if input.domain is not None: + model.domain = input.domain if input.metadata is not None: model.metadata = input.metadata if input.is_public is not None: diff --git a/api/types/type_aimodel.py b/api/types/type_aimodel.py index 6e94dfc..7870d8c 100644 --- a/api/types/type_aimodel.py +++ b/api/types/type_aimodel.py @@ -26,6 +26,7 @@ EndpointAuthType, EndpointHTTPMethod, HFModelClass, + PromptDomain, ) from authorization.types import TypeUser @@ -41,6 +42,7 @@ AIModelFrameworkEnum = strawberry.enum(AIModelFramework) # type: ignore HFModelClassEnum = strawberry.enum(HFModelClass) # type: ignore AIModelLifecycleStageEnum = strawberry.enum(AIModelLifecycleStage) # type: ignore +PromptDomainEnum = strawberry.enum(PromptDomain) # type: ignore @strawberry.type @@ -83,6 +85,7 @@ class AIModelFilter: status: Optional[AIModelStatusEnum] model_type: Optional[AIModelTypeEnum] provider: 
Optional[AIModelProviderEnum] + domain: Optional[PromptDomainEnum] is_public: Optional[bool] is_active: Optional[bool] @@ -123,6 +126,7 @@ class TypeAIModel(BaseType): supported_languages: Optional[strawberry.scalars.JSON] input_schema: Optional[strawberry.scalars.JSON] output_schema: Optional[strawberry.scalars.JSON] + domain: Optional[PromptDomainEnum] metadata: Optional[strawberry.scalars.JSON] status: AIModelStatusEnum is_public: bool diff --git a/api/views/aimodel_detail.py b/api/views/aimodel_detail.py index dab6206..96612dd 100644 --- a/api/views/aimodel_detail.py +++ b/api/views/aimodel_detail.py @@ -1,8 +1,8 @@ """API view for AI Model detail.""" +import logging from typing import Any, Dict, List, Optional -import logging from rest_framework import serializers, status from rest_framework.permissions import AllowAny from rest_framework.request import Request @@ -11,9 +11,9 @@ from api.models.AIModel import AIModel, ModelEndpoint - logger = logging.getLogger(__name__) + class ModelEndpointSerializer(serializers.ModelSerializer): """Serializer for Model Endpoint.""" @@ -59,6 +59,7 @@ class Meta: "tags", "sectors", "geographies", + "domain", "metadata", "status", "is_public", diff --git a/api/views/search_aimodel.py b/api/views/search_aimodel.py index b7fc606..083cc65 100644 --- a/api/views/search_aimodel.py +++ b/api/views/search_aimodel.py @@ -103,6 +103,7 @@ class Meta: "is_individual_model", "has_active_endpoints", "endpoint_count", + "domain", "version_count", "lifecycle_stage", "all_providers", @@ -154,6 +155,7 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: "supports_streaming": "terms", "lifecycle_stage": "terms", "all_providers": "terms", + "domain": "terms", } return searchable_fields, aggregations @@ -208,6 +210,7 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: "supported_languages", "lifecycle_stage", "all_providers", + "domain", ]: # Handle single or multi-value filters filter_values = 
filters[filter_key].split(",") diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index bcf95b3..586d1c7 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -17,6 +17,7 @@ def search( status: Optional[str] = None, model_type: Optional[str] = None, provider: Optional[str] = None, + domain: Optional[str] = None, sort: Optional[str] = None, page: int = 1, page_size: int = 10, @@ -32,6 +33,7 @@ def search( status: Filter by status (ACTIVE, INACTIVE, etc.) model_type: Filter by model type (LLM, VISION, etc.) provider: Filter by provider (OPENAI, ANTHROPIC, etc.) + domain: Filter by domain (HEALTHCARE, EDUCATION, etc.) sort: Sort order (recent, alphabetical) page: Page number (1-indexed) page_size: Number of results per page @@ -58,6 +60,8 @@ def search( params["model_type"] = model_type if provider: params["provider"] = provider + if domain: + params["domain"] = domain if sort: params["sort"] = sort @@ -94,6 +98,7 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: displayName description modelType + domain status isPublic createdAt @@ -207,6 +212,7 @@ def list_all( displayName description modelType + domain status isPublic createdAt diff --git a/search/documents/aimodel_document.py b/search/documents/aimodel_document.py index 4d5ae95..fdeae8b 100644 --- a/search/documents/aimodel_document.py +++ b/search/documents/aimodel_document.py @@ -54,6 +54,7 @@ class AIModelDocument(Document): # Model configuration model_type = fields.KeywordField() + domain = fields.KeywordField() # Status and visibility status = fields.KeywordField() From 8ff0f43eeef0d66ba3dd159fbf9decf6048b87cd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 10 Feb 2026 14:04:11 +0000 Subject: [PATCH 103/127] Bump SDK version to 0.4.17 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py 
b/dataspace_sdk/__version__.py index b1b5b52..b33e846 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.16" +__version__ = "0.4.17" diff --git a/pyproject.toml b/pyproject.toml index 61d9932..9058bcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.16" +version = "0.4.17" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.16" +python_version = "0.4.17" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From f593d479b73446db4204174db46c3ab4e291b10a Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 11 Feb 2026 11:52:43 +0530 Subject: [PATCH 104/127] fix aimodel endpoint in sdk --- dataspace_sdk/resources/aimodels.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index 586d1c7..9dd10c9 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -91,8 +91,8 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: Dictionary containing AI model information """ query = """ - query GetAIModel($id: UUID!) { - aiModel(id: $id) { + query GetAIModel($id: Int!) 
{ + getAiModel(modelId: $id) { id name displayName @@ -171,7 +171,7 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: "/api/graphql", json_data={ "query": query, - "variables": {"id": model_id}, + "variables": {"id": int(model_id)}, }, ) @@ -180,7 +180,7 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") - result: Dict[str, Any] = response.get("data", {}).get("aiModel", {}) + result: Dict[str, Any] = response.get("data", {}).get("getAiModel", {}) return result def list_all( From e06c740829c7e6dd8e60cfab5d3d4b94dd82bd89 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 11 Feb 2026 06:24:06 +0000 Subject: [PATCH 105/127] Bump SDK version to 0.4.18 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index b33e846..cf753f5 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.17" +__version__ = "0.4.18" diff --git a/pyproject.toml b/pyproject.toml index 9058bcb..721b5d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.17" +version = "0.4.18" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.17" +python_version = "0.4.18" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 2cf1312934d63d4d6af77591154db8f19db7a021 Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 11 Feb 2026 11:57:54 +0530 Subject: [PATCH 106/127] fix sdk test --- tests/test_aimodels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/test_aimodels.py b/tests/test_aimodels.py index 600e63e..b5bcd43 100644 --- a/tests/test_aimodels.py +++ b/tests/test_aimodels.py @@ -56,7 +56,7 @@ def test_get_model_by_id_graphql(self, mock_post: MagicMock) -> None: """Test get AI model by ID using GraphQL.""" mock_post.return_value = { "data": { - "aiModel": { + "getAiModel": { "id": "123", "displayName": "Test Model", "description": "A test model", From e89b46d9c1f463c35f584cc39b15afb75c3babde Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 11 Feb 2026 06:35:33 +0000 Subject: [PATCH 107/127] Bump SDK version to 0.4.19 --- dataspace_sdk/__version__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py index cf753f5..cff3639 100644 --- a/dataspace_sdk/__version__.py +++ b/dataspace_sdk/__version__.py @@ -1,3 +1,3 @@ """Version information for DataSpace SDK.""" -__version__ = "0.4.18" +__version__ = "0.4.19" diff --git a/pyproject.toml b/pyproject.toml index 721b5d2..62b39ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataspace-sdk" -version = "0.4.18" +version = "0.4.19" description = "Python SDK for DataSpace API" readme = "docs/sdk/README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' [tool.mypy] -python_version = "0.4.18" +python_version = "0.4.19" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false From 34893bd634ad8678e26d6c519bfb7dc31204f54d Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 11 Feb 2026 14:01:20 +0530 Subject: [PATCH 108/127] add slug to user info --- api/views/auth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/views/auth.py b/api/views/auth.py index ef050d0..94d98dc 100644 --- a/api/views/auth.py +++ b/api/views/auth.py @@ -101,6 +101,7 @@ def get(self, request: Request) -> Response: 
"homepage": org.organization.homepage, # type: ignore[attr-defined] "created": org.organization.created, # type: ignore[attr-defined] "updated": org.organization.modified, # type: ignore[attr-defined] + "slug": org.organization.slug, # type: ignore[attr-defined] } for org in user.organizationmembership_set.all() # type: ignore[union-attr, arg-type] ], From 1823c301904085b2b9a2efd2dc3cdb065532f2ec Mon Sep 17 00:00:00 2001 From: dc Date: Wed, 11 Feb 2026 14:02:02 +0530 Subject: [PATCH 109/127] fix: agg accessed before initialization --- api/views/paginated_elastic_view.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/api/views/paginated_elastic_view.py b/api/views/paginated_elastic_view.py index 038f8d7..7a952a3 100644 --- a/api/views/paginated_elastic_view.py +++ b/api/views/paginated_elastic_view.py @@ -47,9 +47,7 @@ def get_search(self) -> SearchType: """Get search instance.""" if hasattr(self.document_class, "search"): return self.document_class.search() # type: ignore - raise AttributeError( - f"{self.document_class.__name__} does not have a search method" - ) + raise AttributeError(f"{self.document_class.__name__} does not have a search method") def get(self, request: HttpRequest) -> Response: """Handle GET request and return paginated search results.""" @@ -99,10 +97,10 @@ def get(self, request: HttpRequest) -> Response: aggregations.pop("metadata") for agg in metadata_aggregations: label: str = agg["key"]["metadata_label"] - value: str = agg["key"].get("metadata_value", "") - if label not in aggregations: - aggregations[label] = {} - aggregations[label][value] = agg["doc_count"] + value: str = agg["key"].get("metadata_value", "") + if label not in aggregations: + aggregations[label] = {} + aggregations[label][value] = agg["doc_count"] if "catalogs" in aggregations: aggregations.pop("catalogs") @@ -170,9 +168,7 @@ def get(self, request: HttpRequest) -> Response: aggregations["running_status"][agg["key"]] = agg["doc_count"] 
if "is_individual_usecase" in aggregations: - is_individual_usecase_agg = aggregations["is_individual_usecase"][ - "buckets" - ] + is_individual_usecase_agg = aggregations["is_individual_usecase"]["buckets"] aggregations.pop("is_individual_usecase") aggregations["is_individual_usecase"] = {} for agg in is_individual_usecase_agg: From fdfc07817e1251a7b1ea3d153f3401cd99beb0b6 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 16 Feb 2026 19:48:39 +0530 Subject: [PATCH 110/127] add aimodel signals to update index --- api/signals/__init__.py | 2 +- api/signals/aimodel_signals.py | 83 ++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 api/signals/aimodel_signals.py diff --git a/api/signals/__init__.py b/api/signals/__init__.py index 16d9e25..29e3798 100644 --- a/api/signals/__init__.py +++ b/api/signals/__init__.py @@ -1,2 +1,2 @@ # Import signals to register them -from api.signals import dataset_signals, usecase_signals +from api.signals import aimodel_signals, dataset_signals, usecase_signals diff --git a/api/signals/aimodel_signals.py b/api/signals/aimodel_signals.py new file mode 100644 index 0000000..3c65808 --- /dev/null +++ b/api/signals/aimodel_signals.py @@ -0,0 +1,83 @@ +from typing import Any + +import structlog +from django.db.models.signals import post_delete, pre_save +from django.dispatch import receiver + +from api.models.AIModel import AIModel +from api.utils.enums import AIModelStatus +from search.documents import AIModelDocument + +logger = structlog.get_logger(__name__) + +INDEXABLE_STATUSES = {AIModelStatus.ACTIVE, AIModelStatus.APPROVED} + + +def _should_be_indexed(instance: AIModel) -> bool: + """Return True when the AI model should exist in Elasticsearch.""" + return instance.is_public and instance.is_active and instance.status in INDEXABLE_STATUSES + + +@receiver(pre_save, sender=AIModel) +def handle_aimodel_visibility(sender: Any, instance: AIModel, **kwargs: Any) -> None: + """Sync Elasticsearch 
document whenever publish/visibility fields change.""" + if not instance.pk: + # New objects are handled by django-elasticsearch-dsl signal processor + return + + try: + original = AIModel.objects.get(pk=instance.pk) + except AIModel.DoesNotExist: + return + + was_indexable = _should_be_indexed(original) + is_indexable = _should_be_indexed(instance) + + if was_indexable == is_indexable and is_indexable: + # Still indexable, just refresh contents + action = "update" + elif was_indexable and not is_indexable: + action = "delete" + elif not was_indexable and is_indexable: + action = "add" + else: + # Neither was nor is indexable; nothing to do + return + + try: + document = AIModelDocument.get(id=instance.id, ignore=404) + if action == "delete": + if document: + document.delete() + logger.info("Removed AI model from Elasticsearch index", model_id=instance.id) + else: + if document: + document.update(instance) + else: + AIModelDocument().update(instance) + logger.info( + "Synced AI model to Elasticsearch index", model_id=instance.id, action=action + ) + except Exception as exc: # pragma: no cover - logging only + logger.error( + "Failed to sync AI model search document", + model_id=instance.id, + action=action, + error=str(exc), + ) + + +@receiver(post_delete, sender=AIModel) +def remove_aimodel_document(sender: Any, instance: AIModel, **kwargs: Any) -> None: + """Ensure Elasticsearch document gets deleted when the model is removed.""" + try: + document = AIModelDocument.get(id=instance.id, ignore=404) + if document: + document.delete() + logger.info("Removed deleted AI model from Elasticsearch index", model_id=instance.id) + except Exception as exc: # pragma: no cover - logging only + logger.error( + "Failed to delete AI model search document", + model_id=instance.id, + error=str(exc), + ) From 23949809441e8e82d459cb1b65ae213b77710934 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 16 Feb 2026 22:12:33 +0530 Subject: [PATCH 111/127] add collaborative document, views and 
update signals --- DataSpace/settings.py | 1 + api/signals/collaborative_signals.py | 70 +++++ api/urls.py | 6 + api/views/search_collaborative.py | 333 +++++++++++++++++++++ search/documents/__init__.py | 1 + search/documents/collaborative_document.py | 331 ++++++++++++++++++++ 6 files changed, 742 insertions(+) create mode 100644 api/signals/collaborative_signals.py create mode 100644 api/views/search_collaborative.py create mode 100644 search/documents/collaborative_document.py diff --git a/DataSpace/settings.py b/DataSpace/settings.py index b9a19bf..1a6f961 100644 --- a/DataSpace/settings.py +++ b/DataSpace/settings.py @@ -277,6 +277,7 @@ "search.documents.dataset_document": "dataset", "search.documents.usecase_document": "usecase", "search.documents.aimodel_document": "aimodel", + "search.documents.collaborative_document": "collaborative", } diff --git a/api/signals/collaborative_signals.py b/api/signals/collaborative_signals.py new file mode 100644 index 0000000..2556048 --- /dev/null +++ b/api/signals/collaborative_signals.py @@ -0,0 +1,70 @@ +from typing import Any + +import structlog +from django.core.cache import cache +from django.db.models.signals import pre_save +from django.dispatch import receiver + +from api.models.Collaborative import Collaborative +from api.utils.enums import CollaborativeStatus +from search.documents import CollaborativeDocument + +from .dataset_signals import SEARCH_CACHE_VERSION_KEY + +logger = structlog.get_logger(__name__) + + +@receiver(pre_save, sender=Collaborative) +def handle_collaborative_publication(sender: Any, instance: Collaborative, **kwargs: Any) -> None: + """Sync Elasticsearch index when collaborative publication state changes.""" + + try: + if not instance.pk: + return + + original = Collaborative.objects.get(pk=instance.pk) + + status_changing_to_published = ( + original.status != CollaborativeStatus.PUBLISHED + and instance.status == CollaborativeStatus.PUBLISHED + ) + status_changing_from_published = ( + 
original.status == CollaborativeStatus.PUBLISHED + and instance.status != CollaborativeStatus.PUBLISHED + ) + remains_published = ( + original.status == CollaborativeStatus.PUBLISHED + and instance.status == CollaborativeStatus.PUBLISHED + ) + + if status_changing_to_published or status_changing_from_published: + version = cache.get(SEARCH_CACHE_VERSION_KEY, 0) + cache.set(SEARCH_CACHE_VERSION_KEY, version + 1) + logger.info("Invalidated search cache for collaborative", collaborative_id=instance.id) + + if status_changing_from_published: + document = CollaborativeDocument.get(id=instance.id, ignore=404) + if document: + document.delete() + logger.info( + "Removed collaborative from Elasticsearch index", + collaborative_id=instance.id, + ) + elif status_changing_to_published or remains_published: + document = CollaborativeDocument.get(id=instance.id, ignore=404) + if document: + document.update(instance) + else: + CollaborativeDocument().update(instance) + logger.info( + "Synced collaborative to Elasticsearch index", + collaborative_id=instance.id, + ) + + except Exception as exc: # pragma: no cover - logging only + logger.error( + "Error in collaborative publication signal handler", + collaborative_id=getattr(instance, "id", None), + error=str(exc), + ) + # Avoid raising to prevent save failures diff --git a/api/urls.py b/api/urls.py index 86cbb9e..335de6d 100644 --- a/api/urls.py +++ b/api/urls.py @@ -14,6 +14,7 @@ download, generate_dynamic_chart, search_aimodel, + search_collaborative, search_dataset, search_unified, search_usecase, @@ -48,6 +49,11 @@ path("search/dataset/", search_dataset.SearchDataset.as_view(), name="search_dataset"), path("search/usecase/", search_usecase.SearchUseCase.as_view(), name="search_usecase"), path("search/aimodel/", search_aimodel.SearchAIModel.as_view(), name="search_aimodel"), + path( + "search/collaborative/", + search_collaborative.SearchCollaborative.as_view(), + name="search_collaborative", + ), path("search/unified/", 
search_unified.UnifiedSearch.as_view(), name="search_unified"), path( "aimodels//", diff --git a/api/views/search_collaborative.py b/api/views/search_collaborative.py new file mode 100644 index 0000000..17732c6 --- /dev/null +++ b/api/views/search_collaborative.py @@ -0,0 +1,333 @@ +import ast +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +import structlog +from elasticsearch_dsl import A +from elasticsearch_dsl import Q as ESQ +from elasticsearch_dsl import Search +from elasticsearch_dsl.query import Query as ESQuery +from rest_framework import serializers +from rest_framework.permissions import AllowAny + +from api.models import Collaborative, CollaborativeMetadata, Geography, Metadata +from api.utils.telemetry_utils import trace_method +from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView +from search.documents import CollaborativeDocument + +logger = structlog.get_logger(__name__) + + +class MetadataSerializer(serializers.Serializer): + label = serializers.CharField(allow_blank=True) # type: ignore + + +class CollaborativeMetadataSerializer(serializers.ModelSerializer): + metadata_item = MetadataSerializer() + + class Meta: + model = CollaborativeMetadata + fields = ["metadata_item", "value"] + + def to_representation(self, instance: CollaborativeMetadata) -> Dict[str, Any]: + representation = super().to_representation(instance) + + if isinstance(representation["value"], str): + try: + value_list = ast.literal_eval(representation["value"]) + if isinstance(value_list, list): + representation["value"] = ", ".join(str(x) for x in value_list) + except (ValueError, SyntaxError): + pass + + return cast(Dict[str, Any], representation) + + def to_internal_value(self, data: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(data.get("value"), str): + try: + value = data["value"] + data["value"] = value.split(", ") if value else [] + except (ValueError, SyntaxError): + pass + + return cast(Dict[str, Any], 
super().to_internal_value(data)) + + +class CollaborativeDocumentSerializer(serializers.ModelSerializer): + metadata = CollaborativeMetadataSerializer(many=True) + tags = serializers.ListField() + sectors = serializers.ListField(default=[]) + geographies = serializers.ListField(default=[]) + slug = serializers.CharField() + is_individual_collaborative = serializers.BooleanField() + website = serializers.CharField(required=False, allow_blank=True) + contact_email = serializers.EmailField(required=False, allow_blank=True) + platform_url = serializers.CharField(required=False, allow_blank=True) + started_on = serializers.DateTimeField(required=False, allow_null=True) + completed_on = serializers.DateTimeField(required=False, allow_null=True) + + class OrganizationSerializer(serializers.Serializer): + name = serializers.CharField() + logo = serializers.CharField() + + class UserSerializer(serializers.Serializer): + name = serializers.CharField() + bio = serializers.CharField() + profile_picture = serializers.CharField() + + class ContributorSerializer(serializers.Serializer): + name = serializers.CharField() + bio = serializers.CharField() + profile_picture = serializers.CharField() + + class RelatedOrganizationSerializer(serializers.Serializer): + name = serializers.CharField() + logo = serializers.CharField() + relationship_type = serializers.CharField() + + class DatasetSerializer(serializers.Serializer): + title = serializers.CharField() + description = serializers.CharField() + slug = serializers.CharField() + + class UseCaseSerializer(serializers.Serializer): + title = serializers.CharField() + summary = serializers.CharField() + slug = serializers.CharField() + + organization = OrganizationSerializer() + user = UserSerializer() + contributors = ContributorSerializer(many=True) + organizations = RelatedOrganizationSerializer(many=True) + datasets = DatasetSerializer(many=True) + use_cases = UseCaseSerializer(many=True) + + class Meta: + model = Collaborative + 
fields = [ + "id", + "title", + "summary", + "slug", + "created", + "modified", + "status", + "metadata", + "tags", + "sectors", + "geographies", + "is_individual_collaborative", + "organization", + "user", + "contributors", + "organizations", + "datasets", + "use_cases", + "website", + "contact_email", + "platform_url", + "started_on", + "completed_on", + ] + + +class SearchCollaborative(PaginatedElasticSearchAPIView): + serializer_class = CollaborativeDocumentSerializer + document_class = CollaborativeDocument + permission_classes = [AllowAny] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.searchable_fields: List[str] + self.aggregations: Dict[str, str] + self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations() + self.logger = structlog.get_logger(__name__) + + @trace_method( + name="get_searchable_and_aggregations", + attributes={"component": "search_collaborative"}, + ) + def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: + searchable_fields = [ + "title", + "summary", + "tags", + "sectors", + "user.name", + "organization.name", + "contributors.name", + "datasets.title", + "datasets.description", + "use_cases.title", + "use_cases.summary", + "metadata.value", + ] + + aggregations: Dict[str, str] = { + "tags.raw": "terms", + "sectors.raw": "terms", + "geographies.raw": "terms", + "status": "terms", + } + + filterable_metadata = Metadata.objects.filter(filterable=True).all() + for metadata in filterable_metadata: + aggregations[f"metadata.{metadata.label}"] = "terms" # type: ignore + + return searchable_fields, aggregations + + @trace_method(name="add_aggregations", attributes={"component": "search_collaborative"}) + def add_aggregations(self, search: Search) -> Search: + aggregate_fields: List[str] = [] + for aggregation_field in self.aggregations: + if aggregation_field.startswith("metadata."): + aggregate_fields.append(aggregation_field.split(".")[1]) + else: + 
search.aggs.bucket( + aggregation_field.replace(".raw", ""), + self.aggregations[aggregation_field], + field=aggregation_field, + ) + + if aggregate_fields: + metadata_qs = Metadata.objects.filter(filterable=True) + filterable_metadata = [str(meta.label) for meta in metadata_qs] # type: ignore + + metadata_bucket = search.aggs.bucket("metadata", "nested", path="metadata") + composite_agg = A( + "composite", + sources=[ + {"metadata_label": {"terms": {"field": "metadata.metadata_item.label"}}}, + {"metadata_value": {"terms": {"field": "metadata.value"}}}, + ], + size=10000, + ) + metadata_filter = A( + "filter", + { + "bool": { + "must": [{"terms": {"metadata.metadata_item.label": filterable_metadata}}] + } + }, + ) + metadata_bucket.bucket("filtered_metadata", metadata_filter).bucket( + "composite_agg", composite_agg + ) + + return search + + @trace_method(name="generate_q_expression", attributes={"component": "search_collaborative"}) + def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]: + if query: + queries: List[ESQuery] = [] + for field in self.searchable_fields: + if field.startswith("datasets."): + queries.append( + ESQ( + "nested", + path="datasets", + query=ESQ( + "bool", + should=[ + ESQ("wildcard", **{field: {"value": f"*{query}*"}}), + ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}), + ], + ), + ) + ) + elif field.startswith("use_cases."): + queries.append( + ESQ( + "nested", + path="use_cases", + query=ESQ( + "bool", + should=[ + ESQ("wildcard", **{field: {"value": f"*{query}*"}}), + ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}), + ], + ), + ) + ) + elif ( + field.startswith("user.") + or field.startswith("organization.") + or field.startswith("contributors.") + or field.startswith("organizations.") + ): + path = field.split(".")[0] + queries.append( + ESQ( + "nested", + path=path, + query=ESQ( + "bool", + should=[ + ESQ("wildcard", **{field: {"value": f"*{query}*"}}), + ESQ("fuzzy", 
**{field: {"value": query, "fuzziness": "AUTO"}}), + ], + ), + ) + ) + else: + queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}})) + else: + queries = [ESQ("match_all")] + + return ESQ("bool", should=queries, minimum_should_match=1) + + @trace_method(name="add_filters", attributes={"component": "search_collaborative"}) + def add_filters(self, filters: Dict[str, str], search: Search) -> Search: + non_filter_metadata = Metadata.objects.filter(filterable=False).all() + excluded_labels: List[str] = [e.label for e in non_filter_metadata] # type: ignore + + for filter in filters: + if filter in excluded_labels: + continue + elif filter in ["tags", "sectors", "geographies"]: + raw_filter = filter + ".raw" + if raw_filter in self.aggregations: + filter_values = filters[filter].split(",") + + if filter == "geographies": + filter_values = Geography.get_geography_names_with_descendants( + filter_values + ) + + search = search.filter("terms", **{raw_filter: filter_values}) + else: + search = search.filter("term", **{filter: filters[filter]}) + elif filter in ["status", "is_individual_collaborative"]: + search = search.filter("term", **{filter: filters[filter]}) + elif filter in ["user.name", "organization.name"]: + path = filter.split(".")[0] + search = search.filter( + "nested", + path=path, + query={"bool": {"must": {"term": {filter: filters[filter]}}}}, + ) + elif filter in ["datasets.slug", "use_cases.slug"]: + path = filter.split(".")[0] + search = search.filter( + "nested", + path=path, + query={"bool": {"must": {"term": {filter: filters[filter]}}}}, + ) + else: + search = search.filter( + "nested", + path="metadata", + query={"bool": {"must": {"term": {f"metadata.value": filters[filter]}}}}, + ) + return search + + @trace_method(name="add_sort", attributes={"component": "search_collaborative"}) + def add_sort(self, sort: str, search: Search, order: str) -> Search: + if sort == "alphabetical": + search = search.sort({"title.raw": {"order": 
order}}) + elif sort == "recent": + search = search.sort({"modified": {"order": order}}) + elif sort == "started": + search = search.sort({"started_on": {"order": order}}) + elif sort == "completed": + search = search.sort({"completed_on": {"order": order}}) + return search diff --git a/search/documents/__init__.py b/search/documents/__init__.py index f6441ac..64d15f8 100644 --- a/search/documents/__init__.py +++ b/search/documents/__init__.py @@ -1,3 +1,4 @@ from search.documents.aimodel_document import AIModelDocument +from search.documents.collaborative_document import CollaborativeDocument from search.documents.dataset_document import DatasetDocument from search.documents.usecase_document import UseCaseDocument diff --git a/search/documents/collaborative_document.py b/search/documents/collaborative_document.py new file mode 100644 index 0000000..563e3c0 --- /dev/null +++ b/search/documents/collaborative_document.py @@ -0,0 +1,331 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from django_elasticsearch_dsl import Document, Index, KeywordField, fields + +from api.models import ( + Collaborative, + CollaborativeMetadata, + CollaborativeOrganizationRelationship, + Dataset, + Geography, + Metadata, + Organization, + Sector, + UseCase, +) +from api.utils.enums import CollaborativeStatus +from authorization.models import User +from DataSpace import settings +from search.documents.analysers import html_strip, ngram_analyser + +if TYPE_CHECKING: + from api.models import CollaborativeOrganizationRelationship as RelationshipModel + from api.models import Dataset as DatasetModel + from api.models import Organization as OrganizationModel + from api.models import UseCase as UseCaseModel + +INDEX = Index(settings.ELASTICSEARCH_INDEX_NAMES[__name__]) +INDEX.settings(number_of_shards=1, number_of_replicas=0) + + +@INDEX.doc_type +class CollaborativeDocument(Document): + """Elasticsearch document for Collaborative model.""" + + metadata = 
fields.NestedField( + properties={ + "value": KeywordField(multi=True), + "raw": KeywordField(multi=True), + "metadata_item": fields.ObjectField(properties={"label": KeywordField(multi=False)}), + } + ) + + datasets = fields.NestedField( + properties={ + "title": fields.TextField(analyzer=ngram_analyser), + "description": fields.TextField(analyzer=html_strip), + "slug": fields.KeywordField(), + } + ) + + use_cases = fields.NestedField( + properties={ + "title": fields.TextField(analyzer=ngram_analyser), + "summary": fields.TextField(analyzer=html_strip), + "slug": fields.KeywordField(), + } + ) + + title = fields.TextField( + analyzer=ngram_analyser, + fields={ + "raw": KeywordField(multi=False), + }, + ) + + summary = fields.TextField( + analyzer=html_strip, + fields={ + "raw": fields.TextField(analyzer="keyword"), + }, + ) + + logo = fields.TextField(analyzer=ngram_analyser) + cover_image = fields.TextField(analyzer=ngram_analyser) + + status = fields.KeywordField() + slug = fields.KeywordField() + + tags = fields.TextField( + attr="tags_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + sectors = fields.TextField( + attr="sectors_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + geographies = fields.TextField( + attr="geographies_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + organization = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": fields.KeywordField()} + ), + "logo": fields.TextField(analyzer=ngram_analyser), + } + ) + + user = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": 
fields.KeywordField()} + ), + "bio": fields.TextField(analyzer=html_strip), + "profile_picture": fields.TextField(analyzer=ngram_analyser), + } + ) + + contributors = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": fields.KeywordField()} + ), + "bio": fields.TextField(analyzer=html_strip), + "profile_picture": fields.TextField(analyzer=ngram_analyser), + } + ) + + organizations = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": fields.KeywordField()} + ), + "logo": fields.TextField(analyzer=ngram_analyser), + "relationship_type": fields.KeywordField(), + } + ) + + is_individual_collaborative = fields.BooleanField(attr="is_individual_collaborative") + + website = fields.TextField(analyzer=ngram_analyser) + contact_email = fields.KeywordField() + platform_url = fields.TextField(analyzer=ngram_analyser) + started_on = fields.DateField() + completed_on = fields.DateField() + + def prepare_metadata(self, instance: Collaborative) -> List[Dict[str, Any]]: + processed_metadata: List[Dict[str, Any]] = [] + for meta in instance.metadata.all(): # type: CollaborativeMetadata + if not meta.metadata_item: + continue + + value_list = ( + [val.strip() for val in meta.value.split(",")] + if isinstance(meta.value, str) and "," in meta.value + else [meta.value] + ) + processed_metadata.append( + { + "value": value_list, + "metadata_item": {"label": meta.metadata_item.label}, + } + ) + return processed_metadata + + def prepare_datasets(self, instance: Collaborative) -> List[Dict[str, str]]: + datasets_data: List[Dict[str, str]] = [] + for dataset in instance.datasets.all(): + datasets_data.append( + { + "title": dataset.title or "", # type: ignore[attr-defined] + "description": dataset.description or "", # type: ignore[attr-defined] + "slug": dataset.slug or "", # type: ignore[attr-defined] + } + ) + return datasets_data + + def prepare_use_cases(self, instance: 
Collaborative) -> List[Dict[str, str]]: + use_cases_data: List[Dict[str, str]] = [] + for use_case in instance.use_cases.all(): + use_cases_data.append( + { + "title": use_case.title or "", # type: ignore[attr-defined] + "summary": use_case.summary or "", # type: ignore[attr-defined] + "slug": use_case.slug or "", # type: ignore[attr-defined] + } + ) + return use_cases_data + + def prepare_organization(self, instance: Collaborative) -> Optional[Dict[str, str]]: + if instance.organization: + org = instance.organization + logo_url = org.logo.url if org.logo else "" + return {"name": org.name, "logo": logo_url} + return None + + def prepare_user(self, instance: Collaborative) -> Optional[Dict[str, str]]: + if instance.user: + return { + "name": instance.user.full_name, + "bio": instance.user.bio or "", + "profile_picture": ( + instance.user.profile_picture.url if instance.user.profile_picture else "" + ), + } + return None + + def prepare_contributors(self, instance: Collaborative) -> List[Dict[str, str]]: + contributors_data: List[Dict[str, str]] = [] + for contributor in instance.contributors.all(): + contributors_data.append( + { + "name": contributor.full_name, # type: ignore + "bio": contributor.bio or "", # type: ignore + "profile_picture": ( + contributor.profile_picture.url # type: ignore + if contributor.profile_picture # type: ignore + else "" + ), + } + ) + return contributors_data + + def prepare_organizations(self, instance: Collaborative) -> List[Dict[str, str]]: + organizations_data: List[Dict[str, str]] = [] + relationships = CollaborativeOrganizationRelationship.objects.filter(collaborative=instance) + for relationship in relationships: + org = relationship.organization # type: ignore[attr-defined] + logo_url = org.logo.url if org.logo else "" + organizations_data.append( + { + "name": org.name, # type: ignore[attr-defined] + "logo": logo_url, + "relationship_type": relationship.relationship_type, # type: ignore[attr-defined] + } + ) + return 
organizations_data + + def prepare_logo(self, instance: Collaborative) -> str: + if instance.logo: + return str(instance.logo.path.replace("/code/files/", "")) + return "" + + def prepare_cover_image(self, instance: Collaborative) -> str: + if instance.cover_image: + return str(instance.cover_image.path.replace("/code/files/", "")) + return "" + + def should_index_object(self, obj: Collaborative) -> bool: + return obj.status == CollaborativeStatus.PUBLISHED + + def save(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover - thin wrapper + if self.status == CollaborativeStatus.PUBLISHED: + super().save(*args, **kwargs) + else: + self.delete(ignore=404) + + def delete(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover - thin wrapper + super().delete(*args, **kwargs) + + def get_queryset(self) -> Any: + return ( + super(CollaborativeDocument, self) + .get_queryset() + .filter(status=CollaborativeStatus.PUBLISHED) + ) + + def get_instances_from_related( + self, + related_instance: Union[ + Dataset, + UseCase, + Metadata, + CollaborativeMetadata, + Sector, + Organization, + User, + Geography, + ], + ) -> Optional[Union[Collaborative, List[Collaborative]]]: + if isinstance(related_instance, Dataset): + return list(related_instance.collaborative_set.all()) # type: ignore[attr-defined] + if isinstance(related_instance, UseCase): + return list(related_instance.collaborative_set.all()) # type: ignore[attr-defined] + if isinstance(related_instance, Metadata): + collab_metadata_objects = related_instance.collaborativemetadata_set.all() # type: ignore[attr-defined] + return [obj.collaborative for obj in collab_metadata_objects] # type: ignore[attr-defined] + if isinstance(related_instance, CollaborativeMetadata): + return related_instance.collaborative # type: ignore[attr-defined] + if isinstance(related_instance, Sector): + return list(related_instance.collaboratives.all()) + if isinstance(related_instance, Organization): + primary = 
list(related_instance.collaborative_set.all()) + related = list(related_instance.related_collaboratives.all()) + return primary + related + if isinstance(related_instance, User): + owned = list(related_instance.collaborative_set.all()) + contributed = list(related_instance.contributed_collaboratives.all()) + return owned + contributed + if isinstance(related_instance, Geography): + return list(related_instance.collaboratives.all()) + return None + + class Django: + model = Collaborative + + fields = [ + "id", + "created", + "modified", + ] + + related_models = [ + Dataset, + UseCase, + Metadata, + CollaborativeMetadata, + Sector, + Organization, + User, + Geography, + ] From ffc5b86344c4305e794f235861fcc2fa6831bae2 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 16 Feb 2026 22:12:54 +0530 Subject: [PATCH 112/127] add collaborative to unified search --- api/views/search_unified.py | 90 +++++++++++++++++++++++++++---------- 1 file changed, 67 insertions(+), 23 deletions(-) diff --git a/api/views/search_unified.py b/api/views/search_unified.py index a6b4257..5a4df6f 100644 --- a/api/views/search_unified.py +++ b/api/views/search_unified.py @@ -14,9 +14,15 @@ from api.models import Dataset, Geography, Metadata, UseCase from api.models.AIModel import AIModel +from api.models.Collaborative import Collaborative from api.utils.telemetry_utils import trace_method from DataSpace import settings -from search.documents import AIModelDocument, DatasetDocument, UseCaseDocument +from search.documents import ( + AIModelDocument, + CollaborativeDocument, + DatasetDocument, + UseCaseDocument, +) logger = structlog.get_logger(__name__) @@ -25,7 +31,7 @@ class UnifiedSearchResultSerializer(serializers.Serializer): """Serializer for unified search results.""" id = serializers.CharField() - type = serializers.CharField() # 'dataset', 'usecase', or 'aimodel' + type = serializers.CharField() # 'dataset', 'usecase', 'aimodel', or 'collaborative' title = serializers.CharField() description =
serializers.CharField() slug = serializers.CharField(required=False) @@ -102,6 +108,12 @@ def _get_index_names(self, types_list: List[str]) -> List[str]: ) index_names.append(aimodel_index) + if "collaborative" in types_list: + collaborative_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.collaborative_document", "collaborative" + ) + index_names.append(collaborative_index) + return index_names def _build_unified_query(self, query: str) -> ESQ: @@ -172,6 +184,43 @@ def _build_unified_query(self, query: str) -> ESQ: ] ) + # Collaborative nested fields + common_queries.extend( + [ + ESQ( + "nested", + path="datasets", + query=ESQ( + "multi_match", + query=query, + fields=["datasets.title", "datasets.description"], + fuzziness="AUTO", + ), + ignore_unmapped=True, + ), + ESQ( + "nested", + path="use_cases", + query=ESQ( + "multi_match", + query=query, + fields=["use_cases.title", "use_cases.summary"], + fuzziness="AUTO", + ), + ignore_unmapped=True, + ), + ESQ( + "nested", + path="contributors", + query=ESQ( + "match", + **{"contributors.name": {"query": query, "fuzziness": "AUTO"}}, + ), + ignore_unmapped=True, + ), + ] + ) + # Organization and user (common across types) common_queries.extend( [ @@ -187,9 +236,7 @@ def _build_unified_query(self, query: str) -> ESQ: ESQ( "nested", path="user", - query=ESQ( - "match", **{"user.name": {"query": query, "fuzziness": "AUTO"}} - ), + query=ESQ("match", **{"user.name": {"query": query, "fuzziness": "AUTO"}}), ignore_unmapped=True, ), ] @@ -209,9 +256,7 @@ def _apply_filters(self, search: Search, filters: Dict[str, str]) -> Search: if "geographies" in filters: filter_values = filters["geographies"].split(",") - filter_values = Geography.get_geography_names_with_descendants( - filter_values - ) + filter_values = Geography.get_geography_names_with_descendants(filter_values) search = search.filter("terms", **{"geographies.raw": filter_values}) if "status" in filters: @@ -233,6 +278,8 @@ def 
_normalize_result(self, hit: Any) -> Dict[str, Any]: result["type"] = "usecase" elif "aimodel" in index_name: result["type"] = "aimodel" + elif "collaborative" in index_name: + result["type"] = "collaborative" else: result["type"] = "unknown" @@ -256,6 +303,11 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]: result["created"] = result["created_at"] if "updated_at" in result: result["modified"] = result["updated_at"] + elif result["type"] == "collaborative": + if "summary" in result: + result["description"] = result.get("summary", "") + if "title" not in result: + result["title"] = "" else: # dataset if "title" not in result: result["title"] = "" @@ -323,6 +375,8 @@ def perform_unified_search( aggregations["types"]["usecase"] = bucket["doc_count"] elif "aimodel" in index_name: aggregations["types"]["aimodel"] = bucket["doc_count"] + elif "collaborative" in index_name: + aggregations["types"]["collaborative"] = bucket["doc_count"] # Process other aggregations for agg_name in ["tags", "sectors", "geographies", "status"]: @@ -331,11 +385,7 @@ def perform_unified_search( for bucket in aggs_dict[agg_name]["buckets"]: aggregations[agg_name][bucket["key"]] = bucket["doc_count"] - total = ( - response.hits.total.value - if hasattr(response.hits.total, "value") - else len(results) - ) + total = response.hits.total.value if hasattr(response.hits.total, "value") else len(results) return results, total, aggregations @@ -347,7 +397,7 @@ def get(self, request: Any) -> Response: page: int = int(request.GET.get("page", 1)) size: int = int(request.GET.get("size", 10)) entity_types: str = request.GET.get( - "types", "dataset,usecase,aimodel" + "types", "dataset,usecase,aimodel,collaborative" ) # Which entity types to search # Parse entity types @@ -383,9 +433,7 @@ def get(self, request: Any) -> Response: self.logger.error("unified_search_error", error=str(e), exc_info=True) return Response({"error": "An internal error has occurred."}, status=500) - def _build_aggregations( 
- self, results: List[Dict[str, Any]] - ) -> Dict[str, Dict[str, int]]: + def _build_aggregations(self, results: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: """Build aggregations from results.""" aggregations: Dict[str, Dict[str, int]] = { "types": {}, @@ -398,9 +446,7 @@ def _build_aggregations( for result in results: # Count by type result_type = result.get("type", "unknown") - aggregations["types"][result_type] = ( - aggregations["types"].get(result_type, 0) + 1 - ) + aggregations["types"][result_type] = aggregations["types"].get(result_type, 0) + 1 # Count by tags for tag in result.get("tags", []): @@ -408,9 +454,7 @@ def _build_aggregations( # Count by sectors for sector in result.get("sectors", []): - aggregations["sectors"][sector] = ( - aggregations["sectors"].get(sector, 0) + 1 - ) + aggregations["sectors"][sector] = aggregations["sectors"].get(sector, 0) + 1 # Count by geographies for geography in result.get("geographies", []): From a60f4270a64d39dcf986e2338005ef58f49495e1 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 16 Feb 2026 22:42:47 +0530 Subject: [PATCH 113/127] add collborative specific fields to serializer --- api/views/search_unified.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/api/views/search_unified.py b/api/views/search_unified.py index 5a4df6f..96a0df2 100644 --- a/api/views/search_unified.py +++ b/api/views/search_unified.py @@ -76,6 +76,14 @@ class UserSerializer(serializers.Serializer): provider = serializers.CharField(required=False) is_individual_model = serializers.BooleanField(required=False) + # Collaborative specific + is_individual_collaborative = serializers.BooleanField(required=False) + website = serializers.CharField(required=False) + contact_email = serializers.CharField(required=False) + platform_url = serializers.CharField(required=False) + started_on = serializers.DateTimeField(required=False) + completed_on = serializers.DateTimeField(required=False) + class UnifiedSearch(APIView): """View for 
unified search across datasets, usecases, and aimodels.""" From 585da033f93f3ed880463574891371d200ab44bc Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 16 Feb 2026 23:00:01 +0530 Subject: [PATCH 114/127] add dataset count to collaborative document --- search/documents/collaborative_document.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/search/documents/collaborative_document.py b/search/documents/collaborative_document.py index 563e3c0..2661deb 100644 --- a/search/documents/collaborative_document.py +++ b/search/documents/collaborative_document.py @@ -152,6 +152,7 @@ class CollaborativeDocument(Document): platform_url = fields.TextField(analyzer=ngram_analyser) started_on = fields.DateField() completed_on = fields.DateField() + dataset_count = fields.IntegerField() def prepare_metadata(self, instance: Collaborative) -> List[Dict[str, Any]]: processed_metadata: List[Dict[str, Any]] = [] @@ -255,6 +256,9 @@ def prepare_cover_image(self, instance: Collaborative) -> str: return str(instance.cover_image.path.replace("/code/files/", "")) return "" + def prepare_dataset_count(self, instance: Collaborative) -> int: + return instance.datasets.count() + def should_index_object(self, obj: Collaborative) -> bool: return obj.status == CollaborativeStatus.PUBLISHED From 1aefe8228ff641580c05f317b0e8514d46d35f98 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 16 Feb 2026 23:00:01 +0530 Subject: [PATCH 115/127] add dataset count to collaborative document --- DataSpace/settings.py | 3 + api/urls.py | 2 + api/views/search_publisher.py | 306 +++++++++++++++++++++ api/views/search_unified.py | 58 +++- search/documents/__init__.py | 4 + search/documents/publisher_document.py | 360 +++++++++++++++++++++++++ 6 files changed, 725 insertions(+), 8 deletions(-) create mode 100644 api/views/search_publisher.py create mode 100644 search/documents/publisher_document.py diff --git a/DataSpace/settings.py b/DataSpace/settings.py index 1a6f961..7996d81 100644 --- a/DataSpace/settings.py +++ 
b/DataSpace/settings.py @@ -278,6 +278,9 @@ "search.documents.usecase_document": "usecase", "search.documents.aimodel_document": "aimodel", "search.documents.collaborative_document": "collaborative", + "search.documents.publisher_document.OrganizationPublisherDocument": "organization_publisher", + "search.documents.publisher_document.UserPublisherDocument": "user_publisher", + "search.documents.publisher_document": "publisher", } diff --git a/api/urls.py b/api/urls.py index 335de6d..6d67c40 100644 --- a/api/urls.py +++ b/api/urls.py @@ -16,6 +16,7 @@ search_aimodel, search_collaborative, search_dataset, + search_publisher, search_unified, search_usecase, trending_datasets, @@ -55,6 +56,7 @@ name="search_collaborative", ), path("search/unified/", search_unified.UnifiedSearch.as_view(), name="search_unified"), + path("search/publisher/", search_publisher.SearchPublisher.as_view(), name="search_publisher"), path( "aimodels//", aimodel_detail.AIModelDetailView.as_view(), diff --git a/api/views/search_publisher.py b/api/views/search_publisher.py new file mode 100644 index 0000000..ff72a9d --- /dev/null +++ b/api/views/search_publisher.py @@ -0,0 +1,306 @@ +from typing import Any, Dict, List, Tuple + +import structlog +from elasticsearch_dsl import Q as ESQ +from elasticsearch_dsl import Search +from rest_framework import serializers +from rest_framework.permissions import AllowAny +from rest_framework.response import Response + +from api.utils.telemetry_utils import trace_method, track_metrics +from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView +from search.documents import OrganizationPublisherDocument, UserPublisherDocument + +logger = structlog.get_logger(__name__) + + +class PublisherDocumentSerializer(serializers.Serializer): + """Serializer for Publisher document (both Organization and User).""" + + id = serializers.CharField() + name = serializers.CharField() + description = serializers.CharField() + publisher_type = 
serializers.CharField() # 'organization' or 'user' + logo = serializers.CharField(required=False) + slug = serializers.CharField(required=False) + created = serializers.DateTimeField(required=False) + modified = serializers.DateTimeField(required=False) + + # Counts + published_datasets_count = serializers.IntegerField() + published_usecases_count = serializers.IntegerField() + members_count = serializers.IntegerField(required=False) # Only for organizations + contributed_sectors_count = serializers.IntegerField() + + # Organization specific fields + homepage = serializers.CharField(required=False) + contact_email = serializers.CharField(required=False) + organization_types = serializers.CharField(required=False) + github_profile = serializers.CharField(required=False) + linkedin_profile = serializers.CharField(required=False) + twitter_profile = serializers.CharField(required=False) + location = serializers.CharField(required=False) + + # User specific fields + bio = serializers.CharField(required=False) + profile_picture = serializers.CharField(required=False) + username = serializers.CharField(required=False) + email = serializers.CharField(required=False) + first_name = serializers.CharField(required=False) + last_name = serializers.CharField(required=False) + full_name = serializers.CharField(required=False) + + # Search fields + sectors = serializers.ListField(required=False) + + +class SearchPublisher(PaginatedElasticSearchAPIView): + """API view for searching publishers (organizations and users).""" + + serializer_class = PublisherDocumentSerializer + permission_classes = [AllowAny] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.logger = structlog.get_logger(__name__) + + def get_document_classes(self) -> List[Any]: + """Return the document classes to search.""" + return [OrganizationPublisherDocument, UserPublisherDocument] + + def get_index_names(self) -> List[str]: + """Get the index names for publisher search.""" + 
from DataSpace import settings + + org_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.OrganizationPublisherDocument", + "organization_publisher", + ) + user_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.UserPublisherDocument", "user_publisher" + ) + return [org_index, user_index] + + @trace_method(name="build_query", attributes={"component": "publisher_search"}) + def build_query(self, query: str) -> ESQ: + """Build the Elasticsearch query for publisher search.""" + if not query: + return ESQ("match_all") + + # Multi-field search with boosting + queries = [ + ESQ( + "multi_match", + query=query, + fields=["name^3", "full_name^3"], # Boost name fields + fuzziness="AUTO", + ), + ESQ( + "multi_match", + query=query, + fields=["description^2", "bio^2"], # Boost description/bio + fuzziness="AUTO", + ), + ESQ( + "multi_match", + query=query, + fields=["sectors^2"], # Boost sectors + fuzziness="AUTO", + ), + ESQ( + "multi_match", + query=query, + fields=[ + "username", + "email", + "location", + "organization_types", + "first_name", + "last_name", + ], + fuzziness="AUTO", + ), + ] + + return ESQ("bool", should=queries, minimum_should_match=1) + + @trace_method(name="apply_filters", attributes={"component": "publisher_search"}) + def apply_filters(self, search: Search, filters: Dict[str, str]) -> Search: + """Apply filters to the search query.""" + + if "publisher_type" in filters: + # Filter by publisher type (organization or user) + search = search.filter("term", publisher_type=filters["publisher_type"]) + + if "sectors" in filters: + # Filter by sectors + filter_values = filters["sectors"].split(",") + search = search.filter("terms", **{"sectors.raw": filter_values}) + + if "organization_types" in filters: + # Filter by organization types + search = search.filter("term", organization_types=filters["organization_types"]) + + if "location" in filters: + # Filter by location (fuzzy match) + 
search = search.filter("match", location=filters["location"]) + + return search + + @trace_method(name="build_aggregations", attributes={"component": "publisher_search"}) + def build_aggregations(self, search: Search) -> Search: + """Build aggregations for faceted search.""" + + # Publisher type aggregation + search.aggs.bucket("publisher_type", "terms", field="publisher_type") + + # Sectors aggregation + search.aggs.bucket("sectors", "terms", field="sectors.raw", size=50) + + # Organization types aggregation + search.aggs.bucket("organization_types", "terms", field="organization_types", size=20) + + # Location aggregation (top 20 locations) + search.aggs.bucket("locations", "terms", field="location.raw", size=20) + + return search + + @trace_method(name="apply_sorting", attributes={"component": "publisher_search"}) + def apply_sorting(self, search: Search, sort_by: str) -> Search: + """Apply sorting to the search query.""" + + if sort_by == "alphabetical": + search = search.sort("name.raw") + elif sort_by == "datasets_count": + search = search.sort({"published_datasets_count": {"order": "desc"}}) + elif sort_by == "usecases_count": + search = search.sort({"published_usecases_count": {"order": "desc"}}) + elif sort_by == "total_contributions": + # Sort by total datasets + usecases + search = search.sort( + { + "_script": { + "type": "number", + "script": { + "source": "doc['published_datasets_count'].value + doc['published_usecases_count'].value" + }, + "order": "desc", + } + } + ) + elif sort_by == "members_count": + # Only applicable to organizations + search = search.sort({"members_count": {"order": "desc"}}) + elif sort_by == "recent": + search = search.sort({"created": {"order": "desc"}}) + else: + # Default: relevance score + pass + + return search + + @trace_method(name="perform_search", attributes={"component": "publisher_search"}) + def perform_search( + self, + query: str, + filters: Dict[str, str], + page: int, + size: int, + sort_by: str = "relevance", 
+ ) -> Tuple[List[Dict[str, Any]], int, Dict[str, Any]]: + """Perform the publisher search.""" + + # Get index names + index_names = self.get_index_names() + + if not index_names: + return [], 0, {} + + # Create multi-index search + search = Search(index=index_names) + + # Build and apply query + q = self.build_query(query) + search = search.query(q) + + # Apply filters + search = self.apply_filters(search, filters) + + # Apply sorting + search = self.apply_sorting(search, sort_by) + + # Build aggregations + search = self.build_aggregations(search) + + # Pagination + start = (page - 1) * size + search = search[start : start + size] + + # Execute search + try: + response = search.execute() + except Exception as e: + self.logger.error("publisher_search_error", error=str(e), exc_info=True) + return [], 0, {} + + # Process results + results = [] + for hit in response: + result = hit.to_dict() + result["_score"] = hit.meta.score + result["_index"] = hit.meta.index + results.append(result) + + # Process aggregations + aggregations: Dict[str, Any] = {} + if hasattr(response, "aggregations"): + aggs_dict = response.aggregations.to_dict() + + for agg_name in ["publisher_type", "sectors", "organization_types", "locations"]: + if agg_name in aggs_dict: + aggregations[agg_name] = {} + for bucket in aggs_dict[agg_name]["buckets"]: + aggregations[agg_name][bucket["key"]] = bucket["doc_count"] + + total = response.hits.total.value if hasattr(response.hits.total, "value") else len(results) + + return results, total, aggregations + + @trace_method(name="get", attributes={"component": "publisher_search"}) + @track_metrics(name="publisher_search") + def get(self, request: Any) -> Response: + """Handle GET request and return search results.""" + try: + query: str = request.GET.get("query", "") + page: int = int(request.GET.get("page", 1)) + size: int = int(request.GET.get("size", 10)) + sort_by: str = request.GET.get("sort", "relevance") + + # Handle filters + filters: Dict[str, str] 
= {} + for key, values in request.GET.lists(): + if key not in ["query", "page", "size", "sort"]: + if len(values) > 1: + filters[key] = ",".join(values) + else: + filters[key] = values[0] + + # Perform search + results, total, aggregations = self.perform_search(query, filters, page, size, sort_by) + + # Serialize results + serializer = self.serializer_class(results, many=True) + + return Response( + { + "results": serializer.data, + "total": total, + "page": page, + "size": size, + "aggregations": aggregations, + } + ) + + except Exception as e: + self.logger.error("publisher_search_error", error=str(e), exc_info=True) + return Response({"error": "An internal error has occurred."}, status=500) diff --git a/api/views/search_unified.py b/api/views/search_unified.py index 96a0df2..0d9d5c0 100644 --- a/api/views/search_unified.py +++ b/api/views/search_unified.py @@ -1,12 +1,10 @@ """Unified search view that searches across datasets, usecases, and aimodels.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple import structlog -from elasticsearch import Elasticsearch from elasticsearch_dsl import Q as ESQ from elasticsearch_dsl import Search -from elasticsearch_dsl.connections import connections from rest_framework import serializers from rest_framework.permissions import AllowAny from rest_framework.response import Response @@ -21,7 +19,9 @@ AIModelDocument, CollaborativeDocument, DatasetDocument, + OrganizationPublisherDocument, UseCaseDocument, + UserPublisherDocument, ) logger = structlog.get_logger(__name__) @@ -31,7 +31,9 @@ class UnifiedSearchResultSerializer(serializers.Serializer): """Serializer for unified search results.""" id = serializers.CharField() - type = serializers.CharField() # 'dataset', 'usecase', 'aimodel', or 'collaborative' + type = ( + serializers.CharField() + ) # 'dataset', 'usecase', 'aimodel', 'collaborative', or 'publisher' title = serializers.CharField() description = 
serializers.CharField() slug = serializers.CharField(required=False) @@ -84,6 +86,18 @@ class UserSerializer(serializers.Serializer): started_on = serializers.DateTimeField(required=False) completed_on = serializers.DateTimeField(required=False) + # Publisher specific + publisher_type = serializers.CharField(required=False) # 'organization' or 'user' + published_datasets_count = serializers.IntegerField(required=False) + published_usecases_count = serializers.IntegerField(required=False) + members_count = serializers.IntegerField(required=False) + contributed_sectors_count = serializers.IntegerField(required=False) + homepage = serializers.CharField(required=False) + bio = serializers.CharField(required=False) + profile_picture = serializers.CharField(required=False) + username = serializers.CharField(required=False) + full_name = serializers.CharField(required=False) + class UnifiedSearch(APIView): """View for unified search across datasets, usecases, and aimodels.""" @@ -122,6 +136,16 @@ def _get_index_names(self, types_list: List[str]) -> List[str]: ) index_names.append(collaborative_index) + if "publisher" in types_list: + org_publisher_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.OrganizationPublisherDocument", + "organization_publisher", + ) + user_publisher_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.UserPublisherDocument", "user_publisher" + ) + index_names.extend([org_publisher_index, user_publisher_index]) + return index_names def _build_unified_query(self, query: str) -> ESQ: @@ -134,16 +158,16 @@ def _build_unified_query(self, query: str) -> ESQ: ESQ( "multi_match", query=query, - fields=["title^3", "name^3", "display_name^3"], + fields=["title^3", "name^3", "display_name^3", "full_name^3"], fuzziness="AUTO", ), ESQ( "multi_match", query=query, - fields=["description^2", "summary^2"], + fields=["description^2", "summary^2", "bio^2"], fuzziness="AUTO", ), - 
ESQ("multi_match", query=query, fields=["tags^2"], fuzziness="AUTO"), + ESQ("multi_match", query=query, fields=["tags^2", "sectors^2"], fuzziness="AUTO"), ] # Type-specific nested queries @@ -288,6 +312,8 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]: result["type"] = "aimodel" elif "collaborative" in index_name: result["type"] = "collaborative" + elif "publisher" in index_name: + result["type"] = "publisher" else: result["type"] = "unknown" @@ -316,6 +342,17 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]: result["description"] = result.get("summary", "") if "title" not in result: result["title"] = "" + elif result["type"] == "publisher": + # For publishers, use 'name' as title and handle description + if "name" in result: + result["title"] = result.get("name", "") + if "bio" in result and result.get("bio"): + result["description"] = result.get("bio", "") + elif "description" not in result or not result.get("description"): + result["description"] = "" + # Ensure status field exists (publishers don't have traditional status) + if "status" not in result: + result["status"] = "active" else: # dataset if "title" not in result: result["title"] = "" @@ -385,6 +422,11 @@ def perform_unified_search( aggregations["types"]["aimodel"] = bucket["doc_count"] elif "collaborative" in index_name: aggregations["types"]["collaborative"] = bucket["doc_count"] + elif "publisher" in index_name: + # Combine both organization and user publisher counts + if "publisher" not in aggregations["types"]: + aggregations["types"]["publisher"] = 0 + aggregations["types"]["publisher"] += bucket["doc_count"] # Process other aggregations for agg_name in ["tags", "sectors", "geographies", "status"]: @@ -405,7 +447,7 @@ def get(self, request: Any) -> Response: page: int = int(request.GET.get("page", 1)) size: int = int(request.GET.get("size", 10)) entity_types: str = request.GET.get( - "types", "dataset,usecase,aimodel,collaborative" + "types", 
"dataset,usecase,aimodel,collaborative,publisher" ) # Which entity types to search # Parse entity types diff --git a/search/documents/__init__.py b/search/documents/__init__.py index 64d15f8..96a2345 100644 --- a/search/documents/__init__.py +++ b/search/documents/__init__.py @@ -1,4 +1,8 @@ from search.documents.aimodel_document import AIModelDocument from search.documents.collaborative_document import CollaborativeDocument from search.documents.dataset_document import DatasetDocument +from search.documents.publisher_document import ( + OrganizationPublisherDocument, + UserPublisherDocument, +) from search.documents.usecase_document import UseCaseDocument diff --git a/search/documents/publisher_document.py b/search/documents/publisher_document.py new file mode 100644 index 0000000..881547b --- /dev/null +++ b/search/documents/publisher_document.py @@ -0,0 +1,360 @@ +from typing import Any, Dict, List, Optional, Union + +from django_elasticsearch_dsl import Document, Index, KeywordField, fields + +from api.models import Dataset, Organization, Sector, UseCase +from api.utils.enums import DatasetStatus, UseCaseStatus +from authorization.models import User +from DataSpace import settings +from search.documents.analysers import html_strip, ngram_analyser + +INDEX = Index(settings.ELASTICSEARCH_INDEX_NAMES[__name__]) +INDEX.settings(number_of_shards=1, number_of_replicas=0) + + +class PublisherDocument(Document): + """Elasticsearch document for Publisher (Organization and User) models.""" + + # Common fields for both organizations and users + name = fields.TextField( + analyzer=ngram_analyser, + fields={ + "raw": KeywordField(multi=False), + }, + ) + + description = fields.TextField( + analyzer=html_strip, + fields={ + "raw": fields.TextField(analyzer="keyword"), + }, + ) + + publisher_type = fields.KeywordField() # 'organization' or 'user' + + # Organization specific fields + logo = fields.TextField(analyzer=ngram_analyser) + homepage = 
fields.TextField(analyzer=ngram_analyser) + contact_email = fields.KeywordField() + organization_types = fields.KeywordField() + github_profile = fields.TextField(analyzer=ngram_analyser) + linkedin_profile = fields.TextField(analyzer=ngram_analyser) + twitter_profile = fields.TextField(analyzer=ngram_analyser) + location = fields.TextField(analyzer=ngram_analyser) + + # User specific fields + bio = fields.TextField( + analyzer=html_strip, + fields={ + "raw": fields.TextField(analyzer="keyword"), + }, + ) + profile_picture = fields.TextField(analyzer=ngram_analyser) + username = fields.KeywordField() + email = fields.KeywordField() + first_name = fields.TextField(analyzer=ngram_analyser) + last_name = fields.TextField(analyzer=ngram_analyser) + full_name = fields.TextField(analyzer=ngram_analyser) + + # Common metadata + slug = fields.KeywordField() + created = fields.DateField() + modified = fields.DateField() + + # Computed fields + published_datasets_count = fields.IntegerField() + published_usecases_count = fields.IntegerField() + members_count = fields.IntegerField() # Only for organizations + contributed_sectors_count = fields.IntegerField() + + # For search and filtering + sectors = fields.TextField( + attr="sectors_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + def prepare_name(self, instance: Union[Organization, User]) -> str: + """Prepare name field for indexing.""" + if isinstance(instance, Organization): + return getattr(instance, "name", "") + else: # User + return getattr(instance, "full_name", "") or getattr(instance, "username", "") + + def prepare_description(self, instance: Union[Organization, User]) -> str: + """Prepare description field for indexing.""" + if isinstance(instance, Organization): + return instance.description or "" + else: # User + return instance.bio or "" + + def prepare_publisher_type(self, instance: 
Union[Organization, User]) -> str: + """Determine publisher type.""" + return "organization" if isinstance(instance, Organization) else "user" + + def prepare_logo(self, instance: Union[Organization, User]) -> str: + """Prepare logo/profile picture URL.""" + if isinstance(instance, Organization): + logo = getattr(instance, "logo", None) + return str(logo.url) if logo and hasattr(logo, "url") else "" + else: # User + profile_picture = getattr(instance, "profile_picture", None) + return ( + str(profile_picture.url) + if profile_picture and hasattr(profile_picture, "url") + else "" + ) + + def prepare_slug(self, instance: Union[Organization, User]) -> str: + """Prepare slug field.""" + if isinstance(instance, Organization): + return getattr(instance, "slug", "") or "" + else: # User + return str(getattr(instance, "id", "")) # Users don't have slugs, use ID + + def prepare_full_name(self, instance: Union[Organization, User]) -> str: + """Prepare full name for users.""" + if isinstance(instance, User): + first_name = getattr(instance, "first_name", "") + last_name = getattr(instance, "last_name", "") + if first_name and last_name: + return f"{first_name} {last_name}" + elif first_name: + return first_name + elif last_name: + return last_name + else: + return getattr(instance, "username", "") + return "" + + def prepare_published_datasets_count(self, instance: Union[Organization, User]) -> int: + """Get count of published datasets.""" + try: + if isinstance(instance, Organization): + return Dataset.objects.filter( + organization_id=instance.id, status=DatasetStatus.PUBLISHED.value + ).count() + else: # User + return Dataset.objects.filter( + user_id=instance.id, status=DatasetStatus.PUBLISHED.value + ).count() + except Exception: + return 0 + + def prepare_published_usecases_count(self, instance: Union[Organization, User]) -> int: + """Get count of published use cases.""" + try: + if isinstance(instance, Organization): + from django.db.models import Q + + use_cases = 
UseCase.objects.filter( + ( + Q(organization__id=instance.id) + | Q(usecaseorganizationrelationship__organization_id=instance.id) + ), + status=UseCaseStatus.PUBLISHED.value, + ).distinct() + return use_cases.count() + else: # User + return UseCase.objects.filter( + user_id=instance.id, status=UseCaseStatus.PUBLISHED.value + ).count() + except Exception: + return 0 + + def prepare_members_count(self, instance: Union[Organization, User]) -> int: + """Get count of members (only for organizations).""" + if isinstance(instance, Organization): + try: + from authorization.models import OrganizationMembership + + return OrganizationMembership.objects.filter(organization_id=instance.id).count() + except Exception: + return 0 + return 0 + + def prepare_contributed_sectors_count(self, instance: Union[Organization, User]) -> int: + """Get count of sectors contributed to.""" + try: + from api.models import Sector + + if isinstance(instance, Organization): + # Get sectors from published datasets + dataset_sectors = ( + Sector.objects.filter( + datasets__organization_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ) + .values_list("id", flat=True) + .distinct() + ) + + # Get sectors from published use cases + usecase_sectors = ( + Sector.objects.filter( + usecases__usecaseorganizationrelationship__organization_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ) + .values_list("id", flat=True) + .distinct() + ) + else: # User + # Get sectors from published datasets + dataset_sectors = ( + Sector.objects.filter( + datasets__user_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ) + .values_list("id", flat=True) + .distinct() + ) + + # Get sectors from published use cases + usecase_sectors = ( + Sector.objects.filter( + usecases__user_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ) + .values_list("id", flat=True) + .distinct() + ) + + # Combine and deduplicate sectors + sector_ids = set(dataset_sectors) + 
sector_ids.update(usecase_sectors) + + return len(sector_ids) + except Exception: + return 0 + + def prepare_sectors_indexing(self, instance: Union[Organization, User]) -> List[str]: + """Prepare sectors for indexing.""" + try: + from api.models import Sector + + if isinstance(instance, Organization): + # Get sectors from published datasets + dataset_sectors = Sector.objects.filter( + datasets__organization_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ).distinct() + + # Get sectors from published use cases + usecase_sectors = Sector.objects.filter( + usecases__usecaseorganizationrelationship__organization_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ).distinct() + else: # User + # Get sectors from published datasets + dataset_sectors = Sector.objects.filter( + datasets__user_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ).distinct() + + # Get sectors from published use cases + usecase_sectors = Sector.objects.filter( + usecases__user_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ).distinct() + + # Combine and deduplicate sectors + all_sectors = set(dataset_sectors) | set(usecase_sectors) + return [ + getattr(sector, "name", "") for sector in all_sectors if hasattr(sector, "name") + ] + except Exception: + return [] + + def should_index_object(self, obj: Union[Organization, User]) -> bool: + """Check if the object should be indexed (has published content).""" + try: + if isinstance(obj, Organization): + has_datasets = Dataset.objects.filter( + organization_id=obj.id, status=DatasetStatus.PUBLISHED.value + ).exists() + has_usecases = UseCase.objects.filter( + organization_id=obj.id, status=UseCaseStatus.PUBLISHED.value + ).exists() + return has_datasets or has_usecases + else: # User + has_datasets = Dataset.objects.filter( + user_id=obj.id, status=DatasetStatus.PUBLISHED.value + ).exists() + has_usecases = UseCase.objects.filter( + user_id=obj.id, 
status=UseCaseStatus.PUBLISHED.value + ).exists() + return has_datasets or has_usecases + except Exception: + return False + + def get_instances_from_related( + self, + related_instance: Union[Dataset, UseCase], + ) -> Optional[List[Union[Organization, User]]]: + """Get Publisher instances from related models.""" + publishers: List[Union[Organization, User]] = [] + + if isinstance(related_instance, Dataset): + if related_instance.organization: + publishers.append(related_instance.organization) + if related_instance.user: + publishers.append(related_instance.user) + elif isinstance(related_instance, UseCase): + if related_instance.organization: + publishers.append(related_instance.organization) + if related_instance.user: + publishers.append(related_instance.user) + + return publishers if publishers else None + + +@INDEX.doc_type +class OrganizationPublisherDocument(PublisherDocument): + """Organization-specific publisher document.""" + + class Django: + """Django model configuration.""" + + model = Organization + + fields = [ + "id", + ] + + related_models = [ + Dataset, + UseCase, + Sector, + ] + + +@INDEX.doc_type +class UserPublisherDocument(PublisherDocument): + """User-specific publisher document.""" + + class Django: + """Django model configuration.""" + + model = User + + fields = [ + "id", + ] + + related_models = [ + Dataset, + UseCase, + Sector, + ] + + def prepare_created(self, instance: User) -> Any: + """Map date_joined to created for consistency.""" + return instance.date_joined + + def prepare_modified(self, instance: User) -> Any: + """Map last_login to modified for consistency.""" + return instance.last_login From 76dd033c76bc335739ea41136539d70174ac53c8 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 16 Feb 2026 23:59:35 +0530 Subject: [PATCH 116/127] proces profile pic before indexing --- search/documents/publisher_document.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/search/documents/publisher_document.py 
b/search/documents/publisher_document.py index 881547b..4805bf1 100644 --- a/search/documents/publisher_document.py +++ b/search/documents/publisher_document.py @@ -58,8 +58,6 @@ class PublisherDocument(Document): # Common metadata slug = fields.KeywordField() - created = fields.DateField() - modified = fields.DateField() # Computed fields published_datasets_count = fields.IntegerField() @@ -109,6 +107,17 @@ def prepare_logo(self, instance: Union[Organization, User]) -> str: else "" ) + def prepare_profile_picture(self, instance: Union[Organization, User]) -> str: + """Prepare profile picture URL for users.""" + if isinstance(instance, User): + profile_picture = getattr(instance, "profile_picture", None) + return ( + str(profile_picture.url) + if profile_picture and hasattr(profile_picture, "url") + else "" + ) + return "" + def prepare_slug(self, instance: Union[Organization, User]) -> str: """Prepare slug field.""" if isinstance(instance, Organization): @@ -323,6 +332,8 @@ class Django: fields = [ "id", + "created", + "modified", ] related_models = [ @@ -343,6 +354,8 @@ class Django: fields = [ "id", + "date_joined", + "last_login", ] related_models = [ From 0bdfd0188d4dd3b1c110adc37e72dcdda78452ce Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 17 Feb 2026 00:09:48 +0530 Subject: [PATCH 117/127] remove base publisher document --- DataSpace/settings.py | 1 - search/documents/publisher_document.py | 13 ++++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/DataSpace/settings.py b/DataSpace/settings.py index 7996d81..ee5a2a0 100644 --- a/DataSpace/settings.py +++ b/DataSpace/settings.py @@ -280,7 +280,6 @@ "search.documents.collaborative_document": "collaborative", "search.documents.publisher_document.OrganizationPublisherDocument": "organization_publisher", "search.documents.publisher_document.UserPublisherDocument": "user_publisher", - "search.documents.publisher_document": "publisher", } diff --git a/search/documents/publisher_document.py 
b/search/documents/publisher_document.py index 4805bf1..d3abbc8 100644 --- a/search/documents/publisher_document.py +++ b/search/documents/publisher_document.py @@ -8,9 +8,6 @@ from DataSpace import settings from search.documents.analysers import html_strip, ngram_analyser -INDEX = Index(settings.ELASTICSEARCH_INDEX_NAMES[__name__]) -INDEX.settings(number_of_shards=1, number_of_replicas=0) - class PublisherDocument(Document): """Elasticsearch document for Publisher (Organization and User) models.""" @@ -321,10 +318,13 @@ def get_instances_from_related( return publishers if publishers else None -@INDEX.doc_type class OrganizationPublisherDocument(PublisherDocument): """Organization-specific publisher document.""" + class Index: + name = settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.OrganizationPublisherDocument"] + settings = {"number_of_shards": 1, "number_of_replicas": 0} + class Django: """Django model configuration.""" @@ -343,10 +343,13 @@ class Django: ] -@INDEX.doc_type class UserPublisherDocument(PublisherDocument): """User-specific publisher document.""" + class Index: + name = settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.UserPublisherDocument"] + settings = {"number_of_shards": 1, "number_of_replicas": 0} + class Django: """Django model configuration.""" From 13f101d8be8b570f892c46e30927958fecd99afd Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 17 Feb 2026 00:15:48 +0530 Subject: [PATCH 118/127] use decorators to index publishers --- search/documents/publisher_document.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/search/documents/publisher_document.py b/search/documents/publisher_document.py index d3abbc8..b63744d 100644 --- a/search/documents/publisher_document.py +++ b/search/documents/publisher_document.py @@ -8,6 +8,13 @@ from DataSpace import settings from search.documents.analysers import html_strip, ngram_analyser +# Create separate indices for each publisher document type +ORG_INDEX = 
Index(settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.OrganizationPublisherDocument"]) +ORG_INDEX.settings(number_of_shards=1, number_of_replicas=0) + +USER_INDEX = Index(settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.UserPublisherDocument"]) +USER_INDEX.settings(number_of_shards=1, number_of_replicas=0) + class PublisherDocument(Document): """Elasticsearch document for Publisher (Organization and User) models.""" @@ -318,13 +325,10 @@ def get_instances_from_related( return publishers if publishers else None +@ORG_INDEX.doc_type class OrganizationPublisherDocument(PublisherDocument): """Organization-specific publisher document.""" - class Index: - name = settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.OrganizationPublisherDocument"] - settings = {"number_of_shards": 1, "number_of_replicas": 0} - class Django: """Django model configuration.""" @@ -343,13 +347,10 @@ class Django: ] +@USER_INDEX.doc_type class UserPublisherDocument(PublisherDocument): """User-specific publisher document.""" - class Index: - name = settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.UserPublisherDocument"] - settings = {"number_of_shards": 1, "number_of_replicas": 0} - class Django: """Django model configuration.""" From 7d9cab421442fd129457785200c7926de1598968 Mon Sep 17 00:00:00 2001 From: dc Date: Tue, 17 Feb 2026 00:19:26 +0530 Subject: [PATCH 119/127] update normalization for publisher documents --- api/views/search_unified.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/api/views/search_unified.py b/api/views/search_unified.py index 0d9d5c0..8a4e23b 100644 --- a/api/views/search_unified.py +++ b/api/views/search_unified.py @@ -332,7 +332,6 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]: result["title"] = "" if "description" not in result: result["description"] = "" - # AIModel uses created_at/updated_at if "created_at" in result: result["created"] = result["created_at"] if "updated_at" in result: @@ -343,16 +342,20 @@ def 
_normalize_result(self, hit: Any) -> Dict[str, Any]: if "title" not in result: result["title"] = "" elif result["type"] == "publisher": - # For publishers, use 'name' as title and handle description if "name" in result: result["title"] = result.get("name", "") if "bio" in result and result.get("bio"): result["description"] = result.get("bio", "") elif "description" not in result or not result.get("description"): result["description"] = "" - # Ensure status field exists (publishers don't have traditional status) if "status" not in result: result["status"] = "active" + if "tags" not in result: + result["tags"] = [] + if "sectors" not in result or result.get("sectors") is None: + result["sectors"] = [] + if "geographies" not in result: + result["geographies"] = [] else: # dataset if "title" not in result: result["title"] = "" @@ -409,7 +412,6 @@ def perform_unified_search( if hasattr(response, "aggregations"): aggs_dict = response.aggregations.to_dict() - # Process types aggregation if "types" in aggs_dict: aggregations["types"] = {} for bucket in aggs_dict["types"]["buckets"]: @@ -428,7 +430,6 @@ def perform_unified_search( aggregations["types"]["publisher"] = 0 aggregations["types"]["publisher"] += bucket["doc_count"] - # Process other aggregations for agg_name in ["tags", "sectors", "geographies", "status"]: if agg_name in aggs_dict: aggregations[agg_name] = {} @@ -450,10 +451,8 @@ def get(self, request: Any) -> Response: "types", "dataset,usecase,aimodel,collaborative,publisher" ) # Which entity types to search - # Parse entity types types_list = [t.strip() for t in entity_types.split(",")] - # Handle filters filters: Dict[str, str] = {} for key, values in request.GET.lists(): if key not in ["query", "page", "size", "types"]: From 6cfbe7960d20eaa05cc038b758a293f66a715b0c Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 23 Feb 2026 19:05:39 +0530 Subject: [PATCH 120/127] add caching to geography queries --- api/models/Geography.py | 20 +++++++++++++------- 1 file 
changed, 13 insertions(+), 7 deletions(-) diff --git a/api/models/Geography.py b/api/models/Geography.py index 56de119..6ba39e4 100644 --- a/api/models/Geography.py +++ b/api/models/Geography.py @@ -1,16 +1,17 @@ from typing import List, Optional +from django.core.cache import cache from django.db import models from api.utils.enums import GeoTypes +GEOGRAPHY_CACHE_TTL = 60 * 60 * 6 # 6 hours + class Geography(models.Model): id = models.AutoField(primary_key=True) name = models.CharField(max_length=75, unique=True) - code = models.CharField( - max_length=100, null=True, blank=True, unique=False, default="" - ) + code = models.CharField(max_length=100, null=True, blank=True, unique=False, default="") type = models.CharField(max_length=20, choices=GeoTypes.choices) parent_id = models.ForeignKey( "self", on_delete=models.CASCADE, null=True, blank=True, default=None @@ -37,9 +38,7 @@ def get_all_descendant_names(self) -> List[str]: return descendants @classmethod - def get_geography_names_with_descendants( - cls, geography_names: List[str] - ) -> List[str]: + def get_geography_names_with_descendants(cls, geography_names: List[str]) -> List[str]: """ Given a list of geography names, return all names including their descendants. This is a helper method for filtering that expands parent geographies to include children. 
@@ -50,6 +49,11 @@ def get_geography_names_with_descendants( Returns: List of geography names including all descendants """ + cache_key = f"geo_descendants:{':'.join(sorted(geography_names))}" + cached: Optional[List[str]] = cache.get(cache_key) + if cached is not None: + return cached + all_names = set() for name in geography_names: @@ -60,7 +64,9 @@ def get_geography_names_with_descendants( # If geography doesn't exist, just add the name as-is all_names.add(name) - return list(all_names) + result = list(all_names) + cache.set(cache_key, result, timeout=GEOGRAPHY_CACHE_TTL) + return result class Meta: db_table = "geography" From 5c20e6e9113454b294ea8c3a504f371c31fd5212 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 23 Feb 2026 19:05:55 +0530 Subject: [PATCH 121/127] add caching to stats query --- api/schema/stats_schema.py | 41 +++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/api/schema/stats_schema.py b/api/schema/stats_schema.py index 6810a64..4ee8d1f 100644 --- a/api/schema/stats_schema.py +++ b/api/schema/stats_schema.py @@ -2,6 +2,7 @@ import strawberry import strawberry_django +from django.core.cache import cache from django.db.models import Count, Q from strawberry.types import Info @@ -12,6 +13,9 @@ from api.utils.graphql_telemetry import trace_resolver from authorization.models import User +STATS_CACHE_KEY = "platform_stats" +STATS_CACHE_TTL = 60 * 10 # 10 minutes + @strawberry.type class StatsType: @@ -31,20 +35,20 @@ class Query: @trace_resolver(name="stats", attributes={"component": "stats"}) def stats(self, info: Info) -> StatsType: """Get platform statistics""" + cached = cache.get(STATS_CACHE_KEY) + if cached: + return StatsType(**cached) + # Count total users total_users = User.objects.count() # Count published datasets - total_published_datasets = Dataset.objects.filter( - status=DatasetStatus.PUBLISHED - ).count() + total_published_datasets = 
Dataset.objects.filter(status=DatasetStatus.PUBLISHED).count() # Count publishers (organizations and individuals who have published datasets) # First, get organizations that have published datasets org_publishers = ( - Organization.objects.filter(datasets__status=DatasetStatus.PUBLISHED) - .distinct() - .count() + Organization.objects.filter(datasets__status=DatasetStatus.PUBLISHED).distinct().count() ) # Then, get individual users who have published datasets @@ -61,15 +65,16 @@ def stats(self, info: Info) -> StatsType: total_publishers = org_publishers + individual_publishers # Count published usecases - total_published_usecases = UseCase.objects.filter( - status=UseCaseStatus.PUBLISHED - ).count() - - return StatsType( - total_users=total_users, - total_published_datasets=total_published_datasets, - total_publishers=total_publishers, - total_organizations=org_publishers, - total_individuals=individual_publishers, - total_published_usecases=total_published_usecases, - ) + total_published_usecases = UseCase.objects.filter(status=UseCaseStatus.PUBLISHED).count() + + stats_data = { + "total_users": total_users, + "total_published_datasets": total_published_datasets, + "total_publishers": total_publishers, + "total_organizations": org_publishers, + "total_individuals": individual_publishers, + "total_published_usecases": total_published_usecases, + } + cache.set(STATS_CACHE_KEY, stats_data, timeout=STATS_CACHE_TTL) + + return StatsType(**stats_data) From 4ecef7a73a0e530663c50b4f25af2eb4d47809c4 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 23 Feb 2026 19:06:12 +0530 Subject: [PATCH 122/127] add caching to aimodel detail --- api/views/aimodel_detail.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/api/views/aimodel_detail.py b/api/views/aimodel_detail.py index 96612dd..411aaca 100644 --- a/api/views/aimodel_detail.py +++ b/api/views/aimodel_detail.py @@ -3,6 +3,7 @@ import logging from typing import Any, Dict, List, Optional +from django.core.cache import 
cache from rest_framework import serializers, status from rest_framework.permissions import AllowAny from rest_framework.request import Request @@ -11,6 +12,8 @@ from api.models.AIModel import AIModel, ModelEndpoint +AIMODEL_DETAIL_CACHE_TTL = 60 * 15 # 15 minutes + logger = logging.getLogger(__name__) @@ -119,11 +122,17 @@ class AIModelDetailView(APIView): def get(self, request: Request, model_id: str) -> Response: """Get AI model details.""" try: + cache_key = f"aimodel_detail:{model_id}" + cached = cache.get(cache_key) + if cached: + return Response(cached) + model = AIModel.objects.prefetch_related( "tags", "sectors", "geographies", "endpoints", "organization", "user" ).get(id=model_id) serializer = AIModelDetailSerializer(model) + cache.set(cache_key, serializer.data, timeout=AIMODEL_DETAIL_CACHE_TTL) return Response(serializer.data) except AIModel.DoesNotExist: From dacb5dfdf87af647fea5bae5c42274f9a1d47392 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 23 Feb 2026 19:06:33 +0530 Subject: [PATCH 123/127] add caching to role query --- api/views/auditor.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/api/views/auditor.py b/api/views/auditor.py index e3ba4f3..40535ec 100644 --- a/api/views/auditor.py +++ b/api/views/auditor.py @@ -3,6 +3,7 @@ import logging from typing import Any, Dict, List, Optional +from django.core.cache import cache from django.db import transaction from rest_framework import status, views from rest_framework.permissions import IsAuthenticated @@ -12,6 +13,8 @@ from api.models import Organization from authorization.models import OrganizationMembership, Role, User +ROLE_CACHE_TTL = 60 * 60 # 1 hour + logger = logging.getLogger(__name__) @@ -46,8 +49,16 @@ def _check_admin_permission(self, user: User, organization: Organization) -> boo def _get_auditor_role(self) -> Optional[Role]: """Get the auditor role.""" + cached_id = cache.get("role_id:auditor") + if cached_id: + try: + return 
Role.objects.get(pk=cached_id) + except Role.DoesNotExist: + cache.delete("role_id:auditor") try: - return Role.objects.get(name="auditor") + role = Role.objects.get(name="auditor") + cache.set("role_id:auditor", role.pk, timeout=ROLE_CACHE_TTL) + return role except Role.DoesNotExist: logger.error("Auditor role not found. Please run migrations.") return None From bfc2ecb01706ec4d6e794e3008ac78e94140e50a Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 23 Feb 2026 19:06:51 +0530 Subject: [PATCH 124/127] enable caching in elastic queries --- api/views/paginated_elastic_view.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/api/views/paginated_elastic_view.py b/api/views/paginated_elastic_view.py index 7a952a3..75e4465 100644 --- a/api/views/paginated_elastic_view.py +++ b/api/views/paginated_elastic_view.py @@ -56,9 +56,8 @@ def get(self, request: HttpRequest) -> Response: cache_key = self._generate_cache_key(request) cached_result: Optional[Dict[str, Any]] = cache.get(cache_key) - # TODO: Fix cache issues on different model updates - # if cached_result: - # return Response(cached_result) + if cached_result: + return Response(cached_result) # Original search logic query: str = request.GET.get("query", "") From eae7e14f3b5607f7117d5c721abeb752b5d5132d Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 23 Feb 2026 19:09:31 +0530 Subject: [PATCH 125/127] add caching to all search apis --- api/views/search_collaborative.py | 37 ++++++++++++--- api/views/search_dataset.py | 22 +++++++-- api/views/search_unified.py | 21 +++++++++ api/views/search_usecase.py | 75 +++++++++++++++---------------- 4 files changed, 109 insertions(+), 46 deletions(-) diff --git a/api/views/search_collaborative.py b/api/views/search_collaborative.py index 17732c6..38d4507 100644 --- a/api/views/search_collaborative.py +++ b/api/views/search_collaborative.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import structlog +from django.core.cache
import cache from elasticsearch_dsl import A from elasticsearch_dsl import Q as ESQ from elasticsearch_dsl import Search @@ -14,6 +15,8 @@ from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView from search.documents import CollaborativeDocument +METADATA_CACHE_TTL = 60 * 30 # 30 minutes + logger = structlog.get_logger(__name__) @@ -147,6 +150,12 @@ def __init__(self, **kwargs: Any) -> None: attributes={"component": "search_collaborative"}, ) def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: + cached: Optional[Tuple[List[str], Dict[str, str]]] = cache.get( + "collaborative_search_metadata_config" + ) + if cached: + return cached + searchable_fields = [ "title", "summary", @@ -173,7 +182,9 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: for metadata in filterable_metadata: aggregations[f"metadata.{metadata.label}"] = "terms" # type: ignore - return searchable_fields, aggregations + result = (searchable_fields, aggregations) + cache.set("collaborative_search_metadata_config", result, timeout=METADATA_CACHE_TTL) + return result @trace_method(name="add_aggregations", attributes={"component": "search_collaborative"}) def add_aggregations(self, search: Search) -> Search: @@ -189,8 +200,17 @@ def add_aggregations(self, search: Search) -> Search: ) if aggregate_fields: - metadata_qs = Metadata.objects.filter(filterable=True) - filterable_metadata = [str(meta.label) for meta in metadata_qs] # type: ignore + filterable_metadata: List[str] = ( + cache.get("collaborative_filterable_metadata_labels") or [] + ) + if not filterable_metadata: + metadata_qs = Metadata.objects.filter(filterable=True) + filterable_metadata = [str(meta.label) for meta in metadata_qs] # type: ignore + cache.set( + "collaborative_filterable_metadata_labels", + filterable_metadata, + timeout=METADATA_CACHE_TTL, + ) metadata_bucket = search.aggs.bucket("metadata", "nested", path="metadata") composite_agg = A( @@ -277,8 
+297,15 @@ def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQu @trace_method(name="add_filters", attributes={"component": "search_collaborative"}) def add_filters(self, filters: Dict[str, str], search: Search) -> Search: - non_filter_metadata = Metadata.objects.filter(filterable=False).all() - excluded_labels: List[str] = [e.label for e in non_filter_metadata] # type: ignore + excluded_labels: List[str] = cache.get("collaborative_non_filter_metadata_labels") or [] + if not excluded_labels: + non_filter_metadata = Metadata.objects.filter(filterable=False).all() + excluded_labels = [e.label for e in non_filter_metadata] # type: ignore + cache.set( + "collaborative_non_filter_metadata_labels", + excluded_labels, + timeout=METADATA_CACHE_TTL, + ) for filter in filters: if filter in excluded_labels: diff --git a/api/views/search_dataset.py b/api/views/search_dataset.py index 9dafe09..680d509 100644 --- a/api/views/search_dataset.py +++ b/api/views/search_dataset.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast import structlog +from django.core.cache import cache from elasticsearch_dsl import A from elasticsearch_dsl import Q as ESQ from elasticsearch_dsl import Search @@ -14,6 +15,8 @@ from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView from search.documents import DatasetDocument +METADATA_CACHE_TTL = 60 * 30 # 30 minutes + logger = structlog.get_logger(__name__) @@ -147,6 +150,12 @@ def __init__(self, **kwargs: Any) -> None: ) def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: """Get searchable fields and aggregations for the search.""" + cached: Optional[Tuple[List[str], Dict[str, str]]] = cache.get( + "dataset_search_metadata_config" + ) + if cached: + return cached + enabled_metadata = Metadata.objects.filter(enabled=True).all() searchable_fields: List[str] = [] searchable_fields.extend( @@ -170,7 +179,9 @@ def 
get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: for metadata in enabled_metadata: # type: Metadata if metadata.filterable: aggregations[f"metadata.{metadata.label}"] = "terms" - return searchable_fields, aggregations + result = (searchable_fields, aggregations) + cache.set("dataset_search_metadata_config", result, timeout=METADATA_CACHE_TTL) + return result @trace_method(name="add_aggregations", attributes={"component": "search_dataset"}) def add_aggregations(self, search: Search) -> Search: @@ -248,8 +259,13 @@ def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQu @trace_method(name="add_filters", attributes={"component": "search_dataset"}) def add_filters(self, filters: Dict[str, str], search: Search) -> Search: """Add filters to the search query.""" - non_filter_metadata = Metadata.objects.filter(filterable=False).all() - excluded_labels: List[str] = [e.label for e in non_filter_metadata] # type: ignore + excluded_labels: List[str] = cache.get("dataset_non_filter_metadata_labels") or [] + if not excluded_labels: + non_filter_metadata = Metadata.objects.filter(filterable=False).all() + excluded_labels = [e.label for e in non_filter_metadata] # type: ignore + cache.set( + "dataset_non_filter_metadata_labels", excluded_labels, timeout=METADATA_CACHE_TTL + ) for filter in filters: if filter in excluded_labels: diff --git a/api/views/search_unified.py b/api/views/search_unified.py index 8a4e23b..b84c6da 100644 --- a/api/views/search_unified.py +++ b/api/views/search_unified.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Tuple import structlog +from django.core.cache import cache from elasticsearch_dsl import Q as ESQ from elasticsearch_dsl import Search from rest_framework import serializers @@ -13,6 +14,7 @@ from api.models import Dataset, Geography, Metadata, UseCase from api.models.AIModel import AIModel from api.models.Collaborative import Collaborative +from api.signals.dataset_signals import 
SEARCH_CACHE_VERSION_KEY from api.utils.telemetry_utils import trace_method from DataSpace import settings from search.documents import ( @@ -440,10 +442,27 @@ def perform_unified_search( return results, total, aggregations + def _generate_unified_cache_key(self, request: Any) -> str: + """Generate a unique cache key for unified search based on request parameters.""" + params: Dict[str, str] = { + "query": request.GET.get("query", ""), + "page": request.GET.get("page", "1"), + "size": request.GET.get("size", "10"), + "types": request.GET.get("types", "dataset,usecase,aimodel,collaborative,publisher"), + "filters": str(sorted(request.GET.dict().items())), + "version": str(cache.get(SEARCH_CACHE_VERSION_KEY, 0)), + } + return f"unified_search:{hash(frozenset(params.items()))}" + @trace_method(name="get", attributes={"component": "unified_search"}) def get(self, request: Any) -> Response: """Handle GET request and return unified search results.""" try: + cache_key = self._generate_unified_cache_key(request) + cached_result = cache.get(cache_key) + if cached_result: + return Response(cached_result) + query: str = request.GET.get("query", "") page: int = int(request.GET.get("page", 1)) size: int = int(request.GET.get("size", 10)) @@ -476,6 +495,8 @@ def get(self, request: Any) -> Response: "types_searched": types_list, } + cache.set(cache_key, result, timeout=3600) + return Response(result) except Exception as e: diff --git a/api/views/search_usecase.py b/api/views/search_usecase.py index f5ed330..8782d40 100644 --- a/api/views/search_usecase.py +++ b/api/views/search_usecase.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast import structlog +from django.core.cache import cache from elasticsearch_dsl import A from elasticsearch_dsl import Q as ESQ from elasticsearch_dsl import Search @@ -14,6 +15,8 @@ from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView from search.documents import UseCaseDocument 
+METADATA_CACHE_TTL = 60 * 30 # 30 minutes + logger = structlog.get_logger(__name__) @@ -144,9 +147,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.searchable_fields: List[str] self.aggregations: Dict[str, str] - self.searchable_fields, self.aggregations = ( - self.get_searchable_and_aggregations() - ) + self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations() self.logger = structlog.get_logger(__name__) @trace_method( @@ -155,6 +156,12 @@ def __init__(self, **kwargs: Any) -> None: ) def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: """Get searchable fields and aggregations for the search.""" + cached: Optional[Tuple[List[str], Dict[str, str]]] = cache.get( + "usecase_search_metadata_config" + ) + if cached: + return cached + searchable_fields = [ "title", "summary", @@ -181,7 +188,9 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: for metadata in filterable_metadata: aggregations[f"metadata.{metadata.label}"] = "terms" # type: ignore - return searchable_fields, aggregations + result = (searchable_fields, aggregations) + cache.set("usecase_search_metadata_config", result, timeout=METADATA_CACHE_TTL) + return result @trace_method(name="add_aggregations", attributes={"component": "search_usecase"}) def add_aggregations(self, search: Search) -> Search: @@ -199,18 +208,21 @@ def add_aggregations(self, search: Search) -> Search: ) if aggregate_fields: - metadata_qs = Metadata.objects.filter(filterable=True) - filterable_metadata = [str(meta.label) for meta in metadata_qs] # type: ignore + filterable_metadata: List[str] = cache.get("usecase_filterable_metadata_labels") or [] + if not filterable_metadata: + metadata_qs = Metadata.objects.filter(filterable=True) + filterable_metadata = [str(meta.label) for meta in metadata_qs] # type: ignore + cache.set( + "usecase_filterable_metadata_labels", + filterable_metadata, + timeout=METADATA_CACHE_TTL, + ) 
metadata_bucket = search.aggs.bucket("metadata", "nested", path="metadata") composite_agg = A( "composite", sources=[ - { - "metadata_label": { - "terms": {"field": "metadata.metadata_item.label"} - } - }, + {"metadata_label": {"terms": {"field": "metadata.metadata_item.label"}}}, {"metadata_value": {"terms": {"field": "metadata.value"}}}, ], size=10000, @@ -219,13 +231,7 @@ def add_aggregations(self, search: Search) -> Search: "filter", { # type: ignore[arg-type] "bool": { - "must": [ - { - "terms": { - "metadata.metadata_item.label": filterable_metadata - } - } - ] + "must": [{"terms": {"metadata.metadata_item.label": filterable_metadata}}] } }, ) @@ -235,12 +241,8 @@ def add_aggregations(self, search: Search) -> Search: return search - @trace_method( - name="generate_q_expression", attributes={"component": "search_usecase"} - ) - def generate_q_expression( - self, query: str - ) -> Optional[Union[ESQuery, List[ESQuery]]]: + @trace_method(name="generate_q_expression", attributes={"component": "search_usecase"}) + def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]: """Generate Elasticsearch Query expression.""" if query: queries: List[ESQuery] = [] @@ -256,9 +258,7 @@ def generate_q_expression( ESQ("wildcard", **{field: {"value": f"*{query}*"}}), ESQ( "fuzzy", - **{ - field: {"value": query, "fuzziness": "AUTO"} - }, + **{field: {"value": query, "fuzziness": "AUTO"}}, ), ], ), @@ -281,18 +281,14 @@ def generate_q_expression( ESQ("wildcard", **{field: {"value": f"*{query}*"}}), ESQ( "fuzzy", - **{ - field: {"value": query, "fuzziness": "AUTO"} - }, + **{field: {"value": query, "fuzziness": "AUTO"}}, ), ], ), ) ) else: - queries.append( - ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}) - ) + queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}})) else: queries = [ESQ("match_all")] @@ -301,8 +297,13 @@ def generate_q_expression( @trace_method(name="add_filters", attributes={"component": 
"search_usecase"}) def add_filters(self, filters: Dict[str, str], search: Search) -> Search: """Add filters to the search query.""" - non_filter_metadata = Metadata.objects.filter(filterable=False).all() - excluded_labels: List[str] = [e.label for e in non_filter_metadata] # type: ignore + excluded_labels: List[str] = cache.get("usecase_non_filter_metadata_labels") or [] + if not excluded_labels: + non_filter_metadata = Metadata.objects.filter(filterable=False).all() + excluded_labels = [e.label for e in non_filter_metadata] # type: ignore + cache.set( + "usecase_non_filter_metadata_labels", excluded_labels, timeout=METADATA_CACHE_TTL + ) for filter in filters: if filter in excluded_labels: @@ -335,9 +336,7 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: search = search.filter( "nested", path="metadata", - query={ - "bool": {"must": {"term": {f"metadata.value": filters[filter]}}} - }, + query={"bool": {"must": {"term": {f"metadata.value": filters[filter]}}}}, ) return search From f40237031a45ca4d25e942e29c2043a7a05d41f5 Mon Sep 17 00:00:00 2001 From: dc Date: Mon, 23 Feb 2026 19:10:15 +0530 Subject: [PATCH 126/127] cache user by kc token --- authorization/middleware_utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/authorization/middleware_utils.py b/authorization/middleware_utils.py index 684f0a9..57dbe36 100644 --- a/authorization/middleware_utils.py +++ b/authorization/middleware_utils.py @@ -1,8 +1,10 @@ +import hashlib from typing import Any, Callable, Dict, List, Optional, cast import structlog from django.conf import settings from django.contrib.auth.models import AnonymousUser +from django.core.cache import cache from django.db import transaction from django.http import HttpRequest, HttpResponse from django.utils.functional import SimpleLazyObject @@ -16,6 +18,14 @@ logger = structlog.getLogger(__name__) +TOKEN_CACHE_TTL = 60 * 5 # 5 minutes + + +def _get_token_cache_key(token: str) -> str: + 
"""Return a safe Redis key derived from the token without storing the raw token.""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + return f"auth_token:{token_hash}" + def get_user_from_keycloak_token(request: HttpRequest) -> User: """ @@ -59,6 +69,17 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: logger.debug("No token found, returning anonymous user") return cast(User, AnonymousUser()) + # Check cache before doing any validation or DB work + cache_key = _get_token_cache_key(token) + cached_user_id = cache.get(cache_key) + if cached_user_id: + try: + user = User.objects.get(id=cached_user_id) + logger.debug(f"Returning cached authenticated user: {user.username}") + return user + except User.DoesNotExist: + cache.delete(cache_key) + # Log token details for debugging logger.debug(f"Processing token of length: {len(token)}") logger.debug(f"Token type: {type(token)}") @@ -80,6 +101,7 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: logger.debug(f"Valid Django JWT token for user_id: {user_id}") try: user = User.objects.get(id=user_id) + cache.set(cache_key, user.id, timeout=TOKEN_CACHE_TTL) logger.debug(f"Successfully authenticated user via Django JWT: {user.username}") return user except User.DoesNotExist: @@ -130,6 +152,7 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: logger.warning("User synchronization failed, returning anonymous user") return cast(User, AnonymousUser()) + cache.set(cache_key, synced_user.id, timeout=TOKEN_CACHE_TTL) logger.debug( f"Successfully authenticated user: {synced_user.username} (ID: {synced_user.id})" ) From c223bc51162e72e1a3856492d780b64ce7450e75 Mon Sep 17 00:00:00 2001 From: dc Date: Thu, 12 Mar 2026 16:43:31 +0530 Subject: [PATCH 127/127] add document class to cache key --- api/views/paginated_elastic_view.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/views/paginated_elastic_view.py b/api/views/paginated_elastic_view.py index 75e4465..7f892bb 100644 --- 
a/api/views/paginated_elastic_view.py +++ b/api/views/paginated_elastic_view.py @@ -189,6 +189,7 @@ def get(self, request: HttpRequest) -> Response: def _generate_cache_key(self, request: HttpRequest) -> str: """Generate a unique cache key based on request parameters and cache version.""" params: Dict[str, str] = { + "document_type": self.document_class.__name__, "query": request.GET.get("query", ""), "page": request.GET.get("page", "1"), "size": request.GET.get("size", "10"),