diff --git a/src/fish_audio_sdk/schemas.py b/src/fish_audio_sdk/schemas.py index 81d6323..af40743 100644 --- a/src/fish_audio_sdk/schemas.py +++ b/src/fish_audio_sdk/schemas.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field -Backends = Literal["speech-1.5", "speech-1.6", "agent-x0", "s1", "s1-mini"] +Backends = Literal["speech-1.5", "speech-1.6", "agent-x0", "s1", "s1-mini", "s2-pro"] Item = TypeVar("Item") diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index 3ec2343..0167ce6 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -28,6 +28,7 @@ TTSConfig, TTSRequest, ) +from fishaudio.types.shared import warn_if_deprecated_model from .realtime import aiter_websocket_audio, iter_websocket_audio @@ -81,7 +82,7 @@ def stream( latency: Optional[LatencyMode] = None, speed: Optional[float] = None, config: TTSConfig = TTSConfig(), - model: Model = "s1", + model: Model = "s2-pro", request_options: Optional[RequestOptions] = None, ) -> AudioStream: """ @@ -115,6 +116,8 @@ def stream( audio = client.tts.stream(text="Hello world").collect() ``` """ + warn_if_deprecated_model(model) + # Build request payload from config request = _config_to_tts_request(config, text) @@ -163,7 +166,7 @@ def convert( latency: Optional[LatencyMode] = None, speed: Optional[float] = None, config: TTSConfig = TTSConfig(), - model: Model = "s1", + model: Model = "s2-pro", request_options: Optional[RequestOptions] = None, ) -> bytes: """ @@ -225,7 +228,7 @@ def stream_websocket( latency: Optional[LatencyMode] = None, speed: Optional[float] = None, config: TTSConfig = TTSConfig(), - model: Model = "s1", + model: Model = "s2-pro", max_workers: int = 10, ws_options: Optional[WebSocketOptions] = None, ) -> Iterator[bytes]: @@ -310,6 +313,8 @@ def text_generator(): f.write(audio_chunk) ``` """ + warn_if_deprecated_model(model) + # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") @@ -381,7 +386,7 @@ async def stream( latency: Optional[LatencyMode] = None, speed: Optional[float] = None, config: TTSConfig = TTSConfig(), - model: Model = "s1", + model: Model = "s2-pro", request_options: Optional[RequestOptions] = None, ) -> AsyncAudioStream: """ @@ -416,6 +421,8 @@ async def stream( audio = await stream.collect() ``` """ + warn_if_deprecated_model(model) + # Build request payload from config request = _config_to_tts_request(config, text) @@ -464,7 +471,7 @@ async def convert( latency: Optional[LatencyMode] = None, speed: Optional[float] = None, config: TTSConfig = TTSConfig(), - model: Model = "s1", + model: Model = "s2-pro", request_options: Optional[RequestOptions] = None, ) -> bytes: """ @@ -527,7 +534,7 @@ async def stream_websocket( latency: Optional[LatencyMode] = None, speed: Optional[float] = None, config: TTSConfig = TTSConfig(), - model: Model = "s1", + model: Model = "s2-pro", ws_options: Optional[WebSocketOptions] = None, ): """ @@ -610,6 +617,8 @@ async def text_generator(): await f.write(audio_chunk) ``` """ + warn_if_deprecated_model(model) + # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") diff --git a/src/fishaudio/types/shared.py b/src/fishaudio/types/shared.py index 427d5da..879a012 100644 --- a/src/fishaudio/types/shared.py +++ b/src/fishaudio/types/shared.py @@ -1,5 +1,6 @@ """Shared types used across the SDK.""" +import warnings from typing import Generic, Literal, TypeVar from pydantic import BaseModel @@ -21,7 +22,21 @@ class PaginatedResponse(BaseModel, Generic[T]): # Model types -Model = Literal["speech-1.5", "speech-1.6", "s1"] +Model = Literal["speech-1.5", "speech-1.6", "s1", "s2-pro"] + +# Deprecated models +DEPRECATED_MODELS = {"speech-1.5", "speech-1.6"} + + +def warn_if_deprecated_model(model: str) -> None: + """Emit a deprecation warning if a legacy model is used.""" + if model in DEPRECATED_MODELS: + warnings.warn( + f"Model '{model}' is deprecated. Use 's1' or 's2-pro' instead.", + DeprecationWarning, + stacklevel=3, + ) + # Audio format types AudioFormat = Literal["wav", "pcm", "mp3", "opus"] diff --git a/tests/integration/test_tts_integration.py b/tests/integration/test_tts_integration.py index 19e511a..5d4e91e 100644 --- a/tests/integration/test_tts_integration.py +++ b/tests/integration/test_tts_integration.py @@ -45,6 +45,7 @@ def test_tts_with_prosody(self, client, save_audio): # Write to output directory save_audio(audio, "test_prosody.mp3") + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_tts_with_different_models(self, client, save_audio): """Test TTS with different models.""" models = get_args(Model) diff --git a/tests/integration/test_tts_websocket_integration.py b/tests/integration/test_tts_websocket_integration.py index 1dc1604..ee9aa15 100644 --- a/tests/integration/test_tts_websocket_integration.py +++ b/tests/integration/test_tts_websocket_integration.py @@ -6,7 +6,7 @@ from fishaudio import WebSocketOptions from fishaudio.types import FlushEvent, Prosody, TextEvent, TTSConfig -from fishaudio.types.shared import Model +from fishaudio.types.shared import DEPRECATED_MODELS, Model from .conftest import TEST_REFERENCE_ID @@ -35,6 +35,7 @@ def text_stream(): # Save the audio save_audio(audio_chunks, "test_websocket_streaming.mp3") + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.parametrize( "model", [ @@ -44,7 +45,7 @@ def text_stream(): reason="WebSocket unreliable for legacy models" ), ) - if not m.startswith("s1") + if m in DEPRECATED_MODELS else m for m in get_args(Model) ], @@ -137,16 +138,14 @@ def text_stream(): save_audio(audio_chunks, "test_websocket_reference.mp3") def test_websocket_streaming_empty_text(self, client, save_audio): - """Test WebSocket streaming with empty text stream raises error.""" - from fishaudio.exceptions import WebSocketError + """Test WebSocket streaming with empty text stream completes without error.""" def text_stream(): return yield # Make it a generator - # Empty stream should raise WebSocketError as API returns error - with pytest.raises(WebSocketError, match="WebSocket stream ended with error"): - list(client.tts.stream_websocket(text_stream())) + audio_chunks = list(client.tts.stream_websocket(text_stream())) + assert isinstance(audio_chunks, list) def test_websocket_very_long_generation_with_timeout(self, client, save_audio): """ @@ -223,6 +222,7 @@ async def text_stream(): save_audio(audio_chunks, "test_async_websocket_streaming.mp3") + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.asyncio @pytest.mark.parametrize( "model", @@ -233,7 +233,7 @@ async def text_stream(): reason="WebSocket unreliable for legacy models" ), ) - if not m.startswith("s1") + if m in DEPRECATED_MODELS else m for m in get_args(Model) ], @@ -366,14 +366,13 @@ async def text_stream(): @pytest.mark.asyncio async def test_async_websocket_streaming_empty_text(self, async_client, save_audio): - """Test async WebSocket streaming with empty text stream raises error.""" - from fishaudio.exceptions import WebSocketError + """Test async WebSocket streaming with empty text stream completes without error.""" async def text_stream(): return yield # Make it an async generator - # Empty stream should raise WebSocketError as API returns error - with pytest.raises(WebSocketError, match="WebSocket stream ended with error"): - async for chunk in async_client.tts.stream_websocket(text_stream()): - pass + audio_chunks = [] + async for chunk in async_client.tts.stream_websocket(text_stream()): + audio_chunks.append(chunk) + assert isinstance(audio_chunks, list) diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index de62520..132038a 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -63,7 +63,7 @@ def test_stream_basic(self, tts_client, mock_client_wrapper): # Check headers assert call_args[1]["headers"]["Content-Type"] == "application/msgpack" - assert call_args[1]["headers"]["model"] == "s1" # default model + assert call_args[1]["headers"]["model"] == "s2-pro" # default model # Check payload was msgpack encoded assert "content" in call_args[1] diff --git a/tests/unit/test_tts_realtime.py b/tests/unit/test_tts_realtime.py index eaf9210..c3b60a9 100644 --- a/tests/unit/test_tts_realtime.py +++ b/tests/unit/test_tts_realtime.py @@ -93,6 +93,7 @@ def test_stream_websocket_basic( mock_connect_ws.assert_called_once() assert mock_connect_ws.call_args[0][0] == "/v1/tts/live" + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @patch("fishaudio.resources.tts.connect_ws") @patch("fishaudio.resources.tts.ThreadPoolExecutor") def test_stream_websocket_with_config( @@ -425,6 +426,7 @@ async def text_stream(): mock_aconnect_ws.assert_called_once() assert mock_aconnect_ws.call_args[0][0] == "/v1/tts/live" + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.asyncio @patch("fishaudio.resources.tts.aconnect_ws") async def test_stream_websocket_with_config(