From 26e5d968f65ea3892544fa2237120ec123b47c7d Mon Sep 17 00:00:00 2001 From: Arondondon Date: Mon, 16 Mar 2026 18:38:43 +0300 Subject: [PATCH 1/7] feat: rework contracts calls, implement registry contract, enhance importing --- Makefile | 5 ++ pyproject.toml | 1 + snet/sdk/__init__.py | 53 +++---------- snet/sdk/account.py | 52 +++++-------- snet/sdk/client_lib_generator.py | 2 +- snet/sdk/mpe/mpe_contract.py | 40 ++++++---- snet/sdk/mpe/payment_channel.py | 4 +- snet/sdk/mpe/payment_channel_provider.py | 5 +- .../paidcall_payment_strategy.py | 3 +- .../prepaid_payment_strategy.py | 3 +- .../__init__.py | 0 snet/sdk/registry/registry_contract.py | 76 +++++++++++++++++++ .../service_metadata.py | 0 .../storage_provider.py | 25 ++---- snet/sdk/service_client.py | 6 +- snet/sdk/types.py | 1 + snet/sdk/utils/utils.py | 16 +++- tests/unit_tests/test_lib_generator.py | 2 +- tests/unit_tests/test_service_client.py | 2 +- 19 files changed, 171 insertions(+), 125 deletions(-) rename snet/sdk/{storage_provider => registry}/__init__.py (100%) create mode 100644 snet/sdk/registry/registry_contract.py rename snet/sdk/{storage_provider => registry}/service_metadata.py (100%) rename snet/sdk/{storage_provider => registry}/storage_provider.py (76%) create mode 100644 snet/sdk/types.py diff --git a/Makefile b/Makefile index 60994e2..4708723 100644 --- a/Makefile +++ b/Makefile @@ -2,3 +2,8 @@ lint: @ruff check . --fix @ruff format . .PHONY: lint + +test: + python -m coverage run -m pytest tests/ -v && \ + python -m coverage report +.PHONY: test \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f2c6551..2786d02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ [tool.poetry.group.dev.dependencies] ruff = "^0.11" pytest = "^8.3" +coverage = "^7.13" [tool.ruff] line-length = 100 diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index e574fa0..04d5c2b 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -7,7 +7,8 @@ import google.protobuf.internal.api_implementation from google.protobuf import symbol_database as _symbol_database -from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata +from snet.sdk.registry.registry_contract import RegistryContract +from snet.sdk.registry.service_metadata import MPEServiceMetadata with warnings.catch_warnings(): # Suppress the eth-typing package`s warnings related to some new networks @@ -18,9 +19,6 @@ UserWarning, ) - import web3 - -from snet.contracts import get_contract_object from snet.sdk.account import Account from snet.sdk.config import config from snet.sdk.client_lib_generator import ClientLibGenerator @@ -34,12 +32,11 @@ PaymentStrategy, ) from snet.sdk.service_client import ServiceClient -from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.registry.storage_provider import StorageProvider from snet.sdk.custom_typing import ModuleName, ServiceStub from snet.sdk.utils.utils import ( - bytes32_to_str, find_file_by_keyword, - type_converter, + get_we3_object, ) google.protobuf.internal.api_implementation.Type = lambda: "python" @@ -55,29 +52,13 @@ class PaymentStrategyType(Enum): class SnetSDK: - """Base Snet SDK""" - def __init__(self): - self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT)) - - mpe_contract_address = config.MPE_CONTRACT_ADDRESS - if not mpe_contract_address: - self.mpe_contract = MPEContract(self.web3) - else: - self.mpe_contract = MPEContract(self.web3, mpe_contract_address) - - registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS - if registry_contract_address is None: - self.registry_contract = get_contract_object(self.web3, "Registry") - else: - self.registry_contract = get_contract_object( - self.web3, "Registry", registry_contract_address - ) - + self.w3 = get_we3_object() + self.mpe_contract = MPEContract() + self.registry_contract = RegistryContract() self.metadata_provider = StorageProvider(self.registry_contract) - - self.account = Account(self.web3, self.mpe_contract) - self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract) + self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract) + self.account = Account() def create_service_client( self, @@ -134,7 +115,7 @@ def create_service_client( options, self.mpe_contract, self.account, - self.web3, + self.w3, pb2_module, self.payment_channel_provider, lib_generator.protodir, @@ -190,17 +171,7 @@ def _get_service_group_details( return self._get_group_by_group_name(service_metadata, group_name) def get_organization_list(self) -> list: - org_list = self.registry_contract.functions.listOrganizations().call() - organization_list = [] - for idx, org_id in enumerate(org_list): - organization_list.append(bytes32_to_str(org_id)) - return organization_list + return self.registry_contract.list_orgs() def get_services_list(self, org_id: str) -> list: - found, org_service_list = self.registry_contract.functions.listServicesForOrganization( - type_converter("bytes32")(org_id) - ).call() - if not found: - raise Exception(f"Organization with id={org_id} doesn't exist!") - org_service_list = list(map(bytes32_to_str, org_service_list)) - return org_service_list + return self.registry_contract.list_service_for_org(org_id) diff --git a/snet/sdk/account.py b/snet/sdk/account.py index 31b404d..0003758 100644 --- a/snet/sdk/account.py +++ b/snet/sdk/account.py @@ -1,11 +1,10 @@ import json -import web3 from snet.contracts import get_contract_object +from snet.sdk import get_we3_object from snet.sdk.config import config -from snet.sdk.mpe.mpe_contract import MPEContract from snet.sdk.utils.utils import get_address_from_private, normalize_private_key DEFAULT_GAS = 300000 @@ -28,17 +27,13 @@ def __str__(self): class Account: - def __init__(self, w3: web3.Web3, mpe_contract: MPEContract): - self.web3 = w3 - self.mpe_contract = mpe_contract + def __init__(self): + self.w3 = get_we3_object() + self.mpe_address = config.MPE_CONTRACT_ADDRESS - token_contract_address = config.TOKEN_CONTRACT_ADDRESS - if not token_contract_address: - self.token_contract = get_contract_object(self.web3, "FetchToken") - else: - self.token_contract = get_contract_object( - self.web3, "FetchToken", token_contract_address - ) + self.token_contract = get_contract_object( + self.w3, "FetchToken", config.TOKEN_CONTRACT_ADDRESS + ) if config.PRIVATE_KEY: self.private_key = normalize_private_key(config.PRIVATE_KEY) @@ -52,19 +47,19 @@ def __init__(self, w3: web3.Web3, mpe_contract: MPEContract): self.nonce = 0 def _get_nonce(self): - nonce = self.web3.eth.get_transaction_count(self.address) + nonce = self.w3.eth.get_transaction_count(self.address) if self.nonce >= nonce: nonce = self.nonce + 1 self.nonce = nonce return nonce def _get_gas_price(self): - gas_price = self.web3.eth.gas_price + gas_price = self.w3.eth.gas_price if gas_price <= 15000000000: gas_price += gas_price * 1 / 3 - elif gas_price > 15000000000 and gas_price <= 50000000000: + elif 15000000000 < gas_price <= 50000000000: gas_price += gas_price * 1 / 5 - elif gas_price > 50000000000 and gas_price <= 150000000000: + elif 50000000000 < gas_price <= 150000000000: gas_price += 7000000000 elif gas_price > 150000000000: gas_price += gas_price * 1 / 10 @@ -73,20 +68,18 @@ def _get_gas_price(self): def _send_signed_transaction(self, contract_fn, *args): transaction = contract_fn(*args).build_transaction( { - "chainId": int(self.web3.net.version), + "chainId": int(self.w3.net.version), "gas": DEFAULT_GAS, "gasPrice": self._get_gas_price(), "nonce": self._get_nonce(), } ) - signed_txn = self.web3.eth.account.sign_transaction( - transaction, private_key=self.private_key - ) - return self.web3.to_hex(self.web3.eth.send_raw_transaction(signed_txn.raw_transaction)) + signed_txn = self.w3.eth.account.sign_transaction(transaction, private_key=self.private_key) + return self.w3.to_hex(self.w3.eth.send_raw_transaction(signed_txn.raw_transaction)) def send_transaction(self, contract_fn, *args): txn_hash = self._send_signed_transaction(contract_fn, *args) - return self.web3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT) + return self.w3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT) def _parse_receipt(self, receipt, event, encoder=json.JSONEncoder): if receipt.status == 0: @@ -94,23 +87,12 @@ def _parse_receipt(self, receipt, event, encoder=json.JSONEncoder): else: return json.dumps(dict(event().processReceipt(receipt)[0]["args"]), cls=encoder) - def escrow_balance(self): - return self.mpe_contract.balance(self.address) - - def deposit_to_escrow_account(self, amount_in_cogs): - already_approved = self.allowance() - if amount_in_cogs > already_approved: - self.approve_transfer(amount_in_cogs) - return self.mpe_contract.deposit(self, amount_in_cogs) - def approve_transfer(self, amount_in_cogs): return self.send_transaction( self.token_contract.functions.approve, - self.mpe_contract.contract.address, + self.mpe_address, amount_in_cogs, ) def allowance(self): - return self.token_contract.functions.allowance( - self.address, self.mpe_contract.contract.address - ).call() + return self.token_contract.functions.allowance(self.address, self.mpe_address).call() diff --git a/snet/sdk/client_lib_generator.py b/snet/sdk/client_lib_generator.py index eb485c1..234571b 100644 --- a/snet/sdk/client_lib_generator.py +++ b/snet/sdk/client_lib_generator.py @@ -1,7 +1,7 @@ import os from pathlib import Path -from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.registry.storage_provider import StorageProvider from snet.sdk.utils.utils import compile_proto diff --git a/snet/sdk/mpe/mpe_contract.py b/snet/sdk/mpe/mpe_contract.py index 3823da1..d33ea5d 100644 --- a/snet/sdk/mpe/mpe_contract.py +++ b/snet/sdk/mpe/mpe_contract.py @@ -1,21 +1,29 @@ +from typing import Optional + from snet.contracts import get_contract_object +from snet.sdk import get_we3_object, config, Account + class MPEContract: - def __init__(self, w3, address=None): - self.web3 = w3 - if address is None: - self.contract = get_contract_object(self.web3, "MultiPartyEscrow") - else: - self.contract = get_contract_object(self.web3, "MultiPartyEscrow", address) - - def balance(self, address): + def __init__(self): + self.w3 = get_we3_object() + self.contract = get_contract_object( + self.w3, "MultiPartyEscrow", config.MPE_CONTRACT_ADDRESS + ) + + def balance(self, account: Account, address: Optional[str] = None): + if not address: + address = account.address return self.contract.functions.balances(address).call() def deposit(self, account, amount_in_cogs): + already_approved = account.allowance() + if amount_in_cogs > already_approved: + account.approve_transfer(amount_in_cogs) return account.send_transaction(self.contract.functions.deposit, amount_in_cogs) - def open_channel(self, account, payment_address, group_id, amount, expiration): + def open_channel(self, account: Account, payment_address, group_id, amount, expiration): return account.send_transaction( self.contract.functions.openChannel, account.signer_address, @@ -25,7 +33,9 @@ def open_channel(self, account, payment_address, group_id, amount, expiration): expiration, ) - def deposit_and_open_channel(self, account, payment_address, group_id, amount, expiration): + def deposit_and_open_channel( + self, account: Account, payment_address, group_id, amount, expiration + ): already_approved_amount = account.allowance() if amount > already_approved_amount: account.approve_transfer(amount) @@ -38,16 +48,16 @@ def deposit_and_open_channel(self, account, payment_address, group_id, amount, e expiration, ) - def channel_add_funds(self, account, channel_id, amount): + def channel_add_funds(self, account: Account, channel_id, amount): self._fund_escrow_account(account, amount) return account.send_transaction(self.contract.functions.channelAddFunds, channel_id, amount) - def channel_extend(self, account, channel_id, expiration): + def channel_extend(self, account: Account, channel_id, expiration): return account.send_transaction( self.contract.functions.channelExtend, channel_id, expiration ) - def channel_extend_and_add_funds(self, account, channel_id, expiration, amount): + def channel_extend_and_add_funds(self, account: Account, channel_id, expiration, amount): self._fund_escrow_account(account, amount) return account.send_transaction( self.contract.functions.channelExtendAndAddFunds, @@ -56,7 +66,7 @@ def channel_extend_and_add_funds(self, account, channel_id, expiration, amount): amount, ) - def _fund_escrow_account(self, account, amount): + def _fund_escrow_account(self, account: Account, amount): current_escrow_balance = self.balance(account.address) if amount > current_escrow_balance: - account.deposit_to_escrow_account(amount - current_escrow_balance) + self.deposit(amount - current_escrow_balance) diff --git a/snet/sdk/mpe/payment_channel.py b/snet/sdk/mpe/payment_channel.py index f413482..69ace6c 100644 --- a/snet/sdk/mpe/payment_channel.py +++ b/snet/sdk/mpe/payment_channel.py @@ -2,6 +2,7 @@ import importlib from eth_account.messages import defunct_hash_message +from snet.sdk import get_we3_object from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path @@ -9,13 +10,12 @@ class PaymentChannel: def __init__( self, channel_id, - w3, account, payment_channel_state_service_client, mpe_contract, ): self.channel_id = channel_id - self.web3 = w3 + self.web3 = get_we3_object() self.account = account self.mpe_contract = mpe_contract self.payment_channel_state_service_client = payment_channel_state_service_client diff --git a/snet/sdk/mpe/payment_channel_provider.py b/snet/sdk/mpe/payment_channel_provider.py index 40b05c7..f2c0369 100644 --- a/snet/sdk/mpe/payment_channel_provider.py +++ b/snet/sdk/mpe/payment_channel_provider.py @@ -4,6 +4,7 @@ from web3.types import LogReceipt +from snet.sdk import get_we3_object from snet.sdk.mpe.payment_channel import PaymentChannel from snet.contracts import get_contract_deployment_block @@ -13,8 +14,8 @@ class PaymentChannelProvider(object): - def __init__(self, w3, mpe_contract): - self.web3 = w3 + def __init__(self, mpe_contract): + self.web3 = get_we3_object() self.mpe_contract = mpe_contract self.event_topics = [ diff --git a/snet/sdk/payment_strategies/paidcall_payment_strategy.py b/snet/sdk/payment_strategies/paidcall_payment_strategy.py index 534a21e..a7127c4 100644 --- a/snet/sdk/payment_strategies/paidcall_payment_strategy.py +++ b/snet/sdk/payment_strategies/paidcall_payment_strategy.py @@ -36,13 +36,12 @@ def get_payment_metadata(self, service_client): return metadata def select_channel(self, service_client): - account = service_client.account service_client.load_open_channels() service_client.update_channel_states() payment_channels = service_client.payment_channels # picking the first pricing strategy as default for now service_call_price = self.get_price(service_client) - mpe_balance = account.escrow_balance() + mpe_balance = service_client.get_mpe_balance() default_expiration = service_client.default_channel_expiration() if len(payment_channels) < 1: diff --git a/snet/sdk/payment_strategies/prepaid_payment_strategy.py b/snet/sdk/payment_strategies/prepaid_payment_strategy.py index 5a61087..c2e3ce0 100644 --- a/snet/sdk/payment_strategies/prepaid_payment_strategy.py +++ b/snet/sdk/payment_strategies/prepaid_payment_strategy.py @@ -40,13 +40,12 @@ def get_concurrency_token_and_channel(self, service_client): return token, channel def select_channel(self, service_client): - account = service_client.account service_client.load_open_channels() service_client.update_channel_states() payment_channels = service_client.payment_channels service_call_price = self.get_price(service_client) extend_channel_fund = service_call_price * self.call_allowance - mpe_balance = account.escrow_balance() + mpe_balance = service_client.get_mpe_balance() default_expiration = service_client.default_channel_expiration() if len(payment_channels) < 1: diff --git a/snet/sdk/storage_provider/__init__.py b/snet/sdk/registry/__init__.py similarity index 100% rename from snet/sdk/storage_provider/__init__.py rename to snet/sdk/registry/__init__.py diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py new file mode 100644 index 0000000..0fabc28 --- /dev/null +++ b/snet/sdk/registry/registry_contract.py @@ -0,0 +1,76 @@ +from typing import Union + +from snet.contracts import get_contract_object + +from snet.sdk import Account, get_we3_object, config, type_converter, bytes32_to_str + + +class RegistryContract: + def __init__(self): + self.w3 = get_we3_object() + self.contract = get_contract_object(self.w3, "Registry", config.REGISTRY_CONTRACT_ADDRESS) + + # READ METHODS + + def get_org(self, org_id: str): + found, _, org_metadata_uri, owner, members, service_ids = ( + self.contract.functions.getOrganizationById(type_converter("bytes32")(org_id)).call() + ) + if not found: + # TODO: configure exceptions + raise Exception() + + # TODO: new type + return org_metadata_uri, owner, members, service_ids + + def get_service(self, org_id: str, service_id: str) -> bytes: + found, _, service_metadata_uri = self.contract.functions.getServiceRegistrationById( + type_converter("bytes32")(org_id), type_converter("bytes32")(service_id) + ).call() + if not found: + # TODO: configure exceptions + raise Exception() + + return service_metadata_uri + + def list_orgs(self): + org_list = self.contract.functions.listOrganizations().call() + return list(map(bytes32_to_str, org_list)) + + def list_service_for_org(self, org_id: str): + found, org_service_list = self.contract.functions.listServicesForOrganization( + type_converter("bytes32")(org_id) + ).call() + if not found: + # TODO: configure exceptions + raise Exception() + else: + return list(map(bytes32_to_str, org_service_list)) + + # WRITE METHODS + + def add_org_members( + self, account: Account, org_id: str, members: Union[str, list[str], None] + ): ... + + def update_org_metadata(self, account: Account, org_id: str, metadata_uri: str): ... + + def change_org_owner(self, account: Account, org_id: str, new_owner: str): ... + + def create_org( + self, account: Account, org_id: str, metadata_uri: str, members: Union[str, list[str], None] + ): ... + + def create_service(self, account: Account, org_id: str, service_id: str, metadata_uri: str): ... + + def delete_org(self, account: Account, org_id: str): ... + + def delete_service(self, account: Account, org_id: str, service_id: str): ... + + def remove_org_members( + self, account: Account, org_id: str, members_to_remove: Union[str, list[str]] + ): ... + + def update_service_metadata( + self, account: Account, org_id: str, service_id: str, metadata_uri: str + ): ... diff --git a/snet/sdk/storage_provider/service_metadata.py b/snet/sdk/registry/service_metadata.py similarity index 100% rename from snet/sdk/storage_provider/service_metadata.py rename to snet/sdk/registry/service_metadata.py diff --git a/snet/sdk/storage_provider/storage_provider.py b/snet/sdk/registry/storage_provider.py similarity index 76% rename from snet/sdk/storage_provider/storage_provider.py rename to snet/sdk/registry/storage_provider.py index edfdec1..ebd3ef6 100644 --- a/snet/sdk/storage_provider/storage_provider.py +++ b/snet/sdk/registry/storage_provider.py @@ -1,13 +1,13 @@ -import web3 from lighthouseweb3 import Lighthouse import json +from snet.sdk.registry.registry_contract import RegistryContract from snet.sdk.utils.ipfs_utils import ( get_ipfs_client, get_from_ipfs_and_checkhash, ) from snet.sdk.utils.utils import bytesuri_to_hash, safe_extract_proto -from snet.sdk.storage_provider.service_metadata import ( +from snet.sdk.registry.service_metadata import ( MPEServiceMetadata, mpe_service_metadata_from_json, ) @@ -15,20 +15,13 @@ class StorageProvider(object): - def __init__(self, registry_contract): + def __init__(self, registry_contract: RegistryContract): self._registry_contract = registry_contract self._ipfs_client = get_ipfs_client() self.lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) def fetch_org_metadata(self, org_id): - org = web3.Web3.to_bytes(text=org_id).ljust(32, b"\0") - - found, _, org_metadata_uri, _, _, _ = self._registry_contract.functions.getOrganizationById( - org - ).call() - if found is not True: - raise Exception('Organization with org ID "{}" not found '.format(org_id)) - + org_metadata_uri, _, _, _ = self._registry_contract.get_org(org_id) org_provider_type, org_metadata_hash = bytesuri_to_hash(org_metadata_uri) if org_provider_type == "ipfs": @@ -40,15 +33,7 @@ def fetch_org_metadata(self, org_id): return org_metadata def fetch_service_metadata(self, org_id: str, service_id: str) -> MPEServiceMetadata: - org = web3.Web3.to_bytes(text=org_id).ljust(32, b"\0") - service = web3.Web3.to_bytes(text=service_id).ljust(32, b"\0") - - found, _, service_metadata_uri = ( - self._registry_contract.functions.getServiceRegistrationById(org, service).call() - ) - if found is not True: - raise Exception(f"No service '{service_id}' found in organization '{org_id}'") - + service_metadata_uri = self._registry_contract.get_service(org_id, service_id) service_provider_type, service_metadata_hash = bytesuri_to_hash(s=service_metadata_uri) if service_provider_type == "ipfs": diff --git a/snet/sdk/service_client.py b/snet/sdk/service_client.py index 0ffcb93..c0faa87 100644 --- a/snet/sdk/service_client.py +++ b/snet/sdk/service_client.py @@ -21,7 +21,7 @@ PrePaidPaymentStrategy, ) from snet.sdk.resources.root_certificate import certificate -from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata +from snet.sdk.registry.service_metadata import MPEServiceMetadata from snet.sdk.custom_typing import ModuleName, ServiceStub from snet.sdk.utils.utils import ( RESOURCES_PATH, @@ -59,6 +59,7 @@ def __init__( if isinstance(payment_strategy, PrePaidPaymentStrategy): self.payment_strategy.set_concurrent_calls(options["concurrent_calls"]) self.options = options + self.mpe_contract = mpe_contract self.mpe_address = mpe_contract.contract.address self.account = account self.sdk_web3 = sdk_web3 @@ -193,6 +194,9 @@ def _generate_payment_channel_state_service_client(self) -> Any: state_service = importlib.import_module("state_service_pb2_grpc") return state_service.PaymentChannelStateServiceStub(grpc_channel) + def get_mpe_balance(self): + return self.mpe_contract.balance(self.account) + def open_channel(self, amount: int, expiration: int) -> PaymentChannel: payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) diff --git a/snet/sdk/types.py b/snet/sdk/types.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/snet/sdk/types.py @@ -0,0 +1 @@ + diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 0f2d15b..66bbf28 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -1,6 +1,8 @@ import json import sys import importlib.resources +from functools import lru_cache +from typing import Optional from urllib.parse import urlparse from pathlib import Path, PurePath import os @@ -10,19 +12,29 @@ import web3 from eth_typing import BlockNumber from grpc_tools.protoc import main as protoc +from web3 import Web3 from snet import sdk +from snet.sdk import config RESOURCES_PATH = PurePath(os.path.dirname(sdk.__file__)).joinpath("resources") +@lru_cache +def get_we3_object(eth_rpc_endpoint: Optional[str] = None) -> Web3: + if eth_rpc_endpoint is None: + eth_rpc_endpoint = config.ETH_RPC_ENDPOINT + + return web3.Web3(web3.HTTPProvider(eth_rpc_endpoint)) + + def safe_address_converter(a): if not web3.Web3.is_checksum_address(a): raise Exception("%s is not is not a valid Ethereum checksum address" % a) return a -def type_converter(t): +def type_converter(t: str): if t.endswith("[]"): return lambda x: list(map(type_converter(t.replace("[]", "")), json.loads(x))) else: @@ -137,7 +149,7 @@ def get_address_from_private(private_key): def get_current_block_number() -> BlockNumber: - return web3.Web3().eth.block_number + return get_we3_object().eth.block_number class add_to_path: diff --git a/tests/unit_tests/test_lib_generator.py b/tests/unit_tests/test_lib_generator.py index ec0aaeb..fa62ddb 100644 --- a/tests/unit_tests/test_lib_generator.py +++ b/tests/unit_tests/test_lib_generator.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, patch from snet.sdk.client_lib_generator import ClientLibGenerator -from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.registry.storage_provider import StorageProvider class TestClientLibGenerator(unittest.TestCase): diff --git a/tests/unit_tests/test_service_client.py b/tests/unit_tests/test_service_client.py index ee53738..f2c8e1d 100644 --- a/tests/unit_tests/test_service_client.py +++ b/tests/unit_tests/test_service_client.py @@ -8,7 +8,7 @@ from snet.sdk.mpe.mpe_contract import MPEContract from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider from snet.sdk.service_client import ServiceClient -from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata +from snet.sdk.registry.service_metadata import MPEServiceMetadata class TestServiceClient(unittest.TestCase): From 9051518a82e170a821e4bf63c3b6d18e1b152579 Mon Sep 17 00:00:00 2001 From: Arondondon Date: Mon, 16 Mar 2026 18:43:46 +0300 Subject: [PATCH 2/7] fix: change imports --- snet/sdk/account.py | 3 +-- snet/sdk/mpe/mpe_contract.py | 5 +++-- snet/sdk/mpe/payment_channel.py | 3 +-- snet/sdk/mpe/payment_channel_provider.py | 2 +- snet/sdk/registry/registry_contract.py | 4 +++- snet/sdk/utils/utils.py | 2 +- 6 files changed, 10 insertions(+), 9 deletions(-) diff --git a/snet/sdk/account.py b/snet/sdk/account.py index 0003758..d8e6d12 100644 --- a/snet/sdk/account.py +++ b/snet/sdk/account.py @@ -3,9 +3,8 @@ from snet.contracts import get_contract_object -from snet.sdk import get_we3_object from snet.sdk.config import config -from snet.sdk.utils.utils import get_address_from_private, normalize_private_key +from snet.sdk.utils.utils import get_address_from_private, normalize_private_key, get_we3_object DEFAULT_GAS = 300000 TRANSACTION_TIMEOUT = 500 diff --git a/snet/sdk/mpe/mpe_contract.py b/snet/sdk/mpe/mpe_contract.py index d33ea5d..07a22cc 100644 --- a/snet/sdk/mpe/mpe_contract.py +++ b/snet/sdk/mpe/mpe_contract.py @@ -2,8 +2,9 @@ from snet.contracts import get_contract_object -from snet.sdk import get_we3_object, config, Account - +from snet.sdk.account import Account +from snet.sdk.config import config +from snet.sdk.utils.utils import get_we3_object class MPEContract: def __init__(self): diff --git a/snet/sdk/mpe/payment_channel.py b/snet/sdk/mpe/payment_channel.py index 69ace6c..616c70e 100644 --- a/snet/sdk/mpe/payment_channel.py +++ b/snet/sdk/mpe/payment_channel.py @@ -2,8 +2,7 @@ import importlib from eth_account.messages import defunct_hash_message -from snet.sdk import get_we3_object -from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path +from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path, get_we3_object class PaymentChannel: diff --git a/snet/sdk/mpe/payment_channel_provider.py b/snet/sdk/mpe/payment_channel_provider.py index f2c0369..b656d52 100644 --- a/snet/sdk/mpe/payment_channel_provider.py +++ b/snet/sdk/mpe/payment_channel_provider.py @@ -4,7 +4,7 @@ from web3.types import LogReceipt -from snet.sdk import get_we3_object +from snet.sdk.utils.utils import get_we3_object from snet.sdk.mpe.payment_channel import PaymentChannel from snet.contracts import get_contract_deployment_block diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py index 0fabc28..ad6a279 100644 --- a/snet/sdk/registry/registry_contract.py +++ b/snet/sdk/registry/registry_contract.py @@ -2,7 +2,9 @@ from snet.contracts import get_contract_object -from snet.sdk import Account, get_we3_object, config, type_converter, bytes32_to_str +from snet.sdk.account import Account +from snet.sdk.config import config +from snet.sdk.utils.utils import type_converter, bytes32_to_str, get_we3_object class RegistryContract: diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 66bbf28..6fd099d 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -15,7 +15,7 @@ from web3 import Web3 from snet import sdk -from snet.sdk import config +from snet.sdk.config import config RESOURCES_PATH = PurePath(os.path.dirname(sdk.__file__)).joinpath("resources") From 1a1973753d97f29a6a766b0f0414956b463cc0ae Mon Sep 17 00:00:00 2001 From: Arondondon Date: Tue, 17 Mar 2026 14:09:42 +0300 Subject: [PATCH 3/7] feat: add new types, refactor storage provider --- snet/sdk/registry/registry_contract.py | 30 ++++++++---- snet/sdk/registry/storage_provider.py | 44 +++++++++-------- snet/sdk/types.py | 66 ++++++++++++++++++++++++++ snet/sdk/utils/utils.py | 14 ++++-- 4 files changed, 119 insertions(+), 35 deletions(-) diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py index ad6a279..c3034dd 100644 --- a/snet/sdk/registry/registry_contract.py +++ b/snet/sdk/registry/registry_contract.py @@ -4,6 +4,7 @@ from snet.sdk.account import Account from snet.sdk.config import config +from snet.sdk.types import RawOrgData, OrgData, ServiceData, RawServiceData from snet.sdk.utils.utils import type_converter, bytes32_to_str, get_we3_object @@ -14,32 +15,43 @@ def __init__(self): # READ METHODS - def get_org(self, org_id: str): - found, _, org_metadata_uri, owner, members, service_ids = ( + def get_org(self, org_id: str) -> OrgData: + found, found_org_id, org_metadata_uri, owner, members, service_ids = ( self.contract.functions.getOrganizationById(type_converter("bytes32")(org_id)).call() ) if not found: # TODO: configure exceptions raise Exception() - # TODO: new type - return org_metadata_uri, owner, members, service_ids + return OrgData.from_raw_org_data(RawOrgData( + org_id = found_org_id, + metadata_uri = org_metadata_uri, + owner = owner, + members = members, + services = service_ids + )) - def get_service(self, org_id: str, service_id: str) -> bytes: - found, _, service_metadata_uri = self.contract.functions.getServiceRegistrationById( + def get_service(self, org_id: str, service_id: str) -> ServiceData: + found, found_service_id, service_metadata_uri = self.contract.functions.getServiceRegistrationById( type_converter("bytes32")(org_id), type_converter("bytes32")(service_id) ).call() if not found: # TODO: configure exceptions raise Exception() - return service_metadata_uri + return ServiceData.from_raw_service_data( + RawServiceData( + service_id = found_service_id, + metadata_uri = service_metadata_uri + ), + org_id = org_id + ) - def list_orgs(self): + def list_orgs(self) -> list[str]: org_list = self.contract.functions.listOrganizations().call() return list(map(bytes32_to_str, org_list)) - def list_service_for_org(self, org_id: str): + def list_service_for_org(self, org_id: str) -> list[str]: found, org_service_list = self.contract.functions.listServicesForOrganization( type_converter("bytes32")(org_id) ).call() diff --git a/snet/sdk/registry/storage_provider.py b/snet/sdk/registry/storage_provider.py index ebd3ef6..2bfa9b2 100644 --- a/snet/sdk/registry/storage_provider.py +++ b/snet/sdk/registry/storage_provider.py @@ -1,7 +1,10 @@ +from typing import Any + from lighthouseweb3 import Lighthouse import json from snet.sdk.registry.registry_contract import RegistryContract +from snet.sdk.types import StorageType, FileURI from snet.sdk.utils.ipfs_utils import ( get_ipfs_client, get_from_ipfs_and_checkhash, @@ -21,27 +24,17 @@ def __init__(self, registry_contract: RegistryContract): self.lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) def fetch_org_metadata(self, org_id): - org_metadata_uri, _, _, _ = self._registry_contract.get_org(org_id) - org_provider_type, org_metadata_hash = bytesuri_to_hash(org_metadata_uri) + org = self._registry_contract.get_org(org_id) - if org_provider_type == "ipfs": - org_metadata_json = get_from_ipfs_and_checkhash(self._ipfs_client, org_metadata_hash) - else: - org_metadata_json, _ = self.lighthouse_client.download(org_metadata_hash) + org_metadata_json = self._get_from_storage(org.metadata_uri) org_metadata = json.loads(org_metadata_json) return org_metadata def fetch_service_metadata(self, org_id: str, service_id: str) -> MPEServiceMetadata: - service_metadata_uri = self._registry_contract.get_service(org_id, service_id) - service_provider_type, service_metadata_hash = bytesuri_to_hash(s=service_metadata_uri) + service = self._registry_contract.get_service(org_id, service_id) - if service_provider_type == "ipfs": - service_metadata_json = get_from_ipfs_and_checkhash( - self._ipfs_client, service_metadata_hash - ) - else: - service_metadata_json, _ = self.lighthouse_client.download(cid=service_metadata_hash) + service_metadata_json = self._get_from_storage(service.metadata_uri) service_metadata = mpe_service_metadata_from_json(service_metadata_json) return service_metadata @@ -60,17 +53,26 @@ def enhance_service_metadata(self, org_id, service_id): return service_metadata - def fetch_and_extract_proto(self, service_api_source, protodir): + def fetch_and_extract_proto(self, service_api_source, proto_dir): try: - proto_provider_type, service_api_source = bytesuri_to_hash( + tar_uri = bytesuri_to_hash( service_api_source, to_decode=False ) except Exception: - proto_provider_type = "ipfs" + # TODO: change exception based on bytesuri_to_hash function + tar_uri = FileURI(storage_type = StorageType.IPFS, uri_hash = service_api_source) + + spec_tar = self._get_from_storage(tar_uri) + + safe_extract_proto(spec_tar, proto_dir) - if proto_provider_type == "ipfs": - spec_tar = get_from_ipfs_and_checkhash(self._ipfs_client, service_api_source) + def _get_from_storage(self, uri: FileURI) -> Any: + if uri.storage_type == StorageType.IPFS: + file = get_from_ipfs_and_checkhash(self._ipfs_client, uri.uri_hash) + elif uri.storage_type == StorageType.FILECOIN: + file, _ = self.lighthouse_client.download(uri.uri_hash) else: - spec_tar, _ = self.lighthouse_client.download(service_api_source) + # TODO: configure exceptions + raise Exception() - safe_extract_proto(spec_tar, protodir) + return file \ No newline at end of file diff --git a/snet/sdk/types.py b/snet/sdk/types.py index 8b13789..0f7363d 100644 --- a/snet/sdk/types.py +++ b/snet/sdk/types.py @@ -1 +1,67 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Union +from grpc import services + +from snet.sdk.utils.utils import bytes32_to_str, bytesuri_to_hash + + +@dataclass +class RawOrgData: + org_id: bytes + metadata_uri: bytes + owner: str + members: list[str] + services: list[bytes] # IDs + + +@dataclass +class RawServiceData: + service_id: bytes + metadata_uri: bytes + + +class StorageType(Enum): + IPFS = "ipfs" + FILECOIN = "filecoin" + + +@dataclass +class FileURI: + storage_type: StorageType + uri_hash: str + + +@dataclass +class OrgData: + org_id: str + metadata_uri: FileURI + owner: str + members: list[str] + services: list[str] # IDs + + @classmethod + def from_raw_org_data(cls, raw_org_data: RawOrgData): + return OrgData( + org_id = bytes32_to_str(raw_org_data.org_id), + metadata_uri = bytesuri_to_hash(raw_org_data.metadata_uri), + owner = raw_org_data.owner, + members = raw_org_data.members, + services = list(map(bytes32_to_str, raw_org_data.services)) + ) + + +@dataclass +class ServiceData: + org_id: str + service_id: str + metadata_uri: FileURI + + @classmethod + def from_raw_service_data(cls, raw_service_data: RawServiceData, org_id: Union[str, bytes]): + return ServiceData( + org_id = bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, + service_id = bytes32_to_str(raw_service_data.service_id), + metadata_uri = bytesuri_to_hash(raw_service_data.metadata_uri) + ) \ No newline at end of file diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 6fd099d..70af508 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -16,6 +16,7 @@ from snet import sdk from snet.sdk.config import config +from snet.sdk.types import StorageType, FileURI RESOURCES_PATH = PurePath(os.path.dirname(sdk.__file__)).joinpath("resources") @@ -178,11 +179,14 @@ def find_file_by_keyword(directory, keyword, exclude=None): def bytesuri_to_hash(s, to_decode=True): if to_decode: s = s.rstrip(b"\0").decode("ascii") - if s.startswith("ipfs://"): - return "ipfs", s[7:] - elif s.startswith("filecoin://"): - return "filecoin", s[11:] - else: + try: + storage_type, storage_hash = s.split("://") + return FileURI( + StorageType(storage_type), + storage_hash + ) + except ValueError: + # TODO: configure exceptions raise Exception("We support only ipfs and filecoin uri in Registry") From a3895fc86bfec4eba9e8b4d959851f471b3b36f1 Mon Sep 17 00:00:00 2001 From: Arondondon Date: Tue, 17 Mar 2026 15:02:42 +0300 Subject: [PATCH 4/7] fix: apply linter --- snet/sdk/__init__.py | 8 +++---- snet/sdk/client_lib_generator.py | 26 ++++++++++----------- snet/sdk/mpe/mpe_contract.py | 1 + snet/sdk/registry/registry_contract.py | 31 +++++++++++++------------- snet/sdk/registry/storage_provider.py | 8 +++---- snet/sdk/types.py | 22 +++++++++--------- snet/sdk/utils/utils.py | 5 +---- tests/unit_tests/test_lib_generator.py | 24 ++++++++++---------- 8 files changed, 60 insertions(+), 65 deletions(-) diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 04d5c2b..3e9bb52 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -80,7 +80,7 @@ def create_service_client( if force_update: lib_generator.generate_client_library() else: - path_to_pb_files = lib_generator.protodir + path_to_pb_files = lib_generator.proto_dir pb_2_file_name = find_file_by_keyword( path_to_pb_files, keyword="pb2.py", exclude=["training"] ) @@ -118,13 +118,13 @@ def create_service_client( self.w3, pb2_module, self.payment_channel_provider, - lib_generator.protodir, + lib_generator.proto_dir, lib_generator.training_added(), ) return _service_client def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]: - path_to_pb_files = str(lib_generator.protodir) + path_to_pb_files = str(lib_generator.proto_dir) module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator) sys.path.append(path_to_pb_files) try: @@ -140,7 +140,7 @@ def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStu raise Exception(f"Error importing module: {e}") def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName: - path_to_pb_files = lib_generator.protodir + path_to_pb_files = lib_generator.proto_dir file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"]) module_name = os.path.splitext(file_name)[0] return ModuleName(module_name) diff --git a/snet/sdk/client_lib_generator.py b/snet/sdk/client_lib_generator.py index 234571b..bd58ec4 100644 --- a/snet/sdk/client_lib_generator.py +++ b/snet/sdk/client_lib_generator.py @@ -11,21 +11,21 @@ def __init__( metadata_provider: StorageProvider, org_id: str, service_id: str, - protodir: Path | None = None, + proto_dir: Path | None = None, ): self._metadata_provider: StorageProvider = metadata_provider self.org_id: str = org_id self.service_id: str = service_id self.language: str = "python" - self.protodir: Path = protodir if protodir else Path.home().joinpath(".snet") + self.proto_dir: Path = proto_dir if proto_dir else Path.home().joinpath(".snet") self.generate_directories_by_params() def generate_client_library(self) -> None: try: self.receive_proto_files() compilation_result = compile_proto( - entry_path=self.protodir, - codegen_dir=self.protodir, + entry_path=self.proto_dir, + codegen_dir=self.proto_dir, target_language=self.language, add_training=self.training_added(), ) @@ -33,19 +33,19 @@ def generate_client_library(self) -> None: print( f'client libraries for service with id "{self.service_id}" ' f'in org with id "{self.org_id}" ' - f"generated at {self.protodir}" + f"generated at {self.proto_dir}" ) except Exception as e: print(str(e)) def generate_directories_by_params(self) -> None: - if not self.protodir.is_absolute(): - self.protodir = Path.cwd().joinpath(self.protodir) + if not self.proto_dir.is_absolute(): + self.proto_dir = Path.cwd().joinpath(self.proto_dir) self.create_service_client_libraries_path() def create_service_client_libraries_path(self) -> None: - self.protodir = self.protodir.joinpath(self.org_id, self.service_id, self.language) - self.protodir.mkdir(parents=True, exist_ok=True) + self.proto_dir = self.proto_dir.joinpath(self.org_id, self.service_id, self.language) + self.proto_dir.mkdir(parents=True, exist_ok=True) def receive_proto_files(self) -> None: metadata = self._metadata_provider.fetch_service_metadata( @@ -54,17 +54,17 @@ def receive_proto_files(self) -> None: service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash") # Receive proto files - if self.protodir.exists(): - self._metadata_provider.fetch_and_extract_proto(service_api_source, self.protodir) + if self.proto_dir.exists(): + self._metadata_provider.fetch_and_extract_proto(service_api_source, self.proto_dir) else: raise Exception("Directory for storing proto files is not found") def training_added(self) -> bool: - files = os.listdir(self.protodir) + files = os.listdir(self.proto_dir) for file in files: if ".proto" not in file: continue - with open(self.protodir.joinpath(file), "r") as f: + with open(self.proto_dir.joinpath(file), "r") as f: proto_text = f.read() if 'import "training.proto";' in proto_text: return True diff --git a/snet/sdk/mpe/mpe_contract.py b/snet/sdk/mpe/mpe_contract.py index 07a22cc..2a847f6 100644 --- a/snet/sdk/mpe/mpe_contract.py +++ b/snet/sdk/mpe/mpe_contract.py @@ -6,6 +6,7 @@ from snet.sdk.config import config from snet.sdk.utils.utils import get_we3_object + class MPEContract: def __init__(self): self.w3 = get_we3_object() diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py index c3034dd..ccdba03 100644 --- a/snet/sdk/registry/registry_contract.py +++ b/snet/sdk/registry/registry_contract.py @@ -23,28 +23,29 @@ def get_org(self, org_id: str) -> OrgData: # TODO: configure exceptions raise Exception() - return OrgData.from_raw_org_data(RawOrgData( - org_id = found_org_id, - metadata_uri = org_metadata_uri, - owner = owner, - members = members, - services = service_ids - )) + return OrgData.from_raw_org_data( + RawOrgData( + org_id=found_org_id, + metadata_uri=org_metadata_uri, + owner=owner, + members=members, + services=service_ids, + ) + ) def get_service(self, org_id: str, service_id: str) -> ServiceData: - found, found_service_id, service_metadata_uri = self.contract.functions.getServiceRegistrationById( - type_converter("bytes32")(org_id), type_converter("bytes32")(service_id) - ).call() + found, found_service_id, service_metadata_uri = ( + self.contract.functions.getServiceRegistrationById( + type_converter("bytes32")(org_id), type_converter("bytes32")(service_id) + ).call() + ) if not found: # TODO: configure exceptions raise Exception() return ServiceData.from_raw_service_data( - RawServiceData( - service_id = found_service_id, - metadata_uri = service_metadata_uri - ), - org_id = org_id + RawServiceData(service_id=found_service_id, metadata_uri=service_metadata_uri), + org_id=org_id, ) def list_orgs(self) -> list[str]: diff --git a/snet/sdk/registry/storage_provider.py b/snet/sdk/registry/storage_provider.py index 2bfa9b2..f11d698 100644 --- a/snet/sdk/registry/storage_provider.py +++ b/snet/sdk/registry/storage_provider.py @@ -55,12 +55,10 @@ def enhance_service_metadata(self, org_id, service_id): def fetch_and_extract_proto(self, service_api_source, proto_dir): try: - tar_uri = bytesuri_to_hash( - service_api_source, to_decode=False - ) + tar_uri = bytesuri_to_hash(service_api_source, to_decode=False) except Exception: # TODO: change exception based on bytesuri_to_hash function - tar_uri = FileURI(storage_type = StorageType.IPFS, uri_hash = service_api_source) + tar_uri = FileURI(storage_type=StorageType.IPFS, uri_hash=service_api_source) spec_tar = self._get_from_storage(tar_uri) @@ -75,4 +73,4 @@ def _get_from_storage(self, uri: FileURI) -> Any: # TODO: configure exceptions raise Exception() - return file \ No newline at end of file + return file diff --git a/snet/sdk/types.py b/snet/sdk/types.py index 0f7363d..61cc249 100644 --- a/snet/sdk/types.py +++ b/snet/sdk/types.py @@ -2,8 +2,6 @@ from enum import Enum from typing import Union -from grpc import services - from snet.sdk.utils.utils import bytes32_to_str, bytesuri_to_hash @@ -13,7 +11,7 @@ class RawOrgData: metadata_uri: bytes owner: str members: list[str] - services: list[bytes] # IDs + services: list[bytes] # IDs @dataclass @@ -44,11 +42,11 @@ class OrgData: @classmethod def from_raw_org_data(cls, raw_org_data: RawOrgData): return OrgData( - org_id = bytes32_to_str(raw_org_data.org_id), - metadata_uri = bytesuri_to_hash(raw_org_data.metadata_uri), - owner = raw_org_data.owner, - members = raw_org_data.members, - services = list(map(bytes32_to_str, raw_org_data.services)) + org_id=bytes32_to_str(raw_org_data.org_id), + metadata_uri=bytesuri_to_hash(raw_org_data.metadata_uri), + owner=raw_org_data.owner, + members=raw_org_data.members, + services=list(map(bytes32_to_str, raw_org_data.services)), ) @@ -61,7 +59,7 @@ class ServiceData: @classmethod def from_raw_service_data(cls, raw_service_data: RawServiceData, org_id: Union[str, bytes]): return ServiceData( - org_id = bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, - service_id = bytes32_to_str(raw_service_data.service_id), - metadata_uri = bytesuri_to_hash(raw_service_data.metadata_uri) - ) \ No newline at end of file + org_id=bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, + service_id=bytes32_to_str(raw_service_data.service_id), + metadata_uri=bytesuri_to_hash(raw_service_data.metadata_uri), + ) diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 70af508..64122f3 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -181,10 +181,7 @@ def bytesuri_to_hash(s, to_decode=True): s = s.rstrip(b"\0").decode("ascii") try: storage_type, storage_hash = s.split("://") - return FileURI( - StorageType(storage_type), - storage_hash - ) + return FileURI(StorageType(storage_type), storage_hash) except ValueError: # TODO: configure exceptions raise Exception("We support only ipfs and filecoin uri in Registry") diff --git a/tests/unit_tests/test_lib_generator.py b/tests/unit_tests/test_lib_generator.py index fa62ddb..1bcfac0 100644 --- a/tests/unit_tests/test_lib_generator.py +++ b/tests/unit_tests/test_lib_generator.py @@ -18,29 +18,29 @@ def setUp(self): metadata_provider=self.mock_metadata_provider, org_id=self.org_id, service_id=self.service_id, - protodir=self.protodir, + proto_dir=self.protodir, ) @patch("pathlib.Path.mkdir") def test_generate_directories_by_params_by_absolute_path(self, mock_mkdir): expected_library_dir = self.protodir.joinpath(self.org_id, self.service_id, self.language) self.generator.generate_directories_by_params() - self.assertEqual(self.generator.protodir, expected_library_dir) + self.assertEqual(self.generator.proto_dir, expected_library_dir) mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) @patch("pathlib.Path.mkdir") def test_generate_directories_by_params_by_relative_path(self, mock_mkdir): - self.generator.protodir = Path(".snet_test") + self.generator.proto_dir = Path(".snet_test") expected_library_dir = Path.cwd().joinpath( - self.generator.protodir, self.org_id, self.service_id, self.language + self.generator.proto_dir, self.org_id, self.service_id, self.language ) self.generator.generate_directories_by_params() - self.assertEqual(self.generator.protodir, expected_library_dir) + self.assertEqual(self.generator.proto_dir, expected_library_dir) mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) def test_create_service_client_libraries_path(self): mock_protodir = Mock(spec=Path) - self.generator.protodir = mock_protodir + self.generator.proto_dir = mock_protodir mock_library_path = Mock(spec=Path) mock_protodir.joinpath.return_value = mock_library_path @@ -52,7 +52,7 @@ def test_create_service_client_libraries_path(self): mock_library_path.mkdir.assert_called_once_with(parents=True, exist_ok=True) # Assert that the protodir is updated correctly - self.assertEqual(self.generator.protodir, mock_library_path) + self.assertEqual(self.generator.proto_dir, mock_library_path) def test_receive_proto_files_success(self): # Set up mocks @@ -61,8 +61,8 @@ def test_receive_proto_files_success(self): "model_ipfs_hash": os.getenv("MODEL_IPFS_HASH"), } self.mock_metadata_provider.fetch_service_metadata.return_value = mock_metadata - self.generator.protodir = Mock() - self.generator.protodir.exists.return_value = True + self.generator.proto_dir = Mock() + self.generator.proto_dir.exists.return_value = True # Call the method self.generator.receive_proto_files() @@ -75,12 +75,12 @@ def test_receive_proto_files_success(self): org_id=self.org_id, service_id=self.service_id ) self.mock_metadata_provider.fetch_and_extract_proto.assert_called_once_with( - service_api_source, self.generator.protodir + service_api_source, self.generator.proto_dir ) def test_receive_proto_files_failed(self): - self.generator.protodir = Mock() - self.generator.protodir.exists.return_value = False + self.generator.proto_dir = Mock() + self.generator.proto_dir.exists.return_value = False with self.assertRaises(Exception) as context: self.generator.receive_proto_files() From f2af65eeadb32beb899b5823961b9ff005e2c766 Mon Sep 17 00:00:00 2001 From: Arondondon Date: Wed, 18 Mar 2026 17:45:48 +0300 Subject: [PATCH 5/7] feat: start implementing service metadata --- pyproject.toml | 1 + requirements.txt | 1 + snet/sdk/__init__.py | 3 +- snet/sdk/mpe/payment_channel_provider.py | 4 +- snet/sdk/registry/registry_contract.py | 12 +++- snet/sdk/registry/service_metadata.py | 73 ++++++++++++++++++++++++ snet/sdk/types.py | 21 ------- snet/sdk/utils/utils.py | 24 +++++++- 8 files changed, 110 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2786d02..4862524 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "snet-contracts==1.0.1", "lighthouseweb3~=0.1.4", "py-multihash~=3.0", + "pydantic~=2.11", "pydantic-settings~=2.13" ] diff --git a/requirements.txt b/requirements.txt index 4bc7135..5731954 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ ipfshttpclient==0.4.13.2 snet-contracts==1.0.1 lighthouseweb3~=0.1.4 py-multihash~=3.0 +pydantic~=2.11 pydantic-settings~=2.13 \ No newline at end of file diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 3e9bb52..641ea08 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -157,7 +157,8 @@ def _get_group_by_group_name( for group in service_metadata["groups"]: if group["group_name"] == group_name: return group - return {} + # TODO: configure exceptions + raise Exception() def _get_service_group_details( self, service_metadata: MPEServiceMetadata, group_name: str diff --git a/snet/sdk/mpe/payment_channel_provider.py b/snet/sdk/mpe/payment_channel_provider.py index b656d52..0a12fff 100644 --- a/snet/sdk/mpe/payment_channel_provider.py +++ b/snet/sdk/mpe/payment_channel_provider.py @@ -7,6 +7,7 @@ from snet.sdk.utils.utils import get_we3_object from snet.sdk.mpe.payment_channel import PaymentChannel from snet.contracts import get_contract_deployment_block +from snet.sdk.mpe.mpe_contract import MPEContract BLOCKS_PER_BATCH = 50000 @@ -14,7 +15,7 @@ class PaymentChannelProvider(object): - def __init__(self, mpe_contract): + def __init__(self, mpe_contract: MPEContract): self.web3 = get_we3_object() self.mpe_contract = mpe_contract @@ -139,7 +140,6 @@ def get_past_open_channels( map( lambda channel: PaymentChannel( channel["channel_id"], - self.web3, account, payment_channel_state_service_client, self.mpe_contract, diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py index ccdba03..50bd0be 100644 --- a/snet/sdk/registry/registry_contract.py +++ b/snet/sdk/registry/registry_contract.py @@ -5,7 +5,13 @@ from snet.sdk.account import Account from snet.sdk.config import config from snet.sdk.types import RawOrgData, OrgData, ServiceData, RawServiceData -from snet.sdk.utils.utils import type_converter, bytes32_to_str, get_we3_object +from snet.sdk.utils.utils import ( + type_converter, + bytes32_to_str, + get_we3_object, + convert_raw_service_data, + convert_raw_org_data, +) class RegistryContract: @@ -23,7 +29,7 @@ def get_org(self, org_id: str) -> OrgData: # TODO: configure exceptions raise Exception() - return OrgData.from_raw_org_data( + return convert_raw_org_data( RawOrgData( org_id=found_org_id, metadata_uri=org_metadata_uri, @@ -43,7 +49,7 @@ def get_service(self, org_id: str, service_id: str) -> ServiceData: # TODO: configure exceptions raise Exception() - return ServiceData.from_raw_service_data( + return convert_raw_service_data( RawServiceData(service_id=found_service_id, metadata_uri=service_metadata_uri), org_id=org_id, ) diff --git a/snet/sdk/registry/service_metadata.py b/snet/sdk/registry/service_metadata.py index 9ab11b1..c734029 100644 --- a/snet/sdk/registry/service_metadata.py +++ b/snet/sdk/registry/service_metadata.py @@ -39,9 +39,13 @@ import re import json import base64 +import secrets from collections import defaultdict from enum import Enum +from typing import Literal, Any, Optional + +from pydantic import BaseModel, Field, model_validator, ValidationInfo from snet.sdk.utils.utils import is_valid_endpoint @@ -63,6 +67,75 @@ def is_single_value(asset_type): return True +def generate_group_id() -> str: + return base64.b64encode(secrets.token_bytes(32)).decode() + + +class Pricing(BaseModel): + price_model: Literal["fixed_price", "method_price"] = Field(default="fixed_price") + price_in_cogs: int = Field(ge=1) + default: bool = Field(default=True) + + +class Group(BaseModel): + group_name: str = Field(min_length=1, default="default_group") + group_id: str = Field(default_factory=generate_group_id, init=False) + free_calls: int = Field(ge=1) + free_call_signer_address: str + daemon_addresses: list[str] + + +class ServiceDescription(BaseModel): ... + + +class Media(BaseModel): ... + + +class Contributor(BaseModel): ... + + +class ServiceMetadata(BaseModel): + version: int = Field(ge=1, default=1) + display_name: str + encoding: Literal["proto", "json"] = Field(default="proto") + service_type: Literal["grpc", "http", "jsonrpc"] + service_api_source: Optional[str] = Field(min_length=1, default=None, init=False) + model_ipfs_hash: Optional[str] = Field(min_length=1, default=None, init=False, deprecated=True) + mpe_address: str + groups: list[Group] + service_description: ServiceDescription + media: list[Media] + contributors: list[Contributor] + tags: list[str] + + @model_validator(mode="before") + @classmethod + def restrict_deprecated_fields(cls, data: Any, info: ValidationInfo) -> Any: + if not isinstance(data, dict): + return data + + is_fetching = info.context and info.context.get("from_storage") is True + + if not is_fetching and data.get("model_ipfs_hash"): + # TODO: configure exceptions + raise ValueError( + "The 'model_ipfs_hash' field is deprecated and cannot be used " + "to create new metadata. Please use 'service_api_source' instead." + ) + + return data + + def generate_final_json(self): + if self.service_api_source is None: + if self.model_ipfs_hash is None: + # TODO: configure exceptions + raise ValueError("The 'service_api_source' field is missing!") + else: + self.service_api_source, self.model_ipfs_hash = self.model_ipfs_hash, None + + return self.model_dump_json(indent=2, exclude_none=True) + + # TODO: we should use some standard solution here class MPEServiceMetadata: def __init__(self): diff --git a/snet/sdk/types.py b/snet/sdk/types.py index 61cc249..216ed18 100644 --- a/snet/sdk/types.py +++ b/snet/sdk/types.py @@ -1,8 +1,5 @@ from dataclasses import dataclass from enum import Enum -from typing import Union - -from snet.sdk.utils.utils import bytes32_to_str, bytesuri_to_hash @dataclass @@ -39,27 +36,9 @@ class OrgData: members: list[str] services: list[str] # IDs - @classmethod - def from_raw_org_data(cls, raw_org_data: RawOrgData): - return OrgData( - org_id=bytes32_to_str(raw_org_data.org_id), - metadata_uri=bytesuri_to_hash(raw_org_data.metadata_uri), - owner=raw_org_data.owner, - members=raw_org_data.members, - services=list(map(bytes32_to_str, raw_org_data.services)), - ) - @dataclass class ServiceData: org_id: str service_id: str metadata_uri: FileURI - - @classmethod - def from_raw_service_data(cls, raw_service_data: RawServiceData, org_id: Union[str, bytes]): - return ServiceData( - org_id=bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, - service_id=bytes32_to_str(raw_service_data.service_id), - metadata_uri=bytesuri_to_hash(raw_service_data.metadata_uri), - ) diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 64122f3..9e4c876 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -2,7 +2,7 @@ import sys import importlib.resources from functools import lru_cache -from typing import Optional +from typing import Optional, Union from urllib.parse import urlparse from pathlib import Path, PurePath import os @@ -16,11 +16,31 @@ from snet import sdk from snet.sdk.config import config -from snet.sdk.types import StorageType, FileURI +from snet.sdk.types import StorageType, FileURI, RawOrgData, OrgData, RawServiceData, ServiceData RESOURCES_PATH = PurePath(os.path.dirname(sdk.__file__)).joinpath("resources") +def convert_raw_org_data(raw_org_data: RawOrgData) -> OrgData: + return OrgData( + org_id=bytes32_to_str(raw_org_data.org_id), + metadata_uri=bytesuri_to_hash(raw_org_data.metadata_uri), + owner=raw_org_data.owner, + members=raw_org_data.members, + services=list(map(bytes32_to_str, raw_org_data.services)), + ) + + +def convert_raw_service_data( + raw_service_data: RawServiceData, org_id: Union[str, bytes] +) -> ServiceData: + return ServiceData( + org_id=bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, + service_id=bytes32_to_str(raw_service_data.service_id), + metadata_uri=bytesuri_to_hash(raw_service_data.metadata_uri), + ) + + @lru_cache def get_we3_object(eth_rpc_endpoint: Optional[str] = None) -> Web3: if eth_rpc_endpoint is None: From 5df5bf3fffdc2db7ee5ff15c297b548f786759ce Mon Sep 17 00:00:00 2001 From: Arondondon Date: Fri, 20 Mar 2026 16:44:16 +0300 Subject: [PATCH 6/7] feat: end implementing service metadata, rework storage provider --- snet/sdk/__init__.py | 2 +- snet/sdk/custom_typing.py | 5 -- snet/sdk/registry/models.py | 92 +++++++++++++++++++++ snet/sdk/registry/registry_contract.py | 8 +- snet/sdk/registry/service_metadata.py | 78 +++++++++++++----- snet/sdk/registry/storage_provider.py | 107 ++++++++++++++++++------- snet/sdk/service_client.py | 2 +- snet/sdk/types.py | 45 +---------- snet/sdk/utils/ipfs_utils.py | 42 ---------- snet/sdk/utils/utils.py | 56 +------------ 10 files changed, 241 insertions(+), 196 deletions(-) delete mode 100644 snet/sdk/custom_typing.py create mode 100644 snet/sdk/registry/models.py delete mode 100644 snet/sdk/utils/ipfs_utils.py diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 641ea08..2b6825c 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -33,7 +33,7 @@ ) from snet.sdk.service_client import ServiceClient from snet.sdk.registry.storage_provider import StorageProvider -from snet.sdk.custom_typing import ModuleName, ServiceStub +from snet.sdk.types import ModuleName, ServiceStub from snet.sdk.utils.utils import ( find_file_by_keyword, get_we3_object, diff --git a/snet/sdk/custom_typing.py b/snet/sdk/custom_typing.py deleted file mode 100644 index 81ed725..0000000 --- a/snet/sdk/custom_typing.py +++ /dev/null @@ -1,5 +0,0 @@ -from typing import Any, NewType - - -ModuleName = NewType("ModuleName", str) -ServiceStub = NewType("ServiceStub", Any) diff --git a/snet/sdk/registry/models.py b/snet/sdk/registry/models.py new file mode 100644 index 0000000..2ac8b69 --- /dev/null +++ b/snet/sdk/registry/models.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Union + +from snet.sdk.utils.utils import bytes32_to_str + + +@dataclass +class RawOrgData: + org_id: bytes + metadata_uri: bytes + owner: str + members: list[str] + services: list[bytes] # IDs + + +@dataclass +class RawServiceData: + service_id: bytes + metadata_uri: bytes + + +class StorageType(Enum): + IPFS = "ipfs" + FILECOIN = "filecoin" + + +@dataclass +class FileURI: + storage_type: StorageType + uri_hash: str + + @classmethod + def from_raw_uri(cls, string_uri: Union[str, bytes]) -> "FileURI": + if not string_uri: + raise ValueError("'string_uri' cannot be empty!") + + if isinstance(string_uri, bytes): + string_uri = string_uri.rstrip(b"\0").decode("ascii") + + try: + s_t_str, u_h = string_uri.split("://") + except ValueError: + s_t_str = "ipfs" + u_h = string_uri + + s_t = StorageType(s_t_str) + + return cls(s_t, u_h) + + def __str__(self) -> str: + return f"{self.storage_type.value}://{self.uri_hash}" + + @classmethod + def normalize_string_uri(cls, string_uri: str) -> str: + return str(FileURI.from_raw_uri(string_uri)) + + +@dataclass +class OrgData: + org_id: str + metadata_uri: FileURI + owner: str + members: list[str] + services: list[str] # IDs + + @classmethod + def from_raw_data(cls, raw_org_data: RawOrgData) -> "OrgData": + return cls( + org_id=bytes32_to_str(raw_org_data.org_id), + metadata_uri=FileURI.from_raw_uri(raw_org_data.metadata_uri), + owner=raw_org_data.owner, + members=raw_org_data.members, + services=list(map(bytes32_to_str, raw_org_data.services)), + ) + + +@dataclass +class ServiceData: + org_id: str + service_id: str + metadata_uri: FileURI + + @classmethod + def from_raw_data( + cls, raw_service_data: RawServiceData, org_id: Union[str, bytes] + ) -> "ServiceData": + return cls( + org_id=bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, + service_id=bytes32_to_str(raw_service_data.service_id), + metadata_uri=FileURI.from_raw_uri(raw_service_data.metadata_uri), + ) diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py index 50bd0be..5960014 100644 --- a/snet/sdk/registry/registry_contract.py +++ b/snet/sdk/registry/registry_contract.py @@ -4,13 +4,11 @@ from snet.sdk.account import Account from snet.sdk.config import config -from snet.sdk.types import RawOrgData, OrgData, ServiceData, RawServiceData +from snet.sdk.registry.models import RawOrgData, OrgData, ServiceData, RawServiceData from snet.sdk.utils.utils import ( type_converter, bytes32_to_str, get_we3_object, - convert_raw_service_data, - convert_raw_org_data, ) @@ -29,7 +27,7 @@ def get_org(self, org_id: str) -> OrgData: # TODO: configure exceptions raise Exception() - return convert_raw_org_data( + return OrgData.from_raw_data( RawOrgData( org_id=found_org_id, metadata_uri=org_metadata_uri, @@ -49,7 +47,7 @@ def get_service(self, org_id: str, service_id: str) -> ServiceData: # TODO: configure exceptions raise Exception() - return convert_raw_service_data( + return ServiceData.from_raw_data( RawServiceData(service_id=found_service_id, metadata_uri=service_metadata_uri), org_id=org_id, ) diff --git a/snet/sdk/registry/service_metadata.py b/snet/sdk/registry/service_metadata.py index c734029..97de7ca 100644 --- a/snet/sdk/registry/service_metadata.py +++ b/snet/sdk/registry/service_metadata.py @@ -47,6 +47,7 @@ from pydantic import BaseModel, Field, model_validator, ValidationInfo +from snet.sdk.registry.models import FileURI from snet.sdk.utils.utils import is_valid_endpoint @@ -67,46 +68,76 @@ def is_single_value(asset_type): return True +class FileType(Enum): + IMAGE = "image" + VIDEO = "video" + ARCHIVE = "archive" + + +# class AssetType(Enum): +# HERO_IMAGE = "hero_image" +# PROTO_FILE = "proto_file" +# DEMO_COMPONENT = "demo_component" + + +class ServiceType(Enum): + GRPC = "grpc" + HTTP = "http" + JSONRPC = "jsonrpc" + + def generate_group_id() -> str: return base64.b64encode(secrets.token_bytes(32)).decode() class Pricing(BaseModel): price_model: Literal["fixed_price", "method_price"] = Field(default="fixed_price") - price_in_cogs: int = Field(ge=1) + price_in_cogs: int = Field(ge=1, default=1) default: bool = Field(default=True) class Group(BaseModel): group_name: str = Field(min_length=1, default="default_group") group_id: str = Field(default_factory=generate_group_id, init=False) - free_calls: int = Field(ge=1) - free_call_signer_address: str - daemon_addresses: list[str] + free_calls: int = Field(ge=1, default=3) + free_call_signer_address: str = Field(default="") + daemon_addresses: list[str] = Field(default=[]) + endpoints: list[str] = Field(default=[]) + pricing: list[Pricing] = Field(default=[]) -class ServiceDescription(BaseModel): ... +class ServiceDescription(BaseModel): + url: str = Field(default="") + short_description: str = Field(default="") + description: str = Field(default="") -class Media(BaseModel): ... +class Media(BaseModel): + order: int = Field(ge=1, default=1) + url: str = Field(min_length=1) + file_type: FileType + alt_text: str = Field(default="") + asset_type: AssetType -class Contributor(BaseModel): ... +class Contributor(BaseModel): + name: str = Field(min_length=1) + email_id: str = Field(default="") class ServiceMetadata(BaseModel): version: int = Field(ge=1, default=1) - display_name: str + display_name: str = Field(min_length=1) encoding: Literal["proto", "json"] = Field(default="proto") service_type: Literal["grpc", "http", "jsonrpc"] - service_api_source: Optional[str] = Field(min_length=1, default=None, init=False) + service_api_source: Optional[str] = Field(default=None, init=False) model_ipfs_hash: Optional[str] = Field(min_length=1, default=None, init=False, deprecated=True) - mpe_address: str - groups: list[Group] - service_description: ServiceDescription - media: list[Media] - contributors: list[Contributor] - tags: list[str] + mpe_address: Optional[str] = Field(default=None, init=False) + groups: list[Group] = Field(default=[]) + service_description: Optional[ServiceDescription] = Field(default=None) + media: list[Media] = Field(default=[]) + contributors: list[Contributor] = Field(default=[]) + tags: list[str] = Field(default=[]) @model_validator(mode="before") @classmethod @@ -117,7 +148,6 @@ def restrict_deprecated_fields(cls, data: Any, info: ValidationInfo) -> Any: is_fetching = info.context and info.context.get("from_storage") is True if not is_fetching and data.get("model_ipfs_hash"): - # TODO: configure exceptions raise ValueError( "The 'model_ipfs_hash' field is deprecated and cannot be used " "to create new metadata. Please use 'service_api_source' instead." @@ -128,10 +158,22 @@ def restrict_deprecated_fields(cls, data: Any, info: ValidationInfo) -> Any: def generate_final_json(self): if self.service_api_source is None: if self.model_ipfs_hash is None: - # TODO: configure exceptions raise ValueError("The 'service_api_source' field is missing!") else: - self.service_api_source, self.model_ipfs_hash = self.model_ipfs_hash, None + self.service_api_source = FileURI.normalize_string_uri(self.model_ipfs_hash) + self.model_ipfs_hash = None + + if not self.mpe_address: + raise ValueError("The 'mpe_address' field is missing!") + + if len(self.groups) == 0: + raise ValueError("There must be one item in 'groups' field at least!") + + if len(self.contributors) == 0: + raise ValueError("There must be one item in 'contributors' field at least!") + + if not self.service_description: + raise ValueError("The 'mpe_address' field is missing!") return self.model_dump_json(indent=2, exclude_none=True) diff --git a/snet/sdk/registry/storage_provider.py b/snet/sdk/registry/storage_provider.py index f11d698..e40fc13 100644 --- a/snet/sdk/registry/storage_provider.py +++ b/snet/sdk/registry/storage_provider.py @@ -1,15 +1,16 @@ -from typing import Any +import io +import tarfile +from pathlib import Path +from typing import Union +import ipfshttpclient from lighthouseweb3 import Lighthouse import json +import multihash +import hashlib from snet.sdk.registry.registry_contract import RegistryContract -from snet.sdk.types import StorageType, FileURI -from snet.sdk.utils.ipfs_utils import ( - get_ipfs_client, - get_from_ipfs_and_checkhash, -) -from snet.sdk.utils.utils import bytesuri_to_hash, safe_extract_proto +from snet.sdk.registry.models import StorageType, FileURI from snet.sdk.registry.service_metadata import ( MPEServiceMetadata, mpe_service_metadata_from_json, @@ -20,8 +21,8 @@ class StorageProvider(object): def __init__(self, registry_contract: RegistryContract): self._registry_contract = registry_contract - self._ipfs_client = get_ipfs_client() - self.lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) + self._ipfs_client = ipfshttpclient.connect(config.IPFS_ENDPOINT) + self._lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) def fetch_org_metadata(self, org_id): org = self._registry_contract.get_org(org_id) @@ -54,23 +55,75 @@ def enhance_service_metadata(self, org_id, service_id): return service_metadata def fetch_and_extract_proto(self, service_api_source, proto_dir): - try: - tar_uri = bytesuri_to_hash(service_api_source, to_decode=False) - except Exception: - # TODO: change exception based on bytesuri_to_hash function - tar_uri = FileURI(storage_type=StorageType.IPFS, uri_hash=service_api_source) - + tar_uri = FileURI.from_raw_uri(service_api_source) spec_tar = self._get_from_storage(tar_uri) - - safe_extract_proto(spec_tar, proto_dir) - - def _get_from_storage(self, uri: FileURI) -> Any: - if uri.storage_type == StorageType.IPFS: - file = get_from_ipfs_and_checkhash(self._ipfs_client, uri.uri_hash) - elif uri.storage_type == StorageType.FILECOIN: - file, _ = self.lighthouse_client.download(uri.uri_hash) - else: - # TODO: configure exceptions - raise Exception() - + self.safe_extract_proto(spec_tar, proto_dir) + + def _get_from_storage(self, uri: FileURI, decode: bool = True) -> Union[bytes, str]: + match uri.storage_type: + case StorageType.IPFS: + file = self._get_from_ipfs_and_checkhash(uri.uri_hash) + case StorageType.FILECOIN: + file, _ = self._lighthouse_client.download(uri.uri_hash) + case _: + raise ValueError(f"Unsupported storage type: {uri.storage_type}") + if decode: + file = file.decode() return file + + def _get_from_ipfs_and_checkhash(self, ipfs_hash: str, validate: bool = False) -> bytes: + data = self._ipfs_client.cat(ipfs_hash) + + if validate: + block_data = self._ipfs_client.block.get(ipfs_hash) + + try: + mh_bytes = multihash.from_b58_string(ipfs_hash) + decoded = multihash.decode(mh_bytes) + + hash_func_name = decoded.name + expected_digest = decoded.digest + + if hash_func_name == "sha2-256": + actual_digest = hashlib.sha256(block_data).digest() + else: + h = hashlib.new(hash_func_name.replace("-", "")) + h.update(block_data) + actual_digest = h.digest() + + if actual_digest != expected_digest: + raise Exception("IPFS hash mismatch with data") + + except Exception as e: + raise ValueError(f"Integrity check failed: {str(e)}") from e + + return data + + @staticmethod + def safe_extract_proto(spec_tar: bytes, proto_dir: Union[str, Path]) -> None: + dest_dir = Path(proto_dir).resolve() + dest_dir.mkdir(parents=True, exist_ok=True) + + with tarfile.open(fileobj=io.BytesIO(spec_tar)) as f: + valid_members = [] + + for m in f.getmembers(): + if not m.isfile(): + raise ValueError( + f"Security/Format Error: Tarball contains a non-file item: '{m.name}'" + ) + if Path(m.name).parent != Path("."): + raise ValueError( + f"Format Error: Tarball contains nested paths ('{m.name}'). Only flat archives are supported." + ) + if not m.name.endswith(".proto"): + raise ValueError( + f"Format Error: Unexpected file type '{m.name}'. Only .proto files allowed." + ) + target_file = dest_dir / m.name + if target_file.exists(): + target_file.unlink() + print(f"Removed existing file: {target_file}") + valid_members.append(m) + + f.extractall(path=dest_dir, members=valid_members, filter="data") diff --git a/snet/sdk/service_client.py b/snet/sdk/service_client.py index c0faa87..30ac854 100644 --- a/snet/sdk/service_client.py +++ b/snet/sdk/service_client.py @@ -22,7 +22,7 @@ ) from snet.sdk.resources.root_certificate import certificate from snet.sdk.registry.service_metadata import MPEServiceMetadata -from snet.sdk.custom_typing import ModuleName, ServiceStub +from snet.sdk.types import ModuleName, ServiceStub from snet.sdk.utils.utils import ( RESOURCES_PATH, add_to_path, diff --git a/snet/sdk/types.py b/snet/sdk/types.py index 216ed18..81ed725 100644 --- a/snet/sdk/types.py +++ b/snet/sdk/types.py @@ -1,44 +1,5 @@ -from dataclasses import dataclass -from enum import Enum +from typing import Any, NewType -@dataclass -class RawOrgData: - org_id: bytes - metadata_uri: bytes - owner: str - members: list[str] - services: list[bytes] # IDs - - -@dataclass -class RawServiceData: - service_id: bytes - metadata_uri: bytes - - -class StorageType(Enum): - IPFS = "ipfs" - FILECOIN = "filecoin" - - -@dataclass -class FileURI: - storage_type: StorageType - uri_hash: str - - -@dataclass -class OrgData: - org_id: str - metadata_uri: FileURI - owner: str - members: list[str] - services: list[str] # IDs - - -@dataclass -class ServiceData: - org_id: str - service_id: str - metadata_uri: FileURI +ModuleName = NewType("ModuleName", str) +ServiceStub = NewType("ServiceStub", Any) diff --git a/snet/sdk/utils/ipfs_utils.py b/snet/sdk/utils/ipfs_utils.py deleted file mode 100644 index a3e76bd..0000000 --- a/snet/sdk/utils/ipfs_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -import ipfshttpclient -import multihash -import hashlib - -from snet.sdk.config import config - - -def get_from_ipfs_and_checkhash(ipfs_client, ipfs_hash_base58, validate=True): - """ - Get file from IPFS and validate hash - """ - data = ipfs_client.cat(ipfs_hash_base58) - - if validate: - block_data = ipfs_client.block.get(ipfs_hash_base58) - - try: - mh_bytes = multihash.from_b58_string(ipfs_hash_base58) - decoded = multihash.decode(mh_bytes) - - hash_func_name = decoded.name - expected_digest = decoded.digest - - if hash_func_name == "sha2-256": # Standard for IPFS (CIDv0) - actual_digest = hashlib.sha256(block_data).digest() - else: - # Handle other algorithms supported by hashlib if necessary - h = hashlib.new(hash_func_name.replace("-", "")) - h.update(block_data) - actual_digest = h.digest() - - if actual_digest != expected_digest: - raise Exception("IPFS hash mismatch with data") - - except Exception as e: - raise ValueError(f"Integrity check failed: {str(e)}") from e - - return data - - -def get_ipfs_client(): - return ipfshttpclient.connect(config.IPFS_ENDPOINT) diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 9e4c876..33e80ff 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -2,12 +2,10 @@ import sys import importlib.resources from functools import lru_cache -from typing import Optional, Union +from typing import Optional from urllib.parse import urlparse from pathlib import Path, PurePath import os -import tarfile -import io import web3 from eth_typing import BlockNumber @@ -16,31 +14,10 @@ from snet import sdk from snet.sdk.config import config -from snet.sdk.types import StorageType, FileURI, RawOrgData, OrgData, RawServiceData, ServiceData RESOURCES_PATH = PurePath(os.path.dirname(sdk.__file__)).joinpath("resources") -def convert_raw_org_data(raw_org_data: RawOrgData) -> OrgData: - return OrgData( - org_id=bytes32_to_str(raw_org_data.org_id), - metadata_uri=bytesuri_to_hash(raw_org_data.metadata_uri), - owner=raw_org_data.owner, - members=raw_org_data.members, - services=list(map(bytes32_to_str, raw_org_data.services)), - ) - - -def convert_raw_service_data( - raw_service_data: RawServiceData, org_id: Union[str, bytes] -) -> ServiceData: - return ServiceData( - org_id=bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, - service_id=bytes32_to_str(raw_service_data.service_id), - metadata_uri=bytesuri_to_hash(raw_service_data.metadata_uri), - ) - - @lru_cache def get_we3_object(eth_rpc_endpoint: Optional[str] = None) -> Web3: if eth_rpc_endpoint is None: @@ -194,34 +171,3 @@ def find_file_by_keyword(directory, keyword, exclude=None): for file in files: if keyword in file and all(e not in file for e in exclude): return file - - -def bytesuri_to_hash(s, to_decode=True): - if to_decode: - s = s.rstrip(b"\0").decode("ascii") - try: - storage_type, storage_hash = s.split("://") - return FileURI(StorageType(storage_type), storage_hash) - except ValueError: - # TODO: configure exceptions - raise Exception("We support only ipfs and filecoin uri in Registry") - - -def safe_extract_proto(spec_tar, protodir): - """ - Tar files might be dangerous (see https://bugs.python.org/issue21109, - and https://docs.python.org/3/library/tarfile.html, TarFile.extractall warning) - we extract only simple files - """ - with tarfile.open(fileobj=io.BytesIO(spec_tar)) as f: - for m in f.getmembers(): - if os.path.dirname(m.name) != "": - raise Exception("tarball has directories. We do not support it.") - if not m.isfile(): - raise Exception("tarball contains %s which is not a file" % m.name) - fullname = os.path.join(protodir, m.name) - if os.path.exists(fullname): - os.remove(fullname) - print(f"{fullname} removed.") - # now it is safe to call extractall - f.extractall(path=protodir) From 8ac6a69cf48fa60c43cc42d6863dc858143012fb Mon Sep 17 00:00:00 2001 From: Arondondon Date: Fri, 20 Mar 2026 18:34:08 +0300 Subject: [PATCH 7/7] feat: implement org metadata, end reworking storage provider --- snet/sdk/__init__.py | 27 +++++++-- snet/sdk/registry/organization_metadata.py | 66 ++++++++++++++++++++++ snet/sdk/registry/service_metadata.py | 26 ++------- snet/sdk/registry/storage_provider.py | 42 ++++---------- 4 files changed, 106 insertions(+), 55 deletions(-) create mode 100644 snet/sdk/registry/organization_metadata.py diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 2b6825c..35c248a 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -7,6 +7,7 @@ import google.protobuf.internal.api_implementation from google.protobuf import symbol_database as _symbol_database +from snet.sdk.registry.organization_metadata import OrganizationMetadata from snet.sdk.registry.registry_contract import RegistryContract from snet.sdk.registry.service_metadata import MPEServiceMetadata @@ -56,7 +57,7 @@ def __init__(self): self.w3 = get_we3_object() self.mpe_contract = MPEContract() self.registry_contract = RegistryContract() - self.metadata_provider = StorageProvider(self.registry_contract) + self.storage_provider = StorageProvider(self.registry_contract) self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract) self.account = Account() @@ -73,7 +74,7 @@ def create_service_client( # Create and instance of the Config object, # so we can create an instance of ClientLibGenerator - lib_generator = ClientLibGenerator(self.metadata_provider, org_id, service_id) + lib_generator = ClientLibGenerator(self.storage_provider, org_id, service_id) # Download the proto file and generate stubs if needed force_update = config.FORCE_UPDATE @@ -99,7 +100,7 @@ def create_service_client( if payment_strategy is None: payment_strategy = payment_strategy_type.value() - service_metadata = self.metadata_provider.enhance_service_metadata(org_id, service_id) + service_metadata = self._enhance_service_metadata(org_id, service_id) group = self._get_service_group_details(service_metadata, group_name) service_stubs = self.get_service_stub(lib_generator) @@ -123,6 +124,19 @@ def create_service_client( ) return _service_client + def _enhance_service_metadata(self, org_id, service_id): + service_metadata = self.get_service_metadata(org_id, service_id) + org_metadata = self.get_organization_metadata(org_id) + + org_group_map = {} + for group in org_metadata.groups: + org_group_map[group.group_name] = group + + for group in service_metadata.groups: + group.payment = org_group_map[group.group_name].payment + + return service_metadata + def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]: path_to_pb_files = str(lib_generator.proto_dir) module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator) @@ -146,7 +160,12 @@ def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) return ModuleName(module_name) def get_service_metadata(self, org_id, service_id): - return self.metadata_provider.fetch_service_metadata(org_id, service_id) + service = self.registry_contract.get_service(org_id, service_id) + return self.storage_provider.fetch_service_metadata(service.metadata_uri) + + def get_organization_metadata(self, org_id: str) -> OrganizationMetadata: + org = self.registry_contract.get_org(org_id) + return self.storage_provider.fetch_org_metadata(org.metadata_uri) def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict: return service_metadata["groups"][0] diff --git a/snet/sdk/registry/organization_metadata.py b/snet/sdk/registry/organization_metadata.py new file mode 100644 index 0000000..421d1f5 --- /dev/null +++ b/snet/sdk/registry/organization_metadata.py @@ -0,0 +1,66 @@ +from typing import Optional, Literal + +from pydantic import BaseModel, Field + + +class Description(BaseModel): + url: str = Field(default="") + description: str = Field(min_length=1) + short_description: str = Field(min_length=1, max_length=160) + + +class Assets(BaseModel): + hero_image: str = Field(default="") + + +class Contact(BaseModel): + email: str = Field(default="") + phone: str = Field(default="") + contact_type: Literal["general", "support"] + + +class PaymentChannelStorageClient(BaseModel): + connection_timeout: str = Field(default="5s") + request_timeout: str = Field(default="5s") + endpoints: list[str] = Field(default=[]) + + +class Payment(BaseModel): + payment_address: str + payment_expiration_threshold: int = Field(default=40320) + payment_channel_storage_type: Literal["etcd"] = Field(default="etcd") + payment_channel_storage_client: Optional[PaymentChannelStorageClient] = Field(default=None) + + +class Group(BaseModel): + group_name: str + group_id: str + payment: Payment + + +class OrganizationMetadata(BaseModel): + org_name: str = Field(min_length=1) + org_id: Optional[str] = Field(default=None, init=False) + org_type: Literal["organization", "individual"] + description: Optional[Description] = Field(default=None) + assets: Optional[Assets] = Field(default=None) + contacts: list[Contact] = Field(default=[]) + groups: list[Group] = Field(default=[]) + + def generate_final_json(self): + if not self.org_id: + raise ValueError("The 'org_id' field is missing!") + + if not self.description: + raise ValueError("The 'description' field is missing!") + + if not self.assets: + raise ValueError("The 'assets' field is missing!") + + if not self.contacts: + raise ValueError("The 'contacts' field is missing!") + + if len(self.groups) == 0: + raise ValueError("There must be one item in 'groups' field at least!") + + return self.model_dump_json(indent=2, exclude_none=True) diff --git a/snet/sdk/registry/service_metadata.py b/snet/sdk/registry/service_metadata.py index 97de7ca..be0966c 100644 --- a/snet/sdk/registry/service_metadata.py +++ b/snet/sdk/registry/service_metadata.py @@ -48,6 +48,7 @@ from pydantic import BaseModel, Field, model_validator, ValidationInfo from snet.sdk.registry.models import FileURI +from snet.sdk.registry.organization_metadata import Payment from snet.sdk.utils.utils import is_valid_endpoint @@ -68,24 +69,6 @@ def is_single_value(asset_type): return True -class FileType(Enum): - IMAGE = "image" - VIDEO = "video" - ARCHIVE = "archive" - - -# class AssetType(Enum): -# HERO_IMAGE = "hero_image" -# PROTO_FILE = "proto_file" -# DEMO_COMPONENT = "demo_component" - - -class ServiceType(Enum): - GRPC = "grpc" - HTTP = "http" - JSONRPC = "jsonrpc" - - def generate_group_id() -> str: return base64.b64encode(secrets.token_bytes(32)).decode() @@ -104,6 +87,9 @@ class Group(BaseModel): daemon_addresses: list[str] = Field(default=[]) endpoints: list[str] = Field(default=[]) pricing: list[Pricing] = Field(default=[]) + payment: Optional[Payment] = Field( + default=None + ) # The field from org metadata for service client functionality class ServiceDescription(BaseModel): @@ -115,9 +101,9 @@ class ServiceDescription(BaseModel): class Media(BaseModel): order: int = Field(ge=1, default=1) url: str = Field(min_length=1) - file_type: FileType + file_type: Literal["image", "video", "archive"] alt_text: str = Field(default="") - asset_type: AssetType + asset_type: Literal["hero_image", "proto_file", "demo_component"] class Contributor(BaseModel): diff --git a/snet/sdk/registry/storage_provider.py b/snet/sdk/registry/storage_provider.py index e40fc13..9b26b5a 100644 --- a/snet/sdk/registry/storage_provider.py +++ b/snet/sdk/registry/storage_provider.py @@ -9,48 +9,28 @@ import multihash import hashlib -from snet.sdk.registry.registry_contract import RegistryContract +from snet.sdk.registry.organization_metadata import OrganizationMetadata from snet.sdk.registry.models import StorageType, FileURI -from snet.sdk.registry.service_metadata import ( - MPEServiceMetadata, - mpe_service_metadata_from_json, -) +from snet.sdk.registry.service_metadata import ServiceMetadata from snet.sdk.config import config class StorageProvider(object): - def __init__(self, registry_contract: RegistryContract): - self._registry_contract = registry_contract + def __init__(self): self._ipfs_client = ipfshttpclient.connect(config.IPFS_ENDPOINT) self._lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) - def fetch_org_metadata(self, org_id): - org = self._registry_contract.get_org(org_id) - - org_metadata_json = self._get_from_storage(org.metadata_uri) - org_metadata = json.loads(org_metadata_json) + def fetch_org_metadata(self, metadata_uri: FileURI): + org_metadata_json = self._get_from_storage(metadata_uri) + raw_org_metadata = json.loads(org_metadata_json) + org_metadata = OrganizationMetadata(**raw_org_metadata) return org_metadata - def fetch_service_metadata(self, org_id: str, service_id: str) -> MPEServiceMetadata: - service = self._registry_contract.get_service(org_id, service_id) - - service_metadata_json = self._get_from_storage(service.metadata_uri) - service_metadata = mpe_service_metadata_from_json(service_metadata_json) - - return service_metadata - - def enhance_service_metadata(self, org_id, service_id): - service_metadata = self.fetch_service_metadata(org_id, service_id) - org_metadata = self.fetch_org_metadata(org_id) - - org_group_map = {} - for group in org_metadata["groups"]: - org_group_map[group["group_name"]] = group - - for group in service_metadata.m["groups"]: - # merge service group with org_group - group["payment"] = org_group_map[group["group_name"]]["payment"] + def fetch_service_metadata(self, metadata_uri: FileURI) -> ServiceMetadata: + service_metadata_json = self._get_from_storage(metadata_uri) + raw_service_metadata = json.loads(service_metadata_json) + service_metadata = ServiceMetadata(**raw_service_metadata) return service_metadata