Files
ss-tools/superset_tool/utils/network.py
Volobuev Andrey c0a6ca7769 2
2025-06-27 15:20:29 +03:00

215 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional, Dict, Any,BinaryIO,List
import requests
import json
import urllib3
from ..exceptions import AuthenticationError, NetworkError,DashboardNotFoundError,SupersetAPIError,PermissionDeniedError
class APIClient:
"""[NETWORK-CORE] Инкапсулирует HTTP-логику для работы с API.
@contract: Гарантирует retry, SSL-валидацию и стандартные заголовки.
"""
def __init__(
self,
base_url: str,
auth: Dict[str, Any],
verify_ssl: bool = False,
timeout: int = 30
):
self.base_url = base_url
self.auth = auth
self.session = self._init_session(verify_ssl)
self.timeout = timeout
def _init_session(self, verify_ssl: bool) -> requests.Session:
"""[NETWORK-INIT] Настройка сессии с адаптерами."""
session = requests.Session()
session.mount('https://', requests.adapters.HTTPAdapter(max_retries=3))
session.verify = verify_ssl
if not verify_ssl:
urllib3.disable_warnings()
return session
def authenticate(self) -> Dict[str, str]:
"""[AUTH-FLOW] Получение access и CSRF токенов."""
try:
response = self.session.post(
f"{self.base_url}/security/login",
json={**self.auth, "provider": "db", "refresh": True},
timeout=self.timeout
)
response.raise_for_status()
access_token = response.json()["access_token"]
csrf_response = self.session.get(
f"{self.base_url}/security/csrf_token/",
headers={"Authorization": f"Bearer {access_token}"},
timeout=self.timeout
)
csrf_response.raise_for_status()
return {
"access_token": access_token,
"csrf_token": csrf_response.json()["result"]
}
except requests.exceptions.RequestException as e:
raise NetworkError(f"Auth failed: {str(e)}")
def request(
self,
method: str,
endpoint: str,
headers: Optional[Dict] = None,
**kwargs
) -> requests.Response:
"""[NETWORK-CORE] Обертка для запросов с обработкой ошибок."""
try:
response = self.session.request(
method,
f"{self.base_url}{endpoint}",
headers=headers,
timeout=self.timeout,
**kwargs
)
response.raise_for_status()
return response
except requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
raise DashboardNotFoundError(endpoint)
raise SupersetAPIError(str(e))
def upload_file(
self,
endpoint: str,
file_obj: BinaryIO,
file_name: str,
form_field: str = "file",
extra_data: Optional[Dict] = None,
timeout: Optional[int] = None
) -> Dict:
"""[NETWORK] Отправка файла на сервер
@params:
endpoint: API endpoint
file_obj: файловый объект
file_name: имя файла
form_field: имя поля формы
extra_data: дополнительные данные
timeout: таймаут запроса
@return:
Ответ сервера (JSON)
"""
files = {form_field: (file_name, file_obj, 'application/x-zip-compressed')}
headers = {
k: v for k, v in self.headers.items()
if k.lower() != 'content-type'
}
try:
response = self.session.post(
url=f"{self.base_url}{endpoint}",
files=files,
data=extra_data or {},
headers=headers,
timeout=timeout or self.timeout
)
if response.status_code == 403:
raise PermissionDeniedError("Доступ запрещен")
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
error_ctx = {
"endpoint": endpoint,
"file": file_name,
"status_code": getattr(e.response, 'status_code', None)
}
self.logger.error(
"[NETWORK_ERROR] Ошибка загрузки файла",
extra=error_ctx
)
raise
def fetch_paginated_count(
self,
endpoint: str,
query_params: Dict,
count_field: str = "count",
timeout: Optional[int] = None
) -> int:
"""[NETWORK] Получение общего количества элементов в пагинированном API
@params:
endpoint: API endpoint без query-параметров
query_params: параметры для пагинации
count_field: поле с количеством в ответе
timeout: таймаут запроса
@return:
Общее количество элементов
@errors:
- NetworkError: проблемы с соединением
- KeyError: некорректный формат ответа
"""
try:
response = self.request(
method="GET",
endpoint=endpoint,
params={"q": json.dumps(query_params)},
timeout=timeout or self.timeout
)
if count_field not in response:
raise KeyError(f"Ответ API не содержит поле {count_field}")
return response[count_field]
except requests.exceptions.RequestException as e:
error_ctx = {
"endpoint": endpoint,
"params": query_params,
"error": str(e)
}
self.logger.error("[PAGINATION_ERROR]", extra=error_ctx)
raise NetworkError(f"Ошибка пагинации: {str(e)}") from e
def fetch_paginated_data(
self,
endpoint: str,
base_query: Dict,
total_count: int,
results_field: str = "result",
timeout: Optional[int] = None
) -> List[Any]:
"""[NETWORK] Получение всех данных с пагинированного API
@params:
endpoint: API endpoint
base_query: базовые параметры запроса (без page)
total_count: общее количество элементов
results_field: поле с данными в ответе
timeout: таймаут для запросов
@return:
Собранные данные со всех страниц
"""
page_size = base_query['page_size']
total_pages = (total_count + page_size - 1) // page_size
results = []
for page in range(total_pages):
query = {**base_query, 'page': page}
response = self._execute_request(
method="GET",
endpoint=endpoint,
params={"q": json.dumps(query)},
timeout=timeout or self.timeout
)
if results_field not in response:
self.logger.warning(
f"Ответ не содержит поле {results_field}",
extra={"response": response.keys()}
)
continue
results.extend(response[results_field])
return results