diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 71d94ebd8..6f6184afb 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -81,6 +81,7 @@ def __init__( int ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, + attention_type: int = llama_cpp.LLAMA_ATTENTION_TYPE_UNSPECIFIED, rope_freq_base: float = 0.0, rope_freq_scale: float = 0.0, yarn_ext_factor: float = -1.0, @@ -319,6 +320,7 @@ def __init__( else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED ) self.context_params.pooling_type = pooling_type + self.context_params.attention_type = attention_type self.context_params.rope_freq_base = ( rope_freq_base if rope_freq_base != 0.0 else 0 )