diff --git a/.github/workflows/publish-sdk.yml b/.github/workflows/publish-sdk.yml index 5c3994ed..987545c8 100644 --- a/.github/workflows/publish-sdk.yml +++ b/.github/workflows/publish-sdk.yml @@ -23,6 +23,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + ref: ${{ github.ref }} - name: Set up Python uses: actions/setup-python@v4 @@ -35,13 +36,30 @@ jobs: pip install build twine pip install -e ".[dev]" - - name: Update version in setup.py + - name: Update version in __version__.py run: | - sed -i "s/version=\".*\"/version=\"${{ inputs.version }}\"/" setup.py + echo '"""Version information for DataSpace SDK."""' > dataspace_sdk/__version__.py + echo "" >> dataspace_sdk/__version__.py + echo "__version__ = \"${{ inputs.version }}\"" >> dataspace_sdk/__version__.py - - name: Update version in pyproject.toml + - name: Update version in pyproject.toml (if exists) run: | - sed -i "s/version = \".*\"/version = \"${{ inputs.version }}\"/" pyproject.toml + if [ -f pyproject.toml ]; then + sed -i "s/version = \".*\"/version = \"${{ inputs.version }}\"/" pyproject.toml + fi + + - name: Commit and push version changes + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add dataspace_sdk/__version__.py + if [ -f pyproject.toml ]; then git add pyproject.toml; fi + git commit -m "Bump SDK version to ${{ inputs.version }}" || echo "No changes to commit" + + # Get current branch name + BRANCH_NAME=$(git rev-parse --abbrev-ref HEAD) + echo "Pushing to branch: $BRANCH_NAME" + git push origin $BRANCH_NAME || echo "No changes to push" - name: Run tests run: | @@ -65,14 +83,6 @@ jobs: run: | twine upload dist/* - - name: Commit version changes - run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" - git add setup.py pyproject.toml - git commit -m "Bump version to ${{ inputs.version }}" || echo 
"No changes to commit" - git push || echo "No changes to push" - - name: Create GitHub Release uses: actions/create-release@v1 env: diff --git a/DataSpace/settings.py b/DataSpace/settings.py index feca5141..ee5a2a01 100644 --- a/DataSpace/settings.py +++ b/DataSpace/settings.py @@ -22,12 +22,15 @@ from .cache_settings import * -env = environ.Env(DEBUG=(bool, False)) -DEBUG = env.bool("DEBUG", default=True) # Build paths inside the project like this: BASE_DIR / 'subdir'. BASE_DIR = Path(__file__).resolve().parent.parent + +# Load .env file FIRST, before reading any env variables +env = environ.Env(DEBUG=(bool, False)) environ.Env.read_env(os.path.join(BASE_DIR, ".env")) +DEBUG = env.bool("DEBUG", default=True) + # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/4.0/howto/deployment/checklist/ @@ -274,6 +277,9 @@ "search.documents.dataset_document": "dataset", "search.documents.usecase_document": "usecase", "search.documents.aimodel_document": "aimodel", + "search.documents.collaborative_document": "collaborative", + "search.documents.publisher_document.OrganizationPublisherDocument": "organization_publisher", + "search.documents.publisher_document.UserPublisherDocument": "user_publisher", } @@ -301,9 +307,7 @@ # Swagger Settings SWAGGER_SETTINGS = { - "SECURITY_DEFINITIONS": { - "Bearer": {"type": "apiKey", "name": "Authorization", "in": "header"} - } + "SECURITY_DEFINITIONS": {"Bearer": {"type": "apiKey", "name": "Authorization", "in": "header"}} } # Structured Logging Configuration @@ -370,17 +374,11 @@ # OpenTelemetry Sampling Configuration OTEL_TRACES_SAMPLER = "parentbased_traceidratio" -OTEL_TRACES_SAMPLER_ARG = os.getenv( - "OTEL_TRACES_SAMPLER_ARG", "1.0" -) # Sample 100% in dev +OTEL_TRACES_SAMPLER_ARG = os.getenv("OTEL_TRACES_SAMPLER_ARG", "1.0") # Sample 100% in dev # OpenTelemetry Metrics Configuration -OTEL_METRIC_EXPORT_INTERVAL_MILLIS = int( - os.getenv("OTEL_METRIC_EXPORT_INTERVAL_MILLIS", "30000") 
-) -OTEL_METRIC_EXPORT_TIMEOUT_MILLIS = int( - os.getenv("OTEL_METRIC_EXPORT_TIMEOUT_MILLIS", "30000") -) +OTEL_METRIC_EXPORT_INTERVAL_MILLIS = int(os.getenv("OTEL_METRIC_EXPORT_INTERVAL_MILLIS", "30000")) +OTEL_METRIC_EXPORT_TIMEOUT_MILLIS = int(os.getenv("OTEL_METRIC_EXPORT_TIMEOUT_MILLIS", "30000")) # OpenTelemetry Instrumentation Configuration OTEL_PYTHON_DJANGO_INSTRUMENT = True diff --git a/Dockerfile b/Dockerfile index 30967ee2..19a08823 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,12 +1,47 @@ FROM python:3.10 ENV PYTHONDONTWRITEBYTECODE=1 ENV PYTHONUNBUFFERED=1 -RUN echo 'deb http://archive.debian.org/debian stretch main contrib non-free' >> /etc/apt/sources.list && \ - apt-get update && \ + +# Install system dependencies +RUN apt-get update && \ apt-get autoremove -y && \ - apt-get install -y libssl1.0-dev curl git nano wget && \ - apt-get install -y gconf-service libasound2 libatk1.0-0 libc6 libcairo2 libcups2 libdbus-1-3 libexpat1 libfontconfig1 libgcc1 libgconf-2-4 libgdk-pixbuf2.0-0 libglib2.0-0 libgtk-3-0 libnspr4 libpango-1.0-0 libpangocairo-1.0-0 libstdc++6 libx11-6 libx11-xcb1 libxcb1 libxcomposite1 libxcursor1 libxdamage1 libxext6 libxfixes3 libxi6 libxrandr2 libxrender1 libxss1 libxtst6 ca-certificates fonts-liberation libappindicator1 libnss3 lsb-release xdg-utils wget && \ - rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/partial/* + apt-get install -y \ + curl \ + git \ + nano \ + wget \ + ca-certificates \ + fonts-liberation \ + libasound2 \ + libatk1.0-0 \ + libcairo2 \ + libcups2 \ + libdbus-1-3 \ + libexpat1 \ + libfontconfig1 \ + libgdk-pixbuf-2.0-0 \ + libglib2.0-0 \ + libgtk-3-0 \ + libnspr4 \ + libnss3 \ + libpango-1.0-0 \ + libpangocairo-1.0-0 \ + libx11-6 \ + libx11-xcb1 \ + libxcb1 \ + libxcomposite1 \ + libxcursor1 \ + libxdamage1 \ + libxext6 \ + libxfixes3 \ + libxi6 \ + libxrandr2 \ + libxrender1 \ + libxss1 \ + libxtst6 \ + lsb-release \ + xdg-utils && \ + rm -rf /var/lib/apt/lists/* WORKDIR /code diff --git 
a/api/admin.py b/api/admin.py index 2ca2667f..3ce529ea 100644 --- a/api/admin.py +++ b/api/admin.py @@ -59,6 +59,7 @@ class AIModelAdmin(admin.ModelAdmin): list_filter = ( "provider", "model_type", + "domain", "status", "is_public", "is_active", @@ -85,7 +86,7 @@ class AIModelAdmin(admin.ModelAdmin): "Schema", {"fields": ("input_schema", "output_schema"), "classes": ("collapse",)}, ), - ("Metadata", {"fields": ("tags", "metadata"), "classes": ("collapse",)}), + ("Metadata", {"fields": ("tags", "domain", "metadata"), "classes": ("collapse",)}), ("Status & Visibility", {"fields": ("status", "is_public", "is_active")}), ( "Performance Metrics", diff --git a/api/middleware/rate_limit.py b/api/middleware/rate_limit.py index 83cd6050..e53bf83c 100644 --- a/api/middleware/rate_limit.py +++ b/api/middleware/rate_limit.py @@ -1,7 +1,9 @@ import logging +import os import time from typing import Any, Callable, Optional, cast +from django.conf import settings from django.core.cache import cache from django.http import HttpRequest, HttpResponse from redis.exceptions import RedisError @@ -9,12 +11,32 @@ logger = logging.getLogger(__name__) +def get_whitelisted_ips() -> set[str]: + """Get the set of whitelisted IPs from settings or environment variable. 
+ + Can be configured via: + - RATE_LIMIT_WHITELIST_IPS in Django settings (list) + - RATE_LIMIT_WHITELIST_IPS environment variable (comma-separated string) + """ + # First check Django settings + whitelist = getattr(settings, "RATE_LIMIT_WHITELIST_IPS", None) + if whitelist: + return set(whitelist) + + # Fall back to environment variable + env_whitelist = os.environ.get("RATE_LIMIT_WHITELIST_IPS", "") + if env_whitelist: + return {ip.strip() for ip in env_whitelist.split(",") if ip.strip()} + + return set() + + class HttpResponseTooManyRequests(HttpResponse): status_code = 429 def rate_limit_middleware( - get_response: Callable[[HttpRequest], HttpResponse] + get_response: Callable[[HttpRequest], HttpResponse], ) -> Callable[[HttpRequest], HttpResponse]: """Rate limiting middleware that uses a simple cache-based counter.""" @@ -78,10 +100,18 @@ def check_rate_limit(request: HttpRequest) -> bool: return True # Allow request on unexpected error def middleware(request: HttpRequest) -> HttpResponse: + client_ip = get_client_ip(request) + whitelisted_ips = get_whitelisted_ips() + + # Skip rate limiting for whitelisted IPs + if client_ip in whitelisted_ips: + logger.debug(f"Skipping rate limit for whitelisted IP: {client_ip}") + return get_response(request) + if not check_rate_limit(request): logger.warning( f"Rate limited - Method: {request.method}, " - f"Path: {request.path}, IP: {get_client_ip(request)}" + f"Path: {request.path}, IP: {client_ip}" ) return HttpResponseTooManyRequests() diff --git a/api/models/AIModel.py b/api/models/AIModel.py index a34641d4..9646e627 100644 --- a/api/models/AIModel.py +++ b/api/models/AIModel.py @@ -12,6 +12,7 @@ EndpointAuthType, EndpointHTTPMethod, HFModelClass, + PromptDomain, ) User = get_user_model() @@ -82,8 +83,6 @@ class AIModel(models.Model): blank=True, related_name="ai_models", ) - # API Configuration - # Endpoints are stored in separate ModelEndpoint table for flexibility # Model Capabilities supports_streaming = 
models.BooleanField(default=False) @@ -92,19 +91,34 @@ class AIModel(models.Model): ) supported_languages = models.JSONField( default=list, + blank=True, + null=True, help_text="List of supported language codes (e.g., ['en', 'es', 'fr'])", ) # Input/Output Schema - input_schema = models.JSONField(default=dict, help_text="Expected input format and parameters") - output_schema = models.JSONField(default=dict, help_text="Expected output format") + input_schema = models.JSONField( + default=dict, blank=True, null=True, help_text="Expected input format and parameters" + ) + output_schema = models.JSONField( + default=dict, blank=True, null=True, help_text="Expected output format" + ) # Metadata tags = models.ManyToManyField("api.Tag", blank=True) sectors = models.ManyToManyField("api.Sector", blank=True, related_name="ai_models") geographies = models.ManyToManyField("api.Geography", blank=True, related_name="ai_models") + domain = models.CharField( + max_length=200, + choices=PromptDomain.choices, + blank=True, + null=True, + help_text="Domain or category (e.g., healthcare, education, legal)", + ) metadata = models.JSONField( default=dict, + blank=True, + null=True, help_text="Additional metadata (training data info, limitations, etc.)", ) diff --git a/api/models/AIModelVersion.py b/api/models/AIModelVersion.py new file mode 100644 index 00000000..e9b108d6 --- /dev/null +++ b/api/models/AIModelVersion.py @@ -0,0 +1,286 @@ +"""AI Model Version model for version-specific configuration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from django.db import models + +from api.utils.enums import AIModelLifecycleStage, AIModelStatus + +if TYPE_CHECKING: + from django.db.models import QuerySet + + +class AIModelVersion(models.Model): + """ + Version of an AI Model with its own configuration. + Each version can have multiple providers. 
+ """ + + ai_model = models.ForeignKey( + "api.AIModel", + on_delete=models.CASCADE, + related_name="versions", + ) + version = models.CharField(max_length=50, help_text="Version number (e.g., 1.0.0)") + version_notes = models.TextField(blank=True, help_text="Changelog/notes for this version") + + # Version-specific capabilities + supports_streaming = models.BooleanField(default=False) + max_tokens = models.IntegerField(null=True, blank=True, help_text="Maximum tokens supported") + supported_languages = models.JSONField( + default=list, help_text="List of supported language codes" + ) + input_schema = models.JSONField(default=dict, help_text="Expected input format and parameters") + output_schema = models.JSONField(default=dict, help_text="Expected output format") + metadata = models.JSONField(default=dict, help_text="Additional version-specific metadata") + + # Status & Lifecycle + status = models.CharField( + max_length=20, + choices=AIModelStatus.choices, + default=AIModelStatus.REGISTERED, + ) + lifecycle_stage = models.CharField( + max_length=20, + choices=AIModelLifecycleStage.choices, + default=AIModelLifecycleStage.DEVELOPMENT, + help_text="Current lifecycle stage of this version", + ) + is_latest = models.BooleanField(default=False, help_text="Whether this is the latest version") + + # Timestamps + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + published_at = models.DateTimeField(null=True, blank=True) + + class Meta: + unique_together = ["ai_model", "version"] + ordering = ["-created_at"] + verbose_name = "AI Model Version" + verbose_name_plural = "AI Model Versions" + + def __str__(self): + return f"{self.ai_model.name} v{self.version}" + + def save(self, *args, **kwargs): + # If this is set as latest, unset others + if self.is_latest: + AIModelVersion.objects.filter(ai_model=self.ai_model, is_latest=True).exclude( + pk=self.pk + ).update(is_latest=False) + super().save(*args, **kwargs) + + def 
copy_providers_from(self, source_version: AIModelVersion) -> None: + """ + Copy all providers from another version. + Used when creating a new version. + """ + for provider in source_version.providers.all(): # type: ignore[attr-defined] + # Create a copy of the provider + VersionProvider.objects.create( + version=self, + # Provider info + provider=provider.provider, # type: ignore[attr-defined] + provider_model_id=provider.provider_model_id, # type: ignore[attr-defined] + is_primary=provider.is_primary, # type: ignore[attr-defined] + is_active=provider.is_active, # type: ignore[attr-defined] + # API Endpoint Configuration + api_endpoint_url=provider.api_endpoint_url, # type: ignore[attr-defined] + api_http_method=provider.api_http_method, # type: ignore[attr-defined] + api_timeout_seconds=provider.api_timeout_seconds, # type: ignore[attr-defined] + # Authentication Configuration + api_auth_type=provider.api_auth_type, # type: ignore[attr-defined] + api_auth_header_name=provider.api_auth_header_name, # type: ignore[attr-defined] + api_key=provider.api_key, # type: ignore[attr-defined] + api_key_prefix=provider.api_key_prefix, # type: ignore[attr-defined] + # Request/Response Configuration + api_headers=provider.api_headers, # type: ignore[attr-defined] + api_request_template=provider.api_request_template, # type: ignore[attr-defined] + api_response_path=provider.api_response_path, # type: ignore[attr-defined] + # HuggingFace fields + hf_use_pipeline=provider.hf_use_pipeline, # type: ignore[attr-defined] + hf_auth_token=provider.hf_auth_token, # type: ignore[attr-defined] + hf_model_class=provider.hf_model_class, # type: ignore[attr-defined] + hf_attn_implementation=provider.hf_attn_implementation, # type: ignore[attr-defined] + hf_trust_remote_code=provider.hf_trust_remote_code, # type: ignore[attr-defined] + hf_torch_dtype=provider.hf_torch_dtype, # type: ignore[attr-defined] + hf_device_map=provider.hf_device_map, # type: ignore[attr-defined] + 
framework=provider.framework, # type: ignore[attr-defined] + # Additional config + config=provider.config, # type: ignore[attr-defined] + ) + + +class VersionProvider(models.Model): + """ + Provider configuration for a specific version. + A version can have multiple providers (HF, Custom, OpenAI, etc.) + Only ONE can be primary per version. + """ + + from api.utils.enums import ( + AIModelFramework, + AIModelProvider, + EndpointAuthType, + EndpointHTTPMethod, + HFModelClass, + ) + + version = models.ForeignKey( + AIModelVersion, + on_delete=models.CASCADE, + related_name="providers", + ) + + # Provider info + provider = models.CharField(max_length=50, choices=AIModelProvider.choices) + provider_model_id = models.CharField( + max_length=255, + blank=True, + help_text="Provider's model identifier (e.g., gpt-4, claude-3-opus)", + ) + is_primary = models.BooleanField( + default=False, help_text="Whether this is the primary provider for the version" + ) + is_active = models.BooleanField(default=True) + + # ============================================ + # API Endpoint Configuration (for API-based providers) + # ============================================ + api_endpoint_url = models.URLField( + max_length=500, + blank=True, + null=True, + help_text="API endpoint URL (e.g., https://api.openai.com/v1/chat/completions)", + ) + api_http_method = models.CharField( + max_length=10, + choices=EndpointHTTPMethod.choices, + default=EndpointHTTPMethod.POST, + help_text="HTTP method for API calls", + ) + api_timeout_seconds = models.IntegerField( + default=60, + help_text="Request timeout in seconds", + ) + + # ============================================ + # Authentication Configuration + # ============================================ + api_auth_type = models.CharField( + max_length=20, + choices=EndpointAuthType.choices, + default=EndpointAuthType.BEARER, + help_text="Authentication type for API calls", + ) + api_auth_header_name = models.CharField( + max_length=100, + 
default="Authorization", + help_text="Header name for authentication (e.g., Authorization, X-API-Key)", + ) + api_key = models.CharField( + max_length=500, + blank=True, + null=True, + help_text="API key or token (encrypted at rest)", + ) + api_key_prefix = models.CharField( + max_length=50, + blank=True, + default="Bearer", + help_text="Prefix for API key (e.g., Bearer, Token)", + ) + + # ============================================ + # Request/Response Configuration + # ============================================ + api_headers = models.JSONField( + default=dict, + blank=True, + help_text="Additional headers to include in requests", + ) + api_request_template = models.JSONField( + default=dict, + blank=True, + help_text="Request body template with placeholders like {input}, {prompt}", + ) + api_response_path = models.CharField( + max_length=255, + blank=True, + help_text="JSON path to extract response (e.g., choices[0].message.content)", + ) + + # ============================================ + # Huggingface-specific fields + # ============================================ + hf_use_pipeline = models.BooleanField(default=False, help_text="Use Pipeline inference API") + hf_auth_token = models.CharField( + max_length=255, + blank=True, + null=True, + help_text="Huggingface Auth Token for gated models", + ) + hf_model_class = models.CharField( + max_length=100, + choices=HFModelClass.choices, + blank=True, + null=True, + help_text="Specify model head to use", + ) + hf_attn_implementation = models.CharField( + max_length=255, + blank=True, + default="flash_attention_2", + help_text="Attention Function", + ) + hf_trust_remote_code = models.BooleanField( + default=True, + help_text="Trust remote code when loading model", + ) + hf_torch_dtype = models.CharField( + max_length=50, + blank=True, + default="auto", + help_text="Torch dtype (auto, float16, float32, bfloat16)", + ) + hf_device_map = models.CharField( + max_length=50, + blank=True, + default="auto", + 
help_text="Device map for model loading (auto, cpu, cuda)", + ) + framework = models.CharField( + max_length=10, + choices=AIModelFramework.choices, + blank=True, + null=True, + help_text="Framework (PyTorch or TensorFlow)", + ) + + # ============================================ + # Provider-specific configuration (catch-all) + # ============================================ + config = models.JSONField(default=dict, help_text="Additional provider-specific configuration") + + # Timestamps + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["-is_primary", "-created_at"] + verbose_name = "Version Provider" + verbose_name_plural = "Version Providers" + + def __str__(self): + primary_str = " (Primary)" if self.is_primary else "" + return f"{self.version} - {self.provider}{primary_str}" + + def save(self, *args, **kwargs): + # Ensure only one primary per version + if self.is_primary: + VersionProvider.objects.filter(version=self.version, is_primary=True).exclude( + pk=self.pk + ).update(is_primary=False) + super().save(*args, **kwargs) diff --git a/api/models/Dataset.py b/api/models/Dataset.py index 4981cdd2..664fb813 100644 --- a/api/models/Dataset.py +++ b/api/models/Dataset.py @@ -5,7 +5,12 @@ from django.db.models import Sum from django.utils.text import slugify -from api.utils.enums import DatasetAccessType, DatasetLicense, DatasetStatus +from api.utils.enums import ( + DatasetAccessType, + DatasetLicense, + DatasetStatus, + DatasetType, +) if TYPE_CHECKING: from api.models.DataSpace import DataSpace @@ -59,9 +64,7 @@ class Dataset(models.Model): max_length=50, default=DatasetStatus.DRAFT, choices=DatasetStatus.choices ) sectors = models.ManyToManyField("api.Sector", blank=True, related_name="datasets") - geographies = models.ManyToManyField( - "api.Geography", blank=True, related_name="datasets" - ) + geographies = models.ManyToManyField("api.Geography", blank=True, 
related_name="datasets") access_type = models.CharField( max_length=50, default=DatasetAccessType.PUBLIC, @@ -72,6 +75,11 @@ class Dataset(models.Model): default=DatasetLicense.CC_BY_4_0_ATTRIBUTION, choices=DatasetLicense.choices, ) + dataset_type = models.CharField( + max_length=50, + default=DatasetType.DATA, + choices=DatasetType.choices, + ) def save(self, *args: Any, **kwargs: Any) -> None: if not self.slug: @@ -138,10 +146,7 @@ def has_charts(self) -> bool: @property def download_count(self) -> int: return ( - self.resources.aggregate(total_downloads=Sum("download_count"))[ - "total_downloads" - ] - or 0 + self.resources.aggregate(total_downloads=Sum("download_count"))["total_downloads"] or 0 ) @property @@ -176,9 +181,7 @@ def trending_score(self) -> float: return float(base_score) * 0.1 # Calculate recency factor (more recent = higher score) - recent_downloads = ( - recent_resources.aggregate(total=Sum("download_count"))["total"] or 0 - ) + recent_downloads = recent_resources.aggregate(total=Sum("download_count"))["total"] or 0 # Calculate trending score: base score + (recent downloads * recency factor) recency_factor = 2.0 # Weight for recent downloads diff --git a/api/models/Geography.py b/api/models/Geography.py index 56de119f..6ba39e4c 100644 --- a/api/models/Geography.py +++ b/api/models/Geography.py @@ -1,16 +1,17 @@ from typing import List, Optional +from django.core.cache import cache from django.db import models from api.utils.enums import GeoTypes +GEOGRAPHY_CACHE_TTL = 60 * 60 * 6 # 6 hours + class Geography(models.Model): id = models.AutoField(primary_key=True) name = models.CharField(max_length=75, unique=True) - code = models.CharField( - max_length=100, null=True, blank=True, unique=False, default="" - ) + code = models.CharField(max_length=100, null=True, blank=True, unique=False, default="") type = models.CharField(max_length=20, choices=GeoTypes.choices) parent_id = models.ForeignKey( "self", on_delete=models.CASCADE, null=True, 
blank=True, default=None @@ -37,9 +38,7 @@ def get_all_descendant_names(self) -> List[str]: return descendants @classmethod - def get_geography_names_with_descendants( - cls, geography_names: List[str] - ) -> List[str]: + def get_geography_names_with_descendants(cls, geography_names: List[str]) -> List[str]: """ Given a list of geography names, return all names including their descendants. This is a helper method for filtering that expands parent geographies to include children. @@ -50,6 +49,11 @@ def get_geography_names_with_descendants( Returns: List of geography names including all descendants """ + cache_key = f"geo_descendants:{':'.join(sorted(geography_names))}" + cached: Optional[List[str]] = cache.get(cache_key) + if cached is not None: + return cached + all_names = set() for name in geography_names: @@ -60,7 +64,9 @@ def get_geography_names_with_descendants( # If geography doesn't exist, just add the name as-is all_names.add(name) - return list(all_names) + result = list(all_names) + cache.set(cache_key, result, timeout=GEOGRAPHY_CACHE_TTL) + return result class Meta: db_table = "geography" diff --git a/api/models/PromptDataset.py b/api/models/PromptDataset.py new file mode 100644 index 00000000..89cd1d9a --- /dev/null +++ b/api/models/PromptDataset.py @@ -0,0 +1,82 @@ +"""PromptDataset model - extends Dataset with prompt-specific fields.""" + +from django.db import models + +from api.models.Dataset import Dataset +from api.utils.enums import DatasetType, PromptDomain, PromptPurpose, PromptTaskType + + +class PromptDataset(Dataset): + """ + PromptDataset extends Dataset with prompt-specific fields. + + Uses Django multi-table inheritance - PromptDataset has all fields + from Dataset plus additional prompt-specific fields. The parent + Dataset is automatically created and linked via a OneToOne relationship. + + This means PromptDataset: + - Has all Dataset fields (title, description, tags, sectors, etc.) 
+ - Can have DatasetMetadata entries (via inherited relationship) + - Can have Resources (prompt files instead of data files) + - Has additional prompt-specific fields below + """ + + # Prompt task type (e.g., text generation, classification, etc.) + task_type = models.CharField( + max_length=100, + choices=PromptTaskType.choices, + blank=True, + null=True, + ) + + # Target language(s) for the prompts + target_languages = models.JSONField( + blank=True, + null=True, + help_text="List of target languages for the prompts (e.g., ['en', 'hi', 'ta'])", + ) + + # Domain/category of prompts + domain = models.CharField( + max_length=200, + choices=PromptDomain.choices, + blank=True, + null=True, + help_text="Domain or category (e.g., healthcare, education, legal)", + ) + + # Target AI model types + target_model_types = models.JSONField( + blank=True, + null=True, + help_text="List of AI model types these prompts are designed for (e.g., ['GPT', 'LLAMA'])", + ) + + # Evaluation criteria or metrics + evaluation_criteria = models.JSONField( + blank=True, + null=True, + help_text="Criteria or metrics for evaluating prompt effectiveness", + ) + + # Purpose of the prompt dataset + purpose = models.CharField( + max_length=200, + choices=PromptPurpose.choices, + blank=True, + null=True, + help_text="Purpose of the prompt dataset", + ) + + def save(self, *args, **kwargs): + # Ensure dataset_type is always PROMPT for PromptDataset + self.dataset_type = DatasetType.PROMPT + super().save(*args, **kwargs) + + def __str__(self) -> str: + return f"PromptDataset: {self.title}" + + class Meta: + db_table = "prompt_dataset" + verbose_name = "Prompt Dataset" + verbose_name_plural = "Prompt Datasets" diff --git a/api/models/PromptResource.py b/api/models/PromptResource.py new file mode 100644 index 00000000..dc9d442c --- /dev/null +++ b/api/models/PromptResource.py @@ -0,0 +1,69 @@ +"""PromptResource model - extends Resource with prompt-specific fields.""" + +from django.db import models + 
+from api.models.Resource import Resource +from api.utils.enums import PromptFormat + + +class PromptResource(models.Model): + """ + PromptResource adds prompt-specific metadata to a Resource. + + This is a OneToOne extension of Resource (not multi-table inheritance) + to store prompt file-specific fields like format, system prompt presence, etc. + + """ + + resource = models.OneToOneField( + Resource, + on_delete=models.CASCADE, + primary_key=True, + related_name="prompt_details", + ) + + # Prompt format/template information + prompt_format = models.CharField( + max_length=100, + choices=PromptFormat.choices, + blank=True, + null=True, + help_text="Format of prompts in this file (e.g., instruction, chat, completion)", + ) + + # Whether prompts include system instructions + has_system_prompt = models.BooleanField( + default=False, + help_text="Whether the prompts in this file include system-level instructions", + ) + + # Whether prompts include example responses + has_example_responses = models.BooleanField( + default=False, + help_text="Whether the prompts in this file include example/expected responses", + ) + + # Average prompt length (for filtering/search) + avg_prompt_length = models.IntegerField( + blank=True, + null=True, + help_text="Average character length of prompts in this file", + ) + + # Number of prompts in this file + prompt_count = models.IntegerField( + blank=True, + null=True, + help_text="Total number of prompts in this file", + ) + + created = models.DateTimeField(auto_now_add=True) + modified = models.DateTimeField(auto_now=True) + + def __str__(self) -> str: + return f"PromptResource: {self.resource.name}" + + class Meta: + db_table = "prompt_resource" + verbose_name = "Prompt Resource" + verbose_name_plural = "Prompt Resources" diff --git a/api/models/Resource.py b/api/models/Resource.py index 3f5e5be4..35562f26 100644 --- a/api/models/Resource.py +++ b/api/models/Resource.py @@ -56,7 +56,14 @@ class Resource(models.Model): version = 
models.CharField(max_length=50, default="v1.0") def save(self, *args: Any, **kwargs: Any) -> None: - self.slug = slugify(self.name) + if not self.slug: + base_slug = slugify(self.name) + slug = base_slug + counter = 1 + while Resource.objects.filter(slug=slug).exclude(pk=self.pk).exists(): + slug = f"{base_slug}-{counter}" + counter += 1 + self.slug = slug super().save(*args, **kwargs) def __str__(self) -> str: @@ -64,9 +71,7 @@ def __str__(self) -> str: class ResourceFileDetails(models.Model): - resource = models.OneToOneField( - Resource, on_delete=models.CASCADE, null=False, blank=False - ) + resource = models.OneToOneField(Resource, on_delete=models.CASCADE, null=False, blank=False) file = models.FileField(upload_to="resources/", max_length=300) size = models.FloatField(blank=True, null=True) created = models.DateTimeField(auto_now_add=True) @@ -84,9 +89,7 @@ class ResourceDataTable(models.Model): """Model to store indexed CSV data for a resource.""" id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - resource = models.OneToOneField( - Resource, on_delete=models.CASCADE, null=False, blank=False - ) + resource = models.OneToOneField(Resource, on_delete=models.CASCADE, null=False, blank=False) table_name = models.CharField(max_length=255, unique=True) created = models.DateTimeField(auto_now_add=True) modified = models.DateTimeField(auto_now=True) @@ -109,9 +112,7 @@ def save(self, *args, **kwargs): class ResourceVersion(models.Model): - resource = models.ForeignKey( - Resource, on_delete=models.CASCADE, related_name="versions" - ) + resource = models.ForeignKey(Resource, on_delete=models.CASCADE, related_name="versions") version_number = models.CharField(max_length=50) commit_hash = models.CharField(max_length=64, null=True) created_at = models.DateTimeField(auto_now_add=True) @@ -167,9 +168,9 @@ def version_resource_with_dvc(sender, instance: ResourceFileDetails, created, ** return # Get the latest version - last_version: 
Optional[ResourceVersion] = ( - instance.resource.versions.order_by("-created_at").first() - ) + last_version: Optional[ResourceVersion] = instance.resource.versions.order_by( + "-created_at" + ).first() # Handle case when there are no versions yet if last_version is None: @@ -193,9 +194,7 @@ def version_resource_with_dvc(sender, instance: ResourceFileDetails, created, ** # Use DVC to get the previous version try: # Try to checkout the previous version using DVC - rel_path = Path(instance.file.path).relative_to( - settings.DVC_REPO_PATH - ) + rel_path = Path(instance.file.path).relative_to(settings.DVC_REPO_PATH) tag_name = f"{instance.resource.name}-{last_version.version_number}" # Save current file to temp location @@ -245,9 +244,7 @@ def version_resource_with_dvc(sender, instance: ResourceFileDetails, created, ** # Update using DVC dvc_file = dvc.track_resource(instance.file.path, chunked=use_chunked) - message = ( - f"Update resource: {instance.resource.name} to version {new_version}" - ) + message = f"Update resource: {instance.resource.name} to version {new_version}" dvc.commit_version(dvc_file, message) dvc.tag_version(f"{instance.resource.name}-{new_version}") diff --git a/api/models/__init__.py b/api/models/__init__.py index a93f8040..14ac679f 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,5 +1,6 @@ from api.models.AccessModel import AccessModel, AccessModelResource from api.models.AIModel import AIModel, ModelAPIKey, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.models.Catalog import Catalog from api.models.Collaborative import Collaborative from api.models.CollaborativeMetadata import CollaborativeMetadata @@ -12,6 +13,8 @@ from api.models.Geography import Geography from api.models.Metadata import Metadata from api.models.Organization import Organization +from api.models.PromptDataset import PromptDataset +from api.models.PromptResource import PromptResource from api.models.Resource 
import ( Resource, ResourceDataTable, diff --git a/api/schema/aimodel_schema.py b/api/schema/aimodel_schema.py index fe6b120b..f2075a22 100644 --- a/api/schema/aimodel_schema.py +++ b/api/schema/aimodel_schema.py @@ -1,6 +1,6 @@ """GraphQL schema for AI Model.""" -# mypy: disable-error-code=union-attr +# mypy: disable-error-code="union-attr,misc" import datetime from typing import List, Optional @@ -12,19 +12,26 @@ from strawberry_django.pagination import OffsetPaginationInput from api.models.AIModel import AIModel, ModelAPIKey, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.models.Dataset import Tag from api.schema.base_mutation import BaseMutation, MutationResponse from api.schema.extensions import TrackActivity, TrackModelActivity from api.types.type_aimodel import ( AIModelFilter, + AIModelLifecycleStageEnum, AIModelOrder, AIModelProviderEnum, AIModelStatusEnum, AIModelTypeEnum, + AIModelVersionFilter, + AIModelVersionOrder, EndpointAuthTypeEnum, EndpointHTTPMethodEnum, + PromptDomainEnum, TypeAIModel, + TypeAIModelVersion, TypeModelEndpoint, + TypeVersionProvider, ) from api.utils.graphql_telemetry import trace_resolver from authorization.graphql_permissions import IsAuthenticated @@ -38,9 +45,7 @@ def _update_aimodel_tags(model: AIModel, tags: Optional[List[str]]) -> None: return model.tags.clear() for tag in tags: - model.tags.add( - Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0] - ) + model.tags.add(Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0]) model.save() @@ -58,14 +63,14 @@ def _update_aimodel_sectors(model: AIModel, sectors: List[str]) -> None: model.save() -def _update_aimodel_geographies(model: AIModel, geographies: List[str]) -> None: +def _update_aimodel_geographies(model: AIModel, geographies: List[int]) -> None: """Helper function to update geographies for an AI model.""" from api.models import Geography model.geographies.clear() - for 
geography_name in geographies: + for geography_id in geographies: try: - geography = Geography.objects.get(name__iexact=geography_name) + geography = Geography.objects.get(id=geography_id) model.geographies.add(geography) except Geography.DoesNotExist: pass @@ -76,11 +81,11 @@ def _update_aimodel_geographies(model: AIModel, geographies: List[str]) -> None: class CreateAIModelInput: """Input for creating a new AI Model.""" - name: str - display_name: str - description: str model_type: AIModelTypeEnum provider: AIModelProviderEnum + name: Optional[str] = None + display_name: Optional[str] = None + description: Optional[str] = None version: Optional[str] = None provider_model_id: Optional[str] = None supports_streaming: bool = False @@ -90,7 +95,8 @@ class CreateAIModelInput: output_schema: Optional[strawberry.scalars.JSON] = None tags: Optional[List[str]] = None sectors: Optional[List[str]] = None - geographies: Optional[List[str]] = None + geographies: Optional[List[int]] = None + domain: Optional[PromptDomainEnum] = None metadata: Optional[strawberry.scalars.JSON] = None is_public: bool = False @@ -114,7 +120,8 @@ class UpdateAIModelInput: output_schema: Optional[strawberry.scalars.JSON] = None tags: Optional[List[str]] = None sectors: Optional[List[str]] = None - geographies: Optional[List[str]] = None + geographies: Optional[List[int]] = None + domain: Optional[PromptDomainEnum] = None metadata: Optional[strawberry.scalars.JSON] = None is_public: Optional[bool] = None is_active: Optional[bool] = None @@ -159,6 +166,121 @@ class UpdateModelEndpointInput: rate_limit_per_minute: Optional[int] = None +@strawberry.input +class CreateAIModelVersionInput: + """Input for creating a new AI Model Version.""" + + model_id: int + version: str + version_notes: Optional[str] = "" + lifecycle_stage: Optional[AIModelLifecycleStageEnum] = None + supports_streaming: bool = False + max_tokens: Optional[int] = None + supported_languages: Optional[List[str]] = None + input_schema: 
Optional[strawberry.scalars.JSON] = None + output_schema: Optional[strawberry.scalars.JSON] = None + metadata: Optional[strawberry.scalars.JSON] = None + copy_from_version_id: Optional[int] = None + is_latest: Optional[bool] = None + + +@strawberry.input +class UpdateAIModelVersionInput: + """Input for updating an AI Model Version.""" + + id: int + version: Optional[str] = None + version_notes: Optional[str] = None + lifecycle_stage: Optional[AIModelLifecycleStageEnum] = None + supports_streaming: Optional[bool] = None + max_tokens: Optional[int] = None + supported_languages: Optional[List[str]] = None + input_schema: Optional[strawberry.scalars.JSON] = None + output_schema: Optional[strawberry.scalars.JSON] = None + metadata: Optional[strawberry.scalars.JSON] = None + status: Optional[AIModelStatusEnum] = None + is_latest: Optional[bool] = None + + +@strawberry.input +class CreateVersionProviderInput: + """Input for creating a new Version Provider.""" + + version_id: int + provider: AIModelProviderEnum + provider_model_id: Optional[str] = "" + is_primary: bool = False + is_active: bool = True + + # API Endpoint Configuration + api_endpoint_url: Optional[str] = None + api_http_method: Optional[EndpointHTTPMethodEnum] = None + api_timeout_seconds: int = 60 + + # Authentication Configuration + api_auth_type: Optional[EndpointAuthTypeEnum] = None + api_auth_header_name: str = "Authorization" + api_key: Optional[str] = None + api_key_prefix: str = "Bearer" + + # Request/Response Configuration + api_headers: Optional[strawberry.scalars.JSON] = None + api_request_template: Optional[strawberry.scalars.JSON] = None + api_response_path: Optional[str] = None + + # HuggingFace Configuration + hf_use_pipeline: bool = False + hf_auth_token: Optional[str] = None + hf_model_class: Optional[str] = None + hf_attn_implementation: Optional[str] = "flash_attention_2" + hf_trust_remote_code: bool = True + hf_torch_dtype: Optional[str] = "auto" + hf_device_map: Optional[str] = "auto" + 
framework: Optional[str] = None + + # Additional config + config: Optional[strawberry.scalars.JSON] = None + + +@strawberry.input +class UpdateVersionProviderInput: + """Input for updating a Version Provider.""" + + id: int + provider_model_id: Optional[str] = None + is_primary: Optional[bool] = None + is_active: Optional[bool] = None + + # API Endpoint Configuration + api_endpoint_url: Optional[str] = None + api_http_method: Optional[EndpointHTTPMethodEnum] = None + api_timeout_seconds: Optional[int] = None + + # Authentication Configuration + api_auth_type: Optional[EndpointAuthTypeEnum] = None + api_auth_header_name: Optional[str] = None + api_key: Optional[str] = None + api_key_prefix: Optional[str] = None + + # Request/Response Configuration + api_headers: Optional[strawberry.scalars.JSON] = None + api_request_template: Optional[strawberry.scalars.JSON] = None + api_response_path: Optional[str] = None + + # HuggingFace Configuration + hf_use_pipeline: Optional[bool] = None + hf_auth_token: Optional[str] = None + hf_model_class: Optional[str] = None + hf_attn_implementation: Optional[str] = None + hf_trust_remote_code: Optional[bool] = None + hf_torch_dtype: Optional[str] = None + hf_device_map: Optional[str] = None + framework: Optional[str] = None + + # Additional config + config: Optional[strawberry.scalars.JSON] = None + + @strawberry.type class Query: """Queries for AI Models.""" @@ -189,10 +311,8 @@ def ai_models( if user.is_superuser: queryset = AIModel.objects.all() else: - # For authenticated users, show their models and public models - queryset = AIModel.objects.filter( - user=user - ) | AIModel.objects.filter(is_public=True, is_active=True) + # For authenticated users, show their models + queryset = AIModel.objects.filter(user=user, organization=None) else: # For non-authenticated users, only show public active models queryset = AIModel.objects.filter(is_public=True, is_active=True) @@ -203,10 +323,12 @@ def ai_models( if order is not 
strawberry.UNSET: queryset = strawberry_django.ordering.apply(order, queryset, info) + queryset = queryset.distinct() + if pagination is not strawberry.UNSET: queryset = strawberry_django.pagination.apply(pagination, queryset) - return TypeAIModel.from_django_list(list(queryset.distinct())) + return TypeAIModel.from_django_list(list(queryset)) @strawberry.field @trace_resolver(name="get_ai_model", attributes={"component": "aimodel"}) @@ -279,9 +401,7 @@ class Mutation: "get_data": lambda result, **kwargs: { "model_id": str(result.id), "model_name": result.name, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), }, }, ) @@ -292,6 +412,12 @@ def create_ai_model( organization = info.context.context.get("organization") user = info.context.user + # Generate default values if not provided (similar to dataset creation) + timestamp = datetime.datetime.now().strftime("%d %b %Y - %H:%M:%S") + name = input.name or f"untitled-ai-model-{timestamp}" + display_name = input.display_name or f"Untitled AI Model - {timestamp}" + description = input.description or "" + # Prepare supported_languages supported_languages = input.supported_languages or [] @@ -304,10 +430,10 @@ def create_ai_model( try: model = AIModel.objects.create( - name=input.name, - display_name=input.display_name, + name=name, + display_name=display_name, version=input.version or "", - description=input.description, + description=description, model_type=input.model_type, provider=input.provider, provider_model_id=input.provider_model_id or "", @@ -318,6 +444,7 @@ def create_ai_model( supported_languages=supported_languages, input_schema=input_schema, output_schema=output_schema, + domain=input.domain if input.domain else None, metadata=metadata, is_public=input.is_public, status="REGISTERED", @@ -348,9 +475,7 @@ def create_ai_model( "get_data": lambda result, **kwargs: { "model_id": str(result.id), 
"model_name": result.name, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), }, }, ) @@ -372,13 +497,9 @@ def update_ai_model( user=user, organization=model.organization ).first() if not org_member or not org_member.role.can_change: - raise DjangoValidationError( - "You don't have permission to update this model." - ) + raise DjangoValidationError("You don't have permission to update this model.") else: - raise DjangoValidationError( - "You don't have permission to update this model." - ) + raise DjangoValidationError("You don't have permission to update this model.") # Update fields if input.name is not None: @@ -401,6 +522,8 @@ def update_ai_model( model.input_schema = input.input_schema if input.output_schema is not None: model.output_schema = input.output_schema + if input.domain is not None: + model.domain = input.domain if input.metadata is not None: model.metadata = input.metadata if input.is_public is not None: @@ -459,13 +582,9 @@ def delete_ai_model(self, info: Info, model_id: int) -> MutationResponse[bool]: user=user, organization=model.organization ).first() if not org_member or not org_member.role.can_delete: - raise DjangoValidationError( - "You don't have permission to delete this model." - ) + raise DjangoValidationError("You don't have permission to delete this model.") else: - raise DjangoValidationError( - "You don't have permission to delete this model." - ) + raise DjangoValidationError("You don't have permission to delete this model.") model.delete() return MutationResponse.success_response(True) @@ -492,9 +611,7 @@ def create_model_endpoint( try: model = AIModel.objects.get(id=input.model_id) except AIModel.DoesNotExist: - raise DjangoValidationError( - f"AI Model with ID {input.model_id} does not exist." 
- ) + raise DjangoValidationError(f"AI Model with ID {input.model_id} does not exist.") # Check permissions if not user.is_superuser and model.user != user: @@ -513,9 +630,7 @@ def create_model_endpoint( # If this is primary, unset other primary endpoints if input.is_primary: - ModelEndpoint.objects.filter(model=model, is_primary=True).update( - is_primary=False - ) + ModelEndpoint.objects.filter(model=model, is_primary=True).update(is_primary=False) endpoint = ModelEndpoint.objects.create( model=model, @@ -533,9 +648,7 @@ def create_model_endpoint( is_active=input.is_active, ) - return MutationResponse.success_response( - TypeModelEndpoint.from_django(endpoint) - ) + return MutationResponse.success_response(TypeModelEndpoint.from_django(endpoint)) @strawberry.mutation @BaseMutation.mutation( @@ -559,9 +672,7 @@ def update_model_endpoint( try: endpoint = ModelEndpoint.objects.get(id=input.id) except ModelEndpoint.DoesNotExist: - raise DjangoValidationError( - f"Model Endpoint with ID {input.id} does not exist." - ) + raise DjangoValidationError(f"Model Endpoint with ID {input.id} does not exist.") model = endpoint.model @@ -576,9 +687,7 @@ def update_model_endpoint( "You don't have permission to update this endpoint." ) else: - raise DjangoValidationError( - "You don't have permission to update this endpoint." 
- ) + raise DjangoValidationError("You don't have permission to update this endpoint.") # Update fields if input.url is not None: @@ -612,9 +721,7 @@ def update_model_endpoint( endpoint.is_primary = True endpoint.save() - return MutationResponse.success_response( - TypeModelEndpoint.from_django(endpoint) - ) + return MutationResponse.success_response(TypeModelEndpoint.from_django(endpoint)) @strawberry.mutation @BaseMutation.mutation( @@ -629,18 +736,14 @@ def update_model_endpoint( }, }, ) - def delete_model_endpoint( - self, info: Info, endpoint_id: int - ) -> MutationResponse[bool]: + def delete_model_endpoint(self, info: Info, endpoint_id: int) -> MutationResponse[bool]: """Delete a model endpoint.""" user = info.context.user try: endpoint = ModelEndpoint.objects.get(id=endpoint_id) except ModelEndpoint.DoesNotExist: - raise DjangoValidationError( - f"Model Endpoint with ID {endpoint_id} does not exist." - ) + raise DjangoValidationError(f"Model Endpoint with ID {endpoint_id} does not exist.") model = endpoint.model @@ -654,10 +757,430 @@ def delete_model_endpoint( raise DjangoValidationError( "You don't have permission to delete this endpoint." ) + else: + raise DjangoValidationError("You don't have permission to delete this endpoint.") + + endpoint.delete() + return MutationResponse.success_response(True) + + # ==================== VERSION MUTATIONS ==================== + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="create_ai_model_version", + trace_attributes={"component": "aimodel"}, + ) + def create_ai_model_version( + self, info: Info, input: CreateAIModelVersionInput + ) -> MutationResponse[TypeAIModelVersion]: + """Create a new AI model version. 
Optionally copy providers from another version.""" + user = info.context.user + + try: + model = AIModel.objects.get(id=input.model_id) + except AIModel.DoesNotExist: + raise DjangoValidationError(f"AI Model with ID {input.model_id} does not exist.") + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to add versions to this model." + ) else: raise DjangoValidationError( - "You don't have permission to delete this endpoint." + "You don't have permission to add versions to this model." ) - endpoint.delete() + # Create the version + version = AIModelVersion.objects.create( + ai_model=model, + version=input.version, + version_notes=input.version_notes or "", + lifecycle_stage=input.lifecycle_stage.value if input.lifecycle_stage else "DEVELOPMENT", # type: ignore[misc] + supports_streaming=input.supports_streaming, + max_tokens=input.max_tokens, + supported_languages=input.supported_languages or [], + input_schema=input.input_schema or {}, + output_schema=input.output_schema or {}, + metadata=input.metadata or {}, + status="DRAFT", + is_latest=input.is_latest if input.is_latest is not None else True, + ) + + # If copy_from_version_id is provided, copy all providers + if input.copy_from_version_id: + try: + source_version = AIModelVersion.objects.get(id=input.copy_from_version_id) + version.copy_providers_from(source_version) + except AIModelVersion.DoesNotExist: + pass # Silently ignore if source version doesn't exist + + return MutationResponse.success_response(TypeAIModelVersion.from_django(version)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="update_ai_model_version", + trace_attributes={"component": "aimodel"}, + ) + def 
update_ai_model_version( + self, info: Info, input: UpdateAIModelVersionInput + ) -> MutationResponse[TypeAIModelVersion]: + """Update an AI model version.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=input.id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError(f"AI Model Version with ID {input.id} does not exist.") + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError("You don't have permission to update this version.") + else: + raise DjangoValidationError("You don't have permission to update this version.") + + # Update fields + if input.version is not None: + version.version = input.version + if input.version_notes is not None: + version.version_notes = input.version_notes + if input.lifecycle_stage is not None: + version.lifecycle_stage = input.lifecycle_stage.value # type: ignore[misc] + if input.supports_streaming is not None: + version.supports_streaming = input.supports_streaming + if input.max_tokens is not None: + version.max_tokens = input.max_tokens + if input.supported_languages is not None: + version.supported_languages = input.supported_languages + if input.input_schema is not None: + version.input_schema = input.input_schema + if input.output_schema is not None: + version.output_schema = input.output_schema + if input.metadata is not None: + version.metadata = input.metadata + if input.status is not None: + version.status = input.status + if input.is_latest is not None: + version.is_latest = input.is_latest + + version.save() + return MutationResponse.success_response(TypeAIModelVersion.from_django(version)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + 
trace_name="delete_ai_model_version", + trace_attributes={"component": "aimodel"}, + ) + def delete_ai_model_version(self, info: Info, version_id: int) -> MutationResponse[bool]: + """Delete an AI model version.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=version_id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError(f"AI Model Version with ID {version_id} does not exist.") + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_delete: + raise DjangoValidationError("You don't have permission to delete this version.") + else: + raise DjangoValidationError("You don't have permission to delete this version.") + + version.delete() return MutationResponse.success_response(True) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="publish_ai_model_version", + trace_attributes={"component": "aimodel"}, + ) + def publish_ai_model_version( + self, info: Info, version_id: int + ) -> MutationResponse[TypeAIModelVersion]: + """Publish an AI model version and set it as latest.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=version_id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError(f"AI Model Version with ID {version_id} does not exist.") + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to publish this version." 
+ ) + else: + raise DjangoValidationError("You don't have permission to publish this version.") + + from django.utils import timezone + + version.status = "ACTIVE" + version.is_latest = True + version.published_at = timezone.now() + version.save() + + return MutationResponse.success_response(TypeAIModelVersion.from_django(version)) + + # ==================== PROVIDER MUTATIONS ==================== + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="create_version_provider", + trace_attributes={"component": "aimodel"}, + ) + def create_version_provider( + self, info: Info, input: CreateVersionProviderInput + ) -> MutationResponse[TypeVersionProvider]: + """Create a new provider for a version.""" + user = info.context.user + + try: + version = AIModelVersion.objects.get(id=input.version_id) + except AIModelVersion.DoesNotExist: + raise DjangoValidationError( + f"AI Model Version with ID {input.version_id} does not exist." + ) + + model = version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to add providers to this version." + ) + else: + raise DjangoValidationError( + "You don't have permission to add providers to this version." 
+ ) + + provider = VersionProvider.objects.create( + version=version, + provider=input.provider, + provider_model_id=input.provider_model_id or "", + is_primary=input.is_primary, + is_active=input.is_active, + # API Endpoint Configuration + api_endpoint_url=input.api_endpoint_url, + api_http_method=input.api_http_method or "POST", + api_timeout_seconds=input.api_timeout_seconds, + # Authentication Configuration + api_auth_type=input.api_auth_type or "BEARER", + api_auth_header_name=input.api_auth_header_name, + api_key=input.api_key, + api_key_prefix=input.api_key_prefix, + # Request/Response Configuration + api_headers=input.api_headers or {}, + api_request_template=input.api_request_template or {}, + api_response_path=input.api_response_path or "", + # HuggingFace Configuration + hf_use_pipeline=input.hf_use_pipeline, + hf_auth_token=input.hf_auth_token, + hf_model_class=input.hf_model_class, + hf_attn_implementation=input.hf_attn_implementation or "flash_attention_2", + hf_trust_remote_code=input.hf_trust_remote_code, + hf_torch_dtype=input.hf_torch_dtype or "auto", + hf_device_map=input.hf_device_map or "auto", + framework=input.framework, + # Additional config + config=input.config or {}, + ) + + return MutationResponse.success_response(TypeVersionProvider.from_django(provider)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="update_version_provider", + trace_attributes={"component": "aimodel"}, + ) + def update_version_provider( + self, info: Info, input: UpdateVersionProviderInput + ) -> MutationResponse[TypeVersionProvider]: + """Update a version provider.""" + user = info.context.user + + try: + provider = VersionProvider.objects.get(id=input.id) + except VersionProvider.DoesNotExist: + raise DjangoValidationError(f"Version Provider with ID {input.id} does not exist.") + + model = provider.version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if 
model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to update this provider." + ) + else: + raise DjangoValidationError("You don't have permission to update this provider.") + + # Update fields + if input.provider_model_id is not None: + provider.provider_model_id = input.provider_model_id + if input.is_primary is not None: + provider.is_primary = input.is_primary + if input.is_active is not None: + provider.is_active = input.is_active + + # API Endpoint Configuration + if input.api_endpoint_url is not None: + provider.api_endpoint_url = input.api_endpoint_url + if input.api_http_method is not None: + provider.api_http_method = input.api_http_method + if input.api_timeout_seconds is not None: + provider.api_timeout_seconds = input.api_timeout_seconds + + # Authentication Configuration + if input.api_auth_type is not None: + provider.api_auth_type = input.api_auth_type + if input.api_auth_header_name is not None: + provider.api_auth_header_name = input.api_auth_header_name + if input.api_key is not None: + provider.api_key = input.api_key + if input.api_key_prefix is not None: + provider.api_key_prefix = input.api_key_prefix + + # Request/Response Configuration + if input.api_headers is not None: + provider.api_headers = input.api_headers + if input.api_request_template is not None: + provider.api_request_template = input.api_request_template + if input.api_response_path is not None: + provider.api_response_path = input.api_response_path + + # HuggingFace Configuration + if input.hf_use_pipeline is not None: + provider.hf_use_pipeline = input.hf_use_pipeline + if input.hf_auth_token is not None: + provider.hf_auth_token = input.hf_auth_token + if input.hf_model_class is not None: + provider.hf_model_class = input.hf_model_class + if input.hf_attn_implementation is not None: + 
provider.hf_attn_implementation = input.hf_attn_implementation + if input.hf_trust_remote_code is not None: + provider.hf_trust_remote_code = input.hf_trust_remote_code + if input.hf_torch_dtype is not None: + provider.hf_torch_dtype = input.hf_torch_dtype + if input.hf_device_map is not None: + provider.hf_device_map = input.hf_device_map + if input.framework is not None: + provider.framework = input.framework + + # Additional config + if input.config is not None: + provider.config = input.config + + provider.save() + return MutationResponse.success_response(TypeVersionProvider.from_django(provider)) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="delete_version_provider", + trace_attributes={"component": "aimodel"}, + ) + def delete_version_provider(self, info: Info, provider_id: int) -> MutationResponse[bool]: + """Delete a version provider.""" + user = info.context.user + + try: + provider = VersionProvider.objects.get(id=provider_id) + except VersionProvider.DoesNotExist: + raise DjangoValidationError(f"Version Provider with ID {provider_id} does not exist.") + + model = provider.version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_delete: + raise DjangoValidationError( + "You don't have permission to delete this provider." 
+ ) + else: + raise DjangoValidationError("You don't have permission to delete this provider.") + + provider.delete() + return MutationResponse.success_response(True) + + @strawberry.mutation + @BaseMutation.mutation( + permission_classes=[IsAuthenticated], + trace_name="set_primary_provider", + trace_attributes={"component": "aimodel"}, + ) + def set_primary_provider( + self, info: Info, provider_id: int + ) -> MutationResponse[TypeVersionProvider]: + """Set a provider as the primary provider for its version.""" + user = info.context.user + + try: + provider = VersionProvider.objects.get(id=provider_id) + except VersionProvider.DoesNotExist: + raise DjangoValidationError(f"Version Provider with ID {provider_id} does not exist.") + + model = provider.version.ai_model + + # Check permissions + if not user.is_superuser and model.user != user: + if model.organization: + org_member = OrganizationMembership.objects.filter( + user=user, organization=model.organization + ).first() + if not org_member or not org_member.role.can_change: + raise DjangoValidationError( + "You don't have permission to update this provider." 
+ ) + else: + raise DjangoValidationError("You don't have permission to update this provider.") + + provider.is_primary = True + provider.save() # The save method will unset other primaries + + return MutationResponse.success_response(TypeVersionProvider.from_django(provider)) diff --git a/api/schema/dataset_schema.py b/api/schema/dataset_schema.py index 458aa2a2..e0adc049 100644 --- a/api/schema/dataset_schema.py +++ b/api/schema/dataset_schema.py @@ -15,6 +15,7 @@ Geography, Metadata, Organization, + PromptDataset, Resource, ResourceChartDetails, ResourceChartImage, @@ -31,12 +32,18 @@ from api.schema.extensions import TrackActivity, TrackModelActivity from api.types.type_dataset import DatasetFilter, DatasetOrder, TypeDataset from api.types.type_organization import TypeOrganization +from api.types.type_prompt_metadata import TypePromptDataset, prompt_task_type_enum +from api.types.type_resource import TypeResource from api.types.type_resource_chart import TypeResourceChart from api.types.type_resource_chart_image import TypeResourceChartImage from api.utils.enums import ( DatasetAccessType, DatasetLicense, DatasetStatus, + DatasetType, + PromptFormat, + PromptPurpose, + PromptTaskType, UseCaseStatus, ) from api.utils.graphql_telemetry import trace_resolver @@ -50,6 +57,10 @@ DatasetAccessTypeENUM = strawberry.enum(DatasetAccessType) # type: ignore DatasetLicenseENUM = strawberry.enum(DatasetLicense) # type: ignore +DatasetTypeENUM = strawberry.enum(DatasetType) # type: ignore +PromptTaskTypeENUM = strawberry.enum(PromptTaskType) # type: ignore +PromptFormatENUM = strawberry.enum(PromptFormat) # type: ignore +PromptPurposeENUM = strawberry.enum(PromptPurpose) # type: ignore # Create permission classes dynamically with different operations @@ -104,9 +115,7 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return False if user.is_superuser: return True - dataset_perm = DatasetPermission.objects.filter( - user=user, dataset=dataset - 
).first() + dataset_perm = DatasetPermission.objects.filter(user=user, dataset=dataset).first() if dataset_perm: return dataset_perm.role.can_view org_perm = OrganizationMembership.objects.filter( @@ -154,9 +163,7 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: if dataset.user and dataset.user == user: return True # Check if user has specific dataset permissions - dataset_perm = DatasetPermission.objects.filter( - user=user, dataset=dataset - ).first() + dataset_perm = DatasetPermission.objects.filter(user=user, dataset=dataset).first() if dataset_perm: return dataset_perm.role.can_view return False @@ -202,6 +209,8 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: update_dataset_input = kwargs.get("update_dataset_input") if not update_dataset_input: update_dataset_input = kwargs.get("update_metadata_input") + if not update_dataset_input: + update_dataset_input = kwargs.get("update_input") if not update_dataset_input or not hasattr(update_dataset_input, "dataset"): return False @@ -232,9 +241,7 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return True # Check dataset-specific permissions - dataset_perm = DatasetPermission.objects.filter( - user=user, dataset=dataset - ).first() + dataset_perm = DatasetPermission.objects.filter(user=user, dataset=dataset).first() return dataset_perm and dataset_perm.role.can_change and dataset.status == DatasetStatus.DRAFT.value # type: ignore except Dataset.DoesNotExist: @@ -256,9 +263,7 @@ class UpdateMetadataInput: sectors: Optional[List[uuid.UUID]] = None geographies: Optional[List[int]] = None access_type: Optional[DatasetAccessTypeENUM] = DatasetAccessTypeENUM.PUBLIC - license: Optional[DatasetLicenseENUM] = ( - DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE - ) + license: Optional[DatasetLicenseENUM] = DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE @strawberry.input @@ -268,9 +273,56 @@ class UpdateDatasetInput: 
description: Optional[str] = None tags: Optional[List[str]] = None access_type: Optional[DatasetAccessTypeENUM] = DatasetAccessTypeENUM.PUBLIC - license: Optional[DatasetLicenseENUM] = ( - DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE - ) + license: Optional[DatasetLicenseENUM] = DatasetLicenseENUM.CC_BY_SA_4_0_ATTRIBUTION_SHARE_ALIKE + + +@strawberry.input +class CreateDatasetInput: + """Input for creating a new dataset with optional type specification.""" + + dataset_type: Optional[DatasetTypeENUM] = DatasetTypeENUM.DATA + + +@strawberry.input +class PromptMetadataInput: + """Input for prompt-specific metadata.""" + + task_type: Optional[PromptTaskTypeENUM] = None + target_languages: Optional[List[str]] = None + domain: Optional[str] = None + target_model_types: Optional[List[str]] = None + prompt_format: Optional[str] = None + has_system_prompt: Optional[bool] = False + has_example_responses: Optional[bool] = False + avg_prompt_length: Optional[int] = None + prompt_count: Optional[int] = None + evaluation_criteria: Optional[strawberry.scalars.JSON] = None + purpose: Optional[PromptPurposeENUM] = None + + +@strawberry.input +class UpdatePromptMetadataInput: + """Input for updating prompt-specific metadata (dataset-level fields).""" + + dataset: uuid.UUID + task_type: Optional[PromptTaskTypeENUM] = None + target_languages: Optional[List[str]] = None + domain: Optional[str] = None + target_model_types: Optional[List[str]] = None + evaluation_criteria: Optional[strawberry.scalars.JSON] = None + purpose: Optional[PromptPurposeENUM] = None + + +@strawberry.input +class UpdatePromptResourceInput: + """Input for updating prompt-specific resource metadata (file-level fields).""" + + resource: uuid.UUID + prompt_format: Optional[str] = None + has_system_prompt: Optional[bool] = None + has_example_responses: Optional[bool] = None + avg_prompt_length: Optional[int] = None + prompt_count: Optional[int] = None @trace_resolver(name="add_update_dataset_metadata", 
attributes={"component": "dataset"}) @@ -285,9 +337,7 @@ def _add_update_dataset_metadata( metadata_field = Metadata.objects.get(id=metadata_input_item.id) if not metadata_field.enabled: _delete_existing_metadata(dataset) - raise ValueError( - f"Metadata with ID {metadata_input_item.id} is not enabled." - ) + raise ValueError(f"Metadata with ID {metadata_input_item.id} is not enabled.") ds_metadata = DatasetMetadata( dataset=dataset, metadata_item=metadata_field, @@ -296,9 +346,7 @@ def _add_update_dataset_metadata( ds_metadata.save() except Metadata.DoesNotExist as e: _delete_existing_metadata(dataset) - raise ValueError( - f"Metadata with ID {metadata_input_item.id} does not exist." - ) + raise ValueError(f"Metadata with ID {metadata_input_item.id} does not exist.") @trace_resolver(name="update_dataset_tags", attributes={"component": "dataset"}) @@ -307,9 +355,7 @@ def _update_dataset_tags(dataset: Dataset, tags: Optional[List[str]]) -> None: return dataset.tags.clear() for tag in tags: - dataset.tags.add( - Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0] - ) + dataset.tags.add(Tag.objects.get_or_create(defaults={"value": tag}, value__iexact=tag)[0]) dataset.save() @@ -330,9 +376,7 @@ def _add_update_dataset_sectors(dataset: Dataset, sectors: List[uuid.UUID]) -> N dataset.save() -@trace_resolver( - name="add_update_dataset_geographies", attributes={"component": "dataset"} -) +@trace_resolver(name="add_update_dataset_geographies", attributes={"component": "dataset"}) def _add_update_dataset_geographies(dataset: Dataset, geography_ids: List[int]) -> None: """Update geographies for a dataset.""" dataset.geographies.clear() @@ -343,19 +387,16 @@ def _add_update_dataset_geographies(dataset: Dataset, geography_ids: List[int]) @strawberry.type class Query: - @strawberry_django.field( - filters=DatasetFilter, - pagination=True, - order=DatasetOrder, - ) + @strawberry.field @trace_resolver(name="datasets", attributes={"component": "dataset"}) def 
datasets( self, info: Info, - filters: Optional[DatasetFilter] = strawberry.UNSET, - pagination: Optional[OffsetPaginationInput] = strawberry.UNSET, - order: Optional[DatasetOrder] = strawberry.UNSET, - ) -> List[TypeDataset]: + filters: Optional[DatasetFilter] = None, + pagination: Optional[OffsetPaginationInput] = None, + order: Optional[DatasetOrder] = None, + include_public: Optional[bool] = False, + ) -> list[TypeDataset]: """Get all datasets.""" organization = info.context.context.get("organization") user = info.context.user @@ -386,13 +427,20 @@ def datasets( # For non-authenticated users, return empty queryset queryset = Dataset.objects.none() - if filters is not strawberry.UNSET: + # Include public datasets if requested (BEFORE filters/pagination) + if include_public: + queryset = queryset | Dataset.objects.filter(status=DatasetStatus.PUBLISHED) + + # Apply filters FIRST (before any slicing) + if filters is not None: queryset = strawberry_django.filters.apply(filters, queryset, info) - if order is not strawberry.UNSET: + # Apply ordering SECOND + if order is not None: queryset = strawberry_django.ordering.apply(order, queryset, info) - if pagination is not strawberry.UNSET: + # Apply pagination LAST (this will slice the queryset) + if pagination is not None: queryset = strawberry_django.pagination.apply(pagination, queryset) return TypeDataset.from_django_list(queryset) @@ -424,28 +472,24 @@ def get_chart_data( dataset = Dataset.objects.get(id=dataset_id) # Fetch ResourceChartImage for the dataset chart_images = list( - ResourceChartImage.objects.filter(dataset_id=dataset_id).order_by( - "modified" - ) + ResourceChartImage.objects.filter(dataset_id=dataset_id).order_by("modified") + ) + resource_ids = Resource.objects.filter(dataset_id=dataset_id).values_list( + "id", flat=True ) - resource_ids = Resource.objects.filter( - dataset_id=dataset_id - ).values_list("id", flat=True) except Dataset.DoesNotExist: raise ValueError(f"Dataset with ID {dataset_id} 
does not exist.") else: organization = info.context.context.get("organization") if organization: chart_images = list( - ResourceChartImage.objects.filter( - dataset__organization=organization - ).order_by("modified") + ResourceChartImage.objects.filter(dataset__organization=organization).order_by( + "modified" + ) ) else: chart_images = list( - ResourceChartImage.objects.filter(dataset__user=user).order_by( - "modified" - ) + ResourceChartImage.objects.filter(dataset__user=user).order_by("modified") ) if organization: resource_ids = Resource.objects.filter( @@ -458,9 +502,7 @@ def get_chart_data( # Fetch ResourceChartDetails based on the related Resources chart_details = list( - ResourceChartDetails.objects.filter(resource_id__in=resource_ids).order_by( - "modified" - ) + ResourceChartDetails.objects.filter(resource_id__in=resource_ids).order_by("modified") ) # Convert to Strawberry types after getting lists @@ -485,26 +527,14 @@ def get_chart_data( def get_publishers(self, info: Info) -> List[Union[TypeOrganization, TypeUser]]: """Get all publishers (both individual publishers and organizations) who have published datasets.""" # Get all published datasets - published_datasets = Dataset.objects.filter( - status=DatasetStatus.PUBLISHED.value - ) - published_ds_organizations = published_datasets.values_list( - "organization_id", flat=True - ) - published_usecases = UseCase.objects.filter( - status=UseCaseStatus.PUBLISHED.value - ) - published_uc_organizations = published_usecases.values_list( - "organization_id", flat=True - ) - published_organizations = set(published_ds_organizations) | set( - published_uc_organizations - ) + published_datasets = Dataset.objects.filter(status=DatasetStatus.PUBLISHED.value) + published_ds_organizations = published_datasets.values_list("organization_id", flat=True) + published_usecases = UseCase.objects.filter(status=UseCaseStatus.PUBLISHED.value) + published_uc_organizations = published_usecases.values_list("organization_id", 
flat=True) + published_organizations = set(published_ds_organizations) | set(published_uc_organizations) # Get unique organizations that have published datasets - org_publishers = Organization.objects.filter( - id__in=published_organizations - ).distinct() + org_publishers = Organization.objects.filter(id__in=published_organizations).distinct() published_ds_users = published_datasets.values_list("user_id", flat=True) published_uc_users = published_usecases.values_list("user_id", flat=True) @@ -533,27 +563,53 @@ class Mutation: "get_data": lambda result, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "dataset_type": result.dataset_type, + "organization": (str(result.organization.id) if result.organization else None), }, }, ) - def add_dataset(self, info: Info) -> MutationResponse[TypeDataset]: + def add_dataset( + self, info: Info, create_input: Optional[CreateDatasetInput] = None + ) -> MutationResponse[TypeDataset]: # Get organization from context organization = info.context.context.get("organization") dataspace = info.context.context.get("dataspace") user = info.context.user - dataset = Dataset.objects.create( - organization=organization, - dataspace=dataspace, - title=f"New dataset {datetime.datetime.now().strftime('%d %b %Y - %H:%M:%S')}", - description="", - user=user, - access_type=DatasetAccessType.PUBLIC, - license=DatasetLicense.CC_BY_4_0_ATTRIBUTION, - ) + # Determine dataset type + dataset_type = DatasetType.DATA + if create_input and create_input.dataset_type: + dataset_type = create_input.dataset_type + + # Create title based on dataset type + type_label = "prompt dataset" if dataset_type == DatasetType.PROMPT else "dataset" + title = f"New {type_label} {datetime.datetime.now().strftime('%d %b %Y - %H:%M:%S')}" + + # Create PromptDataset or regular Dataset based on type + dataset: Dataset + if dataset_type == DatasetType.PROMPT: + dataset 
= PromptDataset.objects.create( + organization=organization, + dataspace=dataspace, + title=title, + description="", + user=user, + access_type=DatasetAccessType.PUBLIC, + license=DatasetLicense.CC_BY_4_0_ATTRIBUTION, + # dataset_type is set automatically in PromptDataset.save() + ) + else: + dataset = Dataset.objects.create( + organization=organization, + dataspace=dataspace, + title=title, + description="", + user=user, + access_type=DatasetAccessType.PUBLIC, + license=DatasetLicense.CC_BY_4_0_ATTRIBUTION, + dataset_type=dataset_type, + ) + DatasetPermission.objects.create( user=user, dataset=dataset, role=Role.objects.get(name="owner") ) @@ -568,9 +624,7 @@ def add_dataset(self, info: Info) -> MutationResponse[TypeDataset]: "get_data": lambda result, update_metadata_input=None, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), "updated_fields": { "metadata": True, "description": bool( @@ -592,9 +646,7 @@ def add_update_dataset_metadata( except Dataset.DoesNotExist as e: raise DjangoValidationError(f"Dataset with ID {dataset_id} does not exist.") if dataset.status != DatasetStatus.DRAFT.value: - raise DjangoValidationError( - f"Dataset with ID {dataset_id} is not in draft status." 
- ) + raise DjangoValidationError(f"Dataset with ID {dataset_id} is not in draft status.") if update_metadata_input.description: dataset.description = update_metadata_input.description @@ -621,17 +673,12 @@ def add_update_dataset_metadata( get_data=lambda result, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), "updated_fields": { "title": kwargs.get("update_dataset_input").title is not None, - "description": kwargs.get("update_dataset_input").description - is not None, - "access_type": kwargs.get("update_dataset_input").access_type - is not None, - "license": kwargs.get("update_dataset_input").license - is not None, + "description": kwargs.get("update_dataset_input").description is not None, + "access_type": kwargs.get("update_dataset_input").access_type is not None, + "license": kwargs.get("update_dataset_input").license is not None, "tags": kwargs.get("update_dataset_input").tags is not None, }, }, @@ -642,9 +689,7 @@ def add_update_dataset_metadata( name="update_dataset", attributes={"component": "dataset", "operation": "mutation"}, ) - def update_dataset( - self, info: Info, update_dataset_input: UpdateDatasetInput - ) -> TypeDataset: + def update_dataset(self, info: Info, update_dataset_input: UpdateDatasetInput) -> TypeDataset: dataset_id = update_dataset_input.dataset try: dataset = Dataset.objects.get(id=dataset_id) @@ -666,6 +711,90 @@ def update_dataset( _update_dataset_tags(dataset, update_dataset_input.tags) return TypeDataset.from_django(dataset) + @strawberry.mutation( + permission_classes=[UpdateDatasetPermission], + ) + @trace_resolver( + name="update_prompt_metadata", + attributes={"component": "dataset"}, + ) + def update_prompt_metadata( + self, info: Info, update_input: UpdatePromptMetadataInput + ) -> MutationResponse[TypePromptDataset]: + """Update 
prompt-specific metadata for a prompt dataset (dataset-level fields).""" + dataset_id = update_input.dataset + + # Get the PromptDataset directly (it's a child of Dataset via multi-table inheritance) + try: + prompt_dataset = PromptDataset.objects.get(dataset_ptr_id=dataset_id) + except PromptDataset.DoesNotExist: + raise DjangoValidationError( + f"Dataset with ID {dataset_id} is not a prompt dataset or does not exist." + ) + + if prompt_dataset.status != DatasetStatus.DRAFT.value: + raise DjangoValidationError(f"Dataset with ID {dataset_id} is not in draft status.") + + # Update dataset-level fields if provided + if update_input.task_type is not None: + prompt_dataset.task_type = update_input.task_type + if update_input.target_languages is not None: + prompt_dataset.target_languages = update_input.target_languages + if update_input.domain is not None: + prompt_dataset.domain = update_input.domain + if update_input.target_model_types is not None: + prompt_dataset.target_model_types = update_input.target_model_types + if update_input.evaluation_criteria is not None: + prompt_dataset.evaluation_criteria = update_input.evaluation_criteria + if update_input.purpose is not None: + prompt_dataset.purpose = update_input.purpose + + prompt_dataset.save() + return MutationResponse.success_response(TypePromptDataset.from_django(prompt_dataset)) + + @strawberry.mutation( + permission_classes=[IsAuthenticated], + ) + @trace_resolver( + name="update_prompt_resource", + attributes={"component": "resource"}, + ) + def update_prompt_resource( + self, info: Info, update_input: UpdatePromptResourceInput + ) -> MutationResponse[TypeResource]: + """Update prompt-specific metadata for a resource (file-level fields).""" + from api.models import PromptResource # Resource is already imported at module level + + resource_id = update_input.resource + + # Get the Resource + try: + resource = Resource.objects.get(id=resource_id) + except Resource.DoesNotExist: + raise DjangoValidationError(f"Resource with ID {resource_id} does 
not exist.") + + # Check if the dataset is in draft status + if resource.dataset.status != DatasetStatus.DRAFT.value: + raise DjangoValidationError("Cannot update resource - dataset is not in draft status.") + + # Get or create PromptResource + prompt_resource, _ = PromptResource.objects.get_or_create(resource=resource) + + # Update file-level fields if provided + if update_input.prompt_format is not None: + prompt_resource.prompt_format = update_input.prompt_format + if update_input.has_system_prompt is not None: + prompt_resource.has_system_prompt = update_input.has_system_prompt + if update_input.has_example_responses is not None: + prompt_resource.has_example_responses = update_input.has_example_responses + if update_input.avg_prompt_length is not None: + prompt_resource.avg_prompt_length = update_input.avg_prompt_length + if update_input.prompt_count is not None: + prompt_resource.prompt_count = update_input.prompt_count + + prompt_resource.save() + return MutationResponse.success_response(TypeResource.from_django(resource)) + @strawberry_django.mutation( handle_django_errors=True, permission_classes=[PublishDatasetPermission], # type: ignore[list-item] @@ -675,9 +804,7 @@ def update_dataset( get_data=lambda result, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), }, ) ], @@ -706,9 +833,7 @@ def publish_dataset(self, info: Info, dataset_id: uuid.UUID) -> TypeDataset: get_data=lambda result, **kwargs: { "dataset_id": str(result.id), "dataset_title": result.title, - "organization": ( - str(result.organization.id) if result.organization else None - ), + "organization": (str(result.organization.id) if result.organization else None), "updated_fields": {"status": "DRAFT", "action": "unpublished"}, }, ) diff --git a/api/schema/stats_schema.py 
b/api/schema/stats_schema.py index 6810a647..4ee8d1f0 100644 --- a/api/schema/stats_schema.py +++ b/api/schema/stats_schema.py @@ -2,6 +2,7 @@ import strawberry import strawberry_django +from django.core.cache import cache from django.db.models import Count, Q from strawberry.types import Info @@ -12,6 +13,9 @@ from api.utils.graphql_telemetry import trace_resolver from authorization.models import User +STATS_CACHE_KEY = "platform_stats" +STATS_CACHE_TTL = 60 * 10 # 10 minutes + @strawberry.type class StatsType: @@ -31,20 +35,20 @@ class Query: @trace_resolver(name="stats", attributes={"component": "stats"}) def stats(self, info: Info) -> StatsType: """Get platform statistics""" + cached = cache.get(STATS_CACHE_KEY) + if cached: + return StatsType(**cached) + # Count total users total_users = User.objects.count() # Count published datasets - total_published_datasets = Dataset.objects.filter( - status=DatasetStatus.PUBLISHED - ).count() + total_published_datasets = Dataset.objects.filter(status=DatasetStatus.PUBLISHED).count() # Count publishers (organizations and individuals who have published datasets) # First, get organizations that have published datasets org_publishers = ( - Organization.objects.filter(datasets__status=DatasetStatus.PUBLISHED) - .distinct() - .count() + Organization.objects.filter(datasets__status=DatasetStatus.PUBLISHED).distinct().count() ) # Then, get individual users who have published datasets @@ -61,15 +65,16 @@ def stats(self, info: Info) -> StatsType: total_publishers = org_publishers + individual_publishers # Count published usecases - total_published_usecases = UseCase.objects.filter( - status=UseCaseStatus.PUBLISHED - ).count() - - return StatsType( - total_users=total_users, - total_published_datasets=total_published_datasets, - total_publishers=total_publishers, - total_organizations=org_publishers, - total_individuals=individual_publishers, - total_published_usecases=total_published_usecases, - ) + total_published_usecases = 
UseCase.objects.filter(status=UseCaseStatus.PUBLISHED).count() + + stats_data = { + "total_users": total_users, + "total_published_datasets": total_published_datasets, + "total_publishers": total_publishers, + "total_organizations": org_publishers, + "total_individuals": individual_publishers, + "total_published_usecases": total_published_usecases, + } + cache.set(STATS_CACHE_KEY, stats_data, timeout=STATS_CACHE_TTL) + + return StatsType(**stats_data) diff --git a/api/services/model_api_client.py b/api/services/model_api_client.py index 94a39d87..4c84b030 100644 --- a/api/services/model_api_client.py +++ b/api/services/model_api_client.py @@ -1,9 +1,9 @@ """ API Client for making requests to AI model endpoints. Supports various authentication methods and providers. +Works with VersionProvider configuration. """ -import json import time from typing import Any, Dict, Optional @@ -12,52 +12,62 @@ from django.utils import timezone from tenacity import retry, stop_after_attempt, wait_exponential # type: ignore -from api.models import AIModel, ModelAPIKey, ModelEndpoint +from api.models.AIModelVersion import VersionProvider class ModelAPIClient: - """Client for interacting with AI model APIs""" - - def __init__(self, model: AIModel): - self.model = model - self.endpoint = model.get_primary_endpoint() - if not self.endpoint: - raise ValueError(f"No primary endpoint configured for model {model.name}") - - # Get API key if available - self.api_key = self._get_api_key() - - def _get_api_key(self) -> Optional[str]: - """Get the active API key for this model""" - api_key_obj = self.model.api_keys.filter(is_active=True).first() - if api_key_obj: - return api_key_obj.get_key() - return None + """Client for interacting with AI model APIs via VersionProvider configuration.""" + + def __init__(self, provider: VersionProvider): + """ + Initialize the API client with a VersionProvider. 
+ + Args: + provider: VersionProvider instance with API configuration + """ + self.provider = provider + self.version = provider.version + self.model = provider.version.ai_model + + # Validate that we have an API endpoint configured + if not provider.api_endpoint_url: + raise ValueError( + f"No API endpoint URL configured for provider {provider.provider} " + f"on model {self.model.name}" + ) def _build_headers(self) -> Dict[str, str]: - """Build request headers including authentication""" - headers = {"Content-Type": "application/json", **self.endpoint.headers} + """Build request headers including authentication.""" + headers: Dict[str, str] = { + "Content-Type": "application/json", + **(self.provider.api_headers or {}), + } # Add authentication header - if self.api_key and self.endpoint.auth_type != "NONE": - if self.endpoint.auth_type == "BEARER": - headers[self.endpoint.auth_header_name] = f"Bearer {self.api_key}" - elif self.endpoint.auth_type == "API_KEY": - headers[self.endpoint.auth_header_name] = self.api_key - elif self.endpoint.auth_type == "CUSTOM": - # Custom headers should be in endpoint.headers + if self.provider.api_key and self.provider.api_auth_type != "NONE": + auth_value = self.provider.api_key + if self.provider.api_auth_type == "BEARER": + auth_value = f"{self.provider.api_key_prefix} {self.provider.api_key}".strip() + elif self.provider.api_auth_type == "API_KEY": + # Just use the key directly pass + elif self.provider.api_auth_type == "BASIC": + import base64 + + auth_value = f"Basic {base64.b64encode(self.provider.api_key.encode()).decode()}" + + headers[self.provider.api_auth_header_name] = auth_value return headers def _build_request_body( self, input_text: str, parameters: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - """Build request body based on provider and template""" + """Build request body based on provider and template.""" body: Dict[str, Any] - if self.endpoint.request_template: + if self.provider.api_request_template: # 
Use custom template - template_copy = self.endpoint.request_template.copy() + template_copy = self.provider.api_request_template.copy() # Replace placeholders body = self._replace_placeholders(template_copy, input_text, parameters or {}) else: @@ -71,8 +81,22 @@ def _replace_placeholders(self, template: Dict, input_text: str, parameters: Dic result: Dict[str, Any] = {} for key, value in template.items(): if isinstance(value, str): + # template = { + # "model": "{model_id}", + # "messages": [{"role": "user", "content": "{input}"}] + # "temperature": {temperature}, + # "max_tokens": {max_tokens} + # } + # parameters = { + # "temperature": 0.7, + # "max_tokens": 1000 + # } + # replaced_value = value.replace("{input}", input_text) replaced_value = replaced_value.replace("{prompt}", input_text) + replaced_value = replaced_value.replace( + "{model_id}", self.provider.provider_model_id + ) # Replace parameter placeholders for param_key, param_value in parameters.items(): replaced_value = replaced_value.replace(f"{{{param_key}}}", str(param_value)) @@ -93,75 +117,77 @@ def _replace_placeholders(self, template: Dict, input_text: str, parameters: Dic return result def _get_default_template(self, input_text: str, parameters: Dict) -> Dict[str, Any]: - """Get default request template based on provider""" - provider = self.model.provider.upper() + """Get default request template based on provider.""" + provider_type = self.provider.provider.upper() + model_id = self.provider.provider_model_id + max_tokens = self.version.max_tokens or 1000 - if provider == "OPENAI": + if provider_type == "OPENAI": return { - "model": self.model.provider_model_id or "gpt-3.5-turbo", + "model": model_id or "gpt-3.5-turbo", "messages": [{"role": "user", "content": input_text}], "temperature": parameters.get("temperature", 0.7), - "max_tokens": parameters.get("max_tokens", self.model.max_tokens or 1000), + "max_tokens": parameters.get("max_tokens", max_tokens), } - elif "LLAMA" in provider: + elif 
"LLAMA" in provider_type: # Llama models - format depends on provider - if "OLLAMA" in provider: + if "OLLAMA" in provider_type: return { - "model": self.model.provider_model_id or "llama2", + "model": model_id or "llama2", "prompt": input_text, "stream": False, "options": { "temperature": parameters.get("temperature", 0.7), - "num_predict": parameters.get("max_tokens", self.model.max_tokens or 1000), + "num_predict": parameters.get("max_tokens", max_tokens), }, } - elif "TOGETHER" in provider: + elif "TOGETHER" in provider_type: return { - "model": self.model.provider_model_id or "togethercomputer/llama-2-7b-chat", + "model": model_id or "togethercomputer/llama-2-7b-chat", "prompt": input_text, "temperature": parameters.get("temperature", 0.7), - "max_tokens": parameters.get("max_tokens", self.model.max_tokens or 1000), + "max_tokens": parameters.get("max_tokens", max_tokens), } - elif "REPLICATE" in provider: + elif "REPLICATE" in provider_type: return { - "version": self.model.provider_model_id, + "version": model_id, "input": { "prompt": input_text, "temperature": parameters.get("temperature", 0.7), - "max_length": parameters.get("max_tokens", self.model.max_tokens or 1000), + "max_length": parameters.get("max_tokens", max_tokens), }, } else: # Generic Llama format (OpenAI-compatible) return { - "model": self.model.provider_model_id or "llama-2-7b-chat", + "model": model_id or "llama-2-7b-chat", "messages": [{"role": "user", "content": input_text}], "temperature": parameters.get("temperature", 0.7), - "max_tokens": parameters.get("max_tokens", self.model.max_tokens or 1000), + "max_tokens": parameters.get("max_tokens", max_tokens), } else: # Generic template for custom APIs return {"input": input_text, "parameters": parameters} def _extract_response(self, response_data: Dict) -> str: - """Extract text response from API response""" - if self.endpoint.response_path: + """Extract text response from API response.""" + if self.provider.api_response_path: # Use 
custom response path - result: Any = self._get_nested_value(response_data, self.endpoint.response_path) + result: Any = self._get_nested_value(response_data, self.provider.api_response_path) return str(result) # Default extraction based on provider - provider = self.model.provider.upper() + provider_type = self.provider.provider.upper() try: - if provider == "OPENAI": + if provider_type == "OPENAI": return str(response_data["choices"][0]["message"]["content"]) - elif "LLAMA" in provider: - if "OLLAMA" in provider: + elif "LLAMA" in provider_type: + if "OLLAMA" in provider_type: return str(response_data["response"]) - elif "TOGETHER" in provider: + elif "TOGETHER" in provider_type: return str(response_data["output"]["choices"][0]["text"]) - elif "REPLICATE" in provider: + elif "REPLICATE" in provider_type: # Replicate returns array of strings output = response_data.get("output", []) return "".join(output) if isinstance(output, list) else str(output) @@ -204,23 +230,23 @@ def _get_nested_value(self, data: Dict, path: str) -> Any: return current - async def _update_endpoint_success(self) -> None: - """Update endpoint statistics on successful call (async-safe)""" + async def _update_provider_success(self) -> None: + """Update provider statistics on successful call (async-safe).""" def _update() -> None: - self.endpoint.total_requests += 1 - self.endpoint.last_success_at = timezone.now() - self.endpoint.save(update_fields=["total_requests", "last_success_at"]) + # Update provider's updated_at timestamp + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() - async def _update_endpoint_failure(self) -> None: - """Update endpoint statistics on failed call (async-safe)""" + async def _update_provider_failure(self) -> None: + """Update provider statistics on failed call (async-safe).""" def _update() -> None: - self.endpoint.failed_requests += 1 - self.endpoint.last_failure_at = timezone.now() - 
self.endpoint.save(update_fields=["failed_requests", "last_failure_at"]) + # Update provider's updated_at timestamp + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() @@ -228,28 +254,39 @@ def _update() -> None: async def call_async( self, input_text: str, parameters: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - """Make async API call to the model""" + """Make async API call to the model via the provider configuration.""" start_time = time.time() headers = self._build_headers() body = self._build_request_body(input_text, parameters) + endpoint_url: str = self.provider.api_endpoint_url # type: ignore[assignment] + try: - async with httpx.AsyncClient(timeout=self.endpoint.timeout_seconds) as client: - if self.endpoint.http_method == "POST": - response = await client.post(self.endpoint.url, headers=headers, json=body) - elif self.endpoint.http_method == "GET": - response = await client.get(self.endpoint.url, headers=headers, params=body) + async with httpx.AsyncClient(timeout=self.provider.api_timeout_seconds) as client: + if self.provider.api_http_method == "POST": + response = await client.post(endpoint_url, headers=headers, json=body) + elif self.provider.api_http_method == "GET": + response = await client.get(endpoint_url, headers=headers, params=body) else: - raise ValueError(f"Unsupported HTTP method: {self.endpoint.http_method}") + raise ValueError(f"Unsupported HTTP method: {self.provider.api_http_method}") response.raise_for_status() + + # Check if response is JSON + content_type = response.headers.get("content-type", "") + if "application/json" not in content_type: + raise ValueError( + f"Expected JSON response but got {content_type}. 
" + f"Check if the API endpoint URL is correct: {endpoint_url}" + ) + response_data = response.json() latency_ms = (time.time() - start_time) * 1000 - # Update endpoint statistics (async-safe) - await self._update_endpoint_success() + # Update provider statistics (async-safe) + await self._update_provider_success() # Extract response text output_text = self._extract_response(response_data) @@ -260,43 +297,98 @@ async def call_async( "raw_response": response_data, "latency_ms": latency_ms, "status_code": response.status_code, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } except httpx.HTTPStatusError as e: # Update failure statistics (async-safe) - await self._update_endpoint_failure() + await self._update_provider_failure() return { "success": False, "error": f"HTTP {e.response.status_code}: {e.response.text}", "status_code": e.response.status_code, "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } except Exception as e: # Update failure statistics (async-safe) - await self._update_endpoint_failure() + await self._update_provider_failure() return { "success": False, "error": str(e), "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } def call(self, input_text: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """Make synchronous API call to the model""" - import asyncio + """Make synchronous API call to the model using httpx sync client.""" + start_time = time.time() + + headers = self._build_headers() + body = self._build_request_body(input_text, parameters) + endpoint_url: str = self.provider.api_endpoint_url # type: ignore[assignment] - # Run async call in sync context try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # If loop is already running, use nest_asyncio - import nest_asyncio # type: ignore + with 
httpx.Client(timeout=self.provider.api_timeout_seconds) as client: + if self.provider.api_http_method == "POST": + response = client.post(endpoint_url, headers=headers, json=body) + elif self.provider.api_http_method == "GET": + response = client.get(endpoint_url, headers=headers, params=body) + else: + raise ValueError(f"Unsupported HTTP method: {self.provider.api_http_method}") - nest_asyncio.apply() - return loop.run_until_complete(self.call_async(input_text, parameters)) - else: - return asyncio.run(self.call_async(input_text, parameters)) - except RuntimeError: - # No event loop, create one - return asyncio.run(self.call_async(input_text, parameters)) + response.raise_for_status() + + # Check if response is JSON + content_type = response.headers.get("content-type", "") + if "application/json" not in content_type: + raise ValueError( + f"Expected JSON response but got {content_type}. " + f"Check if the API endpoint URL is correct: {endpoint_url}" + ) + + response_data = response.json() + latency_ms = (time.time() - start_time) * 1000 + + # Update provider timestamp (sync) + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) + + # Extract response text + output_text = self._extract_response(response_data) + + return { + "success": True, + "output": output_text, + "raw_response": response_data, + "latency_ms": latency_ms, + "status_code": response.status_code, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, + } + + except httpx.HTTPStatusError as e: + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) + + return { + "success": False, + "error": f"HTTP {e.response.status_code}: {e.response.text}", + "status_code": e.response.status_code, + "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, + } + except Exception as e: + return { + "success": False, + "error": 
str(e), + "latency_ms": (time.time() - start_time) * 1000, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, + } diff --git a/api/services/model_hf_client.py b/api/services/model_hf_client.py index 826d7c73..e28fba97 100644 --- a/api/services/model_hf_client.py +++ b/api/services/model_hf_client.py @@ -1,7 +1,7 @@ """ Hugging Face Client for running local or remote model inference. -Supports both pipeline-based and model-class-based inference, -and integrates with Django model management (AIModel). +Supports both pipeline-based and model-class-based inference. +Works with VersionProvider configuration. """ import time @@ -24,15 +24,36 @@ pipeline, ) -from api.models import AIModel +from api.models.AIModelVersion import VersionProvider class ModelHFClient: - """Client for interacting with Hugging Face models.""" + """Client for interacting with Hugging Face models via VersionProvider configuration.""" - def __init__(self, model: AIModel): - self.model = model + def __init__(self, provider: VersionProvider): + """ + Initialize the HuggingFace client with a VersionProvider. 
+ + Args: + provider: VersionProvider instance with HuggingFace configuration + """ + self.provider = provider + self.version = provider.version + self.model = provider.version.ai_model self.device = self._get_device() + + # Validate provider type + if provider.provider != "HUGGINGFACE": + raise ValueError( + f"ModelHFClient requires HUGGINGFACE provider, got {provider.provider}" + ) + + # Validate model ID is set + if not provider.provider_model_id: + raise ValueError( + f"No provider_model_id configured for HuggingFace provider on model {self.model.name}" + ) + self.model_map = { "AutoModelForCausalLM": AutoModelForCausalLM, "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM, @@ -60,37 +81,58 @@ def _get_device(self) -> int: """Select device (0 for GPU if available, else CPU).""" return 0 if torch.cuda.is_available() else -1 + def _get_torch_dtype(self) -> Any: + """Get torch dtype based on provider configuration.""" + dtype_map = { + "auto": "auto", + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } + dtype_str = self.provider.hf_torch_dtype or "auto" + if dtype_str == "auto": + return torch.float16 if torch.cuda.is_available() else torch.float32 + return dtype_map.get(dtype_str, torch.float32) + def _load_pipeline(self) -> Any: """Initialize a Hugging Face pipeline.""" + framework = self.provider.framework or "pt" return pipeline( task=self.task_map.get(self.model.model_type, "text-generation"), - model=self.model.provider_model_id, - framework="pt", + model=self.provider.provider_model_id, + framework=framework, device=self.device, - trust_remote_code=True, - use_auth_token=self.model.hf_auth_token or None, + trust_remote_code=self.provider.hf_trust_remote_code, + token=self.provider.hf_auth_token or None, ) def _load_model_and_tokenizer(self) -> Tuple[Any, Any]: """Load model and tokenizer for manual inference.""" tokenizer = AutoTokenizer.from_pretrained( - self.model.provider_model_id, - trust_remote_code=True, - 
use_auth_token=self.model.hf_auth_token or None, + self.provider.provider_model_id, + trust_remote_code=self.provider.hf_trust_remote_code, + token=self.provider.hf_auth_token or None, ) model_class = self.model_map.get( - self.model.hf_model_class or "AutoModelForCausalLM", AutoModelForCausalLM - ) - model = model_class.from_pretrained( - pretrained_model_name_or_path=self.model.provider_model_id, - trust_remote_code=True, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - use_auth_token=self.model.hf_auth_token or None, - attn_implementation=self.model.hf_attn_implementation, - device_map="auto", + self.provider.hf_model_class or "AutoModelForCausalLM", AutoModelForCausalLM ) + # Build model loading kwargs + model_kwargs: Dict[str, Any] = { + "pretrained_model_name_or_path": self.provider.provider_model_id, + "trust_remote_code": self.provider.hf_trust_remote_code, + "torch_dtype": self._get_torch_dtype(), + "token": self.provider.hf_auth_token or None, + "device_map": self.provider.hf_device_map or "auto", + } + + # Add attention implementation if specified + if self.provider.hf_attn_implementation: + model_kwargs["attn_implementation"] = self.provider.hf_attn_implementation + + model = model_class.from_pretrained(**model_kwargs) + return model, tokenizer def _generate_from_model(self, model: Any, tokenizer: Any, input_text: str) -> str: @@ -109,26 +151,20 @@ def _generate_from_model(self, model: Any, tokenizer: Any, input_text: str) -> s return decoded_text async def _update_success(self) -> None: - """Update endpoint statistics for successful inference.""" + """Update provider statistics for successful inference.""" def _update() -> None: - endpoint = self.model.get_primary_endpoint() - if endpoint: - endpoint.total_requests += 1 - endpoint.last_success_at = timezone.now() - endpoint.save(update_fields=["total_requests", "last_success_at"]) + self.provider.updated_at = timezone.now() + 
self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() async def _update_failure(self) -> None: - """Update endpoint statistics for failed inference.""" + """Update provider statistics for failed inference.""" def _update() -> None: - endpoint = self.model.get_primary_endpoint() - if endpoint: - endpoint.failed_requests += 1 - endpoint.last_failure_at = timezone.now() - endpoint.save(update_fields=["failed_requests", "last_failure_at"]) + self.provider.updated_at = timezone.now() + self.provider.save(update_fields=["updated_at"]) await sync_to_async(_update)() @@ -136,11 +172,11 @@ def _update() -> None: async def call_async(self, input_text: str) -> Dict[str, Any]: """ Run asynchronous inference on the model. - Supports both pipeline and manual model modes. + Supports both pipeline and manual model modes based on provider configuration. """ start_time = time.time() try: - if self.model.hf_use_pipeline: + if self.provider.hf_use_pipeline: pipe = self._load_pipeline() result = pipe(input_text) output_text = result if isinstance(result, str) else str(result) @@ -155,8 +191,8 @@ async def call_async(self, input_text: str) -> Dict[str, Any]: "success": True, "output": output_text, "latency_ms": latency_ms, - "provider": "HuggingFace", - "model_id": self.model.provider_model_id, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } except Exception as e: @@ -165,8 +201,8 @@ async def call_async(self, input_text: str) -> Dict[str, Any]: "success": False, "error": str(e), "latency_ms": (time.time() - start_time) * 1000, - "provider": "HuggingFace", - "model_id": self.model.provider_model_id, + "provider": self.provider.provider, + "model_id": self.provider.provider_model_id, } def call(self, input_text: str) -> Dict[str, Any]: diff --git a/api/signals/__init__.py b/api/signals/__init__.py index 16d9e257..29e37986 100644 --- a/api/signals/__init__.py +++ b/api/signals/__init__.py @@ -1,2 +1,2 @@ # Import signals to 
register them -from api.signals import dataset_signals, usecase_signals +from api.signals import aimodel_signals, dataset_signals, usecase_signals diff --git a/api/signals/aimodel_signals.py b/api/signals/aimodel_signals.py new file mode 100644 index 00000000..3c658087 --- /dev/null +++ b/api/signals/aimodel_signals.py @@ -0,0 +1,83 @@ +from typing import Any + +import structlog +from django.db.models.signals import post_delete, pre_save +from django.dispatch import receiver + +from api.models.AIModel import AIModel +from api.utils.enums import AIModelStatus +from search.documents import AIModelDocument + +logger = structlog.get_logger(__name__) + +INDEXABLE_STATUSES = {AIModelStatus.ACTIVE, AIModelStatus.APPROVED} + + +def _should_be_indexed(instance: AIModel) -> bool: + """Return True when the AI model should exist in Elasticsearch.""" + return instance.is_public and instance.is_active and instance.status in INDEXABLE_STATUSES + + +@receiver(pre_save, sender=AIModel) +def handle_aimodel_visibility(sender: Any, instance: AIModel, **kwargs: Any) -> None: + """Sync Elasticsearch document whenever publish/visibility fields change.""" + if not instance.pk: + # New objects are handled by django-elasticsearch-dsl signal processor + return + + try: + original = AIModel.objects.get(pk=instance.pk) + except AIModel.DoesNotExist: + return + + was_indexable = _should_be_indexed(original) + is_indexable = _should_be_indexed(instance) + + if was_indexable == is_indexable and is_indexable: + # Still indexable, just refresh contents + action = "update" + elif was_indexable and not is_indexable: + action = "delete" + elif not was_indexable and is_indexable: + action = "add" + else: + # Neither was nor is indexable; nothing to do + return + + try: + document = AIModelDocument.get(id=instance.id, ignore=404) + if action == "delete": + if document: + document.delete() + logger.info("Removed AI model from Elasticsearch index", model_id=instance.id) + else: + if document: + 
document.update(instance) + else: + AIModelDocument().update(instance) + logger.info( + "Synced AI model to Elasticsearch index", model_id=instance.id, action=action + ) + except Exception as exc: # pragma: no cover - logging only + logger.error( + "Failed to sync AI model search document", + model_id=instance.id, + action=action, + error=str(exc), + ) + + +@receiver(post_delete, sender=AIModel) +def remove_aimodel_document(sender: Any, instance: AIModel, **kwargs: Any) -> None: + """Ensure Elasticsearch document gets deleted when the model is removed.""" + try: + document = AIModelDocument.get(id=instance.id, ignore=404) + if document: + document.delete() + logger.info("Removed deleted AI model from Elasticsearch index", model_id=instance.id) + except Exception as exc: # pragma: no cover - logging only + logger.error( + "Failed to delete AI model search document", + model_id=instance.id, + error=str(exc), + ) diff --git a/api/signals/collaborative_signals.py b/api/signals/collaborative_signals.py new file mode 100644 index 00000000..25560484 --- /dev/null +++ b/api/signals/collaborative_signals.py @@ -0,0 +1,70 @@ +from typing import Any + +import structlog +from django.core.cache import cache +from django.db.models.signals import pre_save +from django.dispatch import receiver + +from api.models.Collaborative import Collaborative +from api.utils.enums import CollaborativeStatus +from search.documents import CollaborativeDocument + +from .dataset_signals import SEARCH_CACHE_VERSION_KEY + +logger = structlog.get_logger(__name__) + + +@receiver(pre_save, sender=Collaborative) +def handle_collaborative_publication(sender: Any, instance: Collaborative, **kwargs: Any) -> None: + """Sync Elasticsearch index when collaborative publication state changes.""" + + try: + if not instance.pk: + return + + original = Collaborative.objects.get(pk=instance.pk) + + status_changing_to_published = ( + original.status != CollaborativeStatus.PUBLISHED + and instance.status == 
CollaborativeStatus.PUBLISHED + ) + status_changing_from_published = ( + original.status == CollaborativeStatus.PUBLISHED + and instance.status != CollaborativeStatus.PUBLISHED + ) + remains_published = ( + original.status == CollaborativeStatus.PUBLISHED + and instance.status == CollaborativeStatus.PUBLISHED + ) + + if status_changing_to_published or status_changing_from_published: + version = cache.get(SEARCH_CACHE_VERSION_KEY, 0) + cache.set(SEARCH_CACHE_VERSION_KEY, version + 1) + logger.info("Invalidated search cache for collaborative", collaborative_id=instance.id) + + if status_changing_from_published: + document = CollaborativeDocument.get(id=instance.id, ignore=404) + if document: + document.delete() + logger.info( + "Removed collaborative from Elasticsearch index", + collaborative_id=instance.id, + ) + elif status_changing_to_published or remains_published: + document = CollaborativeDocument.get(id=instance.id, ignore=404) + if document: + document.update(instance) + else: + CollaborativeDocument().update(instance) + logger.info( + "Synced collaborative to Elasticsearch index", + collaborative_id=instance.id, + ) + + except Exception as exc: # pragma: no cover - logging only + logger.error( + "Error in collaborative publication signal handler", + collaborative_id=getattr(instance, "id", None), + error=str(exc), + ) + # Avoid raising to prevent save failures diff --git a/api/types/type_aimodel.py b/api/types/type_aimodel.py index 5e6be994..7870d8cb 100644 --- a/api/types/type_aimodel.py +++ b/api/types/type_aimodel.py @@ -11,17 +11,22 @@ from strawberry.types import Info from api.models.AIModel import AIModel, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.types.base_type import BaseType from api.types.type_dataset import TypeTag from api.types.type_geo import TypeGeo from api.types.type_organization import TypeOrganization from api.types.type_sector import TypeSector from api.utils.enums import ( + 
AIModelFramework, + AIModelLifecycleStage, AIModelProvider, AIModelStatus, AIModelType, EndpointAuthType, EndpointHTTPMethod, + HFModelClass, + PromptDomain, ) from authorization.types import TypeUser @@ -34,6 +39,10 @@ AIModelProviderEnum = strawberry.enum(AIModelProvider) # type: ignore EndpointAuthTypeEnum = strawberry.enum(EndpointAuthType) # type: ignore EndpointHTTPMethodEnum = strawberry.enum(EndpointHTTPMethod) # type: ignore +AIModelFrameworkEnum = strawberry.enum(AIModelFramework) # type: ignore +HFModelClassEnum = strawberry.enum(HFModelClass) # type: ignore +AIModelLifecycleStageEnum = strawberry.enum(AIModelLifecycleStage) # type: ignore +PromptDomainEnum = strawberry.enum(PromptDomain) # type: ignore @strawberry.type @@ -65,9 +74,7 @@ def success_rate(self) -> Optional[float]: """Calculate success rate.""" if self.total_requests == 0: return None - return ( - (self.total_requests - self.failed_requests) / self.total_requests - ) * 100 + return ((self.total_requests - self.failed_requests) / self.total_requests) * 100 @strawberry_django.filter(AIModel) @@ -78,6 +85,7 @@ class AIModelFilter: status: Optional[AIModelStatusEnum] model_type: Optional[AIModelTypeEnum] provider: Optional[AIModelProviderEnum] + domain: Optional[PromptDomainEnum] is_public: Optional[bool] is_active: Optional[bool] @@ -115,10 +123,11 @@ class TypeAIModel(BaseType): user: Optional["TypeUser"] supports_streaming: bool max_tokens: Optional[int] - supported_languages: strawberry.scalars.JSON - input_schema: strawberry.scalars.JSON - output_schema: strawberry.scalars.JSON - metadata: strawberry.scalars.JSON + supported_languages: Optional[strawberry.scalars.JSON] + input_schema: Optional[strawberry.scalars.JSON] + output_schema: Optional[strawberry.scalars.JSON] + domain: Optional[PromptDomainEnum] + metadata: Optional[strawberry.scalars.JSON] status: AIModelStatusEnum is_public: bool is_active: bool @@ -187,3 +196,139 @@ def primary_endpoint(self) -> Optional[TypeModelEndpoint]: if 
endpoint: return TypeModelEndpoint.from_django(endpoint) return None + + @strawberry.field(description="Get all versions of this AI model.") + def versions(self) -> List["TypeAIModelVersion"]: + """Get all versions of this AI model.""" + try: + django_instance = cast(AIModel, self) + queryset = django_instance.versions.all() + return TypeAIModelVersion.from_django_list(list(queryset)) + except Exception: + return [] + + @strawberry.field(description="Get the latest version of this AI model.") + def latest_version(self) -> Optional["TypeAIModelVersion"]: + """Get the latest version of this AI model.""" + try: + django_instance = cast(AIModel, self) + version = django_instance.versions.filter(is_latest=True).first() + if not version: + version = django_instance.versions.order_by("-created_at").first() + if version: + return TypeAIModelVersion.from_django(version) + return None + except Exception: + return None + + +@strawberry.type +class TypeVersionProvider(BaseType): + """GraphQL type for VersionProvider.""" + + id: int + provider: AIModelProviderEnum + provider_model_id: Optional[str] + is_primary: bool + is_active: bool + + # API Endpoint Configuration + api_endpoint_url: Optional[str] + api_http_method: EndpointHTTPMethodEnum + api_timeout_seconds: int + + # Authentication Configuration + api_auth_type: EndpointAuthTypeEnum + api_auth_header_name: str + api_key: Optional[str] + api_key_prefix: Optional[str] + + # Request/Response Configuration + api_headers: strawberry.scalars.JSON + api_request_template: strawberry.scalars.JSON + api_response_path: Optional[str] + + # HuggingFace Configuration + hf_use_pipeline: bool + hf_auth_token: Optional[str] + hf_model_class: Optional[str] + hf_attn_implementation: Optional[str] + hf_trust_remote_code: bool + hf_torch_dtype: Optional[str] + hf_device_map: Optional[str] + framework: Optional[str] + + # Additional config + config: strawberry.scalars.JSON + created_at: datetime + updated_at: datetime + + 
+@strawberry_django.filter(AIModelVersion) +class AIModelVersionFilter: + """Filter for AI Model Version.""" + + id: Optional[int] + status: Optional[AIModelStatusEnum] + is_latest: Optional[bool] + + +@strawberry_django.order(AIModelVersion) +class AIModelVersionOrder: + """Order for AI Model Version.""" + + version: strawberry.auto + created_at: strawberry.auto + updated_at: strawberry.auto + + +@strawberry.type +class TypeAIModelVersion(BaseType): + """GraphQL type for AI Model Version.""" + + id: int + version: str + version_notes: Optional[str] + supports_streaming: bool + max_tokens: Optional[int] + supported_languages: Optional[strawberry.scalars.JSON] + input_schema: Optional[strawberry.scalars.JSON] + output_schema: Optional[strawberry.scalars.JSON] + metadata: Optional[strawberry.scalars.JSON] + status: AIModelStatusEnum + lifecycle_stage: AIModelLifecycleStageEnum + is_latest: bool + created_at: datetime + updated_at: datetime + published_at: Optional[datetime] + + @strawberry.field + def providers(self) -> List[TypeVersionProvider]: + """Get all providers for this version.""" + try: + django_instance = cast(AIModelVersion, self) + queryset = django_instance.providers.all() + return TypeVersionProvider.from_django_list(list(queryset)) + except Exception: + return [] + + @strawberry.field + def primary_provider(self) -> Optional[TypeVersionProvider]: + """Get the primary provider for this version.""" + try: + django_instance = cast(AIModelVersion, self) + provider = django_instance.providers.filter(is_primary=True).first() + if provider: + return TypeVersionProvider.from_django(provider) + return None + except Exception: + return None + + @strawberry.field + def ai_model(self) -> Optional[TypeAIModel]: + """Get the parent AI model.""" + try: + django_instance = cast(AIModelVersion, self) + return TypeAIModel.from_django(django_instance.ai_model) + except Exception: + return None diff --git a/api/types/type_dataset.py b/api/types/type_dataset.py index 
0934a458..b8e71ec3 100644 --- a/api/types/type_dataset.py +++ b/api/types/type_dataset.py @@ -8,19 +8,20 @@ from strawberry.enum import EnumType from strawberry.types import Info -from api.models import Dataset, DatasetMetadata, Resource, Tag +from api.models import Dataset, DatasetMetadata, PromptDataset, Resource, Tag from api.types.base_type import BaseType from api.types.type_dataset_metadata import TypeDatasetMetadata from api.types.type_geo import TypeGeo from api.types.type_organization import TypeOrganization from api.types.type_resource import TypeResource from api.types.type_sector import TypeSector -from api.utils.enums import DatasetStatus +from api.utils.enums import DatasetStatus, DatasetType, PromptTaskType from authorization.types import TypeUser logger = structlog.get_logger("dataspace.type_dataset") dataset_status: EnumType = strawberry.enum(DatasetStatus) # type: ignore +dataset_type_enum: EnumType = strawberry.enum(DatasetType) # type: ignore @strawberry_django.filter(Dataset) @@ -29,6 +30,7 @@ class DatasetFilter: id: Optional[uuid.UUID] status: Optional[dataset_status] + dataset_type: Optional[dataset_type_enum] @strawberry_django.order(Dataset) @@ -55,6 +57,7 @@ class TypeDataset(BaseType): description: Optional[str] slug: str status: dataset_status + dataset_type: dataset_type_enum organization: Optional["TypeOrganization"] created: datetime modified: datetime @@ -107,6 +110,24 @@ def metadata(self) -> List["TypeDatasetMetadata"]: except (AttributeError, DatasetMetadata.DoesNotExist): return [] + @strawberry.field + def prompt_metadata(self) -> Optional[strawberry.scalars.JSON]: + """Get prompt-specific metadata for this dataset (only for PROMPT type datasets).""" + try: + # Check if this dataset is a PromptDataset (via multi-table inheritance) + prompt_dataset = PromptDataset.objects.filter(dataset_ptr_id=self.id).first() + if prompt_dataset: + return { + "task_type": prompt_dataset.task_type, + "target_languages": 
prompt_dataset.target_languages, + "domain": prompt_dataset.domain, + "target_model_types": prompt_dataset.target_model_types, + "evaluation_criteria": prompt_dataset.evaluation_criteria, + } + return None + except (AttributeError, PromptDataset.DoesNotExist): + return None + @strawberry.field def resources(self) -> List["TypeResource"]: """Get resources for this dataset.""" @@ -129,9 +150,7 @@ def formats(self: Any) -> List[str]: except (AttributeError, Resource.DoesNotExist): return [] - @strawberry.field( - description="Get similar datasets for this dataset from elasticsearch index." - ) + @strawberry.field(description="Get similar datasets for this dataset from elasticsearch index.") def similar_datasets(self: Any) -> List["TypeDataset"]: # type: ignore """Get similar datasets for this dataset from elasticsearch index.""" try: @@ -182,9 +201,7 @@ def similar_datasets(self: Any) -> List["TypeDataset"]: # type: ignore sectors = [sector.name for sector in dataset.sectors.all().select_related()] # type: ignore if sectors: - should_queries.append( - ESQ("terms", **{"sectors.raw": sectors, "boost": 2.0}) - ) + should_queries.append(ESQ("terms", **{"sectors.raw": sectors, "boost": 2.0})) # Add metadata similarity # Dataset.metadata is the related_name for DatasetMetadata diff --git a/api/types/type_prompt_metadata.py b/api/types/type_prompt_metadata.py new file mode 100644 index 00000000..4754aaf9 --- /dev/null +++ b/api/types/type_prompt_metadata.py @@ -0,0 +1,45 @@ +"""GraphQL type for PromptDataset.""" + +import uuid +from datetime import datetime +from typing import List, Optional + +import strawberry +import strawberry_django +from strawberry.enum import EnumType +from strawberry.types import Info + +from api.models.PromptDataset import PromptDataset +from api.types.base_type import BaseType +from api.types.type_dataset import TypeDataset +from api.utils.enums import ( + PromptDomain, + PromptTaskType, + TargetLanguage, + TargetModelType, +) + 
+prompt_task_type_enum: EnumType = strawberry.enum(PromptTaskType) # type: ignore +prompt_domain_enum: EnumType = strawberry.enum(PromptDomain) # type: ignore +target_language_enum: EnumType = strawberry.enum(TargetLanguage) # type: ignore +target_model_type_enum: EnumType = strawberry.enum(TargetModelType) # type: ignore + + +@strawberry_django.type( + PromptDataset, + fields="__all__", +) +class TypePromptDataset(TypeDataset): + """ + GraphQL type for PromptDataset. + + Extends TypeDataset with prompt-specific fields. + Inherits all Dataset fields plus adds prompt-specific ones. + """ + + # Dataset-level prompt fields + task_type: Optional[prompt_task_type_enum] + target_languages: Optional[List[target_language_enum]] + domain: Optional[prompt_domain_enum] + target_model_types: Optional[List[target_model_type_enum]] + evaluation_criteria: Optional[strawberry.scalars.JSON] diff --git a/api/types/type_prompt_resource_details.py b/api/types/type_prompt_resource_details.py new file mode 100644 index 00000000..dca66819 --- /dev/null +++ b/api/types/type_prompt_resource_details.py @@ -0,0 +1,22 @@ +"""GraphQL type for prompt resource details.""" + +from typing import Optional + +import strawberry +from strawberry.enum import EnumType + +from api.utils.enums import PromptFormat + +# Create the enum for GraphQL schema +prompt_format_enum: EnumType = strawberry.enum(PromptFormat) # type: ignore + + +@strawberry.type +class TypePromptResourceDetails: + """Prompt-specific fields for a resource/file.""" + + prompt_format: Optional[prompt_format_enum] + has_system_prompt: bool + has_example_responses: bool + avg_prompt_length: Optional[int] + prompt_count: Optional[int] diff --git a/api/types/type_resource.py b/api/types/type_resource.py index 6b541faa..20532023 100644 --- a/api/types/type_resource.py +++ b/api/types/type_resource.py @@ -8,6 +8,7 @@ from strawberry_django import type from api.models import ( + PromptResource, Resource, ResourceFileDetails, ResourceMetadata, 
@@ -17,6 +18,7 @@ from api.types.base_type import BaseType from api.types.type_file_details import TypeFileDetails from api.types.type_preview_data import PreviewData +from api.types.type_prompt_resource_details import TypePromptResourceDetails from api.types.type_resource_metadata import TypeResourceMetadata from api.utils.data_indexing import get_preview_data, get_row_count from api.utils.graphql_telemetry import trace_resolver @@ -104,9 +106,7 @@ def metadata(self) -> List[TypeResourceMetadata]: # return [] @strawberry.field - @trace_resolver( - name="get_resource_file_details", attributes={"component": "resource"} - ) + @trace_resolver(name="get_resource_file_details", attributes={"component": "resource"}) def file_details(self) -> Optional[TypeFileDetails]: """Get file details for this resource. @@ -138,9 +138,7 @@ def schema(self) -> List[TypeResourceSchema]: return [] @strawberry.field - @trace_resolver( - name="get_resource_preview_data", attributes={"component": "resource"} - ) + @trace_resolver(name="get_resource_preview_data", attributes={"component": "resource"}) def preview_data(self) -> PreviewData: """Get preview data for the resource. 
@@ -151,9 +149,11 @@ def preview_data(self) -> PreviewData: file_details = getattr(self, "resourcefiledetails", None) if not file_details or not getattr(self, "preview_details", None): return PreviewData(columns=[], rows=[]) - if not getattr( - self, "preview_enabled", False - ) or not file_details.format.lower() in ["csv", "xls", "xlsx"]: + if not getattr(self, "preview_enabled", False) or not file_details.format.lower() in [ + "csv", + "xls", + "xlsx", + ]: return PreviewData(columns=[], rows=[]) try: @@ -169,9 +169,7 @@ def preview_data(self) -> PreviewData: return PreviewData(columns=[], rows=[]) @strawberry.field - @trace_resolver( - name="get_resource_no_of_entries", attributes={"component": "resource"} - ) + @trace_resolver(name="get_resource_no_of_entries", attributes={"component": "resource"}) def no_of_entries(self) -> int: """Get the number of entries in the resource.""" try: @@ -179,10 +177,7 @@ def no_of_entries(self) -> int: if not file_details: return 0 - if ( - not hasattr(file_details, "format") - or file_details.format.lower() != "csv" - ): + if not hasattr(file_details, "format") or file_details.format.lower() != "csv": return 0 try: @@ -193,3 +188,25 @@ def no_of_entries(self) -> int: except Exception as e: logger.error(f"Error getting number of entries: {str(e)}") return 0 + + @strawberry.field + @trace_resolver(name="get_prompt_details", attributes={"component": "resource"}) + def prompt_details(self) -> Optional[TypePromptResourceDetails]: + """Get prompt-specific details for this resource (only for prompt datasets). 
+ + Returns: + Optional[TypePromptResourceDetails]: Prompt details if they exist, None otherwise + """ + try: + prompt_resource = PromptResource.objects.filter(resource_id=self.id).first() + if prompt_resource: + return TypePromptResourceDetails( + prompt_format=prompt_resource.prompt_format, + has_system_prompt=prompt_resource.has_system_prompt, + has_example_responses=prompt_resource.has_example_responses, + avg_prompt_length=prompt_resource.avg_prompt_length, + prompt_count=prompt_resource.prompt_count, + ) + return None + except (AttributeError, PromptResource.DoesNotExist): + return None diff --git a/api/types/type_sector.py b/api/types/type_sector.py index 40e2d893..c4f801bd 100644 --- a/api/types/type_sector.py +++ b/api/types/type_sector.py @@ -9,7 +9,7 @@ from api.models import Sector from api.types.base_type import BaseType -from api.utils.enums import DatasetStatus +from api.utils.enums import AIModelStatus, DatasetStatus @strawberry.enum @@ -54,6 +54,24 @@ def min_dataset_count(self, queryset: Any, value: Optional[int], prefix: str) -> # Return queryset with filter return queryset, Q(**{f"{prefix}_dataset_count__gte": value}) + @strawberry_django.filter_field + def min_aimodel_count(self, queryset: Any, value: Optional[int], prefix: str) -> tuple[Any, Q]: # type: ignore + # Skip filtering if no value provided + if value is None: + return queryset, Q() + + # Annotate queryset with dataset count + queryset = queryset.annotate( + _aimodel_count=Count( + "ai_models", + filter=Q(ai_models__status=AIModelStatus.ACTIVE), + distinct=True, + ) + ) + + # Return queryset with filter + return queryset, Q(**{f"{prefix}_aimodel_count__gte": value}) + @strawberry_django.filter_field def max_dataset_count(self, queryset: Any, value: Optional[int], prefix: str) -> tuple[Any, Q]: # type: ignore # Skip filtering if no value provided @@ -110,3 +128,7 @@ class TypeSector(BaseType): @strawberry.field def dataset_count(self: Any) -> int: return 
int(self.datasets.filter(status=DatasetStatus.PUBLISHED).count()) + + @strawberry.field + def aimodel_count(self: Any) -> int: + return int(self.ai_models.filter(status=AIModelStatus.ACTIVE).count()) diff --git a/api/urls.py b/api/urls.py index 57f5c30c..6d67c407 100644 --- a/api/urls.py +++ b/api/urls.py @@ -9,11 +9,14 @@ from api.views import ( aimodel_detail, aimodel_execution, + auditor, auth, download, generate_dynamic_chart, search_aimodel, + search_collaborative, search_dataset, + search_publisher, search_unified, search_usecase, trending_datasets, @@ -32,23 +35,40 @@ path("auth/keycloak/login/", auth.KeycloakLoginView.as_view(), name="keycloak_login"), path("auth/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"), path("auth/user/info/", auth.UserInfoView.as_view(), name="user_info"), + # Auditor management endpoints + path( + "organizations//auditors/", + auditor.OrganizationAuditorsView.as_view(), + name="organization_auditors", + ), + path( + "users/search-by-email/", + auditor.SearchUserByEmailView.as_view(), + name="search_user_by_email", + ), # API endpoints path("search/dataset/", search_dataset.SearchDataset.as_view(), name="search_dataset"), path("search/usecase/", search_usecase.SearchUseCase.as_view(), name="search_usecase"), path("search/aimodel/", search_aimodel.SearchAIModel.as_view(), name="search_aimodel"), + path( + "search/collaborative/", + search_collaborative.SearchCollaborative.as_view(), + name="search_collaborative", + ), path("search/unified/", search_unified.UnifiedSearch.as_view(), name="search_unified"), + path("search/publisher/", search_publisher.SearchPublisher.as_view(), name="search_publisher"), path( - "aimodels//", + "aimodels//", aimodel_detail.AIModelDetailView.as_view(), name="aimodel_detail", ), path( - "aimodels//call/", + "aimodels//call/", aimodel_execution.call_aimodel, name="aimodel_call", ), path( - "aimodels//call-async/", + "aimodels//call-async/", aimodel_execution.call_aimodel_async, 
name="aimodel_call_async", ), diff --git a/api/utils/data_indexing.py b/api/utils/data_indexing.py index 7933d684..ce30f0ca 100644 --- a/api/utils/data_indexing.py +++ b/api/utils/data_indexing.py @@ -30,9 +30,7 @@ def get_sql_type(pandas_dtype: str) -> str: return "TEXT" -def create_table_for_resource( - resource: Resource, df: pd.DataFrame -) -> Optional[ResourceDataTable]: +def create_table_for_resource(resource: Resource, df: pd.DataFrame) -> Optional[ResourceDataTable]: """Create a database table for the resource data and index it.""" try: # Create ResourceDataTable entry first to get the table name @@ -65,9 +63,7 @@ def create_table_for_resource( df.to_csv(csv_data, index=False, header=False) csv_data.seek(0) - copy_sql = ( - f'COPY "{temp_table}" ({",".join(quoted_columns)}) FROM STDIN WITH CSV' - ) + copy_sql = f'COPY "{temp_table}" ({",".join(quoted_columns)}) FROM STDIN WITH CSV' cursor.copy_expert(copy_sql, csv_data) # Insert from temp to main table with validation @@ -102,14 +98,10 @@ def index_resource_data(resource: Resource) -> Optional[ResourceDataTable]: try: file_details = resource.resourcefiledetails if not file_details: - logger.info( - f"Resource {resource_id} has no file details, skipping indexing" - ) + logger.info(f"Resource {resource_id} has no file details, skipping indexing") return None except Exception as e: - logger.error( - f"Failed to access file details for resource {resource_id}: {str(e)}" - ) + logger.error(f"Failed to access file details for resource {resource_id}: {str(e)}") return None # Check file format @@ -131,9 +123,7 @@ def index_resource_data(resource: Resource) -> Optional[ResourceDataTable]: ) return None except Exception as e: - logger.error( - f"Failed to determine format for resource {resource_id}: {str(e)}" - ) + logger.error(f"Failed to determine format for resource {resource_id}: {str(e)}") return None # Load tabular data with timeout protection @@ -144,9 +134,7 @@ def index_resource_data(resource: Resource) -> 
Optional[ResourceDataTable]: @contextmanager def timeout(seconds: int) -> Generator[None, None, None]: def handler(signum: int, frame: Any) -> None: - raise TimeoutError( - f"Loading data timed out after {seconds} seconds" - ) + raise TimeoutError(f"Loading data timed out after {seconds} seconds") # Set the timeout handler original_handler = signal.getsignal(signal.SIGALRM) @@ -163,9 +151,7 @@ def handler(signum: int, frame: Any) -> None: with timeout(60): # 60 second timeout for loading data df = load_tabular_data(file_details.file.path, format) except TimeoutError as te: - logger.error( - f"Timeout while loading data for resource {resource_id}: {str(te)}" - ) + logger.error(f"Timeout while loading data for resource {resource_id}: {str(te)}") return None except Exception: # Fallback without timeout if signal.SIGALRM is not available (e.g., on Windows) @@ -204,9 +190,7 @@ def handler(signum: int, frame: Any) -> None: # Rename all but the first occurrence for i, idx in enumerate(indices[1:], 1): df.columns.values[idx] = f"{col}_{i}" - logger.warning( - f"Renamed duplicate columns in resource {resource_id}" - ) + logger.warning(f"Renamed duplicate columns in resource {resource_id}") except Exception as e: logger.error( f"Failed to sanitize column names for resource {resource_id}: {str(e)}" @@ -229,9 +213,7 @@ def handler(signum: int, frame: Any) -> None: existing_table = ResourceDataTable.objects.get(resource=resource) try: with connections[DATA_DB].cursor() as cursor: - cursor.execute( - f'DROP TABLE IF EXISTS "{existing_table.table_name}"' - ) + cursor.execute(f'DROP TABLE IF EXISTS "{existing_table.table_name}"') except Exception as drop_error: logger.error( f"Failed to drop existing table for resource {resource_id}: {str(drop_error)}" @@ -292,15 +274,11 @@ def handler(signum: int, frame: Any) -> None: # For description, preserve existing if available, otherwise auto-generate description = f"Description of column {col}" if col in existing_schemas: - 
existing_description = existing_schemas[col][ - "description" - ] + existing_description = existing_schemas[col]["description"] # Check for None and non-auto-generated descriptions if existing_description is not None: description = existing_description - logger.debug( - f"Preserved custom description for column {col}" - ) + logger.debug(f"Preserved custom description for column {col}") # Create the schema entry ResourceSchema.objects.create( @@ -393,9 +371,7 @@ def get_row_count(resource: Resource) -> int: import traceback error_tb = traceback.format_exc() - logger.error( - f"Error getting row count for resource {resource.id}:\n{str(e)}\n{error_tb}" - ) + logger.error(f"Error getting row count for resource {resource.id}:\n{str(e)}\n{error_tb}") return 0 @@ -429,9 +405,7 @@ def get_preview_data(resource: Resource) -> Optional[PreviewData]: try: if is_all_entries: # For safety, always limit the number of rows returned even for 'all entries' - cursor.execute( - f'SELECT * FROM "{data_table.table_name}" LIMIT 1000' - ) + cursor.execute(f'SELECT * FROM "{data_table.table_name}" LIMIT 1000') else: # Ensure we have valid integer values for the calculation start = int(start_entry) if start_entry is not None else 0 @@ -443,8 +417,8 @@ def get_preview_data(resource: Resource) -> Optional[PreviewData]: columns = [desc[0] for desc in cursor.description] data = cursor.fetchall() - # Convert tuples to lists - rows = [list(row) for row in data] + # Convert tuples to lists and sanitize None values to empty strings + rows = [[cell if cell is not None else "" for cell in row] for row in data] return PreviewData(columns=columns, rows=rows) except Exception as query_error: logger.error( diff --git a/api/utils/enums.py b/api/utils/enums.py index 18ccbabe..826cd77e 100644 --- a/api/utils/enums.py +++ b/api/utils/enums.py @@ -86,6 +86,92 @@ class DatasetStatus(models.TextChoices): ARCHIVED = "ARCHIVED" +class DatasetType(models.TextChoices): + DATA = "DATA" + PROMPT = "PROMPT" + + +class 
PromptTaskType(models.TextChoices): + TEXT_GENERATION = "TEXT_GENERATION" + TEXT_CLASSIFICATION = "TEXT_CLASSIFICATION" + QUESTION_ANSWERING = "QUESTION_ANSWERING" + SUMMARIZATION = "SUMMARIZATION" + TRANSLATION = "TRANSLATION" + SENTIMENT_ANALYSIS = "SENTIMENT_ANALYSIS" + NAMED_ENTITY_RECOGNITION = "NAMED_ENTITY_RECOGNITION" + CODE_GENERATION = "CODE_GENERATION" + CONVERSATION = "CONVERSATION" + INSTRUCTION_FOLLOWING = "INSTRUCTION_FOLLOWING" + REASONING = "REASONING" + CREATIVE_WRITING = "CREATIVE_WRITING" + DATA_EXTRACTION = "DATA_EXTRACTION" + OTHER = "OTHER" + + +class PromptPurpose(models.TextChoices): + RESEARCH = "RESEARCH" + EDUCATION = "EDUCATION" + EVALUATION = "EVALUATION" + OTHER = "OTHER" + + +class PromptDomain(models.TextChoices): + HEALTHCARE = "HEALTHCARE" + EDUCATION = "EDUCATION" + LEGAL = "LEGAL" + FINANCE = "FINANCE" + AGRICULTURE = "AGRICULTURE" + ENVIRONMENT = "ENVIRONMENT" + GOVERNMENT = "GOVERNMENT" + TECHNOLOGY = "TECHNOLOGY" + SCIENCE = "SCIENCE" + SOCIAL_SERVICES = "SOCIAL_SERVICES" + TRANSPORTATION = "TRANSPORTATION" + ENERGY = "ENERGY" + GENERAL = "GENERAL" + OTHER = "OTHER" + + +class PromptFormat(models.TextChoices): + INSTRUCTION = "INSTRUCTION" + CHAT = "CHAT" + COMPLETION = "COMPLETION" + FEW_SHOT = "FEW_SHOT" + CHAIN_OF_THOUGHT = "CHAIN_OF_THOUGHT" + ZERO_SHOT = "ZERO_SHOT" + OTHER = "OTHER" + + +class TargetLanguage(models.TextChoices): + ENGLISH = "ENGLISH" + HINDI = "HINDI" + TAMIL = "TAMIL" + TELUGU = "TELUGU" + BENGALI = "BENGALI" + MARATHI = "MARATHI" + GUJARATI = "GUJARATI" + KANNADA = "KANNADA" + MALAYALAM = "MALAYALAM" + PUNJABI = "PUNJABI" + ODIA = "ODIA" + ASSAMESE = "ASSAMESE" + URDU = "URDU" + OTHER = "OTHER" + + +class TargetModelType(models.TextChoices): + GPT = "GPT" + LLAMA = "LLAMA" + MISTRAL = "MISTRAL" + GEMINI = "GEMINI" + CLAUDE = "CLAUDE" + FALCON = "FALCON" + BLOOM = "BLOOM" + INDIC_LLM = "INDIC_LLM" + CUSTOM = "CUSTOM" + OTHER = "OTHER" + + class DatasetAccessType(models.TextChoices): PUBLIC = "PUBLIC" 
PRIVATE = "PRIVATE" @@ -169,6 +255,7 @@ class AIModelType(models.TextChoices): class AIModelStatus(models.TextChoices): + DRAFT = "DRAFT" REGISTERED = "REGISTERED" VALIDATING = "VALIDATING" ACTIVE = "ACTIVE" @@ -188,6 +275,16 @@ class AIModelProvider(models.TextChoices): HUGGINGFACE = "HUGGINGFACE" +class AIModelLifecycleStage(models.TextChoices): + DEVELOPMENT = "DEVELOPMENT", "Development" + TESTING = "TESTING", "Testing" + BETA = "BETA", "Beta Testing" + STAGING = "STAGING", "Staging" + PRODUCTION = "PRODUCTION", "Production" + DEPRECATED = "DEPRECATED", "Deprecated" + RETIRED = "RETIRED", "Retired" + + class AIModelFramework(models.TextChoices): PYTORCH = "pt", "PyTorch" TENSORFLOW = "tf", "TensorFlow" diff --git a/api/utils/keycloak_utils.py b/api/utils/keycloak_utils.py index 2daa5ebe..15cd6fda 100644 --- a/api/utils/keycloak_utils.py +++ b/api/utils/keycloak_utils.py @@ -56,7 +56,7 @@ def get_token(self, username: str, password: str) -> Dict[str, Any]: def validate_token(self, token: str) -> Dict[str, Any]: """ - Validate a Keycloak token and return the user info. + Validate a token (Django JWT or Keycloak) and return the user info. 
Args: token: The token to validate @@ -64,6 +64,48 @@ def validate_token(self, token: str) -> Dict[str, Any]: Returns: Dict containing the user information """ + from rest_framework_simplejwt.exceptions import InvalidToken, TokenError + from rest_framework_simplejwt.tokens import AccessToken + + # First, try to validate as Django JWT token + try: + logger.debug("Attempting to validate as Django JWT token") + access_token = AccessToken(token) # type: ignore[arg-type] + user_id = access_token.get("user_id") + + if user_id: + logger.debug(f"Valid Django JWT token for user_id: {user_id}") + try: + from authorization.models import User + + user = User.objects.get(id=user_id) + user.save() + + # NOTE: Organizations are managed in DataSpace database, not Keycloak + # Organization memberships should be created/managed through DataSpace's + # organization management interface, not during Keycloak sync + + # Log for debugging Keycloak validation") + # Return user info in Keycloak format + return { + "sub": ( + str(user.keycloak_id) + if hasattr(user, "keycloak_id") and user.keycloak_id + else str(user.id) + ), + "preferred_username": user.username, + "email": user.email, + "given_name": user.first_name, + "family_name": user.last_name, + } + except User.DoesNotExist: + logger.warning(f"User with id {user_id} not found in database") + except (TokenError, InvalidToken) as e: + logger.debug(f"Not a valid Django JWT token: {e}, trying Keycloak validation") + except Exception as e: + logger.debug(f"Error validating Django JWT: {e}, trying Keycloak validation") + + # If Django JWT validation failed, try Keycloak token validation try: # Verify the token is valid token_info = self.keycloak_openid.introspect(token) @@ -71,16 +113,46 @@ def validate_token(self, token: str) -> Dict[str, Any]: logger.warning("Token is not active") return {} - # Get user info from the token - user_info = self.keycloak_openid.userinfo(token) - return user_info + # Try to get user info from the userinfo 
endpoint + # If that fails (403), fall back to token introspection data + try: + user_info = self.keycloak_openid.userinfo(token) + return user_info + except KeycloakError as userinfo_error: + # If userinfo fails (e.g., 403), extract user info from token introspection + logger.warning( + f"Userinfo endpoint failed ({userinfo_error}), using token introspection data" + ) + + # Build user info from introspection response + user_info = { + "sub": token_info.get("sub"), + "preferred_username": token_info.get("username") + or token_info.get("preferred_username"), + "email": token_info.get("email"), + "email_verified": token_info.get("email_verified", False), + "name": token_info.get("name"), + "given_name": token_info.get("given_name"), + "family_name": token_info.get("family_name"), + } + + # Remove None values + user_info = {k: v for k, v in user_info.items() if v is not None} + return user_info + except KeycloakError as e: logger.error(f"Error validating token: {e}") return {} - def get_user_roles(self, token_info: dict) -> list[str]: + def get_user_roles_from_token_info(self, token_info: dict) -> list[str]: """ - Extract roles from a Keycloak token. + Extract roles from token introspection data. + + Args: + token_info: Token introspection response + + Returns: + List of role names """ roles: list[str] = [] @@ -98,61 +170,189 @@ def get_user_roles(self, token_info: dict) -> list[str]: return roles - def get_user_organizations(self, token: str) -> List[Dict[str, Any]]: + def get_user_organizations_from_token_info(self, token_info: dict) -> List[Dict[str, Any]]: """ - Get the organizations a user belongs to from their token. - This assumes that organization information is stored in the token - as client roles or in user attributes. + Get organizations from token introspection data. + + NOTE: In DataSpace, organizations are ALWAYS managed in the database. + This method always returns an empty list. 
+ + Args: + token_info: Token introspection response (not used) + + Returns: + Empty list - organizations are managed in DataSpace database + """ + # Organizations are managed in DataSpace database, not Keycloak + logger.debug("Organizations are managed in DataSpace DB, returning empty list") + return [] + + def get_user_roles(self, token: str) -> list[str]: + """ + Extract roles from a Keycloak token. Args: token: The user's token Returns: - List of organization information + List of role names """ try: # Decode the token to get user info token_info = self.keycloak_openid.decode_token(token) - # Get organization info from resource_access or attributes - # This implementation depends on how organizations are represented in Keycloak - # This is a simplified example - adjust based on your Keycloak configuration + roles: list[str] = [] + + # Extract realm roles + realm_access = token_info.get("realm_access", {}) + if realm_access and "roles" in realm_access: + roles.extend(realm_access["roles"]) # type: ignore[no-any-return] + + # Extract client roles resource_access = token_info.get("resource_access", {}) - client_roles = resource_access.get(self.client_id, {}).get("roles", []) - - # Extract organization info from roles - # Format could be 'org__' or similar - organizations = [] - for role in client_roles: - if role.startswith("org_"): - parts = role.split("_") - if len(parts) >= 3: - org_id = parts[1] - role_name = parts[2] - organizations.append( - {"organization_id": org_id, "role": role_name} - ) - - return organizations - except KeycloakError as e: - logger.error(f"Error getting user organizations: {e}") + client_id = settings.KEYCLOAK_CLIENT_ID + if resource_access and client_id in resource_access: + client_roles = resource_access[client_id].get("roles", []) + roles.extend(client_roles) + + return roles + except Exception as e: + logger.error(f"Error getting user roles: {e}") return [] + def get_user_organizations(self, token: str) -> List[Dict[str, Any]]: + 
""" + Get the organizations a user belongs to from their token. + + NOTE: In DataSpace, organizations and memberships are ALWAYS managed + in the database, NOT in Keycloak. This method always returns an empty list. + + Organizations should be retrieved using AuthorizationService.get_user_organizations() + which queries the OrganizationMembership table. + + Args: + token: The user's token (not used) + + Returns: + Empty list - organizations are managed in DataSpace database + """ + # Organizations are managed in DataSpace database, not Keycloak + # Always return empty list - the actual organizations will be fetched + # from the database via AuthorizationService.get_user_organizations() + logger.debug("Organizations are managed in DataSpace DB, returning empty list") + return [] + + def update_user_in_keycloak(self, user: User) -> bool: + """Update user details in Keycloak using admin credentials.""" + if not user.keycloak_id: + logger.warning("Cannot update user in Keycloak: No keycloak_id", user_id=str(user.id)) + return False + + try: + # Get admin credentials from settings + admin_username = getattr(settings, "KEYCLOAK_ADMIN_USERNAME", "") + admin_password = getattr(settings, "KEYCLOAK_ADMIN_PASSWORD", "") + # Log credential presence (not the actual values) + logger.info( + "Admin credentials check", + username_present=bool(admin_username), + password_present=bool(admin_password), + ) + + if not admin_username or not admin_password: + logger.error("Keycloak admin credentials not configured") + return False + + from keycloak import KeycloakOpenID + + # First get an admin token directly + keycloak_openid = KeycloakOpenID( + server_url=self.server_url, + client_id="admin-cli", # Special client for admin operations + realm_name="master", # Admin users are in master realm + verify=True, + ) + + # Get token + try: + token = keycloak_openid.token( + username=admin_username, + password=admin_password, + grant_type="password", + ) + access_token = 
token.get("access_token") + + if not access_token: + logger.error("Failed to get admin access token") + return False + + # Now use the token to update the user + import requests + + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + user_data = { + "firstName": user.first_name, + "lastName": user.last_name, + "email": user.email, + "emailVerified": True, + } + + # Direct API call to update user + base_url = self.server_url.rstrip("/") # Remove any trailing slash + response = requests.put( + f"{base_url}/admin/realms/{self.realm}/users/{user.keycloak_id}", + headers=headers, + json=user_data, + ) + + if response.status_code == 204: # Success for this endpoint + logger.info( + "Successfully updated user in Keycloak", + user_id=str(user.id), + keycloak_id=user.keycloak_id, + ) + return True + else: + logger.error( + f"Failed to update user in Keycloak: {response.status_code}: {response.text}", + user_id=str(user.id), + ) + return False + + except Exception as token_error: + logger.error( + f"Error getting admin token: {str(token_error)}", + user_id=str(user.id), + ) + return False + + except Exception as e: + logger.error(f"Error updating user in Keycloak: {str(e)}", user_id=str(user.id)) + return False + @transaction.atomic def sync_user_from_keycloak( self, user_info: Dict[str, Any], roles: List[str], - organizations: List[Dict[str, Any]], + organizations: Optional[List[Dict[str, Any]]] = None, ) -> Optional[User]: """ Synchronize user information from Keycloak to Django. - Creates or updates the User and UserOrganization records. + Creates or updates the User record. + + NOTE: Organizations are ALWAYS managed in DataSpace database, not Keycloak. + Organization memberships should be created/managed through DataSpace's + organization management interface. This method does NOT sync organizations. 
Args: user_info: User information from Keycloak - roles: User roles from Keycloak - organizations: User organization memberships from Keycloak + roles: User roles from Keycloak (for is_staff/is_superuser only) + organizations: Ignored - organizations are managed in DataSpace database Returns: The synchronized User object or None if failed @@ -166,17 +366,36 @@ def sync_user_from_keycloak( logger.error("Missing required user information from Keycloak") return None - # Get or create the user - user, created = User.objects.update_or_create( - keycloak_id=keycloak_id, - defaults={ - "username": username, - "email": email, - "first_name": user_info.get("given_name", ""), - "last_name": user_info.get("family_name", ""), - "is_active": True, - }, - ) + # First, try to find user by keycloak_id + user = User.objects.filter(keycloak_id=keycloak_id).first() + + if not user and email: + # If not found by keycloak_id, check if user exists with same email + # and update their keycloak_id (handles pre-existing users) + user = User.objects.filter(email=email).first() + if user: + logger.info( + f"Found existing user with email {email}, updating keycloak_id to {keycloak_id}" + ) + + if user: + # Update existing user + user.keycloak_id = keycloak_id + user.username = username + user.email = email + user.first_name = user_info.get("given_name", "") or user.first_name + user.last_name = user_info.get("family_name", "") or user.last_name + user.is_active = True + else: + # Create new user + user = User( + keycloak_id=keycloak_id, + username=username, + email=email, + first_name=user_info.get("given_name", ""), + last_name=user_info.get("family_name", ""), + is_active=True, + ) # Update user roles based on Keycloak roles if "admin" in roles: @@ -188,41 +407,6 @@ def sync_user_from_keycloak( user.save() - # Update organization memberships - # First, get all existing organization memberships - existing_memberships = OrganizationMembership.objects.filter(user=user) - existing_org_ids = { - 
membership.organization_id for membership in existing_memberships # type: ignore[attr-defined] - } - - # Process organizations from Keycloak - for org_info in organizations: - org_id = org_info.get("organization_id") - role = org_info.get( - "role", "viewer" - ) # Default to viewer if role not specified - - # Try to get the organization - try: - organization = Organization.objects.get(id=org_id) # type: ignore[misc] - - # Create or update the membership - OrganizationMembership.objects.update_or_create( - user=user, organization=organization, defaults={"role": role} - ) - - # Remove from the set of existing memberships - if org_id in existing_org_ids: - existing_org_ids.remove(org_id) - except Organization.DoesNotExist: - logger.warning(f"Organization with ID {org_id} does not exist") - - # Remove memberships that no longer exist in Keycloak - if existing_org_ids: - OrganizationMembership.objects.filter( - user=user, organization_id__in=existing_org_ids - ).delete() - return user except Exception as e: logger.error(f"Error synchronizing user from Keycloak: {e}") diff --git a/api/utils/middleware.py b/api/utils/middleware.py index 145a5929..c3f5e81d 100644 --- a/api/utils/middleware.py +++ b/api/utils/middleware.py @@ -40,9 +40,7 @@ class CustomHttpRequest(HttpRequest): class ContextMiddleware: - def __init__( - self, get_response: Callable[[CustomHttpRequest], HttpResponse] - ) -> None: + def __init__(self, get_response: Callable[[CustomHttpRequest], HttpResponse]) -> None: self.get_response = get_response def __call__(self, request: CustomHttpRequest) -> HttpResponse: @@ -71,7 +69,10 @@ def __call__(self, request: CustomHttpRequest) -> HttpResponse: if organization_slug is None: organization: Optional[Organization] = None else: - organization = get_object_or_404(Organization, slug=organization_slug) + try: + organization = Organization.objects.get(slug=organization_slug) + except Organization.DoesNotExist: + organization = get_object_or_404(Organization, 
id=organization_slug) if dataspace_slug is None: dataspace: Optional[DataSpace] = None else: diff --git a/api/views/aimodel_detail.py b/api/views/aimodel_detail.py index dab62064..411aacab 100644 --- a/api/views/aimodel_detail.py +++ b/api/views/aimodel_detail.py @@ -1,8 +1,9 @@ """API view for AI Model detail.""" +import logging from typing import Any, Dict, List, Optional -import logging +from django.core.cache import cache from rest_framework import serializers, status from rest_framework.permissions import AllowAny from rest_framework.request import Request @@ -11,9 +12,11 @@ from api.models.AIModel import AIModel, ModelEndpoint +AIMODEL_DETAIL_CACHE_TTL = 60 * 15 # 15 minutes logger = logging.getLogger(__name__) + class ModelEndpointSerializer(serializers.ModelSerializer): """Serializer for Model Endpoint.""" @@ -59,6 +62,7 @@ class Meta: "tags", "sectors", "geographies", + "domain", "metadata", "status", "is_public", @@ -118,11 +122,17 @@ class AIModelDetailView(APIView): def get(self, request: Request, model_id: str) -> Response: """Get AI model details.""" try: + cache_key = f"aimodel_detail:{model_id}" + cached = cache.get(cache_key) + if cached: + return Response(cached) + model = AIModel.objects.prefetch_related( "tags", "sectors", "geographies", "endpoints", "organization", "user" ).get(id=model_id) serializer = AIModelDetailSerializer(model) + cache.set(cache_key, serializer.data, timeout=AIMODEL_DETAIL_CACHE_TTL) return Response(serializer.data) except AIModel.DoesNotExist: diff --git a/api/views/aimodel_execution.py b/api/views/aimodel_execution.py index 9c65d01e..b13ca42d 100644 --- a/api/views/aimodel_execution.py +++ b/api/views/aimodel_execution.py @@ -3,9 +3,8 @@ Handles model inference requests via ModelAPIClient and ModelHFClient. 
""" -from typing import Any - import logging +from typing import Any from rest_framework import status from rest_framework.decorators import api_view, permission_classes @@ -66,14 +65,45 @@ def call_aimodel(request: Request, model_id: str) -> Response: ) parameters = request.data.get("parameters", {}) + version_id = request.data.get("version_id") + + # Get the version - either specific version or primary (latest) + if version_id: + primary_version = model.versions.filter(id=version_id).first() + if not primary_version: + return Response( + {"error": f"Version with ID {version_id} not found for this model"}, + status=status.HTTP_400_BAD_REQUEST, + ) + else: + # Fall back to primary (latest) version + primary_version = model.versions.filter(is_latest=True).first() + if not primary_version: + primary_version = model.versions.first() + + if not primary_version: + return Response( + {"error": "No version found for this model"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + primary_provider = primary_version.providers.filter(is_primary=True, is_active=True).first() + if not primary_provider: + primary_provider = primary_version.providers.filter(is_active=True).first() + + if not primary_provider: + return Response( + {"error": "No active provider found for this model"}, + status=status.HTTP_400_BAD_REQUEST, + ) # Route to appropriate client based on provider result: Any - if model.provider == "HUGGINGFACE": - hf_client = ModelHFClient(model) + if primary_provider.provider == "HUGGINGFACE": + hf_client = ModelHFClient(primary_provider) result = hf_client.call(input_text) else: - api_client = ModelAPIClient(model) + api_client = ModelAPIClient(primary_provider) result = api_client.call(input_text, parameters) return Response(result, status=status.HTTP_200_OK) diff --git a/api/views/auditor.py b/api/views/auditor.py new file mode 100644 index 00000000..40535ec8 --- /dev/null +++ b/api/views/auditor.py @@ -0,0 +1,326 @@ +"""REST API views for auditor management.""" + +import 
import logging
from typing import Any, Dict, List, Optional

from django.core.cache import cache
from django.db import transaction
from rest_framework import status, views
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response

from api.models import Organization
from authorization.models import OrganizationMembership, Role, User

# How long the auditor role's primary key stays cached before re-querying the DB.
ROLE_CACHE_TTL = 60 * 60  # 1 hour

logger = logging.getLogger(__name__)


class OrganizationAuditorsView(views.APIView):
    """
    View for managing auditors in an organization.

    GET: List all auditors for an organization
    POST: Add a user as auditor to an organization (by user_id or email)
    DELETE: Remove an auditor from an organization
    """

    permission_classes = [IsAuthenticated]

    def _get_organization(self, organization_id: str) -> Optional[Organization]:
        """Return the organization with the given ID, or None if it does not exist."""
        try:
            return Organization.objects.get(id=organization_id)
        except Organization.DoesNotExist:
            return None

    def _check_admin_permission(self, user: User, organization: Organization) -> bool:
        """Check if user has admin permission for the organization."""
        if user.is_superuser:
            return True
        try:
            membership = OrganizationMembership.objects.get(user=user, organization=organization)
            # Admin role has can_change permission
            return membership.role.can_change  # type: ignore[return-value]
        except OrganizationMembership.DoesNotExist:
            return False

    def _get_auditor_role(self) -> Optional[Role]:
        """Get the auditor role, caching its primary key for ROLE_CACHE_TTL seconds."""
        cached_id = cache.get("role_id:auditor")
        if cached_id:
            try:
                return Role.objects.get(pk=cached_id)
            except Role.DoesNotExist:
                # Stale cache entry (role deleted/recreated); fall through to re-query.
                cache.delete("role_id:auditor")
        try:
            role = Role.objects.get(name="auditor")
            cache.set("role_id:auditor", role.pk, timeout=ROLE_CACHE_TTL)
            return role
        except Role.DoesNotExist:
            logger.error("Auditor role not found. Please run migrations.")
            return None

    @staticmethod
    def _serialize_auditor(user: User, membership: OrganizationMembership) -> Dict[str, Any]:
        """Build the JSON payload describing one auditor membership.

        IDs are serialized as strings and ``joined_at`` is None-guarded so GET
        and POST responses stay consistent (the original POST payload returned
        a raw id and called ``isoformat()`` on a possibly-None ``created_at``).
        """
        return {
            "id": str(user.id),
            "username": user.username,
            "email": user.email,
            "first_name": user.first_name,
            "last_name": user.last_name,
            "profile_picture": user.profile_picture.url if user.profile_picture else None,
            "joined_at": membership.created_at.isoformat() if membership.created_at else None,
        }

    def get(self, request: Request, organization_id: str) -> Response:
        """Get all auditors for an organization."""
        organization = self._get_organization(organization_id)
        if not organization:
            return Response(
                {"error": f"Organization with ID {organization_id} not found"},
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user has permission to view organization members
        if not self._check_admin_permission(request.user, organization):  # type: ignore[arg-type]
            return Response(
                {"error": "You don't have permission to view auditors for this organization"},
                status=status.HTTP_403_FORBIDDEN,
            )

        auditor_role = self._get_auditor_role()
        if not auditor_role:
            return Response(
                {"error": "Auditor role not configured"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        # Get all auditors for this organization
        auditor_memberships = OrganizationMembership.objects.filter(
            organization=organization, role=auditor_role
        ).select_related("user")

        auditors: List[Dict[str, Any]] = [
            self._serialize_auditor(membership.user, membership)  # type: ignore[arg-type]
            for membership in auditor_memberships
        ]

        return Response(
            {
                "organization_id": str(organization.id),
                "organization_name": organization.name,
                "auditors": auditors,
                "count": len(auditors),
            }
        )

    @transaction.atomic
    def post(self, request: Request, organization_id: str) -> Response:
        """
        Add a user as auditor to an organization.

        Request body can contain either:
        - user_id: ID of an existing user
        - email: Email of a user to add (will look up user by email)
        """
        organization = self._get_organization(organization_id)
        if not organization:
            return Response(
                {"error": f"Organization with ID {organization_id} not found"},
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user has admin permission
        if not self._check_admin_permission(request.user, organization):  # type: ignore[arg-type]
            return Response(
                {"error": "You don't have permission to add auditors to this organization"},
                status=status.HTTP_403_FORBIDDEN,
            )

        auditor_role = self._get_auditor_role()
        if not auditor_role:
            return Response(
                {"error": "Auditor role not configured"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        user_id = request.data.get("user_id")
        email = request.data.get("email")

        if not user_id and not email:
            return Response(
                {"error": "Either user_id or email is required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        # Find the user — by ID first, otherwise by email.
        target_user: Optional[User] = None
        if user_id:
            try:
                target_user = User.objects.get(id=user_id)
            except User.DoesNotExist:
                return Response(
                    {"error": f"User with ID {user_id} not found"},
                    status=status.HTTP_404_NOT_FOUND,
                )
        elif email:
            try:
                target_user = User.objects.get(email=email)
            except User.DoesNotExist:
                return Response(
                    {
                        "error": f"User with email {email} not found. The user must have an account in CivicDataSpace first."
                    },
                    status=status.HTTP_404_NOT_FOUND,
                )

        if not target_user:
            return Response(
                {"error": "Could not find user"},
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user is already a member of the organization
        existing_membership = OrganizationMembership.objects.filter(
            user=target_user, organization=organization
        ).first()

        if existing_membership:
            if existing_membership.role == auditor_role:
                return Response(
                    {"error": "User is already an auditor for this organization"},
                    status=status.HTTP_400_BAD_REQUEST,
                )
            # User has a different role; refuse rather than silently replacing it
            # (multiple roles per user may be supported in the future).
            return Response(
                {"error": f"User is already a member of this organization with role '{existing_membership.role.name}'"},  # type: ignore[attr-defined]
                status=status.HTTP_400_BAD_REQUEST,
            )

        # Create the membership
        membership = OrganizationMembership.objects.create(
            user=target_user,
            organization=organization,
            role=auditor_role,
        )

        # Lazy %-style args: formatting is skipped when the level is disabled.
        logger.info(
            "Added user %s as auditor to organization %s",
            target_user.username,
            organization.name,
        )

        return Response(
            {
                "success": True,
                "message": f"User {target_user.username} added as auditor",
                "auditor": self._serialize_auditor(target_user, membership),
            },
            status=status.HTTP_201_CREATED,
        )

    @transaction.atomic
    def delete(self, request: Request, organization_id: str) -> Response:
        """Remove an auditor from an organization."""
        organization = self._get_organization(organization_id)
        if not organization:
            return Response(
                {"error": f"Organization with ID {organization_id} not found"},
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user has admin permission
        if not self._check_admin_permission(request.user, organization):  # type: ignore[arg-type]
            return Response(
                {"error": "You don't have permission to remove auditors from this organization"},
                status=status.HTTP_403_FORBIDDEN,
            )

        user_id = request.data.get("user_id")
        if not user_id:
            return Response(
                {"error": "user_id is required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        auditor_role = self._get_auditor_role()
        if not auditor_role:
            return Response(
                {"error": "Auditor role not configured"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        try:
            membership = OrganizationMembership.objects.get(
                user_id=user_id, organization=organization, role=auditor_role
            )
            username = membership.user.username  # type: ignore[attr-defined]
            membership.delete()

            logger.info("Removed auditor %s from organization %s", username, organization.name)

            return Response(
                {
                    "success": True,
                    "message": f"Auditor {username} removed from organization",
                }
            )
        except OrganizationMembership.DoesNotExist:
            return Response(
                {"error": "User is not an auditor for this organization"},
                status=status.HTTP_404_NOT_FOUND,
            )


class SearchUserByEmailView(views.APIView):
    """
    Search for a user by email.
    Used to find users before adding them as auditors.
    """

    permission_classes = [IsAuthenticated]

    def get(self, request: Request) -> Response:
        """Search for a user by email."""
        email = request.query_params.get("email")
        if not email:
            return Response(
                {"error": "email query parameter is required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        try:
            user = User.objects.get(email=email)
            return Response(
                {
                    "found": True,
                    "user": {
                        # str() for consistency with the auditor endpoints above.
                        "id": str(user.id),
                        "username": user.username,
                        "email": user.email,
                        "first_name": user.first_name,
                        "last_name": user.last_name,
                        "profile_picture": (
                            user.profile_picture.url if user.profile_picture else None
                        ),
                    },
                }
            )
        except User.DoesNotExist:
            return Response(
                {
                    "found": False,
                    "message": f"No user found with email {email}",
                }
            )
org.organization.logo.url if org.organization.logo else None, # type: ignore[attr-defined] + "homepage": org.organization.homepage, # type: ignore[attr-defined] + "created": org.organization.created, # type: ignore[attr-defined] + "updated": org.organization.modified, # type: ignore[attr-defined] + "slug": org.organization.slug, # type: ignore[attr-defined] } for org in user.organizationmembership_set.all() # type: ignore[union-attr, arg-type] ], diff --git a/api/views/paginated_elastic_view.py b/api/views/paginated_elastic_view.py index 038f8d78..7f892bb5 100644 --- a/api/views/paginated_elastic_view.py +++ b/api/views/paginated_elastic_view.py @@ -47,9 +47,7 @@ def get_search(self) -> SearchType: """Get search instance.""" if hasattr(self.document_class, "search"): return self.document_class.search() # type: ignore - raise AttributeError( - f"{self.document_class.__name__} does not have a search method" - ) + raise AttributeError(f"{self.document_class.__name__} does not have a search method") def get(self, request: HttpRequest) -> Response: """Handle GET request and return paginated search results.""" @@ -58,9 +56,8 @@ def get(self, request: HttpRequest) -> Response: cache_key = self._generate_cache_key(request) cached_result: Optional[Dict[str, Any]] = cache.get(cache_key) - # TODO: Fix cache issues on different model updates - # if cached_result: - # return Response(cached_result) + if cached_result: + return Response(cached_result) # Original search logic query: str = request.GET.get("query", "") @@ -99,10 +96,10 @@ def get(self, request: HttpRequest) -> Response: aggregations.pop("metadata") for agg in metadata_aggregations: label: str = agg["key"]["metadata_label"] - value: str = agg["key"].get("metadata_value", "") - if label not in aggregations: - aggregations[label] = {} - aggregations[label][value] = agg["doc_count"] + value: str = agg["key"].get("metadata_value", "") + if label not in aggregations: + aggregations[label] = {} + 
aggregations[label][value] = agg["doc_count"] if "catalogs" in aggregations: aggregations.pop("catalogs") @@ -170,9 +167,7 @@ def get(self, request: HttpRequest) -> Response: aggregations["running_status"][agg["key"]] = agg["doc_count"] if "is_individual_usecase" in aggregations: - is_individual_usecase_agg = aggregations["is_individual_usecase"][ - "buckets" - ] + is_individual_usecase_agg = aggregations["is_individual_usecase"]["buckets"] aggregations.pop("is_individual_usecase") aggregations["is_individual_usecase"] = {} for agg in is_individual_usecase_agg: @@ -194,6 +189,7 @@ def get(self, request: HttpRequest) -> Response: def _generate_cache_key(self, request: HttpRequest) -> str: """Generate a unique cache key based on request parameters and cache version.""" params: Dict[str, str] = { + "document_type": self.document_class.__name__, "query": request.GET.get("query", ""), "page": request.GET.get("page", "1"), "size": request.GET.get("size", "10"), diff --git a/api/views/search_aimodel.py b/api/views/search_aimodel.py index 63f8b7ba..083cc65a 100644 --- a/api/views/search_aimodel.py +++ b/api/views/search_aimodel.py @@ -21,10 +21,15 @@ class AIModelDocumentSerializer(serializers.ModelSerializer): """Serializer for AIModel document.""" tags = serializers.ListField() + sectors = serializers.ListField() + geographies = serializers.ListField() supported_languages = serializers.ListField() is_individual_model = serializers.BooleanField() has_active_endpoints = serializers.BooleanField() endpoint_count = serializers.IntegerField() + version_count = serializers.IntegerField() + lifecycle_stage = serializers.CharField() + all_providers = serializers.ListField() class OrganizationSerializer(serializers.Serializer): name = serializers.CharField() @@ -42,9 +47,31 @@ class EndpointSerializer(serializers.Serializer): is_primary = serializers.BooleanField() is_active = serializers.BooleanField() + class ProviderSerializer(serializers.Serializer): + id = 
serializers.IntegerField() + provider = serializers.CharField() + provider_model_id = serializers.CharField() + is_primary = serializers.BooleanField() + is_active = serializers.BooleanField() + + class VersionSerializer(serializers.Serializer): + id = serializers.IntegerField() + version = serializers.CharField() + version_notes = serializers.CharField(allow_blank=True) + lifecycle_stage = serializers.CharField() + is_latest = serializers.BooleanField() + status = serializers.CharField() + supports_streaming = serializers.BooleanField() + max_tokens = serializers.IntegerField(allow_null=True) + supported_languages = serializers.ListField() + created_at = serializers.DateTimeField() + updated_at = serializers.DateTimeField() + providers = serializers.ListField() + organization = OrganizationSerializer(allow_null=True) user = UserSerializer(allow_null=True) endpoints = EndpointSerializer(many=True) + versions = VersionSerializer(many=True) class Meta: model = AIModel @@ -61,6 +88,8 @@ class Meta: "is_public", "is_active", "tags", + "sectors", + "geographies", "supported_languages", "supports_streaming", "max_tokens", @@ -74,9 +103,14 @@ class Meta: "is_individual_model", "has_active_endpoints", "endpoint_count", + "domain", + "version_count", + "lifecycle_stage", + "all_providers", "organization", "user", "endpoints", + "versions", ] @@ -91,9 +125,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.searchable_fields: List[str] self.aggregations: Dict[str, str] - self.searchable_fields, self.aggregations = ( - self.get_searchable_and_aggregations() - ) + self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations() self.logger = structlog.get_logger(__name__) @trace_method( @@ -115,10 +147,15 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: "provider": "terms", "status": "terms", "tags.raw": "terms", + "sectors.raw": "terms", + "geographies.raw": "terms", "supported_languages": "terms", 
"is_public": "terms", "is_active": "terms", "supports_streaming": "terms", + "lifecycle_stage": "terms", + "all_providers": "terms", + "domain": "terms", } return searchable_fields, aggregations @@ -134,19 +171,13 @@ def add_aggregations(self, search: Search) -> Search: ) return search - @trace_method( - name="generate_q_expression", attributes={"component": "search_aimodel"} - ) - def generate_q_expression( - self, query: str - ) -> Optional[Union[ESQuery, List[ESQuery]]]: + @trace_method(name="generate_q_expression", attributes={"component": "search_aimodel"}) + def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]: """Generate Elasticsearch Query expression.""" if query: queries: List[ESQuery] = [] for field in self.searchable_fields: - queries.append( - ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}) - ) + queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}})) else: queries = [ESQ("match_all")] @@ -164,11 +195,22 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: search = search.filter("terms", **{raw_filter: filter_values}) else: search = search.filter("term", **{filter_key: filters[filter_key]}) + elif filter_key in ["sectors", "geographies"]: + # Handle multi-value filters for sectors and geographies + raw_filter = filter_key + ".raw" + if raw_filter in self.aggregations: + filter_values = filters[filter_key].split(",") + search = search.filter("terms", **{raw_filter: filter_values}) + else: + search = search.filter("term", **{filter_key: filters[filter_key]}) elif filter_key in [ "model_type", "provider", "status", "supported_languages", + "lifecycle_stage", + "all_providers", + "domain", ]: # Handle single or multi-value filters filter_values = filters[filter_key].split(",") diff --git a/api/views/search_collaborative.py b/api/views/search_collaborative.py new file mode 100644 index 00000000..38d45076 --- /dev/null +++ b/api/views/search_collaborative.py @@ 
import ast
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import structlog
from django.core.cache import cache
from elasticsearch_dsl import A
from elasticsearch_dsl import Q as ESQ
from elasticsearch_dsl import Search
from elasticsearch_dsl.query import Query as ESQuery
from rest_framework import serializers
from rest_framework.permissions import AllowAny

from api.models import Collaborative, CollaborativeMetadata, Geography, Metadata
from api.utils.telemetry_utils import trace_method
from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView
from search.documents import CollaborativeDocument

# How long computed search/metadata configuration stays cached.
METADATA_CACHE_TTL = 60 * 30  # 30 minutes

logger = structlog.get_logger(__name__)


class MetadataSerializer(serializers.Serializer):
    """Serializer for a metadata item's label."""

    label = serializers.CharField(allow_blank=True)  # type: ignore


class CollaborativeMetadataSerializer(serializers.ModelSerializer):
    """Serializer for collaborative metadata; values may be stored as stringified lists."""

    metadata_item = MetadataSerializer()

    class Meta:
        model = CollaborativeMetadata
        fields = ["metadata_item", "value"]

    def to_representation(self, instance: CollaborativeMetadata) -> Dict[str, Any]:
        """Flatten a stringified list value (e.g. "['a', 'b']") into "a, b"."""
        representation = super().to_representation(instance)

        if isinstance(representation["value"], str):
            try:
                value_list = ast.literal_eval(representation["value"])
                if isinstance(value_list, list):
                    representation["value"] = ", ".join(str(x) for x in value_list)
            except (ValueError, SyntaxError):
                # Not a Python literal — keep the raw string as-is.
                pass

        return cast(Dict[str, Any], representation)

    def to_internal_value(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Split a "a, b" string back into a list before validation."""
        value = data.get("value")
        if isinstance(value, str):
            # str.split cannot raise; the original's try/except here was dead code.
            data["value"] = value.split(", ") if value else []

        return cast(Dict[str, Any], super().to_internal_value(data))


class CollaborativeDocumentSerializer(serializers.ModelSerializer):
    """Serializer for Collaborative search documents."""

    metadata = CollaborativeMetadataSerializer(many=True)
    tags = serializers.ListField()
    # default=list (callable): a literal [] would be one shared mutable default.
    sectors = serializers.ListField(default=list)
    geographies = serializers.ListField(default=list)
    slug = serializers.CharField()
    is_individual_collaborative = serializers.BooleanField()
    website = serializers.CharField(required=False, allow_blank=True)
    contact_email = serializers.EmailField(required=False, allow_blank=True)
    platform_url = serializers.CharField(required=False, allow_blank=True)
    started_on = serializers.DateTimeField(required=False, allow_null=True)
    completed_on = serializers.DateTimeField(required=False, allow_null=True)

    class OrganizationSerializer(serializers.Serializer):
        name = serializers.CharField()
        logo = serializers.CharField()

    class UserSerializer(serializers.Serializer):
        name = serializers.CharField()
        bio = serializers.CharField()
        profile_picture = serializers.CharField()

    class ContributorSerializer(serializers.Serializer):
        name = serializers.CharField()
        bio = serializers.CharField()
        profile_picture = serializers.CharField()

    class RelatedOrganizationSerializer(serializers.Serializer):
        name = serializers.CharField()
        logo = serializers.CharField()
        relationship_type = serializers.CharField()

    class DatasetSerializer(serializers.Serializer):
        title = serializers.CharField()
        description = serializers.CharField()
        slug = serializers.CharField()

    class UseCaseSerializer(serializers.Serializer):
        title = serializers.CharField()
        summary = serializers.CharField()
        slug = serializers.CharField()

    organization = OrganizationSerializer()
    user = UserSerializer()
    contributors = ContributorSerializer(many=True)
    organizations = RelatedOrganizationSerializer(many=True)
    datasets = DatasetSerializer(many=True)
    use_cases = UseCaseSerializer(many=True)

    class Meta:
        model = Collaborative
        fields = [
            "id",
            "title",
            "summary",
            "slug",
            "created",
            "modified",
            "status",
            "metadata",
            "tags",
            "sectors",
            "geographies",
            "is_individual_collaborative",
            "organization",
            "user",
            "contributors",
            "organizations",
            "datasets",
            "use_cases",
            "website",
            "contact_email",
            "platform_url",
            "started_on",
            "completed_on",
        ]


class SearchCollaborative(PaginatedElasticSearchAPIView):
    """Paginated Elasticsearch view for collaboratives."""

    serializer_class = CollaborativeDocumentSerializer
    document_class = CollaborativeDocument
    permission_classes = [AllowAny]

    # Fields that live in nested documents and need nested queries/filters.
    _NESTED_PREFIXES = ("datasets.", "use_cases.", "user.", "organization.", "contributors.", "organizations.")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.searchable_fields: List[str]
        self.aggregations: Dict[str, str]
        self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations()
        self.logger = structlog.get_logger(__name__)

    @trace_method(
        name="get_searchable_and_aggregations",
        attributes={"component": "search_collaborative"},
    )
    def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]:
        """Return (searchable fields, aggregation map), cached for METADATA_CACHE_TTL."""
        cached: Optional[Tuple[List[str], Dict[str, str]]] = cache.get(
            "collaborative_search_metadata_config"
        )
        if cached:
            return cached

        searchable_fields = [
            "title",
            "summary",
            "tags",
            "sectors",
            "user.name",
            "organization.name",
            "contributors.name",
            "datasets.title",
            "datasets.description",
            "use_cases.title",
            "use_cases.summary",
            "metadata.value",
        ]

        aggregations: Dict[str, str] = {
            "tags.raw": "terms",
            "sectors.raw": "terms",
            "geographies.raw": "terms",
            "status": "terms",
        }

        # Every filterable metadata label gets its own terms aggregation.
        filterable_metadata = Metadata.objects.filter(filterable=True).all()
        for metadata in filterable_metadata:
            aggregations[f"metadata.{metadata.label}"] = "terms"  # type: ignore

        result = (searchable_fields, aggregations)
        cache.set("collaborative_search_metadata_config", result, timeout=METADATA_CACHE_TTL)
        return result

    @trace_method(name="add_aggregations", attributes={"component": "search_collaborative"})
    def add_aggregations(self, search: Search) -> Search:
        """Attach terms buckets plus a nested composite bucket for metadata labels."""
        aggregate_fields: List[str] = []
        for aggregation_field in self.aggregations:
            if aggregation_field.startswith("metadata."):
                aggregate_fields.append(aggregation_field.split(".")[1])
            else:
                search.aggs.bucket(
                    aggregation_field.replace(".raw", ""),
                    self.aggregations[aggregation_field],
                    field=aggregation_field,
                )

        if aggregate_fields:
            filterable_metadata: List[str] = (
                cache.get("collaborative_filterable_metadata_labels") or []
            )
            if not filterable_metadata:
                metadata_qs = Metadata.objects.filter(filterable=True)
                filterable_metadata = [str(meta.label) for meta in metadata_qs]  # type: ignore
                cache.set(
                    "collaborative_filterable_metadata_labels",
                    filterable_metadata,
                    timeout=METADATA_CACHE_TTL,
                )

            metadata_bucket = search.aggs.bucket("metadata", "nested", path="metadata")
            composite_agg = A(
                "composite",
                sources=[
                    {"metadata_label": {"terms": {"field": "metadata.metadata_item.label"}}},
                    {"metadata_value": {"terms": {"field": "metadata.value"}}},
                ],
                size=10000,
            )
            metadata_filter = A(
                "filter",
                {
                    "bool": {
                        "must": [{"terms": {"metadata.metadata_item.label": filterable_metadata}}]
                    }
                },
            )
            metadata_bucket.bucket("filtered_metadata", metadata_filter).bucket(
                "composite_agg", composite_agg
            )

        return search

    @staticmethod
    def _nested_field_query(path: str, field: str, query: str) -> ESQuery:
        """Build a nested wildcard+fuzzy should-query for one nested field."""
        return ESQ(
            "nested",
            path=path,
            query=ESQ(
                "bool",
                should=[
                    ESQ("wildcard", **{field: {"value": f"*{query}*"}}),
                    ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}),
                ],
            ),
        )

    @trace_method(name="generate_q_expression", attributes={"component": "search_collaborative"})
    def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]:
        """Build the search query: nested queries for nested paths, fuzzy otherwise."""
        if query:
            queries: List[ESQuery] = []
            for field in self.searchable_fields:
                if field.startswith(self._NESTED_PREFIXES):
                    # Nested path is the first dotted segment ("datasets", "user", ...).
                    queries.append(self._nested_field_query(field.split(".")[0], field, query))
                else:
                    queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}))
        else:
            queries = [ESQ("match_all")]

        return ESQ("bool", should=queries, minimum_should_match=1)

    @trace_method(name="add_filters", attributes={"component": "search_collaborative"})
    def add_filters(self, filters: Dict[str, str], search: Search) -> Search:
        """Apply term/terms/nested filters; unknown keys are treated as metadata values."""
        excluded_labels: List[str] = cache.get("collaborative_non_filter_metadata_labels") or []
        if not excluded_labels:
            non_filter_metadata = Metadata.objects.filter(filterable=False).all()
            excluded_labels = [e.label for e in non_filter_metadata]  # type: ignore
            cache.set(
                "collaborative_non_filter_metadata_labels",
                excluded_labels,
                timeout=METADATA_CACHE_TTL,
            )

        # "filter_key" rather than "filter": avoid shadowing the builtin.
        for filter_key in filters:
            if filter_key in excluded_labels:
                continue
            elif filter_key in ["tags", "sectors", "geographies"]:
                raw_filter = filter_key + ".raw"
                if raw_filter in self.aggregations:
                    filter_values = filters[filter_key].split(",")

                    if filter_key == "geographies":
                        # A geography filter also matches all of its descendants.
                        filter_values = Geography.get_geography_names_with_descendants(
                            filter_values
                        )

                    search = search.filter("terms", **{raw_filter: filter_values})
                else:
                    search = search.filter("term", **{filter_key: filters[filter_key]})
            elif filter_key in ["status", "is_individual_collaborative"]:
                search = search.filter("term", **{filter_key: filters[filter_key]})
            elif filter_key in ["user.name", "organization.name"]:
                path = filter_key.split(".")[0]
                search = search.filter(
                    "nested",
                    path=path,
                    query={"bool": {"must": {"term": {filter_key: filters[filter_key]}}}},
                )
            elif filter_key in ["datasets.slug", "use_cases.slug"]:
                path = filter_key.split(".")[0]
                search = search.filter(
                    "nested",
                    path=path,
                    query={"bool": {"must": {"term": {filter_key: filters[filter_key]}}}},
                )
            else:
                # Anything else filters on nested metadata values.
                # (was f"metadata.value": an f-string with no placeholders, F541)
                search = search.filter(
                    "nested",
                    path="metadata",
                    query={"bool": {"must": {"term": {"metadata.value": filters[filter_key]}}}},
                )
        return search

    @trace_method(name="add_sort", attributes={"component": "search_collaborative"})
    def add_sort(self, sort: str, search: Search, order: str) -> Search:
        """Sort by the document field mapped from the requested sort name."""
        sort_fields = {
            "alphabetical": "title.raw",
            "recent": "modified",
            "started": "started_on",
            "completed": "completed_on",
        }
        field = sort_fields.get(sort)
        if field:
            search = search.sort({field: {"order": order}})
        return search
serializers.ListField(child=serializers.CharField(), allow_null=True) + prompt_format = serializers.CharField(allow_null=True) + has_system_prompt = serializers.BooleanField(allow_null=True) + has_example_responses = serializers.BooleanField(allow_null=True) + avg_prompt_length = serializers.IntegerField(allow_null=True) + prompt_count = serializers.IntegerField(allow_null=True) + use_case = serializers.CharField(allow_null=True) + + class DatasetDocumentSerializer(serializers.ModelSerializer): """Serializer for Dataset document.""" @@ -70,6 +88,8 @@ class DatasetDocumentSerializer(serializers.ModelSerializer): download_count = serializers.IntegerField() trending_score = serializers.FloatField(required=False) is_individual_dataset = serializers.BooleanField() + dataset_type = serializers.CharField(required=False, default="DATA") + prompt_metadata = PromptMetadataSerializer(required=False, allow_null=True) class OrganizationSerializer(serializers.Serializer): name = serializers.CharField() @@ -93,6 +113,7 @@ class Meta: "created", "modified", "status", + "dataset_type", "metadata", "tags", "sectors", @@ -105,6 +126,7 @@ class Meta: "is_individual_dataset", "organization", "user", + "prompt_metadata", ] @@ -119,9 +141,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.searchable_fields: List[str] self.aggregations: Dict[str, str] - self.searchable_fields, self.aggregations = ( - self.get_searchable_and_aggregations() - ) + self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations() self.logger = structlog.get_logger(__name__) @trace_method( @@ -130,6 +150,12 @@ def __init__(self, **kwargs: Any) -> None: ) def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: """Get searchable fields and aggregations for the search.""" + cached: Optional[Tuple[List[str], Dict[str, str]]] = cache.get( + "dataset_search_metadata_config" + ) + if cached: + return cached + enabled_metadata = 
Metadata.objects.filter(enabled=True).all() searchable_fields: List[str] = [] searchable_fields.extend( @@ -148,11 +174,14 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: "formats.raw": "terms", "catalogs.raw": "terms", "geographies.raw": "terms", + "dataset_type": "terms", } for metadata in enabled_metadata: # type: Metadata if metadata.filterable: aggregations[f"metadata.{metadata.label}"] = "terms" - return searchable_fields, aggregations + result = (searchable_fields, aggregations) + cache.set("dataset_search_metadata_config", result, timeout=METADATA_CACHE_TTL) + return result @trace_method(name="add_aggregations", attributes={"component": "search_dataset"}) def add_aggregations(self, search: Search) -> Search: @@ -175,11 +204,7 @@ def add_aggregations(self, search: Search) -> Search: metadata_bucket = search.aggs.bucket("metadata", "nested", path="metadata") composite_sources = [ - { - "metadata_label": { - "terms": {"field": "metadata.metadata_item.label"} - } - }, + {"metadata_label": {"terms": {"field": "metadata.metadata_item.label"}}}, {"metadata_value": {"terms": {"field": "metadata.value"}}}, ] composite_agg = A( @@ -191,13 +216,7 @@ def add_aggregations(self, search: Search) -> Search: "filter", { # type: ignore[arg-type] "bool": { - "must": [ - { - "terms": { - "metadata.metadata_item.label": filterable_metadata - } - } - ] + "must": [{"terms": {"metadata.metadata_item.label": filterable_metadata}}] } }, ) @@ -207,19 +226,13 @@ def add_aggregations(self, search: Search) -> Search: return search - @trace_method( - name="generate_q_expression", attributes={"component": "search_dataset"} - ) - def generate_q_expression( - self, query: str - ) -> Optional[Union[ESQuery, List[ESQuery]]]: + @trace_method(name="generate_q_expression", attributes={"component": "search_dataset"}) + def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]: """Generate Elasticsearch Query expression.""" if query: 
queries: List[ESQuery] = [] for field in self.searchable_fields: - if field.startswith("resources.name") or field.startswith( - "resources.description" - ): + if field.startswith("resources.name") or field.startswith("resources.description"): queries.append( ESQ( "nested", @@ -230,18 +243,14 @@ def generate_q_expression( ESQ("wildcard", **{field: {"value": f"*{query}*"}}), ESQ( "fuzzy", - **{ - field: {"value": query, "fuzziness": "AUTO"} - }, + **{field: {"value": query, "fuzziness": "AUTO"}}, ), ], ), ) ) else: - queries.append( - ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}) - ) + queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}})) else: queries = [ESQ("match_all")] @@ -250,12 +259,48 @@ def generate_q_expression( @trace_method(name="add_filters", attributes={"component": "search_dataset"}) def add_filters(self, filters: Dict[str, str], search: Search) -> Search: """Add filters to the search query.""" - non_filter_metadata = Metadata.objects.filter(filterable=False).all() - excluded_labels: List[str] = [e.label for e in non_filter_metadata] # type: ignore + excluded_labels: List[str] = cache.get("dataset_non_filter_metadata_labels") or [] + if not excluded_labels: + non_filter_metadata = Metadata.objects.filter(filterable=False).all() + excluded_labels = [e.label for e in non_filter_metadata] # type: ignore + cache.set( + "dataset_non_filter_metadata_labels", excluded_labels, timeout=METADATA_CACHE_TTL + ) for filter in filters: if filter in excluded_labels: continue + elif filter == "dataset_type": + # Filter by dataset type (DATA or PROMPT) + search = search.filter("term", dataset_type=filters[filter]) + elif filter == "task_type": + # Filter by prompt task type (nested in prompt_metadata) + search = search.filter( + "nested", + path="prompt_metadata", + query={ + "bool": {"must": {"term": {"prompt_metadata.task_type": filters[filter]}}} + }, + ) + elif filter == "domain": + # Filter by prompt domain (nested in 
prompt_metadata) + search = search.filter( + "nested", + path="prompt_metadata", + query={"bool": {"must": {"term": {"prompt_metadata.domain": filters[filter]}}}}, + ) + elif filter == "target_languages": + # Filter by target languages (nested in prompt_metadata) + filter_values = filters[filter].split(",") + search = search.filter( + "nested", + path="prompt_metadata", + query={ + "bool": { + "must": {"terms": {"prompt_metadata.target_languages": filter_values}} + } + }, + ) elif filter in ["tags", "sectors", "formats", "catalogs", "geographies"]: raw_filter = filter + ".raw" if raw_filter in self.aggregations: @@ -274,9 +319,7 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: search = search.filter( "nested", path="metadata", - query={ - "bool": {"must": {"term": {f"metadata.value": filters[filter]}}} - }, + query={"bool": {"must": {"term": {f"metadata.value": filters[filter]}}}}, ) return search diff --git a/api/views/search_publisher.py b/api/views/search_publisher.py new file mode 100644 index 00000000..ff72a9dc --- /dev/null +++ b/api/views/search_publisher.py @@ -0,0 +1,306 @@ +from typing import Any, Dict, List, Tuple + +import structlog +from elasticsearch_dsl import Q as ESQ +from elasticsearch_dsl import Search +from rest_framework import serializers +from rest_framework.permissions import AllowAny +from rest_framework.response import Response + +from api.utils.telemetry_utils import trace_method, track_metrics +from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView +from search.documents import OrganizationPublisherDocument, UserPublisherDocument + +logger = structlog.get_logger(__name__) + + +class PublisherDocumentSerializer(serializers.Serializer): + """Serializer for Publisher document (both Organization and User).""" + + id = serializers.CharField() + name = serializers.CharField() + description = serializers.CharField() + publisher_type = serializers.CharField() # 'organization' or 'user' + logo = 
serializers.CharField(required=False) + slug = serializers.CharField(required=False) + created = serializers.DateTimeField(required=False) + modified = serializers.DateTimeField(required=False) + + # Counts + published_datasets_count = serializers.IntegerField() + published_usecases_count = serializers.IntegerField() + members_count = serializers.IntegerField(required=False) # Only for organizations + contributed_sectors_count = serializers.IntegerField() + + # Organization specific fields + homepage = serializers.CharField(required=False) + contact_email = serializers.CharField(required=False) + organization_types = serializers.CharField(required=False) + github_profile = serializers.CharField(required=False) + linkedin_profile = serializers.CharField(required=False) + twitter_profile = serializers.CharField(required=False) + location = serializers.CharField(required=False) + + # User specific fields + bio = serializers.CharField(required=False) + profile_picture = serializers.CharField(required=False) + username = serializers.CharField(required=False) + email = serializers.CharField(required=False) + first_name = serializers.CharField(required=False) + last_name = serializers.CharField(required=False) + full_name = serializers.CharField(required=False) + + # Search fields + sectors = serializers.ListField(required=False) + + +class SearchPublisher(PaginatedElasticSearchAPIView): + """API view for searching publishers (organizations and users).""" + + serializer_class = PublisherDocumentSerializer + permission_classes = [AllowAny] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.logger = structlog.get_logger(__name__) + + def get_document_classes(self) -> List[Any]: + """Return the document classes to search.""" + return [OrganizationPublisherDocument, UserPublisherDocument] + + def get_index_names(self) -> List[str]: + """Get the index names for publisher search.""" + from DataSpace import settings + + org_index = 
settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.OrganizationPublisherDocument", + "organization_publisher", + ) + user_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.UserPublisherDocument", "user_publisher" + ) + return [org_index, user_index] + + @trace_method(name="build_query", attributes={"component": "publisher_search"}) + def build_query(self, query: str) -> ESQ: + """Build the Elasticsearch query for publisher search.""" + if not query: + return ESQ("match_all") + + # Multi-field search with boosting + queries = [ + ESQ( + "multi_match", + query=query, + fields=["name^3", "full_name^3"], # Boost name fields + fuzziness="AUTO", + ), + ESQ( + "multi_match", + query=query, + fields=["description^2", "bio^2"], # Boost description/bio + fuzziness="AUTO", + ), + ESQ( + "multi_match", + query=query, + fields=["sectors^2"], # Boost sectors + fuzziness="AUTO", + ), + ESQ( + "multi_match", + query=query, + fields=[ + "username", + "email", + "location", + "organization_types", + "first_name", + "last_name", + ], + fuzziness="AUTO", + ), + ] + + return ESQ("bool", should=queries, minimum_should_match=1) + + @trace_method(name="apply_filters", attributes={"component": "publisher_search"}) + def apply_filters(self, search: Search, filters: Dict[str, str]) -> Search: + """Apply filters to the search query.""" + + if "publisher_type" in filters: + # Filter by publisher type (organization or user) + search = search.filter("term", publisher_type=filters["publisher_type"]) + + if "sectors" in filters: + # Filter by sectors + filter_values = filters["sectors"].split(",") + search = search.filter("terms", **{"sectors.raw": filter_values}) + + if "organization_types" in filters: + # Filter by organization types + search = search.filter("term", organization_types=filters["organization_types"]) + + if "location" in filters: + # Filter by location (fuzzy match) + search = search.filter("match", 
location=filters["location"]) + + return search + + @trace_method(name="build_aggregations", attributes={"component": "publisher_search"}) + def build_aggregations(self, search: Search) -> Search: + """Build aggregations for faceted search.""" + + # Publisher type aggregation + search.aggs.bucket("publisher_type", "terms", field="publisher_type") + + # Sectors aggregation + search.aggs.bucket("sectors", "terms", field="sectors.raw", size=50) + + # Organization types aggregation + search.aggs.bucket("organization_types", "terms", field="organization_types", size=20) + + # Location aggregation (top 20 locations) + search.aggs.bucket("locations", "terms", field="location.raw", size=20) + + return search + + @trace_method(name="apply_sorting", attributes={"component": "publisher_search"}) + def apply_sorting(self, search: Search, sort_by: str) -> Search: + """Apply sorting to the search query.""" + + if sort_by == "alphabetical": + search = search.sort("name.raw") + elif sort_by == "datasets_count": + search = search.sort({"published_datasets_count": {"order": "desc"}}) + elif sort_by == "usecases_count": + search = search.sort({"published_usecases_count": {"order": "desc"}}) + elif sort_by == "total_contributions": + # Sort by total datasets + usecases + search = search.sort( + { + "_script": { + "type": "number", + "script": { + "source": "doc['published_datasets_count'].value + doc['published_usecases_count'].value" + }, + "order": "desc", + } + } + ) + elif sort_by == "members_count": + # Only applicable to organizations + search = search.sort({"members_count": {"order": "desc"}}) + elif sort_by == "recent": + search = search.sort({"created": {"order": "desc"}}) + else: + # Default: relevance score + pass + + return search + + @trace_method(name="perform_search", attributes={"component": "publisher_search"}) + def perform_search( + self, + query: str, + filters: Dict[str, str], + page: int, + size: int, + sort_by: str = "relevance", + ) -> Tuple[List[Dict[str, 
Any]], int, Dict[str, Any]]: + """Perform the publisher search.""" + + # Get index names + index_names = self.get_index_names() + + if not index_names: + return [], 0, {} + + # Create multi-index search + search = Search(index=index_names) + + # Build and apply query + q = self.build_query(query) + search = search.query(q) + + # Apply filters + search = self.apply_filters(search, filters) + + # Apply sorting + search = self.apply_sorting(search, sort_by) + + # Build aggregations + search = self.build_aggregations(search) + + # Pagination + start = (page - 1) * size + search = search[start : start + size] + + # Execute search + try: + response = search.execute() + except Exception as e: + self.logger.error("publisher_search_error", error=str(e), exc_info=True) + return [], 0, {} + + # Process results + results = [] + for hit in response: + result = hit.to_dict() + result["_score"] = hit.meta.score + result["_index"] = hit.meta.index + results.append(result) + + # Process aggregations + aggregations: Dict[str, Any] = {} + if hasattr(response, "aggregations"): + aggs_dict = response.aggregations.to_dict() + + for agg_name in ["publisher_type", "sectors", "organization_types", "locations"]: + if agg_name in aggs_dict: + aggregations[agg_name] = {} + for bucket in aggs_dict[agg_name]["buckets"]: + aggregations[agg_name][bucket["key"]] = bucket["doc_count"] + + total = response.hits.total.value if hasattr(response.hits.total, "value") else len(results) + + return results, total, aggregations + + @trace_method(name="get", attributes={"component": "publisher_search"}) + @track_metrics(name="publisher_search") + def get(self, request: Any) -> Response: + """Handle GET request and return search results.""" + try: + query: str = request.GET.get("query", "") + page: int = int(request.GET.get("page", 1)) + size: int = int(request.GET.get("size", 10)) + sort_by: str = request.GET.get("sort", "relevance") + + # Handle filters + filters: Dict[str, str] = {} + for key, values in 
request.GET.lists(): + if key not in ["query", "page", "size", "sort"]: + if len(values) > 1: + filters[key] = ",".join(values) + else: + filters[key] = values[0] + + # Perform search + results, total, aggregations = self.perform_search(query, filters, page, size, sort_by) + + # Serialize results + serializer = self.serializer_class(results, many=True) + + return Response( + { + "results": serializer.data, + "total": total, + "page": page, + "size": size, + "aggregations": aggregations, + } + ) + + except Exception as e: + self.logger.error("publisher_search_error", error=str(e), exc_info=True) + return Response({"error": "An internal error has occurred."}, status=500) diff --git a/api/views/search_unified.py b/api/views/search_unified.py index a6b42577..b84c6daf 100644 --- a/api/views/search_unified.py +++ b/api/views/search_unified.py @@ -1,12 +1,11 @@ """Unified search view that searches across datasets, usecases, and aimodels.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple import structlog -from elasticsearch import Elasticsearch +from django.core.cache import cache from elasticsearch_dsl import Q as ESQ from elasticsearch_dsl import Search -from elasticsearch_dsl.connections import connections from rest_framework import serializers from rest_framework.permissions import AllowAny from rest_framework.response import Response @@ -14,9 +13,18 @@ from api.models import Dataset, Geography, Metadata, UseCase from api.models.AIModel import AIModel +from api.models.Collaborative import Collaborative +from api.signals.dataset_signals import SEARCH_CACHE_VERSION_KEY from api.utils.telemetry_utils import trace_method from DataSpace import settings -from search.documents import AIModelDocument, DatasetDocument, UseCaseDocument +from search.documents import ( + AIModelDocument, + CollaborativeDocument, + DatasetDocument, + OrganizationPublisherDocument, + UseCaseDocument, + UserPublisherDocument, +) logger = 
structlog.get_logger(__name__) @@ -25,7 +33,9 @@ class UnifiedSearchResultSerializer(serializers.Serializer): """Serializer for unified search results.""" id = serializers.CharField() - type = serializers.CharField() # 'dataset', 'usecase', or 'aimodel' + type = ( + serializers.CharField() + ) # 'dataset', 'usecase', 'aimodel', 'collaborative', or 'publisher' title = serializers.CharField() description = serializers.CharField() slug = serializers.CharField(required=False) @@ -70,6 +80,26 @@ class UserSerializer(serializers.Serializer): provider = serializers.CharField(required=False) is_individual_model = serializers.BooleanField(required=False) + # Collaborative specific + is_individual_collaborative = serializers.BooleanField(required=False) + website = serializers.CharField(required=False) + contact_email = serializers.CharField(required=False) + platform_url = serializers.CharField(required=False) + started_on = serializers.DateTimeField(required=False) + completed_on = serializers.DateTimeField(required=False) + + # Publisher specific + publisher_type = serializers.CharField(required=False) # 'organization' or 'user' + published_datasets_count = serializers.IntegerField(required=False) + published_usecases_count = serializers.IntegerField(required=False) + members_count = serializers.IntegerField(required=False) + contributed_sectors_count = serializers.IntegerField(required=False) + homepage = serializers.CharField(required=False) + bio = serializers.CharField(required=False) + profile_picture = serializers.CharField(required=False) + username = serializers.CharField(required=False) + full_name = serializers.CharField(required=False) + class UnifiedSearch(APIView): """View for unified search across datasets, usecases, and aimodels.""" @@ -102,6 +132,22 @@ def _get_index_names(self, types_list: List[str]) -> List[str]: ) index_names.append(aimodel_index) + if "collaborative" in types_list: + collaborative_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + 
"search.documents.collaborative_document", "collaborative" + ) + index_names.append(collaborative_index) + + if "publisher" in types_list: + org_publisher_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.OrganizationPublisherDocument", + "organization_publisher", + ) + user_publisher_index = settings.ELASTICSEARCH_INDEX_NAMES.get( + "search.documents.publisher_document.UserPublisherDocument", "user_publisher" + ) + index_names.extend([org_publisher_index, user_publisher_index]) + return index_names def _build_unified_query(self, query: str) -> ESQ: @@ -114,16 +160,16 @@ def _build_unified_query(self, query: str) -> ESQ: ESQ( "multi_match", query=query, - fields=["title^3", "name^3", "display_name^3"], + fields=["title^3", "name^3", "display_name^3", "full_name^3"], fuzziness="AUTO", ), ESQ( "multi_match", query=query, - fields=["description^2", "summary^2"], + fields=["description^2", "summary^2", "bio^2"], fuzziness="AUTO", ), - ESQ("multi_match", query=query, fields=["tags^2"], fuzziness="AUTO"), + ESQ("multi_match", query=query, fields=["tags^2", "sectors^2"], fuzziness="AUTO"), ] # Type-specific nested queries @@ -172,6 +218,43 @@ def _build_unified_query(self, query: str) -> ESQ: ] ) + # Collaborative nested fields + common_queries.extend( + [ + ESQ( + "nested", + path="datasets", + query=ESQ( + "multi_match", + query=query, + fields=["datasets.title", "datasets.description"], + fuzziness="AUTO", + ), + ignore_unmapped=True, + ), + ESQ( + "nested", + path="use_cases", + query=ESQ( + "multi_match", + query=query, + fields=["use_cases.title", "use_cases.summary"], + fuzziness="AUTO", + ), + ignore_unmapped=True, + ), + ESQ( + "nested", + path="contributors", + query=ESQ( + "match", + **{"contributors.name": {"query": query, "fuzziness": "AUTO"}}, + ), + ignore_unmapped=True, + ), + ] + ) + # Organization and user (common across types) common_queries.extend( [ @@ -187,9 +270,7 @@ def _build_unified_query(self, query: str) -> 
ESQ: ESQ( "nested", path="user", - query=ESQ( - "match", **{"user.name": {"query": query, "fuzziness": "AUTO"}} - ), + query=ESQ("match", **{"user.name": {"query": query, "fuzziness": "AUTO"}}), ignore_unmapped=True, ), ] @@ -209,9 +290,7 @@ def _apply_filters(self, search: Search, filters: Dict[str, str]) -> Search: if "geographies" in filters: filter_values = filters["geographies"].split(",") - filter_values = Geography.get_geography_names_with_descendants( - filter_values - ) + filter_values = Geography.get_geography_names_with_descendants(filter_values) search = search.filter("terms", **{"geographies.raw": filter_values}) if "status" in filters: @@ -233,6 +312,10 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]: result["type"] = "usecase" elif "aimodel" in index_name: result["type"] = "aimodel" + elif "collaborative" in index_name: + result["type"] = "collaborative" + elif "publisher" in index_name: + result["type"] = "publisher" else: result["type"] = "unknown" @@ -251,11 +334,30 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]: result["title"] = "" if "description" not in result: result["description"] = "" - # AIModel uses created_at/updated_at if "created_at" in result: result["created"] = result["created_at"] if "updated_at" in result: result["modified"] = result["updated_at"] + elif result["type"] == "collaborative": + if "summary" in result: + result["description"] = result.get("summary", "") + if "title" not in result: + result["title"] = "" + elif result["type"] == "publisher": + if "name" in result: + result["title"] = result.get("name", "") + if "bio" in result and result.get("bio"): + result["description"] = result.get("bio", "") + elif "description" not in result or not result.get("description"): + result["description"] = "" + if "status" not in result: + result["status"] = "active" + if "tags" not in result: + result["tags"] = [] + if "sectors" not in result or result.get("sectors") is None: + result["sectors"] = [] + if 
"geographies" not in result: + result["geographies"] = [] else: # dataset if "title" not in result: result["title"] = "" @@ -312,7 +414,6 @@ def perform_unified_search( if hasattr(response, "aggregations"): aggs_dict = response.aggregations.to_dict() - # Process types aggregation if "types" in aggs_dict: aggregations["types"] = {} for bucket in aggs_dict["types"]["buckets"]: @@ -323,37 +424,54 @@ def perform_unified_search( aggregations["types"]["usecase"] = bucket["doc_count"] elif "aimodel" in index_name: aggregations["types"]["aimodel"] = bucket["doc_count"] + elif "collaborative" in index_name: + aggregations["types"]["collaborative"] = bucket["doc_count"] + elif "publisher" in index_name: + # Combine both organization and user publisher counts + if "publisher" not in aggregations["types"]: + aggregations["types"]["publisher"] = 0 + aggregations["types"]["publisher"] += bucket["doc_count"] - # Process other aggregations for agg_name in ["tags", "sectors", "geographies", "status"]: if agg_name in aggs_dict: aggregations[agg_name] = {} for bucket in aggs_dict[agg_name]["buckets"]: aggregations[agg_name][bucket["key"]] = bucket["doc_count"] - total = ( - response.hits.total.value - if hasattr(response.hits.total, "value") - else len(results) - ) + total = response.hits.total.value if hasattr(response.hits.total, "value") else len(results) return results, total, aggregations + def _generate_unified_cache_key(self, request: Any) -> str: + """Generate a unique cache key for unified search based on request parameters.""" + params: Dict[str, str] = { + "query": request.GET.get("query", ""), + "page": request.GET.get("page", "1"), + "size": request.GET.get("size", "10"), + "types": request.GET.get("types", "dataset,usecase,aimodel,collaborative,publisher"), + "filters": str(sorted(request.GET.dict().items())), + "version": str(cache.get(SEARCH_CACHE_VERSION_KEY, 0)), + } + return f"unified_search:{hash(frozenset(params.items()))}" + @trace_method(name="get", 
attributes={"component": "unified_search"}) def get(self, request: Any) -> Response: """Handle GET request and return unified search results.""" try: + cache_key = self._generate_unified_cache_key(request) + cached_result = cache.get(cache_key) + if cached_result: + return Response(cached_result) + query: str = request.GET.get("query", "") page: int = int(request.GET.get("page", 1)) size: int = int(request.GET.get("size", 10)) entity_types: str = request.GET.get( - "types", "dataset,usecase,aimodel" + "types", "dataset,usecase,aimodel,collaborative,publisher" ) # Which entity types to search - # Parse entity types types_list = [t.strip() for t in entity_types.split(",")] - # Handle filters filters: Dict[str, str] = {} for key, values in request.GET.lists(): if key not in ["query", "page", "size", "types"]: @@ -377,15 +495,15 @@ def get(self, request: Any) -> Response: "types_searched": types_list, } + cache.set(cache_key, result, timeout=3600) + return Response(result) except Exception as e: self.logger.error("unified_search_error", error=str(e), exc_info=True) return Response({"error": "An internal error has occurred."}, status=500) - def _build_aggregations( - self, results: List[Dict[str, Any]] - ) -> Dict[str, Dict[str, int]]: + def _build_aggregations(self, results: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: """Build aggregations from results.""" aggregations: Dict[str, Dict[str, int]] = { "types": {}, @@ -398,9 +516,7 @@ def _build_aggregations( for result in results: # Count by type result_type = result.get("type", "unknown") - aggregations["types"][result_type] = ( - aggregations["types"].get(result_type, 0) + 1 - ) + aggregations["types"][result_type] = aggregations["types"].get(result_type, 0) + 1 # Count by tags for tag in result.get("tags", []): @@ -408,9 +524,7 @@ def _build_aggregations( # Count by sectors for sector in result.get("sectors", []): - aggregations["sectors"][sector] = ( - aggregations["sectors"].get(sector, 0) + 1 - ) + 
aggregations["sectors"][sector] = aggregations["sectors"].get(sector, 0) + 1 # Count by geographies for geography in result.get("geographies", []): diff --git a/api/views/search_usecase.py b/api/views/search_usecase.py index f5ed330f..8782d40a 100644 --- a/api/views/search_usecase.py +++ b/api/views/search_usecase.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast import structlog +from django.core.cache import cache from elasticsearch_dsl import A from elasticsearch_dsl import Q as ESQ from elasticsearch_dsl import Search @@ -14,6 +15,8 @@ from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView from search.documents import UseCaseDocument +METADATA_CACHE_TTL = 60 * 30 # 30 minutes + logger = structlog.get_logger(__name__) @@ -144,9 +147,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.searchable_fields: List[str] self.aggregations: Dict[str, str] - self.searchable_fields, self.aggregations = ( - self.get_searchable_and_aggregations() - ) + self.searchable_fields, self.aggregations = self.get_searchable_and_aggregations() self.logger = structlog.get_logger(__name__) @trace_method( @@ -155,6 +156,12 @@ def __init__(self, **kwargs: Any) -> None: ) def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: """Get searchable fields and aggregations for the search.""" + cached: Optional[Tuple[List[str], Dict[str, str]]] = cache.get( + "usecase_search_metadata_config" + ) + if cached: + return cached + searchable_fields = [ "title", "summary", @@ -181,7 +188,9 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]: for metadata in filterable_metadata: aggregations[f"metadata.{metadata.label}"] = "terms" # type: ignore - return searchable_fields, aggregations + result = (searchable_fields, aggregations) + cache.set("usecase_search_metadata_config", result, timeout=METADATA_CACHE_TTL) + return result 
@trace_method(name="add_aggregations", attributes={"component": "search_usecase"}) def add_aggregations(self, search: Search) -> Search: @@ -199,18 +208,21 @@ def add_aggregations(self, search: Search) -> Search: ) if aggregate_fields: - metadata_qs = Metadata.objects.filter(filterable=True) - filterable_metadata = [str(meta.label) for meta in metadata_qs] # type: ignore + filterable_metadata: List[str] = cache.get("usecase_filterable_metadata_labels") or [] + if not filterable_metadata: + metadata_qs = Metadata.objects.filter(filterable=True) + filterable_metadata = [str(meta.label) for meta in metadata_qs] # type: ignore + cache.set( + "usecase_filterable_metadata_labels", + filterable_metadata, + timeout=METADATA_CACHE_TTL, + ) metadata_bucket = search.aggs.bucket("metadata", "nested", path="metadata") composite_agg = A( "composite", sources=[ - { - "metadata_label": { - "terms": {"field": "metadata.metadata_item.label"} - } - }, + {"metadata_label": {"terms": {"field": "metadata.metadata_item.label"}}}, {"metadata_value": {"terms": {"field": "metadata.value"}}}, ], size=10000, @@ -219,13 +231,7 @@ def add_aggregations(self, search: Search) -> Search: "filter", { # type: ignore[arg-type] "bool": { - "must": [ - { - "terms": { - "metadata.metadata_item.label": filterable_metadata - } - } - ] + "must": [{"terms": {"metadata.metadata_item.label": filterable_metadata}}] } }, ) @@ -235,12 +241,8 @@ def add_aggregations(self, search: Search) -> Search: return search - @trace_method( - name="generate_q_expression", attributes={"component": "search_usecase"} - ) - def generate_q_expression( - self, query: str - ) -> Optional[Union[ESQuery, List[ESQuery]]]: + @trace_method(name="generate_q_expression", attributes={"component": "search_usecase"}) + def generate_q_expression(self, query: str) -> Optional[Union[ESQuery, List[ESQuery]]]: """Generate Elasticsearch Query expression.""" if query: queries: List[ESQuery] = [] @@ -256,9 +258,7 @@ def generate_q_expression( 
ESQ("wildcard", **{field: {"value": f"*{query}*"}}), ESQ( "fuzzy", - **{ - field: {"value": query, "fuzziness": "AUTO"} - }, + **{field: {"value": query, "fuzziness": "AUTO"}}, ), ], ), @@ -281,18 +281,14 @@ def generate_q_expression( ESQ("wildcard", **{field: {"value": f"*{query}*"}}), ESQ( "fuzzy", - **{ - field: {"value": query, "fuzziness": "AUTO"} - }, + **{field: {"value": query, "fuzziness": "AUTO"}}, ), ], ), ) ) else: - queries.append( - ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}}) - ) + queries.append(ESQ("fuzzy", **{field: {"value": query, "fuzziness": "AUTO"}})) else: queries = [ESQ("match_all")] @@ -301,8 +297,13 @@ def generate_q_expression( @trace_method(name="add_filters", attributes={"component": "search_usecase"}) def add_filters(self, filters: Dict[str, str], search: Search) -> Search: """Add filters to the search query.""" - non_filter_metadata = Metadata.objects.filter(filterable=False).all() - excluded_labels: List[str] = [e.label for e in non_filter_metadata] # type: ignore + excluded_labels: List[str] = cache.get("usecase_non_filter_metadata_labels") or [] + if not excluded_labels: + non_filter_metadata = Metadata.objects.filter(filterable=False).all() + excluded_labels = [e.label for e in non_filter_metadata] # type: ignore + cache.set( + "usecase_non_filter_metadata_labels", excluded_labels, timeout=METADATA_CACHE_TTL + ) for filter in filters: if filter in excluded_labels: @@ -335,9 +336,7 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search: search = search.filter( "nested", path="metadata", - query={ - "bool": {"must": {"term": {f"metadata.value": filters[filter]}}} - }, + query={"bool": {"must": {"term": {f"metadata.value": filters[filter]}}}}, ) return search diff --git a/authorization/authentication.py b/authorization/authentication.py index 49c6655e..c6aef186 100644 --- a/authorization/authentication.py +++ b/authorization/authentication.py @@ -5,7 +5,7 @@ from rest_framework.exceptions 
import AuthenticationFailed from rest_framework.request import Request -from authorization.keycloak import keycloak_manager +from api.utils.keycloak_utils import keycloak_manager from authorization.models import User @@ -41,20 +41,37 @@ def authenticate(self, request: Request) -> Optional[Tuple[User, str]]: # Ensure we have a subject ID in the token if not user_info.get("sub"): - raise AuthenticationFailed( - "Token validation succeeded but missing subject ID" - ) - + raise AuthenticationFailed("Token validation succeeded but missing subject ID") + + # Check if this is a Django JWT (already validated user) or Keycloak token + # Django JWTs have 'user_id' in the payload, Keycloak tokens don't + from rest_framework_simplejwt.exceptions import InvalidToken, TokenError + from rest_framework_simplejwt.tokens import AccessToken + + try: + # Try to decode as Django JWT + access_token = AccessToken(token) # type: ignore[arg-type] + user_id = access_token.get("user_id") + + if user_id: + # This is a Django JWT - user is already synced, just get from DB + user = User.objects.get(id=user_id) + return (user, token) + except (TokenError, InvalidToken, User.DoesNotExist): + # Not a Django JWT or user not found, continue with Keycloak flow + pass + + # This is a Keycloak token - sync the user # Get user roles and organizations from the token roles = keycloak_manager.get_user_roles(token) organizations = keycloak_manager.get_user_organizations(token) # Sync the user information with our database - user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) - if not user: + synced_user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) + if not synced_user: raise AuthenticationFailed("Failed to synchronize user information") - return (user, token) + return (synced_user, token) def authenticate_header(self, request: Request) -> str: """ diff --git a/authorization/backends.py b/authorization/backends.py index 1da27967..13792a94 100644 --- 
a/authorization/backends.py +++ b/authorization/backends.py @@ -4,7 +4,7 @@ from django.contrib.auth.backends import ModelBackend from django.http import HttpRequest -from authorization.keycloak import keycloak_manager +from api.utils.keycloak_utils import keycloak_manager from authorization.models import User @@ -14,10 +14,7 @@ class KeycloakAuthenticationBackend(ModelBackend): """ def authenticate( # type: ignore[override] - self, - request: Optional[HttpRequest] = None, - token: Optional[str] = None, - **kwargs: Any + self, request: Optional[HttpRequest] = None, token: Optional[str] = None, **kwargs: Any ) -> Optional[User]: """ Authenticate a user based on a Keycloak token. @@ -44,9 +41,7 @@ def authenticate( # type: ignore[override] # Get user roles and organizations from the token roles: List[str] = keycloak_manager.get_user_roles(token) - organizations: List[Dict[str, Any]] = keycloak_manager.get_user_organizations( - token - ) + organizations: List[Dict[str, Any]] = keycloak_manager.get_user_organizations(token) # Sync the user information with our database user: Optional[User] = keycloak_manager.sync_user_from_keycloak( diff --git a/authorization/keycloak.py b/authorization/keycloak.py index 6ea56e54..a42ce294 100644 --- a/authorization/keycloak.py +++ b/authorization/keycloak.py @@ -1,523 +1,6 @@ -from typing import Any, Dict, List, Optional, Type, TypeVar, cast +"""DEPRECATED: This module is deprecated. 
Use api.utils.keycloak_utils instead.""" -import structlog -from django.conf import settings -from django.db import transaction -from keycloak import KeycloakAdmin, KeycloakOpenID -from keycloak.exceptions import KeycloakError +# Import from the new location for backward compatibility +from api.utils.keycloak_utils import KeycloakManager, keycloak_manager -from api.models import Organization -from authorization.models import OrganizationMembership, Role, User - -logger = structlog.getLogger(__name__) - -# Type variables for model classes -T = TypeVar("T") - - -class KeycloakManager: - """ - Utility class to manage Keycloak integration with Django. - Handles token validation, user synchronization, and role mapping. - """ - - def __init__(self) -> None: - import structlog - - logger = structlog.getLogger(__name__) - - self.server_url: str = settings.KEYCLOAK_SERVER_URL - self.realm: str = settings.KEYCLOAK_REALM - self.client_id: str = settings.KEYCLOAK_CLIENT_ID - self.client_secret: str = settings.KEYCLOAK_CLIENT_SECRET - - # Log Keycloak connection details (without secrets) - logger.debug( - f"Initializing Keycloak connection to {self.server_url} " - f"for realm {self.realm} and client {self.client_id}" - ) - - try: - self.keycloak_openid: KeycloakOpenID = KeycloakOpenID( - server_url=self.server_url, - client_id=self.client_id, - realm_name=self.realm, - client_secret_key=self.client_secret, - ) - - logger.debug("Keycloak client initialized successfully") - except Exception as e: - logger.error(f"Failed to initialize Keycloak client: {e}") - # Use cast to satisfy the type checker - self.keycloak_openid = cast(KeycloakOpenID, object()) - - def get_token(self, username: str, password: str) -> Dict[str, Any]: - """ - Get a Keycloak token for a user. 
- - Args: - username: The username - password: The password - - Returns: - Dict containing the token information - """ - try: - return self.keycloak_openid.token(username, password) - except KeycloakError as e: - logger.error(f"Error getting token: {e}") - raise - - def validate_token(self, token: str) -> Dict[str, Any]: - """ - Validate a Keycloak token and return the user info. - Only validates by contacting Keycloak directly - no local validation. - - Args: - token: The token to validate - - Returns: - Dict containing the user information or empty dict if validation fails - """ - import structlog - - logger = structlog.getLogger(__name__) - - # Log token for debugging - logger.debug(f"Validating token of length: {len(token)}") - - # Only try to contact Keycloak directly - don't create users from local token decoding - try: - logger.debug("Attempting to get user info from Keycloak") - user_info = self.keycloak_openid.userinfo(token) - if user_info and isinstance(user_info, dict): - logger.debug("Successfully retrieved user info from Keycloak") - logger.debug(f"User info: {user_info}") - return user_info - else: - logger.warning("Keycloak returned empty or invalid user info") - return {} - except Exception as e: - logger.warning(f"Failed to get user info from Keycloak: {e}") - return {} - - def get_user_roles(self, token: str) -> List[str]: - """ - Get the roles for a user from their token. 
- - Args: - token: The user's token - - Returns: - List of role names - """ - import structlog - from django.conf import settings - - logger = structlog.getLogger(__name__) - - # Get roles directly from token - logger.debug("Extracting roles from token") - - logger.debug(f"Getting roles from token of length: {len(token)}") - - try: - # Decode the token to get the roles - try: - token_info: Dict[str, Any] = self.keycloak_openid.decode_token(token) - logger.debug("Successfully decoded token for roles") - except Exception as decode_error: - logger.warning(f"Failed to decode token for roles: {decode_error}") - # If we can't decode the token, try to get roles from introspection - try: - token_info = self.keycloak_openid.introspect(token) - logger.debug("Using introspection result for roles") - except Exception as introspect_error: - logger.error( - f"Failed to introspect token for roles: {introspect_error}" - ) - return [] - - # Extract roles from token info - realm_access: Dict[str, Any] = token_info.get("realm_access", {}) - roles = cast(List[str], realm_access.get("roles", [])) - - # Also check resource_access for client roles - resource_access = token_info.get("resource_access", {}) - client_roles = resource_access.get(self.client_id, {}).get("roles", []) - - # Combine realm and client roles - all_roles = list(set(roles + client_roles)) - logger.debug(f"Found roles: {all_roles}") - - return all_roles - except KeycloakError as e: - logger.error(f"Error getting user roles: {e}") - return [] - except Exception as e: - logger.error(f"Unexpected error getting user roles: {e}") - return [] - - def get_user_organizations(self, token: str) -> List[Dict[str, Any]]: - """ - Get the organizations a user belongs to from their token. - This assumes that organization information is stored in the token - as client roles or in user attributes. 
- - Args: - token: The user's token - - Returns: - List of organization information - """ - import structlog - from django.conf import settings - - logger = structlog.getLogger(__name__) - - logger.debug(f"Getting organizations from token of length: {len(token)}") - - try: - # Decode the token to get user info - token_info = {} - try: - token_info = self.keycloak_openid.decode_token(token) - logger.debug("Successfully decoded token for organizations") - except Exception as decode_error: - logger.warning( - f"Failed to decode token for organizations: {decode_error}" - ) - # If we can't decode the token, try to get info from introspection - try: - token_info = self.keycloak_openid.introspect(token) - logger.debug("Using introspection result for organizations") - except Exception as introspect_error: - logger.error( - f"Failed to introspect token for organizations: {introspect_error}" - ) - return [] - - # Get organization info from resource_access or attributes - # This implementation depends on how organizations are represented in Keycloak - resource_access = token_info.get("resource_access", {}) - client_roles = resource_access.get(self.client_id, {}).get("roles", []) - - logger.debug(f"Found client roles: {client_roles}") - - # Extract organization info from roles - # Format could be 'org__' or similar - organizations = [] - for role in client_roles: - if role.startswith("org_"): - parts = role.split("_") - if len(parts) >= 3: - org_id = parts[1] - role_name = parts[2] - organizations.append( - {"organization_id": org_id, "role": role_name} - ) - - # If no organizations found through roles, check user attributes - if not organizations and token_info.get("attributes"): - attrs = token_info.get("attributes", {}) - org_attrs = attrs.get("organizations", []) - - if isinstance(org_attrs, str): - org_attrs = [org_attrs] # Convert single string to list - - for org_attr in org_attrs: - try: - # Format could be 'org_id:role' - org_id, role = org_attr.split(":") - 
organizations.append({"organization_id": org_id, "role": role}) - except ValueError: - # If no role specified, use default - organizations.append( - {"organization_id": org_attr, "role": "viewer"} - ) - - logger.debug(f"Found organizations: {organizations}") - return organizations - except KeycloakError as e: - logger.error(f"Error getting user organizations: {e}") - return [] - except Exception as e: - logger.error(f"Unexpected error getting user organizations: {e}") - return [] - - @transaction.atomic - def sync_user_from_keycloak( - self, - user_info: Dict[str, Any], - roles: List[str], - organizations: List[Dict[str, Any]], - ) -> Optional[User]: - """ - Synchronize user information from Keycloak to Django. - Creates or updates the User record and organization memberships. - - Args: - user_info: User information from Keycloak - roles: User roles from Keycloak (not used when maintaining roles in DB) - organizations: User organizations from Keycloak - - Returns: - The synchronized User object or None if failed - """ - import structlog - - logger = structlog.getLogger(__name__) - - # Log the user info we're trying to sync - logger.debug(f"Attempting to sync user with info: {user_info}") - - try: - # Extract key user information - keycloak_id = user_info.get("sub") - email = user_info.get("email") - username = user_info.get("preferred_username") or email - - # Validate required fields - if not keycloak_id or not username: - logger.error("Missing required user information from Keycloak") - return None - - # Initialize variables - user = None - created = False - - # Try to find user by keycloak_id first - try: - user = User.objects.get(keycloak_id=keycloak_id) - logger.debug(f"Found existing user by keycloak_id: {user.username}") - - # Update user details - user.username = str(username) if username else "" # type: ignore[assignment] - user.email = str(email) if email else "" # type: ignore[assignment] - user.first_name = ( - str(user_info.get("given_name", "")) - if 
user_info.get("given_name") - else "" - ) - user.last_name = ( - str(user_info.get("family_name", "")) - if user_info.get("family_name") - else "" - ) - user.is_active = True - user.save() - except User.DoesNotExist: - # Try to find user by email - if email: - try: - user = User.objects.get(email=email) - logger.debug(f"Found existing user by email: {user.username}") - - # Update keycloak_id and other details - user.keycloak_id = str(keycloak_id) if keycloak_id else "" # type: ignore[assignment] - user.username = str(username) if username else "" # type: ignore[assignment] - user.first_name = ( - str(user_info.get("given_name", "")) - if user_info.get("given_name") - else "" - ) - user.last_name = ( - str(user_info.get("family_name", "")) - if user_info.get("family_name") - else "" - ) - user.is_active = True - user.save() - except User.DoesNotExist: - # Try to find user by username - try: - user = User.objects.get(username=username) - logger.debug( - f"Found existing user by username: {user.username}" - ) - - # Update keycloak_id and other details - user.keycloak_id = str(keycloak_id) if keycloak_id else "" # type: ignore[assignment] - user.email = str(email) if email else "" # type: ignore[assignment] - user.first_name = ( - str(user_info.get("given_name", "")) - if user_info.get("given_name") - else "" - ) - user.last_name = ( - str(user_info.get("family_name", "")) - if user_info.get("family_name") - else "" - ) - user.is_active = True - user.save() - except User.DoesNotExist: - # Create new user - logger.debug( - f"Creating new user with keycloak_id: {keycloak_id}" - ) - user = User.objects.create( - keycloak_id=str(keycloak_id) if keycloak_id else "", # type: ignore[arg-type] - username=str(username) if username else "", # type: ignore[arg-type] - email=str(email) if email else "", # type: ignore[arg-type] - first_name=( - str(user_info.get("given_name", "")) - if user_info.get("given_name") - else "" - ), - last_name=( - str(user_info.get("family_name", "")) 
- if user_info.get("family_name") - else "" - ), - is_active=True, - ) - created = True - - # If this is a new user, we'll keep default permissions - if created: - pass - - if user is not None: # Check that user is not None before saving - user.save() - - # If this is a new user and we want to sync organization memberships - # We'll only create new memberships for organizations found in Keycloak - # but we won't update existing memberships or remove any - if user is not None and created and organizations: - # Process organizations from Keycloak - only for new users - for org_info in organizations: - org_id: Optional[str] = org_info.get("organization_id") - if not org_id: - continue - - # Try to get the organization - try: - organization: Organization = Organization.objects.get(id=org_id) - - # For new users, assign the default viewer role - # The actual role management will be done in the application - default_role: Role = Role.objects.get(name="viewer") - - # Create the organization membership with default role - # Only if it doesn't already exist - OrganizationMembership.objects.get_or_create( - user=user, - organization=organization, - defaults={"role": default_role}, - ) - except Organization.DoesNotExist as e: - logger.error( - f"Error processing organization from Keycloak: {e}" - ) - except Role.DoesNotExist as e: - logger.error(f"Default viewer role not found: {e}") - - # We don't remove organization memberships that are no longer in Keycloak - # since we're maintaining roles in the database - - return user - except Exception as e: - logger.error(f"Error synchronizing user from Keycloak: {e}") - return None - - def update_user_in_keycloak(self, user: User) -> bool: - """Update user details in Keycloak using admin credentials.""" - if not user.keycloak_id: - logger.warning( - "Cannot update user in Keycloak: No keycloak_id", user_id=str(user.id) - ) - return False - - try: - # Get admin credentials from settings - admin_username = getattr(settings, 
"KEYCLOAK_ADMIN_USERNAME", "") - admin_password = getattr(settings, "KEYCLOAK_ADMIN_PASSWORD", "") - - # Log credential presence (not the actual values) - logger.info( - "Admin credentials check", - username_present=bool(admin_username), - password_present=bool(admin_password), - ) - - if not admin_username or not admin_password: - logger.error("Keycloak admin credentials not configured") - return False - - from keycloak import KeycloakOpenID - - # First get an admin token directly - keycloak_openid = KeycloakOpenID( - server_url=self.server_url, - client_id="admin-cli", # Special client for admin operations - realm_name="master", # Admin users are in master realm - verify=True, - ) - - # Get token - try: - token = keycloak_openid.token( - username=admin_username, - password=admin_password, - grant_type="password", - ) - access_token = token.get("access_token") - - if not access_token: - logger.error("Failed to get admin access token") - return False - - # Now use the token to update the user - import requests - - headers = { - "Authorization": f"Bearer {access_token}", - "Content-Type": "application/json", - } - - user_data = { - "firstName": user.first_name, - "lastName": user.last_name, - "email": user.email, - "emailVerified": True, - } - - # Direct API call to update user - base_url = self.server_url.rstrip("/") # Remove any trailing slash - response = requests.put( - f"{base_url}/admin/realms/{self.realm}/users/{user.keycloak_id}", - headers=headers, - json=user_data, - ) - - if response.status_code == 204: # Success for this endpoint - logger.info( - "Successfully updated user in Keycloak", - user_id=str(user.id), - keycloak_id=user.keycloak_id, - ) - return True - else: - logger.error( - f"Failed to update user in Keycloak: {response.status_code}: {response.text}", - user_id=str(user.id), - ) - return False - - except Exception as token_error: - logger.error( - f"Error getting admin token: {str(token_error)}", - user_id=str(user.id), - ) - return False - - 
except Exception as e: - logger.error( - f"Error updating user in Keycloak: {str(e)}", user_id=str(user.id) - ) - return False - - -# Create a singleton instance -keycloak_manager = KeycloakManager() +__all__ = ["KeycloakManager", "keycloak_manager"] diff --git a/authorization/middleware_utils.py b/authorization/middleware_utils.py index fa70a622..57dbe36c 100644 --- a/authorization/middleware_utils.py +++ b/authorization/middleware_utils.py @@ -1,19 +1,31 @@ +import hashlib from typing import Any, Callable, Dict, List, Optional, cast import structlog from django.conf import settings from django.contrib.auth.models import AnonymousUser +from django.core.cache import cache from django.db import transaction from django.http import HttpRequest, HttpResponse from django.utils.functional import SimpleLazyObject from rest_framework.request import Request +from rest_framework_simplejwt.exceptions import InvalidToken, TokenError +from rest_framework_simplejwt.tokens import AccessToken from api.utils.debug_utils import debug_auth_headers, debug_token_validation -from authorization.keycloak import keycloak_manager +from api.utils.keycloak_utils import keycloak_manager from authorization.models import User logger = structlog.getLogger(__name__) +TOKEN_CACHE_TTL = 60 * 5 # 5 minutes + + +def _get_token_cache_key(token: str) -> str: + """Return a safe Redis key derived from the token without storing the raw token.""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + return f"auth_token:{token_hash}" + def get_user_from_keycloak_token(request: HttpRequest) -> User: """ @@ -50,15 +62,24 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: # Use the raw header value, but check if it might be a raw token # (no 'Bearer ' prefix but still a valid JWT format) token = auth_header - logger.debug( - f"Using raw Authorization header as token, length: {len(token)}" - ) + logger.debug(f"Using raw Authorization header as token, length: {len(token)}") # If no token 
found, return anonymous user if not token: logger.debug("No token found, returning anonymous user") return cast(User, AnonymousUser()) + # Check cache before doing any validation or DB work + cache_key = _get_token_cache_key(token) + cached_user_id = cache.get(cache_key) + if cached_user_id: + try: + user = User.objects.get(id=cached_user_id) + logger.debug(f"Returning cached authenticated user: {user.username}") + return user + except User.DoesNotExist: + cache.delete(cache_key) + # Log token details for debugging logger.debug(f"Processing token of length: {len(token)}") logger.debug(f"Token type: {type(token)}") @@ -70,7 +91,29 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: # For debugging, print the raw token logger.debug(f"Raw token: {token}") + # First, try to validate as Django JWT token + try: + logger.debug("Attempting to validate as Django JWT token") + access_token = AccessToken(token) + user_id = access_token.get("user_id") + + if user_id: + logger.debug(f"Valid Django JWT token for user_id: {user_id}") + try: + user = User.objects.get(id=user_id) + cache.set(cache_key, user.id, timeout=TOKEN_CACHE_TTL) + logger.debug(f"Successfully authenticated user via Django JWT: {user.username}") + return user + except User.DoesNotExist: + logger.warning(f"User with id {user_id} not found in database") + except (TokenError, InvalidToken) as e: + logger.debug(f"Not a valid Django JWT token: {e}, trying Keycloak validation") + except Exception as e: + logger.debug(f"Error validating Django JWT: {e}, trying Keycloak validation") + + # If Django JWT validation failed, try Keycloak token validation try: + logger.debug("Attempting to validate as Keycloak token") # Try direct validation without any complex logic user_info = keycloak_manager.validate_token(token) @@ -91,14 +134,10 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: return cast(User, AnonymousUser()) # Log the user info for debugging - logger.debug( - f"User info from 
token: {user_info.keys() if user_info else 'None'}" - ) + logger.debug(f"User info from token: {user_info.keys() if user_info else 'None'}") logger.debug(f"User sub: {user_info.get('sub', 'None')}") logger.debug(f"User email: {user_info.get('email', 'None')}") - logger.debug( - f"User preferred_username: {user_info.get('preferred_username', 'None')}" - ) + logger.debug(f"User preferred_username: {user_info.get('preferred_username', 'None')}") # Get user roles and organizations from the token roles = keycloak_manager.get_user_roles(token) @@ -108,18 +147,19 @@ def get_user_from_keycloak_token(request: HttpRequest) -> User: logger.debug(f"User organizations from token: {organizations}") # Sync the user information with our database - user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) - if not user: + synced_user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) + if not synced_user: logger.warning("User synchronization failed, returning anonymous user") return cast(User, AnonymousUser()) + cache.set(cache_key, synced_user.id, timeout=TOKEN_CACHE_TTL) logger.debug( - f"Successfully authenticated user: {user.username} (ID: {user.id})" + f"Successfully authenticated user: {synced_user.username} (ID: {synced_user.id})" ) # Return the authenticated user - logger.debug(f"Returning authenticated user: {user.username}") - return user + logger.debug(f"Returning authenticated user: {synced_user.username}") + return synced_user except Exception as e: logger.error(f"Error in get_user_from_keycloak_token: {str(e)}") return cast(User, AnonymousUser()) diff --git a/authorization/schema/mutation.py b/authorization/schema/mutation.py index ec229518..fdf51f0a 100644 --- a/authorization/schema/mutation.py +++ b/authorization/schema/mutation.py @@ -15,6 +15,7 @@ MutationResponse, ) from api.utils.graphql_telemetry import trace_resolver +from api.utils.keycloak_utils import keycloak_manager from authorization.models import 
OrganizationMembership, Role, User from authorization.permissions import IsAuthenticated from authorization.schema.inputs import ( @@ -34,12 +35,9 @@ @strawberry.type class Mutation: @strawberry.mutation(permission_classes=[IsAuthenticated]) - @trace_resolver( - name="update_user", attributes={"component": "user", "operation": "mutation"} - ) + @trace_resolver(name="update_user", attributes={"component": "user", "operation": "mutation"}) def update_user(self, info: Info, input: UpdateUserInput) -> TypeUser: """Update user details and sync with Keycloak.""" - from authorization.keycloak import keycloak_manager user = info.context.user @@ -139,9 +137,7 @@ def add_user_to_organization( user=user, organization=organization ) # If we get here, the membership exists - raise DjangoValidationError( - "User is already a member of this organization." - ) + raise DjangoValidationError("User is already a member of this organization.") except OrganizationMembership.DoesNotExist: # Membership doesn't exist, so create it membership = OrganizationMembership.objects.create( @@ -211,13 +207,9 @@ def remove_user_from_organization( organization = info.context.context.get("organization") # Check if the membership already exists - membership = OrganizationMembership.objects.get( - user=user, organization=organization - ) + membership = OrganizationMembership.objects.get(user=user, organization=organization) membership.delete() - return SuccessResponse( - success=True, message="User removed from organization" - ) + return SuccessResponse(success=True, message="User removed from organization") except User.DoesNotExist: raise DjangoValidationError(f"User with ID {input.user_id} does not exist.") except Role.DoesNotExist: @@ -270,8 +262,6 @@ def assign_dataset_permission( ) if result: - return SuccessResponse( - success=True, message="Permission assigned successfully" - ) + return SuccessResponse(success=True, message="Permission assigned successfully") else: return 
SuccessResponse(success=False, message="Failed to assign permission") diff --git a/authorization/services.py b/authorization/services.py index 7c3e8fb5..7ed117f8 100644 --- a/authorization/services.py +++ b/authorization/services.py @@ -27,11 +27,14 @@ def get_user_organizations(user_id: int) -> List[Dict[str, Any]]: List of dictionaries containing organization info and user's role """ # Use explicit type annotation for the queryset result - memberships = OrganizationMembership.objects.filter( - user_id=user_id - ).select_related( + memberships = OrganizationMembership.objects.filter(user_id=user_id).select_related( "organization", "role" ) # type: ignore[attr-defined] + + logger.info( + f"Getting organizations for user_id={user_id}, found {memberships.count()} memberships" + ) + return [ { "id": membership.organization.id, # type: ignore[attr-defined] @@ -59,9 +62,7 @@ def get_user_datasets(user_id: int) -> List[Dict[str, Any]]: List of dictionaries containing dataset info and user's role """ # Use explicit type annotation for the queryset result - dataset_permissions = DatasetPermission.objects.filter( - user_id=user_id - ).select_related( + dataset_permissions = DatasetPermission.objects.filter(user_id=user_id).select_related( "dataset", "role" ) # type: ignore[attr-defined] return [ @@ -80,9 +81,7 @@ def get_user_datasets(user_id: int) -> List[Dict[str, Any]]: ] @staticmethod - def check_organization_permission( - user_id: int, organization_id: int, operation: str - ) -> bool: + def check_organization_permission(user_id: int, organization_id: int, operation: str) -> bool: """ Check if a user has permission to perform an operation on an organization. 
@@ -120,9 +119,7 @@ def check_organization_permission( return False @staticmethod - def check_dataset_permission( - user_id: int, dataset_id: Union[int, str], operation: str - ) -> bool: + def check_dataset_permission(user_id: int, dataset_id: Union[int, str], operation: str) -> bool: """ Check if a user has permission to perform an operation on a dataset. Checks both organization-level and dataset-specific permissions. diff --git a/authorization/views.py b/authorization/views.py index 3c5aab93..ff20f9a3 100644 --- a/authorization/views.py +++ b/authorization/views.py @@ -6,8 +6,8 @@ from rest_framework.views import APIView from rest_framework_simplejwt.tokens import RefreshToken +from api.utils.keycloak_utils import keycloak_manager from authorization.consent import UserConsent -from authorization.keycloak import keycloak_manager from authorization.models import User from authorization.serializers import UserConsentSerializer from authorization.services import AuthorizationService @@ -62,9 +62,7 @@ def post(self, request: Request) -> Response: logger.info(f"Syncing user with Keycloak ID: {user_info.get('sub')}") user = keycloak_manager.sync_user_from_keycloak(user_info, roles, organizations) if not user: - logger.error( - f"Failed to sync user with Keycloak ID: {user_info.get('sub')}" - ) + logger.error(f"Failed to sync user with Keycloak ID: {user_info.get('sub')}") return Response( {"error": "Failed to synchronize user information"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -176,9 +174,7 @@ def put(self, request: Request) -> Response: user=request.user, activity_tracking_enabled=False # type: ignore[misc] ) - serializer = UserConsentSerializer( - consent, data=request.data, context={"request": request} - ) + serializer = UserConsentSerializer(consent, data=request.data, context={"request": request}) if serializer.is_valid(): serializer.save() return Response(serializer.data) diff --git a/dataspace_sdk/__init__.py b/dataspace_sdk/__init__.py index 
bb813609..1848fc94 100644 --- a/dataspace_sdk/__init__.py +++ b/dataspace_sdk/__init__.py @@ -1,5 +1,6 @@ """DataSpace Python SDK for programmatic access to DataSpace resources.""" +from dataspace_sdk.__version__ import __version__ from dataspace_sdk.client import DataSpaceClient from dataspace_sdk.exceptions import ( DataSpaceAPIError, @@ -8,7 +9,6 @@ DataSpaceValidationError, ) -__version__ = "0.1.0" __all__ = [ "DataSpaceClient", "DataSpaceAPIError", diff --git a/dataspace_sdk/__version__.py b/dataspace_sdk/__version__.py new file mode 100644 index 00000000..cff36390 --- /dev/null +++ b/dataspace_sdk/__version__.py @@ -0,0 +1,3 @@ +"""Version information for DataSpace SDK.""" + +__version__ = "0.4.19" diff --git a/dataspace_sdk/auth.py b/dataspace_sdk/auth.py index 372f5d18..4ecac6e2 100644 --- a/dataspace_sdk/auth.py +++ b/dataspace_sdk/auth.py @@ -1,5 +1,6 @@ """Authentication module for DataSpace SDK.""" +import time from typing import Any, Dict, Optional import requests @@ -10,19 +11,310 @@ class AuthClient: """Handles authentication with DataSpace API.""" - def __init__(self, base_url: str): + def __init__( + self, + base_url: str, + keycloak_url: Optional[str] = None, + keycloak_realm: Optional[str] = None, + keycloak_client_id: Optional[str] = None, + keycloak_client_secret: Optional[str] = None, + ): """ Initialize the authentication client. 
Args: base_url: Base URL of the DataSpace API + keycloak_url: Keycloak server URL (e.g., "https://opub-kc.civicdatalab.in") + keycloak_realm: Keycloak realm name (e.g., "DataSpace") + keycloak_client_id: Keycloak client ID (e.g., "dataspace") + keycloak_client_secret: Optional client secret for confidential clients """ self.base_url = base_url.rstrip("/") + self.keycloak_url = keycloak_url.rstrip("/") if keycloak_url else None + self.keycloak_realm = keycloak_realm + self.keycloak_client_id = keycloak_client_id + self.keycloak_client_secret = keycloak_client_secret + + # Session state self.access_token: Optional[str] = None self.refresh_token: Optional[str] = None + self.keycloak_access_token: Optional[str] = None + self.keycloak_refresh_token: Optional[str] = None + self.token_expires_at: Optional[float] = None self.user_info: Optional[Dict] = None - def login_with_keycloak(self, keycloak_token: str) -> Dict[str, Any]: + # Stored credentials for auto-relogin + self._username: Optional[str] = None + self._password: Optional[str] = None + + def login(self, username: str, password: str) -> Dict[str, Any]: + """ + Login using username and password via Keycloak. + + Args: + username: User's username or email + password: User's password + + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + """ + if not all([self.keycloak_url, self.keycloak_realm, self.keycloak_client_id]): + raise DataSpaceAuthError( + "Keycloak configuration missing. Please provide keycloak_url, " + "keycloak_realm, and keycloak_client_id when initializing the client." 
+ ) + + # Store credentials for auto-relogin + self._username = username + self._password = password + + # Get Keycloak token + keycloak_token = self._get_keycloak_token(username, password) + + # Login to DataSpace backend + return self._login_with_keycloak_token(keycloak_token) + + def login_as_service_account(self) -> Dict[str, Any]: + """ + Login using client credentials (service account). + + This method authenticates the client itself (not a user) using + the client_id and client_secret. Requires the Keycloak client + to have "Service Accounts Enabled". + + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + """ + if not all( + [ + self.keycloak_url, + self.keycloak_realm, + self.keycloak_client_id, + self.keycloak_client_secret, + ] + ): + raise DataSpaceAuthError( + "Service account authentication requires keycloak_url, " + "keycloak_realm, keycloak_client_id, and keycloak_client_secret." + ) + + # Get Keycloak token using client credentials + keycloak_token = self._get_service_account_token() + + # Login to DataSpace backend + return self._login_with_keycloak_token(keycloak_token) + + def _get_keycloak_token(self, username: str, password: str) -> str: + """ + Get Keycloak access token using username and password. 
+ + Args: + username: User's username or email + password: User's password + + Returns: + Keycloak access token + + Raises: + DataSpaceAuthError: If authentication fails + """ + token_url = ( + f"{self.keycloak_url}/auth/realms/{self.keycloak_realm}/" + f"protocol/openid-connect/token" + ) + + data = { + "grant_type": "password", + "client_id": self.keycloak_client_id, + "username": username, + "password": password, + } + + if self.keycloak_client_secret: + data["client_secret"] = self.keycloak_client_secret + + try: + response = requests.post( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code == 200: + token_data = response.json() + self.keycloak_access_token = token_data.get("access_token") + self.keycloak_refresh_token = token_data.get("refresh_token") + + # Calculate token expiration time + expires_in = token_data.get("expires_in", 300) + self.token_expires_at = time.time() + expires_in + + if not self.keycloak_access_token: + raise DataSpaceAuthError("No access token in Keycloak response") + + return self.keycloak_access_token + else: + error_data = response.json() + error_msg = error_data.get( + "error_description", + error_data.get("error", "Keycloak authentication failed"), + ) + raise DataSpaceAuthError( + f"Keycloak login failed: {error_msg}", + status_code=response.status_code, + response=error_data, + ) + except requests.RequestException as e: + raise DataSpaceAuthError(f"Network error during Keycloak authentication: {str(e)}") + + def _get_service_account_token(self) -> str: + """ + Get Keycloak access token using client credentials (service account). 
+ + Returns: + Keycloak access token + + Raises: + DataSpaceAuthError: If authentication fails + """ + token_url = ( + f"{self.keycloak_url}/auth/realms/{self.keycloak_realm}/" + f"protocol/openid-connect/token" + ) + + data = { + "grant_type": "client_credentials", + "client_id": self.keycloak_client_id, + "client_secret": self.keycloak_client_secret, + } + + try: + response = requests.post( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code == 200: + token_data = response.json() + self.keycloak_access_token = token_data.get("access_token") + self.keycloak_refresh_token = token_data.get("refresh_token") + + # Calculate token expiration time + expires_in = token_data.get("expires_in", 300) + self.token_expires_at = time.time() + expires_in + + if not self.keycloak_access_token: + raise DataSpaceAuthError("No access token in Keycloak response") + + return self.keycloak_access_token + else: + error_data = response.json() + error_msg = error_data.get( + "error_description", + error_data.get("error", "Service account authentication failed"), + ) + raise DataSpaceAuthError( + f"Service account login failed: {error_msg}. " + f"Ensure 'Service Accounts Enabled' is ON in Keycloak client settings.", + status_code=response.status_code, + response=error_data, + ) + except requests.RequestException as e: + raise DataSpaceAuthError( + f"Network error during service account authentication: {str(e)}" + ) + + def _refresh_keycloak_token(self) -> str: + """ + Refresh Keycloak access token using refresh token. 
+ + Returns: + New Keycloak access token + + Raises: + DataSpaceAuthError: If token refresh fails + """ + if not self.keycloak_refresh_token: + # If no refresh token, try to relogin with stored credentials + if self._username and self._password: + return self._get_keycloak_token(self._username, self._password) + raise DataSpaceAuthError("No refresh token or credentials available") + + token_url = ( + f"{self.keycloak_url}/auth/realms/{self.keycloak_realm}/" + f"protocol/openid-connect/token" + ) + + data = { + "grant_type": "refresh_token", + "client_id": self.keycloak_client_id, + "refresh_token": self.keycloak_refresh_token, + } + + if self.keycloak_client_secret: + data["client_secret"] = self.keycloak_client_secret + + try: + response = requests.post( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code == 200: + token_data = response.json() + self.keycloak_access_token = token_data.get("access_token") + self.keycloak_refresh_token = token_data.get("refresh_token") + + expires_in = token_data.get("expires_in", 300) + self.token_expires_at = time.time() + expires_in + + if not self.keycloak_access_token: + raise DataSpaceAuthError("No access token in refresh response") + + return self.keycloak_access_token + else: + # Refresh failed, try to relogin with stored credentials + if self._username and self._password: + return self._get_keycloak_token(self._username, self._password) + raise DataSpaceAuthError("Keycloak token refresh failed") + except requests.RequestException as e: + # Network error, try to relogin with stored credentials + if self._username and self._password: + return self._get_keycloak_token(self._username, self._password) + raise DataSpaceAuthError(f"Network error during token refresh: {str(e)}") + + def _ensure_valid_keycloak_token(self) -> str: + """ + Ensure we have a valid Keycloak token, refreshing if necessary. 
+ + Returns: + Valid Keycloak access token + + Raises: + DataSpaceAuthError: If unable to get valid token + """ + # Check if token is expired or about to expire (within 30 seconds) + if ( + not self.keycloak_access_token + or not self.token_expires_at + or time.time() >= (self.token_expires_at - 30) + ): + # Token expired or about to expire, refresh it + if self.keycloak_refresh_token or (self._username and self._password): + return self._refresh_keycloak_token() + raise DataSpaceAuthError("No valid token or credentials available") + + return self.keycloak_access_token + + def _login_with_keycloak_token(self, keycloak_token: str) -> Dict[str, Any]: """ Login using a Keycloak token. @@ -140,3 +432,39 @@ def _get_auth_headers(self) -> Dict[str, str]: def is_authenticated(self) -> bool: """Check if the client is authenticated.""" return self.access_token is not None + + def ensure_authenticated(self) -> None: + """ + Ensure the client is authenticated, attempting auto-relogin if needed. + + Raises: + DataSpaceAuthError: If unable to authenticate + """ + if not self.is_authenticated(): + # Try to relogin with stored credentials + if self._username and self._password: + self.login(self._username, self._password) + else: + raise DataSpaceAuthError("Not authenticated. Please call login() first.") + + def get_valid_token(self) -> str: + """ + Get a valid access token, refreshing if necessary. 
+ + Returns: + Valid access token + + Raises: + DataSpaceAuthError: If unable to get valid token + """ + # First ensure we have a valid Keycloak token + if self.keycloak_url and self.keycloak_realm: + keycloak_token = self._ensure_valid_keycloak_token() + # Re-login to backend with fresh Keycloak token if needed + if not self.access_token: + self._login_with_keycloak_token(keycloak_token) + + if not self.access_token: + raise DataSpaceAuthError("No access token available") + + return self.access_token diff --git a/dataspace_sdk/base.py b/dataspace_sdk/base.py index bdf38e16..9f53e2da 100644 --- a/dataspace_sdk/base.py +++ b/dataspace_sdk/base.py @@ -25,6 +25,7 @@ def __init__(self, base_url: str, auth_client: Any = None): """ self.base_url = base_url.rstrip("/") self.auth_client = auth_client + self.default_headers: Dict[str, str] = {} def _get_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: """ @@ -38,6 +39,9 @@ def _get_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> D """ headers = {"Content-Type": "application/json"} + if self.default_headers: + headers.update(self.default_headers) + if self.auth_client and self.auth_client.is_authenticated(): headers["Authorization"] = f"Bearer {self.auth_client.access_token}" diff --git a/dataspace_sdk/client.py b/dataspace_sdk/client.py index 93b3e50c..5f85eb28 100644 --- a/dataspace_sdk/client.py +++ b/dataspace_sdk/client.py @@ -4,7 +4,9 @@ from dataspace_sdk.auth import AuthClient from dataspace_sdk.resources.aimodels import AIModelClient +from dataspace_sdk.resources.auditors import AuditorClient from dataspace_sdk.resources.datasets import DatasetClient +from dataspace_sdk.resources.sectors import SectorClient from dataspace_sdk.resources.usecases import UseCaseClient @@ -33,24 +35,97 @@ class DataSpaceClient: >>> org_usecases = client.usecases.get_organization_usecases("org-uuid") """ - def __init__(self, base_url: str): + def __init__( + self, + base_url: 
str, + keycloak_url: Optional[str] = None, + keycloak_realm: Optional[str] = None, + keycloak_client_id: Optional[str] = None, + keycloak_client_secret: Optional[str] = None, + ): """ Initialize the DataSpace client. Args: base_url: Base URL of the DataSpace API (e.g., "https://api.dataspace.example.com") + keycloak_url: Keycloak server URL (e.g., "https://opub-kc.civicdatalab.in") + keycloak_realm: Keycloak realm name (e.g., "DataSpace") + keycloak_client_id: Keycloak client ID (e.g., "dataspace") + keycloak_client_secret: Optional client secret for confidential clients """ self.base_url = base_url.rstrip("/") - self._auth = AuthClient(self.base_url) + self._auth = AuthClient( + self.base_url, + keycloak_url=keycloak_url, + keycloak_realm=keycloak_realm, + keycloak_client_id=keycloak_client_id, + keycloak_client_secret=keycloak_client_secret, + ) # Initialize resource clients self.datasets = DatasetClient(self.base_url, self._auth) self.aimodels = AIModelClient(self.base_url, self._auth) self.usecases = UseCaseClient(self.base_url, self._auth) + self.sectors = SectorClient(self.base_url, self._auth) + self.auditors = AuditorClient(self.base_url, self._auth) - def login(self, keycloak_token: str) -> dict: + def login(self, username: str, password: str) -> dict: """ - Login using a Keycloak token. + Login using username and password. + + Args: + username: User's username or email + password: User's password + + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + + Example: + >>> client = DataSpaceClient( + ... base_url="https://api.dataspace.example.com", + ... keycloak_url="https://opub-kc.civicdatalab.in", + ... keycloak_realm="DataSpace", + ... keycloak_client_id="dataspace" + ... 
) + >>> user_info = client.login(username="user@example.com", password="secret") + >>> print(user_info["user"]["username"]) + """ + return self._auth.login(username, password) + + def login_as_service_account(self) -> dict: + """ + Login using client credentials (service account). + + This method authenticates the client itself using client_id and client_secret. + The Keycloak client must have "Service Accounts Enabled" turned ON. + + This is the recommended approach for backend services and automated tasks. + + Returns: + Dictionary containing user info and tokens + + Raises: + DataSpaceAuthError: If authentication fails + + Example: + >>> client = DataSpaceClient( + ... base_url="https://api.dataspace.example.com", + ... keycloak_url="https://opub-kc.civicdatalab.in", + ... keycloak_realm="DataSpace", + ... keycloak_client_id="dataspace", + ... keycloak_client_secret="your-secret" + ... ) + >>> info = client.login_as_service_account() + >>> print("Authenticated as service account") + """ + return self._auth.login_as_service_account() + + def login_with_token(self, keycloak_token: str) -> dict: + """ + Login using a pre-obtained Keycloak token. Args: keycloak_token: Valid Keycloak access token @@ -63,10 +138,10 @@ def login(self, keycloak_token: str) -> dict: Example: >>> client = DataSpaceClient(base_url="https://api.dataspace.example.com") - >>> user_info = client.login(keycloak_token="your_token") + >>> user_info = client.login_with_token(keycloak_token="your_token") >>> print(user_info["user"]["username"]) """ - return self._auth.login_with_keycloak(keycloak_token) + return self._auth._login_with_keycloak_token(keycloak_token) def refresh_token(self) -> str: """ @@ -112,6 +187,22 @@ def is_authenticated(self) -> bool: """ return self._auth.is_authenticated() + def set_organization(self, organization_id: str) -> None: + """ + Set the organization header for all subsequent API requests. 
+ + The DataSpace backend reads the 'organization' header to scope + queries to a specific organization. + + Args: + organization_id: Organization ID to include in requests + """ + self.datasets.default_headers["organization"] = organization_id + self.aimodels.default_headers["organization"] = organization_id + self.usecases.default_headers["organization"] = organization_id + self.sectors.default_headers["organization"] = organization_id + self.auditors.default_headers["organization"] = organization_id + @property def user(self) -> Optional[dict]: """ diff --git a/dataspace_sdk/resources/__init__.py b/dataspace_sdk/resources/__init__.py index dc02cf6e..e0f64724 100644 --- a/dataspace_sdk/resources/__init__.py +++ b/dataspace_sdk/resources/__init__.py @@ -1,7 +1,9 @@ """Resource clients for DataSpace SDK.""" from dataspace_sdk.resources.aimodels import AIModelClient +from dataspace_sdk.resources.auditors import AuditorClient from dataspace_sdk.resources.datasets import DatasetClient +from dataspace_sdk.resources.sectors import SectorClient from dataspace_sdk.resources.usecases import UseCaseClient -__all__ = ["DatasetClient", "AIModelClient", "UseCaseClient"] +__all__ = ["DatasetClient", "AIModelClient", "UseCaseClient", "SectorClient", "AuditorClient"] diff --git a/dataspace_sdk/resources/aimodels.py b/dataspace_sdk/resources/aimodels.py index 0b0d56e8..9dd10c93 100644 --- a/dataspace_sdk/resources/aimodels.py +++ b/dataspace_sdk/resources/aimodels.py @@ -17,6 +17,7 @@ def search( status: Optional[str] = None, model_type: Optional[str] = None, provider: Optional[str] = None, + domain: Optional[str] = None, sort: Optional[str] = None, page: int = 1, page_size: int = 10, @@ -32,6 +33,7 @@ def search( status: Filter by status (ACTIVE, INACTIVE, etc.) model_type: Filter by model type (LLM, VISION, etc.) provider: Filter by provider (OPENAI, ANTHROPIC, etc.) + domain: Filter by domain (HEALTHCARE, EDUCATION, etc.) 
sort: Sort order (recent, alphabetical) page: Page number (1-indexed) page_size: Number of results per page @@ -58,10 +60,12 @@ def search( params["model_type"] = model_type if provider: params["provider"] = provider + if domain: + params["domain"] = domain if sort: params["sort"] = sort - return self.get("/api/search/aimodel/", params=params) + return super().get("/api/search/aimodel/", params=params) def get_by_id(self, model_id: str) -> Dict[str, Any]: """ @@ -73,7 +77,8 @@ def get_by_id(self, model_id: str) -> Dict[str, Any]: Returns: Dictionary containing AI model information """ - return self.get(f"/api/aimodels/{model_id}/") + # Use parent class get method with full endpoint path + return super().get(f"/api/aimodels/{model_id}/") def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: """ @@ -86,26 +91,14 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: Dictionary containing AI model information """ query = """ - query GetAIModel($id: UUID!) { - aiModel(id: $id) { + query GetAIModel($id: Int!) 
{ + getAiModel(modelId: $id) { id name displayName description modelType - provider - version - providerModelId - hfUsePipeline - hfAuthToken - hfModelClass - hfAttnImplementation - framework - supportsStreaming - maxTokens - supportedLanguages - inputSchema - outputSchema + domain status isPublic createdAt @@ -126,13 +119,49 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: id name } - endpoints { + versions { id - name - url - httpMethod - authType - isActive + version + versionNotes + lifecycleStage + isLatest + supportsStreaming + maxTokens + supportedLanguages + inputSchema + outputSchema + status + createdAt + updatedAt + publishedAt + providers { + id + provider + providerModelId + isPrimary + isActive + # API Configuration + apiEndpointUrl + apiHttpMethod + apiTimeoutSeconds + apiAuthType + apiAuthHeaderName + apiKey + apiKeyPrefix + apiHeaders + apiRequestTemplate + apiResponsePath + # HuggingFace Configuration + hfUsePipeline + hfAuthToken + hfModelClass + hfAttnImplementation + hfTrustRemoteCode + hfTorchDtype + hfDeviceMap + framework + config + } } } } @@ -142,7 +171,7 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: "/api/graphql", json_data={ "query": query, - "variables": {"id": model_id}, + "variables": {"id": int(model_id)}, }, ) @@ -151,7 +180,7 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]: raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") - result: Dict[str, Any] = response.get("data", {}).get("aiModel", {}) + result: Dict[str, Any] = response.get("data", {}).get("getAiModel", {}) return result def list_all( @@ -183,8 +212,7 @@ def list_all( displayName description modelType - provider - version + domain status isPublic createdAt @@ -197,6 +225,28 @@ def list_all( id value } + sectors { + id + name + slug + } + geographies { + id + name + } + versions { + id + version + lifecycleStage + isLatest + status + providers { + id + provider + providerModelId + isPrimary + } + } } } """ @@ 
-293,7 +343,11 @@ def delete_model(self, model_id: str) -> Dict[str, Any]: return self.delete(f"/api/aimodels/{model_id}/") def call_model( - self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None + self, + model_id: str, + input_text: str, + parameters: Optional[Dict[str, Any]] = None, + version_id: Optional[int] = None, ) -> Dict[str, Any]: """ Call an AI model with input text using the appropriate client (API or HuggingFace). @@ -302,6 +356,7 @@ def call_model( model_id: UUID of the AI model input_text: Input text to process parameters: Optional parameters for the model call (temperature, max_tokens, etc.) + version_id: Optional specific version ID to call (defaults to primary/latest version) Returns: Dictionary containing model response: @@ -314,13 +369,24 @@ def call_model( ... } """ + payload: Dict[str, Any] = { + "input_text": input_text, + "parameters": parameters or {}, + } + if version_id is not None: + payload["version_id"] = version_id + return self.post( f"/api/aimodels/{model_id}/call/", - json_data={"input_text": input_text, "parameters": parameters or {}}, + json_data=payload, ) def call_model_async( - self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None + self, + model_id: str, + input_text: str, + parameters: Optional[Dict[str, Any]] = None, + version_id: Optional[int] = None, ) -> Dict[str, Any]: """ Call an AI model asynchronously (returns task ID for long-running operations). 
@@ -329,6 +395,7 @@ def call_model_async( model_id: UUID of the AI model input_text: Input text to process parameters: Optional parameters for the model call + version_id: Optional specific version ID to call (defaults to primary/latest version) Returns: Dictionary containing task information: @@ -338,7 +405,624 @@ def call_model_async( "model_id": str } """ + payload: Dict[str, Any] = { + "input_text": input_text, + "parameters": parameters or {}, + } + if version_id is not None: + payload["version_id"] = version_id + return self.post( f"/api/aimodels/{model_id}/call-async/", - json_data={"input_text": input_text, "parameters": parameters or {}}, + json_data=payload, + ) + + # ==================== Version Management ==================== + + def get_versions(self, model_id: int) -> List[Dict[str, Any]]: + """ + Get all versions for an AI model. + + Args: + model_id: ID of the AI model + + Returns: + List of version dictionaries + """ + query = """ + query GetModelVersions($filters: AIModelFilter) { + aiModels(filters: $filters) { + versions { + id + version + versionNotes + lifecycleStage + isLatest + supportsStreaming + maxTokens + supportedLanguages + status + createdAt + updatedAt + publishedAt + providers { + id + provider + providerModelId + isPrimary + isActive + } + } + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"filters": {"id": model_id}}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + models = response.get("data", {}).get("aiModels", []) + if models: + result: List[Dict[str, Any]] = models[0].get("versions", []) + return result + return [] + + def create_version( + self, + model_id: int, + version: str, + lifecycle_stage: str = "DEVELOPMENT", + is_latest: bool = False, + copy_from_version_id: Optional[int] = None, + version_notes: Optional[str] = None, + supports_streaming: 
bool = False, + max_tokens: Optional[int] = None, + supported_languages: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + Create a new version for an AI model. + + Args: + model_id: ID of the AI model + version: Version string (e.g., "1.0", "2.1") + lifecycle_stage: One of DEVELOPMENT, TESTING, BETA, STAGING, PRODUCTION, DEPRECATED, RETIRED + is_latest: Whether this should be the primary version + copy_from_version_id: Optional version ID to copy providers from + version_notes: Optional notes about this version + supports_streaming: Whether this version supports streaming + max_tokens: Maximum tokens supported + supported_languages: List of supported language codes + + Returns: + Dictionary containing created version information + """ + mutation = """ + mutation CreateAIModelVersion($input: CreateAIModelVersionInput!) { + createAiModelVersion(input: $input) { + success + data { + id + version + lifecycleStage + isLatest + status + } + errors + } + } + """ + + input_data: Dict[str, Any] = { + "modelId": model_id, + "version": version, + "lifecycleStage": lifecycle_stage, + "isLatest": is_latest, + "supportsStreaming": supports_streaming, + } + + if copy_from_version_id: + input_data["copyFromVersionId"] = copy_from_version_id + if version_notes: + input_data["versionNotes"] = version_notes + if max_tokens: + input_data["maxTokens"] = max_tokens + if supported_languages: + input_data["supportedLanguages"] = supported_languages + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("createAiModelVersion", {}) + return result + + def update_version( + self, + version_id: int, + version: Optional[str] = None, + lifecycle_stage: Optional[str] = None, + is_latest: Optional[bool] = None, 
+ version_notes: Optional[str] = None, + status: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Update an AI model version. + + Args: + version_id: ID of the version to update + version: New version string + lifecycle_stage: New lifecycle stage + is_latest: Whether this should be the primary version + version_notes: New version notes + status: New status + + Returns: + Dictionary containing updated version information + """ + mutation = """ + mutation UpdateAIModelVersion($input: UpdateAIModelVersionInput!) { + updateAiModelVersion(input: $input) { + success + data { + id + version + lifecycleStage + isLatest + status + } + errors + } + } + """ + + input_data: Dict[str, Any] = {"id": version_id} + + if version is not None: + input_data["version"] = version + if lifecycle_stage is not None: + input_data["lifecycleStage"] = lifecycle_stage + if is_latest is not None: + input_data["isLatest"] = is_latest + if version_notes is not None: + input_data["versionNotes"] = version_notes + if status is not None: + input_data["status"] = status + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("updateAiModelVersion", {}) + return result + + # ==================== Provider Management ==================== + + def get_version_providers(self, version_id: int) -> List[Dict[str, Any]]: + """ + Get all providers for a specific version. + + Args: + version_id: ID of the version + + Returns: + List of provider dictionaries + """ + query = """ + query GetVersionProviders($versionId: Int!) 
{ + aiModelVersion(id: $versionId) { + providers { + id + provider + providerModelId + isPrimary + isActive + # API Configuration + apiEndpointUrl + apiHttpMethod + apiTimeoutSeconds + apiAuthType + apiAuthHeaderName + apiKey + apiKeyPrefix + apiHeaders + apiRequestTemplate + apiResponsePath + # HuggingFace Configuration + hfUsePipeline + hfAuthToken + hfModelClass + hfAttnImplementation + hfTrustRemoteCode + hfTorchDtype + hfDeviceMap + framework + config + } + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"versionId": version_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + version_data = response.get("data", {}).get("aiModelVersion", {}) + result: List[Dict[str, Any]] = version_data.get("providers", []) if version_data else [] + return result + + def create_provider( + self, + version_id: int, + provider: str, + provider_model_id: str, + is_primary: bool = False, + # API Configuration + api_endpoint_url: Optional[str] = None, + api_http_method: str = "POST", + api_timeout_seconds: int = 60, + api_auth_type: str = "BEARER", + api_auth_header_name: str = "Authorization", + api_key: Optional[str] = None, + api_key_prefix: str = "Bearer", + api_headers: Optional[Dict[str, str]] = None, + api_request_template: Optional[Dict[str, Any]] = None, + api_response_path: Optional[str] = None, + # HuggingFace Configuration + hf_use_pipeline: bool = False, + hf_model_class: Optional[str] = None, + hf_auth_token: Optional[str] = None, + hf_attn_implementation: Optional[str] = None, + hf_trust_remote_code: bool = True, + hf_torch_dtype: Optional[str] = "auto", + hf_device_map: Optional[str] = "auto", + framework: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Create a new provider for a version. 
+ + Args: + version_id: ID of the version + provider: Provider type (OPENAI, LLAMA_OLLAMA, LLAMA_TOGETHER, LLAMA_REPLICATE, + LLAMA_CUSTOM, CUSTOM, HUGGINGFACE) + provider_model_id: Model ID at the provider (e.g., "gpt-4", "meta-llama/Llama-2-7b") + is_primary: Whether this is the primary provider + api_endpoint_url: Full URL for the API endpoint + api_http_method: HTTP method (POST, GET) + api_timeout_seconds: Request timeout in seconds + api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE) + api_auth_header_name: Header name for authentication + api_key: API key or token + api_key_prefix: Prefix for the API key (e.g., "Bearer") + api_headers: Additional headers as dict + api_request_template: Request body template as dict + api_response_path: JSON path to extract response text + hf_use_pipeline: For HuggingFace - whether to use pipeline API + hf_model_class: For HuggingFace - model class (e.g., "AutoModelForCausalLM") + hf_auth_token: For HuggingFace - auth token for gated models + hf_attn_implementation: For HuggingFace - attention implementation + hf_trust_remote_code: For HuggingFace - trust remote code + hf_torch_dtype: For HuggingFace - torch dtype (auto, float16, bfloat16) + hf_device_map: For HuggingFace - device map (auto, cuda, cpu) + framework: Framework (pt, tf) + config: Additional configuration + + Returns: + Dictionary containing created provider information + """ + mutation = """ + mutation CreateVersionProvider($input: CreateVersionProviderInput!) 
{ + createVersionProvider(input: $input) { + success + data { + id + provider + providerModelId + isPrimary + isActive + } + errors + } + } + """ + + input_data: Dict[str, Any] = { + "versionId": version_id, + "provider": provider, + "providerModelId": provider_model_id, + "isPrimary": is_primary, + # API Configuration + "apiHttpMethod": api_http_method, + "apiTimeoutSeconds": api_timeout_seconds, + "apiAuthType": api_auth_type, + "apiAuthHeaderName": api_auth_header_name, + "apiKeyPrefix": api_key_prefix, + # HuggingFace Configuration + "hfUsePipeline": hf_use_pipeline, + "hfTrustRemoteCode": hf_trust_remote_code, + } + + # Optional API fields + if api_endpoint_url: + input_data["apiEndpointUrl"] = api_endpoint_url + if api_key: + input_data["apiKey"] = api_key + if api_headers: + input_data["apiHeaders"] = api_headers + if api_request_template: + input_data["apiRequestTemplate"] = api_request_template + if api_response_path: + input_data["apiResponsePath"] = api_response_path + + # Optional HuggingFace fields + if hf_model_class: + input_data["hfModelClass"] = hf_model_class + if hf_auth_token: + input_data["hfAuthToken"] = hf_auth_token + if hf_attn_implementation: + input_data["hfAttnImplementation"] = hf_attn_implementation + if hf_torch_dtype: + input_data["hfTorchDtype"] = hf_torch_dtype + if hf_device_map: + input_data["hfDeviceMap"] = hf_device_map + if framework: + input_data["framework"] = framework + if config: + input_data["config"] = config + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("createVersionProvider", {}) + return result + + def update_provider( + self, + provider_id: int, + provider_model_id: Optional[str] = None, + is_primary: Optional[bool] = None, + # 
API Configuration + api_endpoint_url: Optional[str] = None, + api_http_method: Optional[str] = None, + api_timeout_seconds: Optional[int] = None, + api_auth_type: Optional[str] = None, + api_auth_header_name: Optional[str] = None, + api_key: Optional[str] = None, + api_key_prefix: Optional[str] = None, + api_headers: Optional[Dict[str, str]] = None, + api_request_template: Optional[Dict[str, Any]] = None, + api_response_path: Optional[str] = None, + # HuggingFace Configuration + hf_use_pipeline: Optional[bool] = None, + hf_model_class: Optional[str] = None, + hf_auth_token: Optional[str] = None, + hf_attn_implementation: Optional[str] = None, + hf_trust_remote_code: Optional[bool] = None, + hf_torch_dtype: Optional[str] = None, + hf_device_map: Optional[str] = None, + framework: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Update a provider. + + Args: + provider_id: ID of the provider to update + provider_model_id: New model ID at the provider + is_primary: Whether this is the primary provider + api_endpoint_url: Full URL for the API endpoint + api_http_method: HTTP method (POST, GET) + api_timeout_seconds: Request timeout in seconds + api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE) + api_auth_header_name: Header name for authentication + api_key: API key or token + api_key_prefix: Prefix for the API key (e.g., "Bearer") + api_headers: Additional headers as dict + api_request_template: Request body template as dict + api_response_path: JSON path to extract response text + hf_use_pipeline: For HuggingFace - whether to use pipeline API + hf_model_class: For HuggingFace - model class + hf_auth_token: For HuggingFace - auth token + hf_attn_implementation: For HuggingFace - attention implementation + hf_trust_remote_code: For HuggingFace - trust remote code + hf_torch_dtype: For HuggingFace - torch dtype + hf_device_map: For HuggingFace - device map + framework: Framework (pt, tf) + 
config: Additional configuration + + Returns: + Dictionary containing updated provider information + """ + mutation = """ + mutation UpdateVersionProvider($input: UpdateVersionProviderInput!) { + updateVersionProvider(input: $input) { + success + data { + id + provider + providerModelId + isPrimary + isActive + } + errors + } + } + """ + + input_data: Dict[str, Any] = {"id": provider_id} + + if provider_model_id is not None: + input_data["providerModelId"] = provider_model_id + if is_primary is not None: + input_data["isPrimary"] = is_primary + # API Configuration + if api_endpoint_url is not None: + input_data["apiEndpointUrl"] = api_endpoint_url + if api_http_method is not None: + input_data["apiHttpMethod"] = api_http_method + if api_timeout_seconds is not None: + input_data["apiTimeoutSeconds"] = api_timeout_seconds + if api_auth_type is not None: + input_data["apiAuthType"] = api_auth_type + if api_auth_header_name is not None: + input_data["apiAuthHeaderName"] = api_auth_header_name + if api_key is not None: + input_data["apiKey"] = api_key + if api_key_prefix is not None: + input_data["apiKeyPrefix"] = api_key_prefix + if api_headers is not None: + input_data["apiHeaders"] = api_headers + if api_request_template is not None: + input_data["apiRequestTemplate"] = api_request_template + if api_response_path is not None: + input_data["apiResponsePath"] = api_response_path + # HuggingFace Configuration + if hf_use_pipeline is not None: + input_data["hfUsePipeline"] = hf_use_pipeline + if hf_model_class is not None: + input_data["hfModelClass"] = hf_model_class + if hf_auth_token is not None: + input_data["hfAuthToken"] = hf_auth_token + if hf_attn_implementation is not None: + input_data["hfAttnImplementation"] = hf_attn_implementation + if hf_trust_remote_code is not None: + input_data["hfTrustRemoteCode"] = hf_trust_remote_code + if hf_torch_dtype is not None: + input_data["hfTorchDtype"] = hf_torch_dtype + if hf_device_map is not None: + 
input_data["hfDeviceMap"] = hf_device_map + if framework is not None: + input_data["framework"] = framework + if config is not None: + input_data["config"] = config + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"input": input_data}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("updateVersionProvider", {}) + return result + + def delete_provider(self, provider_id: int) -> Dict[str, Any]: + """ + Delete a provider. + + Args: + provider_id: ID of the provider to delete + + Returns: + Dictionary containing deletion response + """ + mutation = """ + mutation DeleteVersionProvider($providerId: Int!) { + deleteVersionProvider(providerId: $providerId) { + success + errors + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": mutation, + "variables": {"providerId": provider_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("deleteVersionProvider", {}) + return result + + # ==================== Helper Methods ==================== + + def get_primary_version(self, model_id: int) -> Optional[Dict[str, Any]]: + """ + Get the primary (latest) version for an AI model. + + Args: + model_id: ID of the AI model + + Returns: + Dictionary containing the primary version, or None if no versions exist + """ + versions = self.get_versions(model_id) + for version in versions: + if version.get("isLatest"): + return version + return versions[0] if versions else None + + def get_primary_provider(self, version_id: int) -> Optional[Dict[str, Any]]: + """ + Get the primary provider for a version. 
+ + Args: + version_id: ID of the version + + Returns: + Dictionary containing the primary provider, or None if no providers exist + """ + providers = self.get_version_providers(version_id) + for provider in providers: + if provider.get("isPrimary"): + return provider + return providers[0] if providers else None diff --git a/dataspace_sdk/resources/auditors.py b/dataspace_sdk/resources/auditors.py new file mode 100644 index 00000000..6fbd547c --- /dev/null +++ b/dataspace_sdk/resources/auditors.py @@ -0,0 +1,267 @@ +"""Auditor management resource for DataSpace SDK.""" + +from typing import Any, Dict, Optional + +import requests + +from dataspace_sdk.exceptions import DataSpaceAuthError + + +class AuditorClient: + """ + Client for managing auditors in organizations. + + Auditors are users with the 'auditor' role in an organization, + who can audit AI models registered by that organization. + """ + + def __init__(self, base_url: str, auth_client: Any): + """ + Initialize the auditor client. + + Args: + base_url: Base URL of the DataSpace API + auth_client: Authentication client instance + """ + self._base_url = base_url.rstrip("/") + self._auth = auth_client + self.default_headers: Dict[str, str] = {} + + def _get_headers(self) -> Dict[str, str]: + """Get request headers with authentication.""" + headers = {"Content-Type": "application/json"} + if self.default_headers: + headers.update(self.default_headers) + if self._auth and self._auth.access_token: + headers["Authorization"] = f"Bearer {self._auth.access_token}" + return headers + + def get_organization_auditors(self, organization_id: str) -> Dict[str, Any]: + """ + Get all auditors for an organization. 
+ + Args: + organization_id: UUID of the organization + + Returns: + Dictionary containing: + - organization_id: str + - organization_name: str + - auditors: List of auditor dictionaries + - count: int + + Raises: + DataSpaceAuthError: If not authenticated or permission denied + + Example: + >>> result = client.auditors.get_organization_auditors("org-uuid") + >>> for auditor in result["auditors"]: + ... print(f"{auditor['username']} - {auditor['email']}") + """ + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/organizations/{organization_id}/auditors/" + + response = requests.get( + url, + headers=self._get_headers(), + ) + + if response.status_code == 200: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + elif response.status_code == 403: + raise DataSpaceAuthError("Permission denied: You must be an admin of this organization") + elif response.status_code == 404: + raise DataSpaceAuthError(f"Organization {organization_id} not found") + else: + error_data: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Failed to get auditors"), + status_code=response.status_code, + response=error_data, + ) + + def add_auditor( + self, + organization_id: str, + user_id: Optional[str] = None, + email: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Add a user as auditor to an organization. + + You must provide either user_id or email. If the user is found, + they will be added as an auditor to the organization. 
+ + Args: + organization_id: UUID of the organization + user_id: Optional user ID to add as auditor + email: Optional email to look up user and add as auditor + + Returns: + Dictionary containing: + - success: bool + - message: str + - auditor: Dictionary with auditor details + + Raises: + DataSpaceAuthError: If not authenticated, permission denied, or user not found + ValueError: If neither user_id nor email is provided + + Example: + >>> # Add by user ID + >>> result = client.auditors.add_auditor("org-uuid", user_id="user-uuid") + >>> + >>> # Add by email + >>> result = client.auditors.add_auditor("org-uuid", email="auditor@example.com") + """ + if not user_id and not email: + raise ValueError("Either user_id or email must be provided") + + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/organizations/{organization_id}/auditors/" + + payload: Dict[str, str] = {} + if user_id: + payload["user_id"] = user_id + if email: + payload["email"] = email + + response = requests.post( + url, + json=payload, + headers=self._get_headers(), + ) + + if response.status_code == 201: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + elif response.status_code == 403: + raise DataSpaceAuthError("Permission denied: You must be an admin of this organization") + elif response.status_code == 404: + error_data = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Organization or user not found"), + status_code=response.status_code, + response=error_data, + ) + elif response.status_code == 400: + error_data = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Invalid request"), + status_code=response.status_code, + response=error_data, + ) + else: + error_data = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Failed to add auditor"), + status_code=response.status_code, + response=error_data, + ) + + def 
remove_auditor(self, organization_id: str, user_id: str) -> Dict[str, Any]: + """ + Remove an auditor from an organization. + + Args: + organization_id: UUID of the organization + user_id: ID of the user to remove as auditor + + Returns: + Dictionary containing: + - success: bool + - message: str + + Raises: + DataSpaceAuthError: If not authenticated, permission denied, or user not an auditor + + Example: + >>> result = client.auditors.remove_auditor("org-uuid", "user-uuid") + >>> print(result["message"]) + """ + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/organizations/{organization_id}/auditors/" + + response = requests.delete( + url, + json={"user_id": user_id}, + headers=self._get_headers(), + ) + + if response.status_code == 200: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + elif response.status_code == 403: + raise DataSpaceAuthError("Permission denied: You must be an admin of this organization") + elif response.status_code == 404: + error_data: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Organization or auditor not found"), + status_code=response.status_code, + response=error_data, + ) + else: + error_data_other: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data_other.get("error", "Failed to remove auditor"), + status_code=response.status_code, + response=error_data_other, + ) + + def search_user_by_email(self, email: str) -> Dict[str, Any]: + """ + Search for a user by email address. + + This is useful to check if a user exists before adding them as an auditor. 
+ + Args: + email: Email address to search for + + Returns: + Dictionary containing: + - found: bool + - user: Dictionary with user details (if found) + - message: str (if not found) + + Raises: + DataSpaceAuthError: If not authenticated + + Example: + >>> result = client.auditors.search_user_by_email("user@example.com") + >>> if result["found"]: + ... print(f"Found user: {result['user']['username']}") + ... else: + ... print("User not found") + """ + self._auth.ensure_authenticated() + + url = f"{self._base_url}/api/users/search-by-email/" + + response = requests.get( + url, + params={"email": email}, + headers=self._get_headers(), + ) + + if response.status_code == 200: + result: Dict[str, Any] = response.json() + return result + elif response.status_code == 401: + raise DataSpaceAuthError("Authentication required") + else: + error_data: Dict[str, Any] = response.json() + raise DataSpaceAuthError( + error_data.get("error", "Failed to search user"), + status_code=response.status_code, + response=error_data, + ) diff --git a/dataspace_sdk/resources/datasets.py b/dataspace_sdk/resources/datasets.py index bc655e98..24dee57a 100644 --- a/dataspace_sdk/resources/datasets.py +++ b/dataspace_sdk/resources/datasets.py @@ -16,6 +16,7 @@ def search( geographies: Optional[List[str]] = None, status: Optional[str] = None, access_type: Optional[str] = None, + dataset_type: Optional[str] = None, sort: Optional[str] = None, page: int = 1, page_size: int = 10, @@ -30,6 +31,7 @@ def search( geographies: Filter by geographies status: Filter by status (DRAFT, PUBLISHED, etc.) access_type: Filter by access type (OPEN, RESTRICTED, etc.) 
+ dataset_type: Filter by dataset type (DATA, PROMPT) sort: Sort order (recent, alphabetical) page: Page number (1-indexed) page_size: Number of results per page @@ -54,10 +56,12 @@ def search( params["status"] = status if access_type: params["access_type"] = access_type + if dataset_type: + params["dataset_type"] = dataset_type if sort: params["sort"] = sort - return self.get("/api/search/dataset/", params=params) + return super().get("/api/search/dataset/", params=params) def get_by_id(self, dataset_id: str) -> Dict[str, Any]: """ @@ -71,20 +75,23 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: """ query = """ query GetDataset($id: UUID!) { - dataset(id: $id) { + getDataset(datasetId: $id) { id title description status - accessType - license - createdAt - updatedAt + datasetType + created + modified + downloadCount organization { id name description } + user { + id + } tags { id value @@ -99,12 +106,18 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: } resources { id - title + name description - fileDetails - schema - createdAt - updatedAt + fileDetails { + format + size + } + schema { + id + fieldName + format + description + } } } } @@ -123,7 +136,7 @@ def get_by_id(self, dataset_id: str) -> Dict[str, Any]: raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") - result: Dict[str, Any] = response.get("data", {}).get("dataset", {}) + result: Dict[str, Any] = response.get("data", {}).get("getDataset", {}) return result def list_all( @@ -154,8 +167,8 @@ def list_all( status accessType license - createdAt - updatedAt + created + updated organization { id name @@ -231,3 +244,364 @@ def get_organization_datasets( limit=limit, offset=offset, ) + + def create(self, dataset_type: str = "DATA") -> Dict[str, Any]: + """ + Create a new dataset using GraphQL. 
+ + Args: + dataset_type: Type of dataset to create (DATA or PROMPT) + + Returns: + Dictionary containing the created dataset information + """ + query = """ + mutation AddDataset($createInput: CreateDatasetInput) { + addDataset(createInput: $createInput) { + success + errors + data { + id + title + description + status + datasetType + created + updated + } + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"createInput": {"datasetType": dataset_type}}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("addDataset", {}) + return result + + def get_prompt_by_id(self, dataset_id: str) -> Dict[str, Any]: + """ + Get a prompt dataset by ID with prompt-specific metadata. + + Args: + dataset_id: UUID of the prompt dataset + + Returns: + Dictionary containing prompt dataset information including prompt metadata + """ + query = """ + query GetPromptDataset($id: UUID!) 
{ + getDataset(datasetId: $id) { + id + title + description + status + datasetType + created + modified + downloadCount + organization { + id + name + description + } + user { + id + } + tags { + id + value + } + sectors { + id + name + } + geographies { + id + name + } + resources { + id + name + fileDetails { + format + size + } + promptDetails { + promptFormat + hasSystemPrompt + hasExampleResponses + promptCount + } + } + promptMetadata + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"id": dataset_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("getDataset", {}) + return result + + def list_prompts( + self, + status: Optional[str] = None, + task_type: Optional[str] = None, + domain: Optional[str] = None, + organization_id: Optional[str] = None, + include_public: Optional[bool] = False, + limit: int = 10, + offset: int = 0, + ) -> Any: + """ + List prompt datasets with pagination using GraphQL. + + Args: + status: Filter by status (DRAFT, PUBLISHED, etc.) 
+ task_type: Filter by prompt task type + domain: Filter by domain + organization_id: Filter by organization + include_public: Include public datasets + limit: Number of results to return + offset: Number of results to skip + + Returns: + List of prompt datasets + """ + query = """ + query ListPromptDatasets($filters: DatasetFilter, $pagination: OffsetPaginationInput, $include_public: Boolean) { + datasets(filters: $filters, pagination: $pagination, includePublic: $include_public) { + id + title + description + status + accessType + datasetType + created + organization { + id + name + } + tags { + id + value + } + promptMetadata + resources { + id + name + fileDetails { + format + size + } + promptDetails { + promptFormat + hasSystemPrompt + hasExampleResponses + promptCount + } + } + } + } + """ + + filters: Dict[str, Any] = {"datasetType": "PROMPT"} + if status: + filters["status"] = status + if organization_id: + filters["organization"] = {"id": {"exact": organization_id}} + + variables: Dict[str, Any] = { + "pagination": {"limit": limit, "offset": offset}, + "filters": filters, + "include_public": include_public, + } + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": variables, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + data = response.get("data", {}) + datasets_result: Any = data.get("datasets", []) if isinstance(data, dict) else [] + return datasets_result + + def search_prompts( + self, + query: Optional[str] = None, + task_type: Optional[str] = None, + domain: Optional[str] = None, + target_languages: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + sectors: Optional[List[str]] = None, + sort: Optional[str] = None, + page: int = 1, + page_size: int = 10, + ) -> Dict[str, Any]: + """ + Search for prompt datasets specifically. 
+ + Args: + query: Search query string + task_type: Filter by prompt task type (TEXT_GENERATION, QUESTION_ANSWERING, etc.) + domain: Filter by domain (healthcare, education, etc.) + target_languages: Filter by target languages + tags: Filter by tags + sectors: Filter by sectors + sort: Sort order (recent, alphabetical) + page: Page number (1-indexed) + page_size: Number of results per page + + Returns: + Dictionary containing search results and metadata + """ + params: Dict[str, Any] = { + "page": page, + "page_size": page_size, + "dataset_type": "PROMPT", + } + + if query: + params["q"] = query + if task_type: + params["task_type"] = task_type + if domain: + params["domain"] = domain + if target_languages: + params["target_languages"] = ",".join(target_languages) + if tags: + params["tags"] = ",".join(tags) + if sectors: + params["sectors"] = ",".join(sectors) + if sort: + params["sort"] = sort + + return super().get("/api/search/dataset/", params=params) + + def update_prompt_metadata( + self, + dataset_id: str, + task_type: Optional[str] = None, + target_languages: Optional[List[str]] = None, + domain: Optional[str] = None, + target_model_types: Optional[List[str]] = None, + prompt_format: Optional[str] = None, + has_system_prompt: Optional[bool] = None, + has_example_responses: Optional[bool] = None, + avg_prompt_length: Optional[int] = None, + prompt_count: Optional[int] = None, + use_case: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Update prompt-specific metadata for a prompt dataset. 
+ + Args: + dataset_id: UUID of the prompt dataset + task_type: Type of prompt task + target_languages: List of target languages + domain: Domain/category of prompts + target_model_types: List of target AI model types + prompt_format: Format of prompts + has_system_prompt: Whether prompts include system instructions + has_example_responses: Whether prompts include example responses + avg_prompt_length: Average prompt length + prompt_count: Total number of prompts + use_case: Description of intended use cases + + Returns: + Dictionary containing the updated prompt metadata + """ + query = """ + mutation UpdatePromptMetadata($updateInput: UpdatePromptMetadataInput!) { + updatePromptMetadata(updateInput: $updateInput) { + success + errors + data { + id + title + description + status + datasetType + taskType + targetLanguages + domain + targetModelTypes + promptFormat + hasSystemPrompt + hasExampleResponses + avgPromptLength + promptCount + useCase + } + } + } + """ + + variables: Dict[str, Any] = {"dataset": dataset_id} + + if task_type is not None: + variables["taskType"] = task_type + if target_languages is not None: + variables["targetLanguages"] = target_languages + if domain is not None: + variables["domain"] = domain + if target_model_types is not None: + variables["targetModelTypes"] = target_model_types + if prompt_format is not None: + variables["promptFormat"] = prompt_format + if has_system_prompt is not None: + variables["hasSystemPrompt"] = has_system_prompt + if has_example_responses is not None: + variables["hasExampleResponses"] = has_example_responses + if avg_prompt_length is not None: + variables["avgPromptLength"] = avg_prompt_length + if prompt_count is not None: + variables["promptCount"] = prompt_count + if use_case is not None: + variables["useCase"] = use_case + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"updateInput": variables}, + }, + ) + + if "errors" in response: + from 
dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("updatePromptMetadata", {}) + return result diff --git a/dataspace_sdk/resources/sectors.py b/dataspace_sdk/resources/sectors.py new file mode 100644 index 00000000..e5e07002 --- /dev/null +++ b/dataspace_sdk/resources/sectors.py @@ -0,0 +1,132 @@ +"""Sector resource client for DataSpace SDK.""" + +from typing import Any, Dict, List, Optional + +from dataspace_sdk.base import BaseAPIClient + + +class SectorClient(BaseAPIClient): + """Client for interacting with Sector resources.""" + + def list_all( + self, + search: Optional[str] = None, + min_dataset_count: Optional[int] = None, + min_aimodel_count: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """ + List all sectors with optional filters using GraphQL. + + Args: + search: Search query for name/description + min_dataset_count: Filter sectors with at least this many published datasets + min_aimodel_count: Filter sectors with at least this many active AI models + limit: Number of results to return + offset: Number of results to skip + + Returns: + List of sector dictionaries + """ + query = """ + query ListSectors($filters: SectorFilter, $pagination: OffsetPaginationInput) { + sectors(filters: $filters, pagination: $pagination) { + id + name + slug + description + datasetCount + aimodelCount + } + } + """ + + filters: Dict[str, Any] = {} + if search: + filters["search"] = search + if min_dataset_count is not None: + filters["minDatasetCount"] = min_dataset_count + if min_aimodel_count is not None: + filters["minAimodelCount"] = min_aimodel_count + + variables: Dict[str, Any] = { + "pagination": {"limit": limit, "offset": offset}, + } + if filters: + variables["filters"] = filters + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": variables, + }, + ) + + if 
"errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + data = response.get("data", {}) + sectors_result: List[Dict[str, Any]] = ( + data.get("sectors", []) if isinstance(data, dict) else [] + ) + return sectors_result + + def get_by_id(self, sector_id: str) -> Dict[str, Any]: + """ + Get a sector by ID using GraphQL. + + Args: + sector_id: UUID of the sector + + Returns: + Dictionary containing sector information + """ + query = """ + query GetSector($id: UUID!) { + sector(id: $id) { + id + name + slug + description + datasetCount + aimodelCount + } + } + """ + + response = self.post( + "/api/graphql", + json_data={ + "query": query, + "variables": {"id": sector_id}, + }, + ) + + if "errors" in response: + from dataspace_sdk.exceptions import DataSpaceAPIError + + raise DataSpaceAPIError(f"GraphQL error: {response['errors']}") + + result: Dict[str, Any] = response.get("data", {}).get("sector", {}) + return result + + def get_sectors_with_aimodels( + self, + limit: int = 100, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """ + Get sectors that have at least one active AI model. 
+ + Args: + limit: Number of results to return + offset: Number of results to skip + + Returns: + List of sector dictionaries with AI models + """ + return self.list_all(min_aimodel_count=1, limit=limit, offset=offset) diff --git a/dataspace_sdk/resources/usecases.py b/dataspace_sdk/resources/usecases.py index 3d75a2f9..6aca2cd1 100644 --- a/dataspace_sdk/resources/usecases.py +++ b/dataspace_sdk/resources/usecases.py @@ -57,7 +57,7 @@ def search( if sort: params["sort"] = sort - return self.get("/api/search/usecase/", params=params) + return super().get("/api/search/usecase/", params=params) def get_by_id(self, usecase_id: int) -> Dict[str, Any]: """ diff --git a/docs/sdk/AUTHENTICATION.md b/docs/sdk/AUTHENTICATION.md new file mode 100644 index 00000000..65aeb132 --- /dev/null +++ b/docs/sdk/AUTHENTICATION.md @@ -0,0 +1,369 @@ +# DataSpace SDK Authentication Guide + +This guide explains how to authenticate with the DataSpace SDK using username/password with automatic token management. + +## Table of Contents + +- [Quick Start](#quick-start) +- [Authentication Methods](#authentication-methods) +- [Automatic Token Management](#automatic-token-management) +- [Configuration](#configuration) +- [Error Handling](#error-handling) +- [Best Practices](#best-practices) + +## Quick Start + +### Username/Password Login (Recommended) + +```python +from dataspace_sdk import DataSpaceClient + +# Initialize with Keycloak configuration +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +# Login once - credentials stored for auto-refresh +user_info = client.login( + username="your-email@example.com", + password="your-password" +) + +# Now use the client - tokens auto-refresh! +datasets = client.datasets.search(query="health") +``` + +## Authentication Methods + +### 1. 
Username/Password (Recommended) + +Best for: +- Scripts and automation +- Long-running applications +- Development and testing + +```python +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +user_info = client.login( + username="user@example.com", + password="secret" +) +``` + +**Features:** +- ✅ Automatic token refresh +- ✅ Automatic re-login if refresh fails +- ✅ No manual token management needed +- ✅ Credentials stored securely in memory + +### 2. Pre-obtained Keycloak Token + +Best for: +- When you already have a token from another source +- Browser-based applications +- SSO integrations + +```python +client = DataSpaceClient(base_url="https://dataspace.civicdatalab.in") + +user_info = client.login_with_token(keycloak_token="eyJhbGci...") +``` + +**Note:** This method does NOT support automatic token refresh or re-login. + +## Automatic Token Management + +The SDK automatically handles token expiration and refresh: + +### How It Works + +1. **Login**: You provide username/password once +2. **Token Storage**: SDK stores credentials securely in memory +3. **Auto-Refresh**: When token expires (within 30 seconds), SDK automatically refreshes it +4. **Auto-Relogin**: If refresh fails, SDK automatically re-authenticates with stored credentials +5. **Transparent**: All of this happens automatically - you don't need to do anything! + +### Example: Long-Running Script + +```python +import time +from dataspace_sdk import DataSpaceClient + +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +# Login once +client.login(username="user@example.com", password="secret") + +# Run for hours - tokens auto-refresh! 
+while True: + datasets = client.datasets.search(query="health") + print(f"Found {len(datasets.get('results', []))} datasets") + + # Sleep for 10 minutes + time.sleep(600) + + # SDK automatically refreshes tokens as needed + # No manual intervention required! +``` + +## Configuration + +### Keycloak Configuration + +You need these details from your Keycloak setup: + +```python +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", # DataSpace API URL + keycloak_url="https://opub-kc.civicdatalab.in", # Keycloak server URL + keycloak_realm="DataSpace", # Realm name + keycloak_client_id="dataspace", # Client ID + keycloak_client_secret="optional-secret" # Only for confidential clients +) +``` + +### Finding Your Keycloak Details + +1. **Keycloak URL**: Your Keycloak server address +2. **Realm**: Usually shown in Keycloak admin console +3. **Client ID**: Found in Keycloak → Clients → Your Client +4. **Client Secret**: Only needed if client is "confidential" (check Access Type in Keycloak) + +### Environment Variables (Recommended) + +Store credentials securely: + +```python +import os +from dataspace_sdk import DataSpaceClient + +client = DataSpaceClient( + base_url=os.getenv("DATASPACE_API_URL"), + keycloak_url=os.getenv("KEYCLOAK_URL"), + keycloak_realm=os.getenv("KEYCLOAK_REALM"), + keycloak_client_id=os.getenv("KEYCLOAK_CLIENT_ID"), + keycloak_client_secret=os.getenv("KEYCLOAK_CLIENT_SECRET") # Optional +) + +client.login( + username=os.getenv("DATASPACE_USERNAME"), + password=os.getenv("DATASPACE_PASSWORD") +) +``` + +Create a `.env` file: + +```bash +DATASPACE_API_URL=https://dataspace.civicdatalab.in +KEYCLOAK_URL=https://opub-kc.civicdatalab.in +KEYCLOAK_REALM=DataSpace +KEYCLOAK_CLIENT_ID=dataspace +DATASPACE_USERNAME=your-email@example.com +DATASPACE_PASSWORD=your-password +``` + +## Error Handling + +### Authentication Errors + +```python +from dataspace_sdk import DataSpaceClient +from dataspace_sdk.exceptions import DataSpaceAuthError + 
+client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) + +try: + user_info = client.login( + username="user@example.com", + password="wrong-password" + ) +except DataSpaceAuthError as e: + print(f"Login failed: {e}") + print(f"Status code: {e.status_code}") + print(f"Response: {e.response}") +``` + +### Common Errors + +| Error | Cause | Solution | +|-------|-------|----------| +| `Keycloak configuration missing` | Missing keycloak_url, realm, or client_id | Provide all required Keycloak parameters | +| `Keycloak login failed: invalid_grant` | Wrong username/password | Check credentials | +| `Keycloak login failed: invalid_client` | Wrong client_id or client requires consent | Check client_id or disable consent in Keycloak | +| `Resource not found` | Wrong Keycloak URL or realm | Verify keycloak_url and realm name | +| `Not authenticated` | Trying to use API before login | Call `client.login()` first | + +### Checking Authentication Status + +```python +# Check if authenticated +if client.is_authenticated(): + print("Authenticated!") +else: + print("Not authenticated") + +# Get current user info +user = client.user +if user: + print(f"Logged in as: {user['username']}") +``` + +## Best Practices + +### 1. Use Environment Variables + +Never hardcode credentials: + +```python +# ❌ Bad +client.login(username="user@example.com", password="secret123") + +# ✅ Good +client.login( + username=os.getenv("DATASPACE_USERNAME"), + password=os.getenv("DATASPACE_PASSWORD") +) +``` + +### 2. Login Once, Use Everywhere + +```python +# ✅ Good - Login once at startup +client = DataSpaceClient(...) +client.login(username=..., password=...) + +# Use client throughout your application +# Tokens auto-refresh! +datasets = client.datasets.search(...) +models = client.aimodels.search(...) +``` + +### 3. 
Handle Errors Gracefully + +```python +from dataspace_sdk.exceptions import DataSpaceAuthError + +try: + client.login(username=username, password=password) +except DataSpaceAuthError as e: + logger.error(f"Authentication failed: {e}") + # Handle error appropriately + sys.exit(1) +``` + +### 4. Use Confidential Clients for Production + +For production applications, use a confidential client with client_secret: + +```python +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace-prod", + keycloak_client_secret=os.getenv("KEYCLOAK_CLIENT_SECRET") +) +``` + +### 5. Don't Store Passwords in Code + +```python +# ❌ Bad +PASSWORD = "secret123" + +# ✅ Good - Use environment variables +password = os.getenv("DATASPACE_PASSWORD") + +# ✅ Better - Use secrets management +from your_secrets_manager import get_secret +password = get_secret("dataspace_password") +``` + +## Advanced Usage + +### Manual Token Refresh + +While automatic refresh is recommended, you can manually refresh: + +```python +# Get new access token +new_token = client.refresh_token() +print(f"New token: {new_token}") +``` + +### Get Current Access Token + +```python +# Get current token (e.g., to pass to another service) +token = client.access_token +print(f"Current token: {token}") +``` + +### Ensure Authentication + +Force authentication check and auto-relogin if needed: + +```python +# This will auto-relogin if not authenticated +client._auth.ensure_authenticated() +``` + +## Troubleshooting + +### Token Keeps Expiring + +If you're experiencing frequent token expiration: + +1. **Check token lifetime**: Keycloak admin → Realm Settings → Tokens → Access Token Lifespan +2. **Verify auto-refresh**: SDK automatically refreshes 30 seconds before expiration +3. 
**Check credentials**: Ensure username/password are stored correctly + +### "Resource not found" Error + +This usually means wrong Keycloak URL: + +```python +# Try with /auth prefix +keycloak_url="https://opub-kc.civicdatalab.in/auth" + +# Or without +keycloak_url="https://opub-kc.civicdatalab.in" + +# Test in browser: +# https://opub-kc.civicdatalab.in/auth/realms/DataSpace/.well-known/openid-configuration +``` + +### "Client requires user consent" Error + +In Keycloak admin console: + +1. Go to Clients → your client +2. Find "Consent Required" setting +3. Set to OFF +4. Save + +## Next Steps + +- [Quick Start Guide](QUICKSTART.md) +- [Full SDK Documentation](README.md) +- [Examples](../../examples/) +- [API Reference](API_REFERENCE.md) diff --git a/docs/sdk/QUICKSTART.md b/docs/sdk/QUICKSTART.md index 915432bb..6138293c 100644 --- a/docs/sdk/QUICKSTART.md +++ b/docs/sdk/QUICKSTART.md @@ -21,13 +21,23 @@ pip install -e . ```python from dataspace_sdk import DataSpaceClient -# 1. Initialize the client -client = DataSpaceClient(base_url="https://api.dataspace.example.com") +# 1. Initialize the client with Keycloak configuration +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace" +) -# 2. Login with your Keycloak token -client.login(keycloak_token="your_keycloak_token") +# 2. Login with username and password +# Credentials are stored securely for automatic token refresh +user_info = client.login( + username="your-email@example.com", + password="your-password" +) +print(f"Logged in as: {user_info['user']['username']}") -# 3. Search for datasets +# 3. 
Search for datasets (tokens auto-refresh as needed) datasets = client.datasets.search(query="health", page_size=5) print(f"Found {datasets['total']} datasets") @@ -36,6 +46,18 @@ dataset = client.datasets.get_by_id("dataset-uuid") print(f"Dataset: {dataset['title']}") ``` +### Alternative: Login with Pre-obtained Token + +If you already have a Keycloak token: + +```python +# Initialize without Keycloak config +client = DataSpaceClient(base_url="https://dataspace.civicdatalab.in") + +# Login with token +client.login_with_token(keycloak_token="your_keycloak_token") +``` + ## Common Operations ### Search Resources diff --git a/docs/sdk/README.md b/docs/sdk/README.md index 8a251cf6..19a67473 100644 --- a/docs/sdk/README.md +++ b/docs/sdk/README.md @@ -29,11 +29,20 @@ pip install -e ".[dev]" ```python from dataspace_sdk import DataSpaceClient -# Initialize the client -client = DataSpaceClient(base_url="https://api.dataspace.example.com") +# Initialize the client with Keycloak configuration +client = DataSpaceClient( + base_url="https://dev.api.civicdataspace.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace", + keycloak_client_secret="your_client_secret" +) -# Login with Keycloak token -user_info = client.login(keycloak_token="your_keycloak_token") +# Login with username and password +user_info = client.login( + username="your_email@example.com", + password="your_password" +) print(f"Logged in as: {user_info['user']['username']}") # Search for datasets @@ -49,36 +58,66 @@ dataset = client.datasets.get_by_id("dataset-uuid") print(f"Dataset: {dataset['title']}") # Get organization's resources -org_id = user_info['user']['organizations'][0]['id'] -org_datasets = client.datasets.get_organization_datasets(org_id) +if user_info['user']['organizations']: + org_id = user_info['user']['organizations'][0]['id'] + org_datasets = client.datasets.get_organization_datasets(org_id) ``` ## Features -- 
**Authentication**: Login with Keycloak tokens and automatic token refresh +- **Authentication**: Multiple authentication methods (username/password, Keycloak token, service account) +- **Automatic Token Management**: Automatic token refresh and re-login - **Datasets**: Search, retrieve, and list datasets with filtering and pagination -- **AI Models**: Search, retrieve, and list AI models with filtering +- **AI Models**: Search, retrieve, call, and list AI models with filtering - **Use Cases**: Search, retrieve, and list use cases with filtering - **Organization Resources**: Get resources specific to your organizations - **GraphQL & REST**: Supports both GraphQL and REST API endpoints +- **Error Handling**: Comprehensive exception handling with detailed error messages ## Authentication -### Login with Keycloak +The SDK supports three authentication methods: + +### 1. Username and Password (Recommended for Users) ```python from dataspace_sdk import DataSpaceClient -client = DataSpaceClient(base_url="https://api.dataspace.example.com") +client = DataSpaceClient( + base_url="https://dev.api.civicdataspace.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace", + keycloak_client_secret="your_client_secret" +) -# Login with Keycloak token -response = client.login(keycloak_token="your_keycloak_token") +# Login with username and password +user_info = client.login( + username="your_email@example.com", + password="your_password" +) # Access user information -print(response['user']['username']) -print(response['user']['organizations']) +print(user_info['user']['username']) +print(user_info['user']['organizations']) ``` +### 2. Keycloak Token (For Token Pass-through) + +```python +# Login with an existing Keycloak token +response = client.login_with_token(keycloak_token="your_keycloak_token") +``` + +### 3. 
Service Account (For Backend Services)
+
+```python
+# Login as a service account using client credentials
+service_info = client.login_as_service_account()
+```
+
+For detailed authentication documentation, see [AUTHENTICATION.md](./AUTHENTICATION.md).
+
 ### Token Refresh
 
 ```python
@@ -203,6 +242,36 @@ print(f"Provider: {model['provider']}")
 print(f"Endpoints: {len(model['endpoints'])}")
 ```
 
+### Call an AI Model
+
+```python
+# Call an AI model with input text
+result = client.aimodels.call_model(
+    model_id="model-uuid",
+    input_text="What is the capital of France?",
+    parameters={
+        "temperature": 0.7,
+        "max_tokens": 100
+    }
+)
+
+if result['success']:
+    print(f"Output: {result['output']}")
+    print(f"Latency: {result['latency_ms']}ms")
+    print(f"Provider: {result['provider']}")
+else:
+    print(f"Error: {result['error']}")
+
+# For long-running operations, use async call
+task = client.aimodels.call_model_async(
+    model_id="model-uuid",
+    input_text="Generate a long document...",
+    parameters={"max_tokens": 2000}
+)
+print(f"Task ID: {task['task_id']}")
+print(f"Status: {task['status']}")
+```
+
 ### List All AI Models
 
 ```python
@@ -249,6 +318,7 @@ results = client.usecases.search(
 ### Get Use Case by ID
 
 ```python
+# Get use case by ID
 usecase = client.usecases.get_by_id(123)
 
 print(f"Title: {usecase['title']}")
@@ -375,7 +445,9 @@ Main client for interacting with DataSpace API.
 
 **Methods:**
 
-- `login(keycloak_token: str) -> dict`: Login with Keycloak token
+- `login(username: str, password: str) -> dict`: Login with username and password
+- `login_with_token(keycloak_token: str) -> dict`: Login with Keycloak token
+- `login_as_service_account() -> dict`: Login as service account (client credentials)
 - `refresh_token() -> str`: Refresh access token
 - `get_user_info() -> dict`: Get current user information
 - `is_authenticated() -> bool`: Check authentication status
@@ -395,10 +467,12 @@ Client for dataset operations.
**Methods:** - `search(...)`: Search datasets with filters -- `get_by_id(dataset_id: str)`: Get dataset by UUID +- `get_by_id(dataset_id: str)`: Get dataset by UUID (GraphQL) - `list_all(...)`: List all datasets with pagination - `get_trending(limit: int)`: Get trending datasets - `get_organization_datasets(organization_id: str, ...)`: Get organization's datasets +- `get_resources(dataset_id: str)`: Get dataset resources +- `list_by_organization(organization_id: str, ...)`: List datasets by organization ### AIModelClient @@ -409,8 +483,13 @@ Client for AI model operations. - `search(...)`: Search AI models with filters - `get_by_id(model_id: str)`: Get AI model by UUID (REST) - `get_by_id_graphql(model_id: str)`: Get AI model by UUID (GraphQL) +- `call_model(model_id: str, input_text: str, parameters: dict)`: Call an AI model +- `call_model_async(model_id: str, input_text: str, parameters: dict)`: Call an AI model asynchronously - `list_all(...)`: List all AI models with pagination - `get_organization_models(organization_id: str, ...)`: Get organization's AI models +- `create(data: dict)`: Create a new AI model +- `update(model_id: str, data: dict)`: Update an AI model +- `delete_model(model_id: str)`: Delete an AI model ### UseCaseClient @@ -419,7 +498,7 @@ Client for use case operations. **Methods:** - `search(...)`: Search use cases with filters -- `get_by_id(usecase_id: int)`: Get use case by ID +- `get_by_id(usecase_id: int)`: Get use case by ID (GraphQL) - `list_all(...)`: List all use cases with pagination - `get_organization_usecases(organization_id: str, ...)`: Get organization's use cases diff --git a/examples/username_password_login.py b/examples/username_password_login.py new file mode 100644 index 00000000..c4772287 --- /dev/null +++ b/examples/username_password_login.py @@ -0,0 +1,69 @@ +""" +Example: Login with username and password with automatic token refresh. + +This example demonstrates: +1. Initializing the client with Keycloak configuration +2. 
Logging in with username and password +3. Automatic token refresh when tokens expire +4. Automatic re-login if refresh fails +""" + +from dataspace_sdk import DataSpaceClient +from dataspace_sdk.exceptions import DataSpaceAuthError + +# Initialize client with Keycloak configuration +client = DataSpaceClient( + base_url="https://dataspace.civicdatalab.in", + keycloak_url="https://opub-kc.civicdatalab.in", + keycloak_realm="DataSpace", + keycloak_client_id="dataspace", + # keycloak_client_secret="your-secret" # Only if using confidential client +) + +# Login with username and password +# Credentials are stored securely for automatic re-login +try: + user_info = client.login(username="your-email@example.com", password="your-password") + print(f"✓ Logged in as: {user_info['user']['username']}") + print(f"✓ Organizations: {[org['name'] for org in user_info['user'].get('organizations', [])]}") +except DataSpaceAuthError as e: + print(f"✗ Login failed: {e}") + exit(1) + +# Now you can use the client normally +# Tokens will be automatically refreshed when they expire + +# Example 1: Search datasets +print("\n--- Searching datasets ---") +datasets = client.datasets.search(query="health", page_size=5) +print(f"Found {len(datasets.get('results', []))} datasets") + +# Example 2: Get user's organization datasets +print("\n--- Getting organization datasets ---") +user_orgs = user_info["user"].get("organizations", []) +if user_orgs: + org_id = user_orgs[0]["id"] + org_datasets = client.datasets.get_organization_datasets(org_id, limit=10, offset=0) + print(f"Organization has datasets: {org_datasets}") + +# Example 3: Token automatically refreshes +# Even if you use the client after tokens expire, it will auto-refresh +print("\n--- Simulating long-running session ---") +print("The SDK will automatically refresh tokens as needed...") + +# You can also manually check authentication status +if client.is_authenticated(): + print("✓ Still authenticated") +else: + print("✗ Not 
authenticated")
+
+# Example 4: Get current user info (will auto-refresh if needed)
+print("\n--- Getting user info ---")
+try:
+    current_user = client.get_user_info()
+    print(f"Current user: {current_user.get('username')}")
+except DataSpaceAuthError as e:
+    print(f"Error: {e}")
+
+print("\n✓ All operations completed successfully!")
+print("Note: The SDK automatically handled token refresh and re-login as needed.")
diff --git a/files/public/resources/dummydistrict.csv b/files/public/resources/dummydistrict.csv
deleted file mode 100644
index 1a8dd45d..00000000
--- a/files/public/resources/dummydistrict.csv
+++ /dev/null
@@ -1,3 +0,0 @@
-district,Value,,,,,
-DHUBRI,10,,,,,
-South Salmara Mancachar,30,,,,,
diff --git a/pyproject.toml b/pyproject.toml
index 721b5d2a..d1e5fa80 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "dataspace-sdk"
-version = "0.4.18"
+version = "0.4.19"
 description = "Python SDK for DataSpace API"
 readme = "docs/sdk/README.md"
 requires-python = ">=3.8"
@@ -53,7 +53,8 @@ target-version = ['py38', 'py39', 'py310', 'py311']
 include = '\.pyi?$'
 
 [tool.mypy]
-python_version = "0.4.18"
+
+python_version = "3.8"
 warn_return_any = true
 warn_unused_configs = true
 disallow_untyped_defs = false
diff --git a/requirements.txt b/requirements.txt
index e1bd9190..56a97311 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,26 +18,26 @@ elasticsearch-dsl==8.12.0
 googleapis-common-protos==1.65.0
 graphql-core==3.2.3
 gunicorn==23.0.0
-grpcio==1.62.1
+grpcio==1.68.1
 idna==3.6
 importlib-metadata==6.11.0
 numpy==1.26.4
-opentelemetry-api==1.23.0
-opentelemetry-exporter-otlp==1.23.0
-opentelemetry-exporter-otlp-proto-common==1.23.0
-opentelemetry-exporter-otlp-proto-grpc==1.23.0
-opentelemetry-exporter-otlp-proto-http==1.23.0
-opentelemetry-instrumentation==0.44b0
-opentelemetry-instrumentation-django==0.44b0
-opentelemetry-instrumentation-wsgi==0.44b0
-opentelemetry-instrumentation-redis==0.44b0 -opentelemetry-instrumentation-elasticsearch==0.44b0 -opentelemetry-instrumentation-requests==0.44b0 -opentelemetry-instrumentation-sqlalchemy==0.44b0 -opentelemetry-proto==1.23.0 -opentelemetry-sdk==1.23.0 -opentelemetry-semantic-conventions==0.44b0 -opentelemetry-util-http==0.44b0 +opentelemetry-api==1.28.2 +opentelemetry-exporter-otlp==1.28.2 +opentelemetry-exporter-otlp-proto-common==1.28.2 +opentelemetry-exporter-otlp-proto-grpc==1.28.2 +opentelemetry-exporter-otlp-proto-http==1.28.2 +opentelemetry-instrumentation==0.49b2 +opentelemetry-instrumentation-django==0.49b2 +opentelemetry-instrumentation-wsgi==0.49b2 +opentelemetry-instrumentation-redis==0.49b2 +opentelemetry-instrumentation-elasticsearch==0.49b2 +opentelemetry-instrumentation-requests==0.49b2 +opentelemetry-instrumentation-sqlalchemy==0.49b2 +opentelemetry-proto==1.28.2 +opentelemetry-sdk==1.28.2 +opentelemetry-semantic-conventions==0.49b2 +opentelemetry-util-http==0.49b2 packaging==24.0 pandas==2.2.2 openpyxl==3.1.2 # For Excel file support @@ -45,42 +45,42 @@ odfpy==1.4.1 # For ODS file support pyarrow==15.0.0 # For Parquet and Feather file support xlrd==2.0.1 # For Excel file support pillow==10.4.0 -protobuf==4.25.3 -pydantic==2.6.1 # Slightly downgraded for compatibility -pydantic_core==2.16.2 # Matched with pydantic +protobuf==5.28.3 +pydantic==2.9.2 +pydantic_core==2.23.4 # Matched with pydantic pyecharts==2.0.3 pyecharts-snapshot @ git+https://github.com/Deepthi-Chand/pyecharts-snapshot.git@8d6cadd055db6c919a1447064185d00d1b30ce01 python-dateutil==2.9.0.post0 -python-dotenv==1.0.1 +python-dotenv==1.1.1 python-magic==0.4.27 pytz==2024.1 requests==2.31.0 sqlalchemy==2.0.39 six==1.16.0 sqlparse==0.4.4 -strawberry-graphql==0.225.1 -strawberry-graphql-django==0.37.1 -typing-extensions==4.9.0 # Fixed version that satisfies most dependencies +strawberry-graphql==0.235.2 +strawberry-graphql-django==0.42.0 +typing-extensions==4.15.0 tzdata==2024.1 
urllib3==2.2.1 -uvicorn==0.27.1 # Downgraded for better compatibility +uvicorn==0.27.1 wrapt==1.16.0 zipp==3.18.1 snapshot-selenium==0.0.2 -selenium==4.16.0 # Downgraded for compatibility +selenium==4.16.0 # Security and Environment Management python-decouple==3.8 -django-debug-toolbar==4.2.0 # Downgraded for better Django compatibility +django-debug-toolbar==4.2.0 django-ratelimit==4.1.0 # Caching and Performance -django-redis==5.4.0 # Added for Redis caching -redis==5.0.1 # Added Redis client +django-redis==5.4.0 +redis==5.0.1 django-cacheops==7.0.2 # Code Quality and Type Checking -mypy==1.7.1 # Downgraded for compatibility +mypy==1.7.1 black==24.2.0 isort==5.13.2 pre-commit==3.6.2 @@ -88,10 +88,10 @@ flake8==7.0.0 pandas-stubs==2.2.0.240218 # Matched with pandas version # Data Comparison and Versioning -deepdiff==6.7.1 # For intelligent version detection +deepdiff==6.7.1 # Logging -structlog==24.1.0 # Downgraded for better compatibility +structlog==24.1.0 # Development and Debugging ipython==8.22.1 @@ -99,20 +99,26 @@ ipython==8.22.1 # API Documentation drf-yasg==1.21.7 drf-yasg[validation]==1.21.7 -djangorestframework-simplejwt==5.3.1 # Downgraded for better DRF compatibility +djangorestframework-simplejwt==5.3.1 # Keycloak Integration python-keycloak==5.5.1 django-keycloak-auth==1.0.0 # API Versioning and Throttling -django-filter==23.5 # Downgraded for better Django compatibility +django-filter==23.5 -django-stubs==4.2.7 # Downgraded for better compatibility -djangorestframework-stubs==3.14.5 # Matched with DRF version +django-stubs==4.2.7 +djangorestframework-stubs==3.14.5 #whitenoise for managing static files whitenoise==6.9.0 # Activity stream for tracking user actions django-activity-stream==2.0.0 + +tenacity==9.1.2 +torch==2.9.0 +transformers==4.57.1 +sentencepiece==0.2.1 +accelerate==1.11.0 diff --git a/search/documents/__init__.py b/search/documents/__init__.py index f6441acf..96a2345c 100644 --- a/search/documents/__init__.py +++ 
b/search/documents/__init__.py @@ -1,3 +1,8 @@ from search.documents.aimodel_document import AIModelDocument +from search.documents.collaborative_document import CollaborativeDocument from search.documents.dataset_document import DatasetDocument +from search.documents.publisher_document import ( + OrganizationPublisherDocument, + UserPublisherDocument, +) from search.documents.usecase_document import UseCaseDocument diff --git a/search/documents/aimodel_document.py b/search/documents/aimodel_document.py index 164f2d51..fdeae8bc 100644 --- a/search/documents/aimodel_document.py +++ b/search/documents/aimodel_document.py @@ -5,6 +5,7 @@ from django_elasticsearch_dsl import Document, Index, KeywordField, fields from api.models.AIModel import AIModel, ModelEndpoint +from api.models.AIModelVersion import AIModelVersion, VersionProvider from api.models.Dataset import Tag from api.models.Geography import Geography from api.models.Organization import Organization @@ -45,11 +46,15 @@ class AIModelDocument(Document): ) version = fields.KeywordField() + provider = fields.KeywordField() + provider_model_id = fields.KeywordField() + supported_languages = fields.KeywordField(multi=True) + supports_streaming = fields.BooleanField() + max_tokens = fields.IntegerField() # Model configuration model_type = fields.KeywordField() - provider = fields.KeywordField() - provider_model_id = fields.KeywordField() + domain = fields.KeywordField() # Status and visibility status = fields.KeywordField() @@ -89,13 +94,6 @@ class AIModelDocument(Document): multi=True, ) - # Supported languages (stored as JSON array in model) - supported_languages = fields.KeywordField(multi=True) - - # Capabilities - supports_streaming = fields.BooleanField() - max_tokens = fields.IntegerField() - # Performance metrics average_latency_ms = fields.FloatField() success_rate = fields.FloatField() @@ -130,10 +128,47 @@ class AIModelDocument(Document): } ) + versions = fields.NestedField( + properties={ + "id": 
fields.IntegerField(), + "version": fields.KeywordField(), + "version_notes": fields.TextField(analyzer=html_strip), + "lifecycle_stage": fields.KeywordField(), + "is_latest": fields.BooleanField(), + "status": fields.KeywordField(), + "supports_streaming": fields.BooleanField(), + "max_tokens": fields.IntegerField(), + "supported_languages": fields.KeywordField(multi=True), + "created_at": fields.DateField(), + "updated_at": fields.DateField(), + "providers": fields.NestedField( + properties={ + "id": fields.IntegerField(), + "provider": fields.KeywordField(), + "provider_model_id": fields.KeywordField(), + "is_primary": fields.BooleanField(), + "is_active": fields.BooleanField(), + # API Configuration + "api_endpoint_url": fields.KeywordField(), + "api_http_method": fields.KeywordField(), + "api_timeout_seconds": fields.IntegerField(), + "api_auth_type": fields.KeywordField(), + # HuggingFace Configuration + "hf_use_pipeline": fields.BooleanField(), + "hf_model_class": fields.KeywordField(), + "framework": fields.KeywordField(), + } + ), + } + ) + # Computed fields is_individual_model = fields.BooleanField() has_active_endpoints = fields.BooleanField() endpoint_count = fields.IntegerField() + version_count = fields.IntegerField() + lifecycle_stage = fields.KeywordField() # Primary version's lifecycle stage + all_providers = fields.KeywordField(multi=True) # All unique providers across versions def prepare_organization(self, instance: AIModel) -> Optional[Dict[str, str]]: """Prepare organization data for indexing, including logo URL.""" @@ -150,9 +185,7 @@ def prepare_user(self, instance: AIModel) -> Optional[Dict[str, str]]: "name": instance.user.full_name, "bio": instance.user.bio or "", "profile_picture": ( - instance.user.profile_picture.url - if instance.user.profile_picture - else "" + instance.user.profile_picture.url if instance.user.profile_picture else "" ), } return None @@ -184,6 +217,140 @@ def prepare_endpoint_count(self, instance: AIModel) -> int: 
"""Count the number of endpoints.""" return instance.endpoints.count() + def prepare_versions(self, instance: AIModel) -> List[Dict[str, Any]]: + """Prepare versions data for indexing.""" + versions_data: List[Dict[str, Any]] = [] + for version in instance.versions.all(): # type: ignore[attr-defined] + version_obj: AIModelVersion = version # type: ignore[assignment] + providers_data: List[Dict[str, Any]] = [] + for provider in version_obj.providers.all(): # type: ignore[attr-defined] + provider_obj: VersionProvider = provider # type: ignore[assignment] + providers_data.append( + { + "id": provider_obj.id, + "provider": provider_obj.provider, + "provider_model_id": provider_obj.provider_model_id, + "is_primary": provider_obj.is_primary, + "is_active": provider_obj.is_active, + # API Configuration + "api_endpoint_url": provider_obj.api_endpoint_url, + "api_http_method": provider_obj.api_http_method, + "api_timeout_seconds": provider_obj.api_timeout_seconds, + "api_auth_type": provider_obj.api_auth_type, + # HuggingFace Configuration + "hf_use_pipeline": provider_obj.hf_use_pipeline, + "hf_model_class": provider_obj.hf_model_class, + "framework": provider_obj.framework, + } + ) + versions_data.append( + { + "id": version_obj.id, + "version": version_obj.version, + "version_notes": version_obj.version_notes or "", + "lifecycle_stage": version_obj.lifecycle_stage, + "is_latest": version_obj.is_latest, + "status": version_obj.status, + "supports_streaming": version_obj.supports_streaming, + "max_tokens": version_obj.max_tokens, + "supported_languages": version_obj.supported_languages or [], + "created_at": version_obj.created_at, + "updated_at": version_obj.updated_at, + "providers": providers_data, + } + ) + return versions_data + + def _get_primary_version(self, instance: AIModel) -> Optional[AIModelVersion]: + """Get the primary (latest) version of the model.""" + primary = instance.versions.filter(is_latest=True).first() # type: ignore[attr-defined] + if not primary: 
+ primary = instance.versions.first() # type: ignore[attr-defined] + return primary # type: ignore[return-value] + + def _get_primary_provider(self, version: Optional[AIModelVersion]) -> Optional[VersionProvider]: + """Get the primary provider of a version.""" + if not version: + return None + primary = version.providers.filter(is_primary=True).first() # type: ignore[attr-defined] + if not primary: + primary = version.providers.first() # type: ignore[attr-defined] + return primary # type: ignore[return-value] + + def prepare_version(self, instance: AIModel) -> str: + """Prepare version from primary version for backward compatibility.""" + primary_version = self._get_primary_version(instance) + if primary_version: + return str(primary_version.version) + # Fallback to legacy field on AIModel + return instance.version or "" + + def prepare_provider(self, instance: AIModel) -> str: + """Prepare provider from primary version's primary provider for backward compatibility.""" + primary_version = self._get_primary_version(instance) + primary_provider = self._get_primary_provider(primary_version) + if primary_provider: + return str(primary_provider.provider) + # Fallback to legacy field on AIModel + return instance.provider or "" + + def prepare_provider_model_id(self, instance: AIModel) -> str: + """Prepare provider_model_id from primary version's primary provider.""" + primary_version = self._get_primary_version(instance) + primary_provider = self._get_primary_provider(primary_version) + if primary_provider: + return primary_provider.provider_model_id or "" + # Fallback to legacy field on AIModel + return instance.provider_model_id or "" + + def prepare_supported_languages(self, instance: AIModel) -> List[str]: + """Prepare supported_languages from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version and primary_version.supported_languages: + return list(primary_version.supported_languages) + # Fallback to legacy field on 
AIModel + return list(instance.supported_languages or []) + + def prepare_supports_streaming(self, instance: AIModel) -> bool: + """Prepare supports_streaming from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version: + return bool(primary_version.supports_streaming) + # Fallback to legacy field on AIModel + return bool(instance.supports_streaming) + + def prepare_max_tokens(self, instance: AIModel) -> Optional[int]: + """Prepare max_tokens from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version and primary_version.max_tokens: + return int(primary_version.max_tokens) + # Fallback to legacy field on AIModel + return instance.max_tokens + + def prepare_version_count(self, instance: AIModel) -> int: + """Count the number of versions.""" + return instance.versions.count() # type: ignore[attr-defined] + + def prepare_lifecycle_stage(self, instance: AIModel) -> str: + """Get lifecycle stage from primary version.""" + primary_version = self._get_primary_version(instance) + if primary_version: + return str(primary_version.lifecycle_stage) + return "DEVELOPMENT" + + def prepare_all_providers(self, instance: AIModel) -> List[str]: + """Get all unique providers across all versions.""" + providers: set[str] = set() + for version in instance.versions.all(): # type: ignore[attr-defined] + version_obj: AIModelVersion = version # type: ignore[assignment] + for provider in version_obj.providers.all(): # type: ignore[attr-defined] + provider_obj: VersionProvider = provider # type: ignore[assignment] + providers.add(str(provider_obj.provider)) + # Also include legacy provider if set + if instance.provider: + providers.add(str(instance.provider)) + return list(providers) + def should_index_object(self, obj: AIModel) -> bool: """ Check if the object should be indexed. 
@@ -225,7 +392,14 @@ def get_queryset(self) -> Any: def get_instances_from_related( self, related_instance: Union[ - ModelEndpoint, Organization, User, Tag, Sector, Geography + ModelEndpoint, + Organization, + User, + Tag, + Sector, + Geography, + AIModelVersion, + VersionProvider, ], ) -> Optional[Union[AIModel, List[AIModel]]]: """Get AIModel instances from related models.""" @@ -241,6 +415,10 @@ def get_instances_from_related( return list(related_instance.ai_models.all()) elif isinstance(related_instance, Geography): return list(related_instance.ai_models.all()) + elif isinstance(related_instance, AIModelVersion): + return related_instance.ai_model + elif isinstance(related_instance, VersionProvider): + return related_instance.version.ai_model return None class Django: @@ -262,4 +440,6 @@ class Django: Tag, Sector, Geography, + AIModelVersion, + VersionProvider, ] diff --git a/search/documents/collaborative_document.py b/search/documents/collaborative_document.py new file mode 100644 index 00000000..2661deb0 --- /dev/null +++ b/search/documents/collaborative_document.py @@ -0,0 +1,335 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from django_elasticsearch_dsl import Document, Index, KeywordField, fields + +from api.models import ( + Collaborative, + CollaborativeMetadata, + CollaborativeOrganizationRelationship, + Dataset, + Geography, + Metadata, + Organization, + Sector, + UseCase, +) +from api.utils.enums import CollaborativeStatus +from authorization.models import User +from DataSpace import settings +from search.documents.analysers import html_strip, ngram_analyser + +if TYPE_CHECKING: + from api.models import CollaborativeOrganizationRelationship as RelationshipModel + from api.models import Dataset as DatasetModel + from api.models import Organization as OrganizationModel + from api.models import UseCase as UseCaseModel + +INDEX = Index(settings.ELASTICSEARCH_INDEX_NAMES[__name__]) +INDEX.settings(number_of_shards=1, 
number_of_replicas=0) + + +@INDEX.doc_type +class CollaborativeDocument(Document): + """Elasticsearch document for Collaborative model.""" + + metadata = fields.NestedField( + properties={ + "value": KeywordField(multi=True), + "raw": KeywordField(multi=True), + "metadata_item": fields.ObjectField(properties={"label": KeywordField(multi=False)}), + } + ) + + datasets = fields.NestedField( + properties={ + "title": fields.TextField(analyzer=ngram_analyser), + "description": fields.TextField(analyzer=html_strip), + "slug": fields.KeywordField(), + } + ) + + use_cases = fields.NestedField( + properties={ + "title": fields.TextField(analyzer=ngram_analyser), + "summary": fields.TextField(analyzer=html_strip), + "slug": fields.KeywordField(), + } + ) + + title = fields.TextField( + analyzer=ngram_analyser, + fields={ + "raw": KeywordField(multi=False), + }, + ) + + summary = fields.TextField( + analyzer=html_strip, + fields={ + "raw": fields.TextField(analyzer="keyword"), + }, + ) + + logo = fields.TextField(analyzer=ngram_analyser) + cover_image = fields.TextField(analyzer=ngram_analyser) + + status = fields.KeywordField() + slug = fields.KeywordField() + + tags = fields.TextField( + attr="tags_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + sectors = fields.TextField( + attr="sectors_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + geographies = fields.TextField( + attr="geographies_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + organization = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": fields.KeywordField()} + ), + "logo": 
fields.TextField(analyzer=ngram_analyser), + } + ) + + user = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": fields.KeywordField()} + ), + "bio": fields.TextField(analyzer=html_strip), + "profile_picture": fields.TextField(analyzer=ngram_analyser), + } + ) + + contributors = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": fields.KeywordField()} + ), + "bio": fields.TextField(analyzer=html_strip), + "profile_picture": fields.TextField(analyzer=ngram_analyser), + } + ) + + organizations = fields.NestedField( + properties={ + "name": fields.TextField( + analyzer=ngram_analyser, fields={"raw": fields.KeywordField()} + ), + "logo": fields.TextField(analyzer=ngram_analyser), + "relationship_type": fields.KeywordField(), + } + ) + + is_individual_collaborative = fields.BooleanField(attr="is_individual_collaborative") + + website = fields.TextField(analyzer=ngram_analyser) + contact_email = fields.KeywordField() + platform_url = fields.TextField(analyzer=ngram_analyser) + started_on = fields.DateField() + completed_on = fields.DateField() + dataset_count = fields.IntegerField() + + def prepare_metadata(self, instance: Collaborative) -> List[Dict[str, Any]]: + processed_metadata: List[Dict[str, Any]] = [] + for meta in instance.metadata.all(): # type: CollaborativeMetadata + if not meta.metadata_item: + continue + + value_list = ( + [val.strip() for val in meta.value.split(",")] + if isinstance(meta.value, str) and "," in meta.value + else [meta.value] + ) + processed_metadata.append( + { + "value": value_list, + "metadata_item": {"label": meta.metadata_item.label}, + } + ) + return processed_metadata + + def prepare_datasets(self, instance: Collaborative) -> List[Dict[str, str]]: + datasets_data: List[Dict[str, str]] = [] + for dataset in instance.datasets.all(): + datasets_data.append( + { + "title": dataset.title or "", # type: ignore[attr-defined] + 
"description": dataset.description or "", # type: ignore[attr-defined] + "slug": dataset.slug or "", # type: ignore[attr-defined] + } + ) + return datasets_data + + def prepare_use_cases(self, instance: Collaborative) -> List[Dict[str, str]]: + use_cases_data: List[Dict[str, str]] = [] + for use_case in instance.use_cases.all(): + use_cases_data.append( + { + "title": use_case.title or "", # type: ignore[attr-defined] + "summary": use_case.summary or "", # type: ignore[attr-defined] + "slug": use_case.slug or "", # type: ignore[attr-defined] + } + ) + return use_cases_data + + def prepare_organization(self, instance: Collaborative) -> Optional[Dict[str, str]]: + if instance.organization: + org = instance.organization + logo_url = org.logo.url if org.logo else "" + return {"name": org.name, "logo": logo_url} + return None + + def prepare_user(self, instance: Collaborative) -> Optional[Dict[str, str]]: + if instance.user: + return { + "name": instance.user.full_name, + "bio": instance.user.bio or "", + "profile_picture": ( + instance.user.profile_picture.url if instance.user.profile_picture else "" + ), + } + return None + + def prepare_contributors(self, instance: Collaborative) -> List[Dict[str, str]]: + contributors_data: List[Dict[str, str]] = [] + for contributor in instance.contributors.all(): + contributors_data.append( + { + "name": contributor.full_name, # type: ignore + "bio": contributor.bio or "", # type: ignore + "profile_picture": ( + contributor.profile_picture.url # type: ignore + if contributor.profile_picture # type: ignore + else "" + ), + } + ) + return contributors_data + + def prepare_organizations(self, instance: Collaborative) -> List[Dict[str, str]]: + organizations_data: List[Dict[str, str]] = [] + relationships = CollaborativeOrganizationRelationship.objects.filter(collaborative=instance) + for relationship in relationships: + org = relationship.organization # type: ignore[attr-defined] + logo_url = org.logo.url if org.logo else "" + 
organizations_data.append( + { + "name": org.name, # type: ignore[attr-defined] + "logo": logo_url, + "relationship_type": relationship.relationship_type, # type: ignore[attr-defined] + } + ) + return organizations_data + + def prepare_logo(self, instance: Collaborative) -> str: + if instance.logo: + return str(instance.logo.path.replace("/code/files/", "")) + return "" + + def prepare_cover_image(self, instance: Collaborative) -> str: + if instance.cover_image: + return str(instance.cover_image.path.replace("/code/files/", "")) + return "" + + def prepare_dataset_count(self, instance: Collaborative) -> int: + return instance.datasets.count() + + def should_index_object(self, obj: Collaborative) -> bool: + return obj.status == CollaborativeStatus.PUBLISHED + + def save(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover - thin wrapper + if self.status == CollaborativeStatus.PUBLISHED: + super().save(*args, **kwargs) + else: + self.delete(ignore=404) + + def delete(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover - thin wrapper + super().delete(*args, **kwargs) + + def get_queryset(self) -> Any: + return ( + super(CollaborativeDocument, self) + .get_queryset() + .filter(status=CollaborativeStatus.PUBLISHED) + ) + + def get_instances_from_related( + self, + related_instance: Union[ + Dataset, + UseCase, + Metadata, + CollaborativeMetadata, + Sector, + Organization, + User, + Geography, + ], + ) -> Optional[Union[Collaborative, List[Collaborative]]]: + if isinstance(related_instance, Dataset): + return list(related_instance.collaborative_set.all()) # type: ignore[attr-defined] + if isinstance(related_instance, UseCase): + return list(related_instance.collaborative_set.all()) # type: ignore[attr-defined] + if isinstance(related_instance, Metadata): + collab_metadata_objects = related_instance.collaborativemetadata_set.all() # type: ignore[attr-defined] + return [obj.collaborative for obj in collab_metadata_objects] # type: ignore[attr-defined] + if 
isinstance(related_instance, CollaborativeMetadata): + return related_instance.collaborative # type: ignore[attr-defined] + if isinstance(related_instance, Sector): + return list(related_instance.collaboratives.all()) + if isinstance(related_instance, Organization): + primary = list(related_instance.collaborative_set.all()) + related = list(related_instance.related_collaboratives.all()) + return primary + related + if isinstance(related_instance, User): + owned = list(related_instance.collaborative_set.all()) + contributed = list(related_instance.contributed_collaboratives.all()) + return owned + contributed + if isinstance(related_instance, Geography): + return list(related_instance.collaboratives.all()) + return None + + class Django: + model = Collaborative + + fields = [ + "id", + "created", + "modified", + ] + + related_models = [ + Dataset, + UseCase, + Metadata, + CollaborativeMetadata, + Sector, + Organization, + User, + Geography, + ] diff --git a/search/documents/dataset_document.py b/search/documents/dataset_document.py index 8578fb4f..83f8d6e7 100644 --- a/search/documents/dataset_document.py +++ b/search/documents/dataset_document.py @@ -9,10 +9,11 @@ Geography, Metadata, Organization, + PromptDataset, Resource, Sector, ) -from api.utils.enums import DatasetStatus +from api.utils.enums import DatasetStatus, DatasetType from authorization.models import User from DataSpace import settings from search.documents.analysers import html_strip, ngram_analyser @@ -29,9 +30,7 @@ class DatasetDocument(Document): properties={ "value": KeywordField(multi=True), "raw": KeywordField(multi=True), - "metadata_item": fields.ObjectField( - properties={"label": KeywordField(multi=False)} - ), + "metadata_item": fields.ObjectField(properties={"label": KeywordField(multi=False)}), } ) @@ -127,10 +126,22 @@ class DatasetDocument(Document): is_individual_dataset = fields.BooleanField(attr="is_individual_dataset") + dataset_type = fields.KeywordField() + has_charts = 
fields.BooleanField(attr="has_charts") download_count = fields.IntegerField(attr="download_count") trending_score = fields.FloatField(attr="trending_score") + # Prompt-specific metadata (nested object, only populated for PROMPT type datasets) + prompt_metadata = fields.NestedField( + properties={ + "task_type": KeywordField(), + "target_languages": KeywordField(multi=True), + "domain": KeywordField(), + "target_model_types": KeywordField(multi=True), + } + ) + def prepare_metadata(self, instance: Dataset) -> List[Dict[str, Any]]: """Preprocess comma-separated metadata values into arrays.""" processed_metadata: List[Dict[str, Any]] = [] @@ -167,13 +178,30 @@ def prepare_user(self, instance: Dataset) -> Optional[Dict[str, str]]: "name": instance.user.full_name, "bio": instance.user.bio or "", "profile_picture": ( - instance.user.profile_picture.url - if instance.user.profile_picture - else "" + instance.user.profile_picture.url if instance.user.profile_picture else "" ), } return None + def prepare_prompt_metadata(self, instance: Dataset) -> Optional[Dict[str, Any]]: + """Prepare prompt metadata for indexing (only for PROMPT type datasets).""" + if instance.dataset_type != DatasetType.PROMPT: + return None + + try: + # With multi-table inheritance, check if this Dataset has a PromptDataset child + prompt_dataset = PromptDataset.objects.filter(dataset_ptr_id=instance.id).first() + if prompt_dataset: + return { + "task_type": prompt_dataset.task_type, + "target_languages": prompt_dataset.target_languages or [], + "domain": prompt_dataset.domain, + "target_model_types": prompt_dataset.target_model_types or [], + } + except PromptDataset.DoesNotExist: + pass + return None + def should_index_object(self, obj: Dataset) -> bool: """Check if the object should be indexed.""" return obj.status == DatasetStatus.PUBLISHED @@ -191,11 +219,7 @@ def delete(self, *args: Any, **kwargs: Any) -> None: def get_queryset(self) -> Any: """Get the queryset for indexing.""" - return ( - 
super(DatasetDocument, self) - .get_queryset() - .filter(status=DatasetStatus.PUBLISHED) - ) + return super(DatasetDocument, self).get_queryset().filter(status=DatasetStatus.PUBLISHED) def get_instances_from_related( self, @@ -203,6 +227,7 @@ def get_instances_from_related( Resource, Metadata, DatasetMetadata, + PromptDataset, Sector, Organization, User, @@ -218,6 +243,9 @@ def get_instances_from_related( return [obj.dataset for obj in ds_metadata_objects] # type: ignore elif isinstance(related_instance, DatasetMetadata): return related_instance.dataset + elif isinstance(related_instance, PromptDataset): + # PromptDataset IS a Dataset (multi-table inheritance), cast to Dataset + return Dataset.objects.get(pk=related_instance.pk) elif isinstance(related_instance, Sector): return list(related_instance.datasets.all()) elif isinstance(related_instance, Organization): @@ -245,6 +273,7 @@ class Django: Resource, Metadata, DatasetMetadata, + PromptDataset, Sector, Organization, User, diff --git a/search/documents/publisher_document.py b/search/documents/publisher_document.py new file mode 100644 index 00000000..b63744d1 --- /dev/null +++ b/search/documents/publisher_document.py @@ -0,0 +1,377 @@ +from typing import Any, Dict, List, Optional, Union + +from django_elasticsearch_dsl import Document, Index, KeywordField, fields + +from api.models import Dataset, Organization, Sector, UseCase +from api.utils.enums import DatasetStatus, UseCaseStatus +from authorization.models import User +from DataSpace import settings +from search.documents.analysers import html_strip, ngram_analyser + +# Create separate indices for each publisher document type +ORG_INDEX = Index(settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.OrganizationPublisherDocument"]) +ORG_INDEX.settings(number_of_shards=1, number_of_replicas=0) + +USER_INDEX = Index(settings.ELASTICSEARCH_INDEX_NAMES[f"{__name__}.UserPublisherDocument"]) +USER_INDEX.settings(number_of_shards=1, number_of_replicas=0) + + +class 
PublisherDocument(Document): + """Elasticsearch document for Publisher (Organization and User) models.""" + + # Common fields for both organizations and users + name = fields.TextField( + analyzer=ngram_analyser, + fields={ + "raw": KeywordField(multi=False), + }, + ) + + description = fields.TextField( + analyzer=html_strip, + fields={ + "raw": fields.TextField(analyzer="keyword"), + }, + ) + + publisher_type = fields.KeywordField() # 'organization' or 'user' + + # Organization specific fields + logo = fields.TextField(analyzer=ngram_analyser) + homepage = fields.TextField(analyzer=ngram_analyser) + contact_email = fields.KeywordField() + organization_types = fields.KeywordField() + github_profile = fields.TextField(analyzer=ngram_analyser) + linkedin_profile = fields.TextField(analyzer=ngram_analyser) + twitter_profile = fields.TextField(analyzer=ngram_analyser) + location = fields.TextField(analyzer=ngram_analyser) + + # User specific fields + bio = fields.TextField( + analyzer=html_strip, + fields={ + "raw": fields.TextField(analyzer="keyword"), + }, + ) + profile_picture = fields.TextField(analyzer=ngram_analyser) + username = fields.KeywordField() + email = fields.KeywordField() + first_name = fields.TextField(analyzer=ngram_analyser) + last_name = fields.TextField(analyzer=ngram_analyser) + full_name = fields.TextField(analyzer=ngram_analyser) + + # Common metadata + slug = fields.KeywordField() + + # Computed fields + published_datasets_count = fields.IntegerField() + published_usecases_count = fields.IntegerField() + members_count = fields.IntegerField() # Only for organizations + contributed_sectors_count = fields.IntegerField() + + # For search and filtering + sectors = fields.TextField( + attr="sectors_indexing", + analyzer=ngram_analyser, + fields={ + "raw": fields.KeywordField(multi=True), + "suggest": fields.CompletionField(multi=True), + }, + multi=True, + ) + + def prepare_name(self, instance: Union[Organization, User]) -> str: + """Prepare name 
field for indexing.""" + if isinstance(instance, Organization): + return getattr(instance, "name", "") + else: # User + return getattr(instance, "full_name", "") or getattr(instance, "username", "") + + def prepare_description(self, instance: Union[Organization, User]) -> str: + """Prepare description field for indexing.""" + if isinstance(instance, Organization): + return instance.description or "" + else: # User + return instance.bio or "" + + def prepare_publisher_type(self, instance: Union[Organization, User]) -> str: + """Determine publisher type.""" + return "organization" if isinstance(instance, Organization) else "user" + + def prepare_logo(self, instance: Union[Organization, User]) -> str: + """Prepare logo/profile picture URL.""" + if isinstance(instance, Organization): + logo = getattr(instance, "logo", None) + return str(logo.url) if logo and hasattr(logo, "url") else "" + else: # User + profile_picture = getattr(instance, "profile_picture", None) + return ( + str(profile_picture.url) + if profile_picture and hasattr(profile_picture, "url") + else "" + ) + + def prepare_profile_picture(self, instance: Union[Organization, User]) -> str: + """Prepare profile picture URL for users.""" + if isinstance(instance, User): + profile_picture = getattr(instance, "profile_picture", None) + return ( + str(profile_picture.url) + if profile_picture and hasattr(profile_picture, "url") + else "" + ) + return "" + + def prepare_slug(self, instance: Union[Organization, User]) -> str: + """Prepare slug field.""" + if isinstance(instance, Organization): + return getattr(instance, "slug", "") or "" + else: # User + return str(getattr(instance, "id", "")) # Users don't have slugs, use ID + + def prepare_full_name(self, instance: Union[Organization, User]) -> str: + """Prepare full name for users.""" + if isinstance(instance, User): + first_name = getattr(instance, "first_name", "") + last_name = getattr(instance, "last_name", "") + if first_name and last_name: + return 
f"{first_name} {last_name}" + elif first_name: + return first_name + elif last_name: + return last_name + else: + return getattr(instance, "username", "") + return "" + + def prepare_published_datasets_count(self, instance: Union[Organization, User]) -> int: + """Get count of published datasets.""" + try: + if isinstance(instance, Organization): + return Dataset.objects.filter( + organization_id=instance.id, status=DatasetStatus.PUBLISHED.value + ).count() + else: # User + return Dataset.objects.filter( + user_id=instance.id, status=DatasetStatus.PUBLISHED.value + ).count() + except Exception: + return 0 + + def prepare_published_usecases_count(self, instance: Union[Organization, User]) -> int: + """Get count of published use cases.""" + try: + if isinstance(instance, Organization): + from django.db.models import Q + + use_cases = UseCase.objects.filter( + ( + Q(organization__id=instance.id) + | Q(usecaseorganizationrelationship__organization_id=instance.id) + ), + status=UseCaseStatus.PUBLISHED.value, + ).distinct() + return use_cases.count() + else: # User + return UseCase.objects.filter( + user_id=instance.id, status=UseCaseStatus.PUBLISHED.value + ).count() + except Exception: + return 0 + + def prepare_members_count(self, instance: Union[Organization, User]) -> int: + """Get count of members (only for organizations).""" + if isinstance(instance, Organization): + try: + from authorization.models import OrganizationMembership + + return OrganizationMembership.objects.filter(organization_id=instance.id).count() + except Exception: + return 0 + return 0 + + def prepare_contributed_sectors_count(self, instance: Union[Organization, User]) -> int: + """Get count of sectors contributed to.""" + try: + from api.models import Sector + + if isinstance(instance, Organization): + # Get sectors from published datasets + dataset_sectors = ( + Sector.objects.filter( + datasets__organization_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ) + 
.values_list("id", flat=True) + .distinct() + ) + + # Get sectors from published use cases + usecase_sectors = ( + Sector.objects.filter( + usecases__usecaseorganizationrelationship__organization_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ) + .values_list("id", flat=True) + .distinct() + ) + else: # User + # Get sectors from published datasets + dataset_sectors = ( + Sector.objects.filter( + datasets__user_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ) + .values_list("id", flat=True) + .distinct() + ) + + # Get sectors from published use cases + usecase_sectors = ( + Sector.objects.filter( + usecases__user_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ) + .values_list("id", flat=True) + .distinct() + ) + + # Combine and deduplicate sectors + sector_ids = set(dataset_sectors) + sector_ids.update(usecase_sectors) + + return len(sector_ids) + except Exception: + return 0 + + def prepare_sectors_indexing(self, instance: Union[Organization, User]) -> List[str]: + """Prepare sectors for indexing.""" + try: + from api.models import Sector + + if isinstance(instance, Organization): + # Get sectors from published datasets + dataset_sectors = Sector.objects.filter( + datasets__organization_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ).distinct() + + # Get sectors from published use cases + usecase_sectors = Sector.objects.filter( + usecases__usecaseorganizationrelationship__organization_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ).distinct() + else: # User + # Get sectors from published datasets + dataset_sectors = Sector.objects.filter( + datasets__user_id=instance.id, + datasets__status=DatasetStatus.PUBLISHED.value, + ).distinct() + + # Get sectors from published use cases + usecase_sectors = Sector.objects.filter( + usecases__user_id=instance.id, + usecases__status=UseCaseStatus.PUBLISHED.value, + ).distinct() + + # Combine and deduplicate sectors + 
all_sectors = set(dataset_sectors) | set(usecase_sectors) + return [ + getattr(sector, "name", "") for sector in all_sectors if hasattr(sector, "name") + ] + except Exception: + return [] + + def should_index_object(self, obj: Union[Organization, User]) -> bool: + """Check if the object should be indexed (has published content).""" + try: + if isinstance(obj, Organization): + has_datasets = Dataset.objects.filter( + organization_id=obj.id, status=DatasetStatus.PUBLISHED.value + ).exists() + has_usecases = UseCase.objects.filter( + organization_id=obj.id, status=UseCaseStatus.PUBLISHED.value + ).exists() + return has_datasets or has_usecases + else: # User + has_datasets = Dataset.objects.filter( + user_id=obj.id, status=DatasetStatus.PUBLISHED.value + ).exists() + has_usecases = UseCase.objects.filter( + user_id=obj.id, status=UseCaseStatus.PUBLISHED.value + ).exists() + return has_datasets or has_usecases + except Exception: + return False + + def get_instances_from_related( + self, + related_instance: Union[Dataset, UseCase], + ) -> Optional[List[Union[Organization, User]]]: + """Get Publisher instances from related models.""" + publishers: List[Union[Organization, User]] = [] + + if isinstance(related_instance, Dataset): + if related_instance.organization: + publishers.append(related_instance.organization) + if related_instance.user: + publishers.append(related_instance.user) + elif isinstance(related_instance, UseCase): + if related_instance.organization: + publishers.append(related_instance.organization) + if related_instance.user: + publishers.append(related_instance.user) + + return publishers if publishers else None + + +@ORG_INDEX.doc_type +class OrganizationPublisherDocument(PublisherDocument): + """Organization-specific publisher document.""" + + class Django: + """Django model configuration.""" + + model = Organization + + fields = [ + "id", + "created", + "modified", + ] + + related_models = [ + Dataset, + UseCase, + Sector, + ] + + 
+@USER_INDEX.doc_type +class UserPublisherDocument(PublisherDocument): + """User-specific publisher document.""" + + class Django: + """Django model configuration.""" + + model = User + + fields = [ + "id", + "date_joined", + "last_login", + ] + + related_models = [ + Dataset, + UseCase, + Sector, + ] + + def prepare_created(self, instance: User) -> Any: + """Map date_joined to created for consistency.""" + return instance.date_joined + + def prepare_modified(self, instance: User) -> Any: + """Map last_login to modified for consistency.""" + return instance.last_login diff --git a/setup.py b/setup.py index 207fc45f..57603549 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,23 @@ """Setup configuration for DataSpace Python SDK.""" +import os +from typing import Any, Dict + from setuptools import find_packages, setup +# Read version from __version__.py +version: Dict[str, Any] = {} +version_file = os.path.join(os.path.dirname(__file__), "dataspace_sdk", "__version__.py") +with open(version_file, "r", encoding="utf-8") as f: + exec(f.read(), version) + with open("docs/sdk/README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setup( name="dataspace-sdk", - version="0.4.18", + + version=version["__version__"], author="CivicDataLab", author_email="tech@civicdatalab.in", description="Python SDK for DataSpace API - programmatic access to datasets, AI models, and use cases", diff --git a/tests/test_aimodels.py b/tests/test_aimodels.py index a184f3ea..b5bcd436 100644 --- a/tests/test_aimodels.py +++ b/tests/test_aimodels.py @@ -20,10 +20,10 @@ def test_init(self) -> None: self.assertEqual(self.client.base_url, self.base_url) self.assertEqual(self.client.auth_client, self.auth_client) - @patch.object(AIModelClient, "get") - def test_search_models(self, mock_get: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_search_models(self, mock_request: MagicMock) -> None: """Test AI model search.""" - mock_get.return_value = { + 
mock_request.return_value = { "total": 5, "results": [{"id": "1", "displayName": "Test Model", "modelType": "LLM"}], } @@ -33,12 +33,12 @@ def test_search_models(self, mock_get: MagicMock) -> None: self.assertEqual(result["total"], 5) self.assertEqual(len(result["results"]), 1) self.assertEqual(result["results"][0]["displayName"], "Test Model") - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "get") - def test_get_model_by_id(self, mock_get: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_get_model_by_id(self, mock_request: MagicMock) -> None: """Test get AI model by ID.""" - mock_get.return_value = { + mock_request.return_value = { "id": "123", "displayName": "Test Model", "modelType": "LLM", @@ -49,14 +49,14 @@ def test_get_model_by_id(self, mock_get: MagicMock) -> None: self.assertEqual(result["id"], "123") self.assertEqual(result["displayName"], "Test Model") - mock_get.assert_called_once() + mock_request.assert_called_once() @patch.object(AIModelClient, "post") def test_get_model_by_id_graphql(self, mock_post: MagicMock) -> None: """Test get AI model by ID using GraphQL.""" mock_post.return_value = { "data": { - "aiModel": { + "getAiModel": { "id": "123", "displayName": "Test Model", "description": "A test model", @@ -90,33 +90,136 @@ def test_get_organization_models(self, mock_post: MagicMock) -> None: self.assertIsInstance(result, (list, dict)) mock_post.assert_called_once() - @patch.object(AIModelClient, "get") - def test_search_with_filters(self, mock_get: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_search_with_filters(self, mock_request: MagicMock) -> None: """Test AI model search with filters.""" - mock_get.return_value = {"total": 3, "results": []} + mock_request.return_value = {"total": 3, "results": []} result = self.client.search( query="language", tags=["nlp"], sectors=["tech"], + status="ACTIVE", model_type="LLM", provider="OpenAI", - 
status="ACTIVE", ) self.assertEqual(result["total"], 3) - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(AIModelClient, "post") - def test_graphql_error_handling(self, mock_post: MagicMock) -> None: + @patch.object(AIModelClient, "_make_request") + def test_graphql_error_handling(self, mock_request: MagicMock) -> None: """Test GraphQL error handling.""" from dataspace_sdk.exceptions import DataSpaceAPIError - mock_post.return_value = {"errors": [{"message": "GraphQL error"}]} + mock_request.return_value = {"errors": [{"message": "GraphQL error"}]} with self.assertRaises(DataSpaceAPIError): self.client.get_by_id_graphql("123") + @patch.object(AIModelClient, "_make_request") + def test_call_model(self, mock_request: MagicMock) -> None: + """Test calling an AI model.""" + mock_request.return_value = { + "success": True, + "output": "Paris is the capital of France.", + "latency_ms": 150, + "provider": "OpenAI", + } + + result = self.client.call_model( + model_id="123", + input_text="What is the capital of France?", + parameters={"temperature": 0.7, "max_tokens": 100}, + ) + + self.assertTrue(result["success"]) + self.assertEqual(result["output"], "Paris is the capital of France.") + self.assertEqual(result["latency_ms"], 150) + mock_request.assert_called_once() + + @patch.object(AIModelClient, "_make_request") + def test_call_model_async(self, mock_request: MagicMock) -> None: + """Test calling an AI model asynchronously.""" + mock_request.return_value = { + "task_id": "task-456", + "status": "PENDING", + "created_at": "2024-01-01T00:00:00Z", + } + + result = self.client.call_model_async( + model_id="123", + input_text="Generate a long document", + parameters={"max_tokens": 2000}, + ) + + self.assertEqual(result["task_id"], "task-456") + self.assertEqual(result["status"], "PENDING") + mock_request.assert_called_once() + + @patch.object(AIModelClient, "_make_request") + def test_call_model_error(self, mock_request: MagicMock) -> None: 
+ """Test AI model call with error.""" + mock_request.return_value = { + "success": False, + "error": "Model not available", + } + + result = self.client.call_model( + model_id="123", + input_text="Test input", + ) + + self.assertFalse(result["success"]) + self.assertEqual(result["error"], "Model not available") + + @patch.object(AIModelClient, "_make_request") + def test_create_model(self, mock_request: MagicMock) -> None: + """Test creating an AI model.""" + mock_request.return_value = { + "id": "new-model-123", + "displayName": "New Model", + "modelType": "LLM", + } + + result = self.client.create( + { + "displayName": "New Model", + "modelType": "LLM", + "provider": "OpenAI", + } + ) + + self.assertEqual(result["id"], "new-model-123") + self.assertEqual(result["displayName"], "New Model") + mock_request.assert_called_once() + + @patch.object(AIModelClient, "_make_request") + def test_update_model(self, mock_request: MagicMock) -> None: + """Test updating an AI model.""" + mock_request.return_value = { + "id": "123", + "displayName": "Updated Model", + "description": "Updated description", + } + + result = self.client.update( + "123", {"displayName": "Updated Model", "description": "Updated description"} + ) + + self.assertEqual(result["displayName"], "Updated Model") + mock_request.assert_called_once() + + @patch.object(AIModelClient, "_make_request") + def test_delete_model(self, mock_request: MagicMock) -> None: + """Test deleting an AI model.""" + mock_request.return_value = {"message": "Model deleted successfully"} + + result = self.client.delete_model("123") + + self.assertEqual(result["message"], "Model deleted successfully") + mock_request.assert_called_once() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_auth.py b/tests/test_auth.py index c05b502f..cf0212c4 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -13,18 +13,60 @@ class TestAuthClient(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" 
self.base_url = "https://api.test.com" - self.auth_client = AuthClient(self.base_url) + self.keycloak_url = "https://keycloak.test.com" + self.keycloak_realm = "test-realm" + self.keycloak_client_id = "test-client" + self.auth_client = AuthClient( + self.base_url, + keycloak_url=self.keycloak_url, + keycloak_realm=self.keycloak_realm, + keycloak_client_id=self.keycloak_client_id, + ) def test_init(self) -> None: """Test AuthClient initialization.""" self.assertEqual(self.auth_client.base_url, self.base_url) + self.assertEqual(self.auth_client.keycloak_url, self.keycloak_url) + self.assertEqual(self.auth_client.keycloak_realm, self.keycloak_realm) + self.assertEqual(self.auth_client.keycloak_client_id, self.keycloak_client_id) self.assertIsNone(self.auth_client.access_token) self.assertIsNone(self.auth_client.refresh_token) self.assertIsNone(self.auth_client.user_info) @patch("dataspace_sdk.auth.requests.post") - def test_login_success(self, mock_post: MagicMock) -> None: - """Test successful login.""" + def test_login_with_username_password(self, mock_post: MagicMock) -> None: + """Test successful login with username/password.""" + # Mock Keycloak token response + keycloak_response = MagicMock() + keycloak_response.status_code = 200 + keycloak_response.json.return_value = { + "access_token": "keycloak_access_token", + "refresh_token": "keycloak_refresh_token", + "expires_in": 300, + } + + # Mock DataSpace backend login response + backend_response = MagicMock() + backend_response.status_code = 200 + backend_response.json.return_value = { + "access": "test_access_token", + "refresh": "test_refresh_token", + "user": {"id": "123", "username": "testuser"}, + } + + mock_post.side_effect = [keycloak_response, backend_response] + + result = self.auth_client.login("testuser", "password") + + self.assertEqual(self.auth_client.access_token, "test_access_token") + self.assertEqual(self.auth_client.refresh_token, "test_refresh_token") + 
self.assertIsNotNone(self.auth_client.user_info) + self.assertEqual(result["user"]["username"], "testuser") + self.assertEqual(mock_post.call_count, 2) + + @patch("dataspace_sdk.auth.requests.post") + def test_login_with_token(self, mock_post: MagicMock) -> None: + """Test successful login with Keycloak token.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { @@ -34,7 +76,7 @@ def test_login_success(self, mock_post: MagicMock) -> None: } mock_post.return_value = mock_response - result = self.auth_client.login_with_keycloak("test_keycloak_token") + result = self.auth_client._login_with_keycloak_token("test_keycloak_token") self.assertEqual(self.auth_client.access_token, "test_access_token") self.assertEqual(self.auth_client.refresh_token, "test_refresh_token") @@ -46,11 +88,11 @@ def test_login_failure(self, mock_post: MagicMock) -> None: """Test failed login.""" mock_response = MagicMock() mock_response.status_code = 401 - mock_response.json.return_value = {"error": "Invalid token"} + mock_response.json.return_value = {"error": "invalid_grant"} mock_post.return_value = mock_response with self.assertRaises(DataSpaceAuthError): - self.auth_client.login_with_keycloak("invalid_token") + self.auth_client.login("invalid_user", "invalid_password") @patch("dataspace_sdk.auth.requests.post") def test_refresh_token_success(self, mock_post: MagicMock) -> None: @@ -99,6 +141,51 @@ def test_is_authenticated(self) -> None: self.auth_client.access_token = "test_token" self.assertTrue(self.auth_client.is_authenticated()) + @patch("dataspace_sdk.auth.requests.post") + def test_login_as_service_account(self, mock_post: MagicMock) -> None: + """Test successful service account login.""" + # Create auth client with client secret + auth_client = AuthClient( + self.base_url, + keycloak_url=self.keycloak_url, + keycloak_realm=self.keycloak_realm, + keycloak_client_id=self.keycloak_client_id, + keycloak_client_secret="test_secret", + ) + + # 
Mock Keycloak token response + keycloak_response = MagicMock() + keycloak_response.status_code = 200 + keycloak_response.json.return_value = { + "access_token": "service_access_token", + "refresh_token": "service_refresh_token", + "expires_in": 300, + } + + # Mock DataSpace backend login response + backend_response = MagicMock() + backend_response.status_code = 200 + backend_response.json.return_value = { + "access": "test_access_token", + "refresh": "test_refresh_token", + "user": {"id": "service-123", "username": "service-account"}, + } + + mock_post.side_effect = [keycloak_response, backend_response] + + result = auth_client.login_as_service_account() + + self.assertEqual(auth_client.access_token, "test_access_token") + self.assertEqual(auth_client.refresh_token, "test_refresh_token") + self.assertIsNotNone(auth_client.user_info) + self.assertEqual(result["user"]["username"], "service-account") + self.assertEqual(mock_post.call_count, 2) + + def test_login_as_service_account_no_secret(self) -> None: + """Test service account login without client secret.""" + with self.assertRaises(DataSpaceAuthError): + self.auth_client.login_as_service_account() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_client.py b/tests/test_client.py index 870a94f5..7d27d44d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -22,15 +22,28 @@ def test_init(self) -> None: self.assertIsNotNone(self.client.aimodels) self.assertIsNotNone(self.client.usecases) - @patch("dataspace_sdk.client.AuthClient.login_with_keycloak") + @patch("dataspace_sdk.client.AuthClient.login") def test_login(self, mock_login: MagicMock) -> None: - """Test login method.""" + """Test login method with username/password.""" mock_login.return_value = { "access": "token", "user": {"id": "123", "username": "testuser"}, } - result = self.client.login("test_token") + result = self.client.login("testuser", "password") + + self.assertEqual(result["user"]["username"], "testuser") + 
mock_login.assert_called_once_with("testuser", "password") + + @patch("dataspace_sdk.client.AuthClient._login_with_keycloak_token") + def test_login_with_token(self, mock_login: MagicMock) -> None: + """Test login method with Keycloak token.""" + mock_login.return_value = { + "access": "token", + "user": {"id": "123", "username": "testuser"}, + } + + result = self.client.login_with_token("test_token") self.assertEqual(result["user"]["username"], "testuser") mock_login.assert_called_once_with("test_token") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 4219cb4c..74653975 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -20,10 +20,10 @@ def test_init(self) -> None: self.assertEqual(self.client.base_url, self.base_url) self.assertEqual(self.client.auth_client, self.auth_client) - @patch.object(DatasetClient, "get") - def test_search_datasets(self, mock_get: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_search_datasets(self, mock_request: MagicMock) -> None: """Test dataset search.""" - mock_get.return_value = { + mock_request.return_value = { "total": 10, "results": [{"id": "1", "title": "Test Dataset"}], } @@ -32,22 +32,22 @@ def test_search_datasets(self, mock_get: MagicMock) -> None: self.assertEqual(result["total"], 10) self.assertEqual(len(result["results"]), 1) - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(DatasetClient, "post") - def test_get_dataset_by_id(self, mock_post: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_get_dataset_by_id(self, mock_request: MagicMock) -> None: """Test get dataset by ID.""" - mock_post.return_value = {"data": {"dataset": {"id": "123", "title": "Test Dataset"}}} + mock_request.return_value = {"data": {"getDataset": {"id": "123", "title": "Test Dataset"}}} result = self.client.get_by_id("123") self.assertEqual(result["id"], "123") self.assertEqual(result["title"], "Test Dataset") - 
@patch.object(DatasetClient, "post") - def test_list_all_datasets(self, mock_post: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_list_all_datasets(self, mock_request: MagicMock) -> None: """Test list all datasets.""" - mock_post.return_value = {"data": {"datasets": [{"id": "1", "title": "Dataset 1"}]}} + mock_request.return_value = {"data": {"datasets": [{"id": "1", "title": "Dataset 1"}]}} result = self.client.list_all(limit=10, offset=0) @@ -73,10 +73,10 @@ def test_get_organization_datasets(self, mock_post: MagicMock) -> None: self.assertIsInstance(result, (list, dict)) mock_post.assert_called_once() - @patch.object(DatasetClient, "get") - def test_search_with_filters(self, mock_get: MagicMock) -> None: + @patch.object(DatasetClient, "_make_request") + def test_search_with_filters(self, mock_request: MagicMock) -> None: """Test dataset search with filters.""" - mock_get.return_value = {"total": 5, "results": []} + mock_request.return_value = {"total": 5, "results": []} result = self.client.search( query="health", @@ -87,7 +87,57 @@ def test_search_with_filters(self, mock_get: MagicMock) -> None: ) self.assertEqual(result["total"], 5) - mock_get.assert_called_once() + mock_request.assert_called_once() + + @patch.object(DatasetClient, "_make_request") + def test_get_dataset_with_resources(self, mock_request: MagicMock) -> None: + """Test get dataset by ID which includes resources.""" + mock_request.return_value = { + "data": { + "getDataset": { + "id": "dataset-123", + "title": "Test Dataset", + "resources": [ + { + "id": "res-1", + "title": "Resource 1", + "fileDetails": {"format": "CSV"}, + } + ], + } + } + } + + result = self.client.get_by_id("dataset-123") + + self.assertEqual(result["id"], "dataset-123") + self.assertEqual(len(result["resources"]), 1) + self.assertEqual(result["resources"][0]["title"], "Resource 1") + mock_request.assert_called_once() + + @patch.object(DatasetClient, "list_all") + def 
test_get_organization_datasets_via_list_all(self, mock_list_all: MagicMock) -> None: + """Test get datasets by organization.""" + mock_list_all.return_value = [ + {"id": "1", "title": "Org Dataset 1"}, + {"id": "2", "title": "Org Dataset 2"}, + ] + + result = self.client.get_organization_datasets("org-123", limit=10) + + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + mock_list_all.assert_called_once_with(organization_id="org-123", limit=10, offset=0) + + @patch.object(DatasetClient, "_make_request") + def test_search_with_sorting(self, mock_request: MagicMock) -> None: + """Test dataset search with sorting.""" + mock_request.return_value = {"total": 3, "results": []} + + result = self.client.search(query="test", sort="recent", page=1, page_size=10) + + self.assertEqual(result["total"], 3) + mock_request.assert_called_once() if __name__ == "__main__": diff --git a/tests/test_usecases.py b/tests/test_usecases.py index 6d966afa..02d75f32 100644 --- a/tests/test_usecases.py +++ b/tests/test_usecases.py @@ -20,10 +20,10 @@ def test_init(self) -> None: self.assertEqual(self.client.base_url, self.base_url) self.assertEqual(self.client.auth_client, self.auth_client) - @patch.object(UseCaseClient, "get") - def test_search_usecases(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, "_make_request") + def test_search_usecases(self, mock_request: MagicMock) -> None: """Test use case search.""" - mock_get.return_value = { + mock_request.return_value = { "total": 8, "results": [ { @@ -40,7 +40,7 @@ def test_search_usecases(self, mock_get: MagicMock) -> None: self.assertEqual(result["total"], 8) self.assertEqual(len(result["results"]), 1) self.assertEqual(result["results"][0]["title"], "Test Use Case") - mock_get.assert_called_once() + mock_request.assert_called_once() @patch.object(UseCaseClient, "post") def test_get_usecase_by_id(self, mock_post: MagicMock) -> None: @@ -82,10 +82,10 @@ def test_get_organization_usecases(self, mock_post: MagicMock) 
-> None: self.assertIsInstance(result, (list, dict)) mock_post.assert_called_once() - @patch.object(UseCaseClient, "get") - def test_search_with_filters(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, "_make_request") + def test_search_with_filters(self, mock_request: MagicMock) -> None: """Test use case search with filters.""" - mock_get.return_value = {"total": 4, "results": []} + mock_request.return_value = {"total": 4, "results": []} result = self.client.search( query="monitoring", @@ -96,17 +96,17 @@ def test_search_with_filters(self, mock_get: MagicMock) -> None: ) self.assertEqual(result["total"], 4) - mock_get.assert_called_once() + mock_request.assert_called_once() - @patch.object(UseCaseClient, "get") - def test_search_with_sorting(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, "_make_request") + def test_search_with_sorting(self, mock_request: MagicMock) -> None: """Test use case search with sorting.""" - mock_get.return_value = {"total": 2, "results": []} + mock_request.return_value = {"total": 2, "results": []} result = self.client.search(query="test", sort="completed_on", page=1, page_size=10) self.assertEqual(result["total"], 2) - mock_get.assert_called_once() + mock_request.assert_called_once() @patch.object(UseCaseClient, "post") def test_graphql_error_handling(self, mock_post: MagicMock) -> None: @@ -118,21 +118,17 @@ def test_graphql_error_handling(self, mock_post: MagicMock) -> None: with self.assertRaises(DataSpaceAPIError): self.client.get_by_id(123) - @patch.object(UseCaseClient, "get") - def test_search_pagination(self, mock_get: MagicMock) -> None: + @patch.object(UseCaseClient, "_make_request") + def test_search_pagination(self, mock_request: MagicMock) -> None: """Test use case search with pagination.""" - mock_get.return_value = { - "total": 50, - "page": 2, - "page_size": 20, - "results": [], - } + mock_request.return_value = {"total": 50, "results": [], "page": 2, "page_size": 20} result = 
self.client.search(query="test", page=2, page_size=20) self.assertEqual(result["total"], 50) self.assertEqual(result["page"], 2) - mock_get.assert_called_once() + self.assertEqual(result["page_size"], 20) + mock_request.assert_called_once() if __name__ == "__main__":