265 lines
12 KiB
Python
265 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
||
# pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches,unused-argument
|
||
"""
|
||
[MODULE] Сетевой клиент для API
|
||
|
||
[DESCRIPTION]
|
||
Инкапсулирует низкоуровневую HTTP-логику для взаимодействия с Superset API.
|
||
"""
|
||
|
||
# [IMPORTS] Стандартная библиотека
|
||
from typing import Optional, Dict, Any, BinaryIO, List, Union
|
||
import json
|
||
import io
|
||
from pathlib import Path
|
||
|
||
# [IMPORTS] Сторонние библиотеки
|
||
import requests
|
||
import urllib3 # Для отключения SSL-предупреждений
|
||
|
||
# [IMPORTS] Локальные модули
|
||
from superset_tool.exceptions import (
|
||
AuthenticationError,
|
||
NetworkError,
|
||
DashboardNotFoundError,
|
||
SupersetAPIError,
|
||
PermissionDeniedError
|
||
)
|
||
from superset_tool.utils.logger import SupersetLogger # Импорт логгера
|
||
|
||
# [CONSTANTS] Defaults shared by the HTTP session's retry policy and request timeouts.
DEFAULT_RETRIES: int = 3  # `total=` passed to the session's urllib3 Retry policy
DEFAULT_BACKOFF_FACTOR: float = 0.5  # `backoff_factor=` passed to the same Retry policy
DEFAULT_TIMEOUT: int = 30  # fallback per-request timeout (seconds, as requests interprets it)
|
||
|
||
class APIClient:
    """[NETWORK-CORE] Encapsulates the HTTP logic for talking to the Superset API.

    Owns a ``requests.Session`` configured with a retry policy, authenticates
    lazily the first time :attr:`headers` is read, and translates ``requests``
    failures into the project's exception types.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        verify_ssl: bool = True,
        timeout: int = DEFAULT_TIMEOUT,
        logger: Optional[SupersetLogger] = None
    ):
        """Create the client and its HTTP session (no request is sent here).

        Args:
            config: Connection settings. ``config["base_url"]`` is prefixed to
                every endpoint; ``config["auth"]`` is sent as the JSON body of
                the login request.
            verify_ssl: Whether the session verifies TLS certificates.
            timeout: Default per-request timeout used when a call passes none.
            logger: Logger to use; a new ``SupersetLogger`` is created if omitted.
        """
        self.logger = logger or SupersetLogger(name="APIClient")
        self.logger.info("[INFO][APIClient.__init__][ENTER] Initializing APIClient.")
        self.base_url = config.get("base_url")
        self.auth = config.get("auth")
        # Must exist before _init_session(), which reads "verify_ssl" from it.
        self.request_settings = {
            "verify_ssl": verify_ssl,
            "timeout": timeout
        }
        self.session = self._init_session()
        self._tokens: Dict[str, str] = {}
        self._authenticated = False
        self.logger.info("[INFO][APIClient.__init__][SUCCESS] APIClient initialized.")

    def _init_session(self) -> requests.Session:
        """Build a session with retries and the configured TLS verification.

        The session retries up to ``DEFAULT_RETRIES`` times, with
        ``DEFAULT_BACKOFF_FACTOR`` backoff, on HTTP 500/502/503/504 responses
        to HEAD/GET/POST/PUT/DELETE requests.

        Returns:
            The configured ``requests.Session``.
        """
        self.logger.debug("[DEBUG][APIClient._init_session][ENTER] Initializing session.")
        session = requests.Session()
        # NOTE(review): POST is not idempotent. A POST that succeeded server-side
        # but answered 5xx will be sent again. Confirm this is acceptable for the
        # endpoints used (e.g. imports).
        retries = requests.adapters.Retry(
            total=DEFAULT_RETRIES,
            backoff_factor=DEFAULT_BACKOFF_FACTOR,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods={"HEAD", "GET", "POST", "PUT", "DELETE"}
        )
        adapter = requests.adapters.HTTPAdapter(max_retries=retries)
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        verify_ssl = self.request_settings.get("verify_ssl", True)
        session.verify = verify_ssl
        if not verify_ssl:
            # Otherwise urllib3 emits a warning on every unverified request.
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
            self.logger.warning("[WARNING][APIClient._init_session][STATE_CHANGE] SSL verification disabled.")
        self.logger.debug("[DEBUG][APIClient._init_session][SUCCESS] Session initialized.")
        return session

    def authenticate(self) -> Dict[str, str]:
        """Log in, fetch a CSRF token, and cache both for subsequent requests.

        Sends a POST to ``{base_url}/security/login`` with ``self.auth`` as the
        JSON body and reads ``access_token`` from the reply. Then sends a GET to
        ``{base_url}/security/csrf_token/`` with that bearer token and reads
        ``result`` from the reply.

        Returns:
            ``{"access_token": ..., "csrf_token": ...}`` (also stored on the client).

        Raises:
            AuthenticationError: Either request returned an HTTP error status.
            NetworkError: A transport error occurred, or a response body was not
                JSON or lacked the expected key.
        """
        self.logger.info(f"[INFO][APIClient.authenticate][ENTER] Authenticating to {self.base_url}")
        timeout = self.request_settings.get("timeout", DEFAULT_TIMEOUT)
        try:
            login_url = f"{self.base_url}/security/login"
            response = self.session.post(
                login_url,
                json=self.auth,
                timeout=timeout
            )
            response.raise_for_status()
            access_token = response.json()["access_token"]

            csrf_url = f"{self.base_url}/security/csrf_token/"
            csrf_response = self.session.get(
                csrf_url,
                headers={"Authorization": f"Bearer {access_token}"},
                timeout=timeout
            )
            csrf_response.raise_for_status()
            csrf_token = csrf_response.json()["result"]
        except requests.exceptions.HTTPError as e:
            self.logger.error(f"[ERROR][APIClient.authenticate][FAILURE] Authentication failed: {e}")
            raise AuthenticationError(f"Authentication failed: {e}") from e
        except (requests.exceptions.RequestException, KeyError, ValueError, TypeError) as e:
            # ValueError: body is not JSON (a plain ValueError on older requests).
            # TypeError/KeyError: JSON was not an object with the expected key.
            self.logger.error(f"[ERROR][APIClient.authenticate][FAILURE] Network or parsing error: {e}")
            raise NetworkError(f"Network or parsing error during authentication: {e}") from e

        self._tokens = {
            "access_token": access_token,
            "csrf_token": csrf_token
        }
        self._authenticated = True
        # The tokens are bearer credentials; never write them to the log.
        self.logger.info("[INFO][APIClient.authenticate][SUCCESS] Authenticated successfully.")
        return self._tokens

    @property
    def headers(self) -> Dict[str, str]:
        """Default JSON request headers carrying the auth and CSRF tokens.

        Calls :meth:`authenticate` first if the client has not authenticated
        yet, so reading this property may perform network requests and raise
        ``AuthenticationError`` / ``NetworkError``.
        """
        if not self._authenticated:
            self.authenticate()
        return {
            "Authorization": f"Bearer {self._tokens['access_token']}",
            "X-CSRFToken": self._tokens.get("csrf_token", ""),
            "Referer": self.base_url,
            "Content-Type": "application/json"
        }

    def request(
        self,
        method: str,
        endpoint: str,
        headers: Optional[Dict] = None,
        raw_response: bool = False,
        **kwargs
    ) -> Union[requests.Response, Dict[str, Any]]:
        """Send an authenticated request to ``base_url + endpoint``.

        Args:
            method: HTTP method, e.g. ``"GET"``.
            endpoint: Path appended verbatim to ``base_url``.
            headers: Extra headers merged over the defaults from :attr:`headers`.
            raw_response: Return the ``requests.Response`` instead of its JSON.
            **kwargs: Forwarded to ``Session.request``. A ``timeout`` entry
                overrides the client's default timeout.

        Returns:
            The response object if ``raw_response`` is true, otherwise the
            decoded JSON body.

        Raises:
            DashboardNotFoundError, PermissionDeniedError, AuthenticationError,
            SupersetAPIError: Mapped from HTTP error statuses
                (see ``_handle_http_error``).
            NetworkError: Transport-level failure (see ``_handle_network_error``).
        """
        self.logger.debug(f"[DEBUG][APIClient.request][ENTER] Requesting {method} {endpoint}")
        full_url = f"{self.base_url}{endpoint}"
        _headers = self.headers.copy()
        if headers:
            _headers.update(headers)
        timeout = kwargs.pop('timeout', self.request_settings.get("timeout", DEFAULT_TIMEOUT))
        try:
            response = self.session.request(
                method,
                full_url,
                headers=_headers,
                timeout=timeout,
                **kwargs
            )
            response.raise_for_status()
            self.logger.debug(f"[DEBUG][APIClient.request][SUCCESS] Request successful for {method} {endpoint}")
            return response if raw_response else response.json()
        except requests.exceptions.HTTPError as e:
            self.logger.error(f"[ERROR][APIClient.request][FAILURE] HTTP error for {method} {endpoint}: {e}")
            # Always raises.
            self._handle_http_error(e, endpoint, context={})
        except requests.exceptions.RequestException as e:
            self.logger.error(f"[ERROR][APIClient.request][FAILURE] Network error for {method} {endpoint}: {e}")
            # Always raises.
            self._handle_network_error(e, full_url)

    def _handle_http_error(self, e, endpoint, context):
        """Translate a ``requests`` HTTPError into a project exception; always raises.

        404 -> DashboardNotFoundError, 403 -> PermissionDeniedError,
        401 -> AuthenticationError, anything else -> SupersetAPIError
        (message includes the status code and response body).
        """
        status_code = e.response.status_code
        # NOTE(review): every 404 is reported as DashboardNotFoundError, even for
        # non-dashboard endpoints.
        if status_code == 404:
            raise DashboardNotFoundError(endpoint, context=context) from e
        if status_code == 403:
            raise PermissionDeniedError("Доступ запрещен.", **context) from e
        if status_code == 401:
            raise AuthenticationError("Аутентификация не удалась.", **context) from e
        raise SupersetAPIError(f"Ошибка API: {status_code} - {e.response.text}", **context) from e

    def _handle_network_error(self, e, url):
        """Wrap a transport-level ``RequestException`` in ``NetworkError``; always raises.

        Timeouts are checked before connection errors.
        """
        if isinstance(e, requests.exceptions.Timeout):
            msg = "Таймаут запроса"
        elif isinstance(e, requests.exceptions.ConnectionError):
            msg = "Ошибка соединения"
        else:
            msg = f"Неизвестная сетевая ошибка: {e}"
        raise NetworkError(msg, url=url) from e

    def upload_file(
        self,
        endpoint: str,
        file_info: Dict[str, Any],
        extra_data: Optional[Dict] = None,
        timeout: Optional[int] = None
    ) -> Dict:
        """Upload a file as multipart/form-data to ``base_url + endpoint``.

        Args:
            endpoint: Path appended verbatim to ``base_url``.
            file_info: Describes the upload. It has these keys:
                - ``"file_obj"``: a path (``str``/``Path``), an ``io.BytesIO``,
                  or any object with ``read()``.
                - ``"file_name"``: the filename sent in the multipart part. For
                  a path it defaults to the path's basename.
                - ``"form_field"``: the form field name; defaults to ``"file"``.
            extra_data: Additional form fields sent alongside the file.
            timeout: Per-call timeout; the client default is used if falsy.

        Returns:
            The decoded JSON response.

        Raises:
            TypeError: ``file_obj`` is of an unsupported type.
            SupersetAPIError, NetworkError: See ``_perform_upload``.
        """
        self.logger.info(f"[INFO][APIClient.upload_file][ENTER] Uploading file to {endpoint}")
        full_url = f"{self.base_url}{endpoint}"
        _headers = self.headers.copy()
        # requests must set its own multipart Content-Type (with boundary).
        _headers.pop('Content-Type', None)
        file_obj = file_info.get("file_obj")
        file_name = file_info.get("file_name")
        form_field = file_info.get("form_field", "file")
        content_type = 'application/x-zip-compressed'

        if isinstance(file_obj, (str, Path)):
            if file_name is None:
                # Without this, the multipart part would be sent with no filename.
                file_name = Path(file_obj).name
            # Keep the file open for the whole upload; it is closed afterwards.
            with open(file_obj, 'rb') as file_to_upload:
                files_payload = {form_field: (file_name, file_to_upload, content_type)}
                return self._perform_upload(full_url, files_payload, extra_data, _headers, timeout)
        if isinstance(file_obj, io.BytesIO):
            # getvalue() sends the whole buffer regardless of the stream position.
            files_payload = {form_field: (file_name, file_obj.getvalue(), content_type)}
            return self._perform_upload(full_url, files_payload, extra_data, _headers, timeout)
        if hasattr(file_obj, 'read'):
            files_payload = {form_field: (file_name, file_obj, content_type)}
            return self._perform_upload(full_url, files_payload, extra_data, _headers, timeout)

        self.logger.error(f"[ERROR][APIClient.upload_file][FAILURE] Unsupported file_obj type: {type(file_obj)}")
        raise TypeError(f"Неподдерживаемый тип 'file_obj': {type(file_obj)}")

    def _perform_upload(self, url, files, data, headers, timeout):
        """POST a multipart payload and return the decoded JSON response.

        Raises:
            SupersetAPIError: The server returned an HTTP error status.
            NetworkError: Any other ``requests`` failure.
        """
        self.logger.debug(f"[DEBUG][APIClient._perform_upload][ENTER] Performing upload to {url}")
        try:
            response = self.session.post(
                url=url,
                files=files,
                data=data or {},
                headers=headers,
                timeout=timeout or self.request_settings.get("timeout")
            )
            response.raise_for_status()
            self.logger.info(f"[INFO][APIClient._perform_upload][SUCCESS] Upload successful to {url}")
            return response.json()
        except requests.exceptions.HTTPError as e:
            self.logger.error(f"[ERROR][APIClient._perform_upload][FAILURE] HTTP error during upload: {e}")
            raise SupersetAPIError(f"Ошибка API при загрузке: {e.response.text}") from e
        except requests.exceptions.RequestException as e:
            self.logger.error(f"[ERROR][APIClient._perform_upload][FAILURE] Network error during upload: {e}")
            raise NetworkError(f"Ошибка сети при загрузке: {e}", url=url) from e

    def fetch_paginated_count(
        self,
        endpoint: str,
        query_params: Dict,
        count_field: str = "count",
        timeout: Optional[int] = None
    ) -> int:
        """Return the total item count reported by a list endpoint.

        ``query_params`` is JSON-encoded into the ``q`` query parameter. The
        value of ``count_field`` in the response is returned, or 0 if absent.
        """
        self.logger.debug(f"[DEBUG][APIClient.fetch_paginated_count][ENTER] Fetching paginated count for {endpoint}")
        response_json = self.request(
            method="GET",
            endpoint=endpoint,
            params={"q": json.dumps(query_params)},
            timeout=timeout or self.request_settings.get("timeout")
        )
        count = response_json.get(count_field, 0)
        self.logger.debug(f"[DEBUG][APIClient.fetch_paginated_count][SUCCESS] Fetched paginated count: {count}")
        return count

    def fetch_paginated_data(
        self,
        endpoint: str,
        pagination_options: Dict[str, Any],
        timeout: Optional[int] = None
    ) -> List[Any]:
        """Fetch every page of a list endpoint and concatenate the results.

        Args:
            endpoint: Path appended verbatim to ``base_url``.
            pagination_options: A dict with these keys:
                - ``"base_query"``: the query dict; it must contain a positive
                  ``"page_size"``.
                - ``"total_count"``: the number of items to fetch.
                - ``"results_field"``: the response key holding each page's
                  items; defaults to ``"result"``.
            timeout: Per-request timeout; the client default is used if falsy.

        Returns:
            All items from pages ``0 .. ceil(total_count / page_size) - 1``.

        Raises:
            ValueError: ``page_size`` is missing or not positive.
        """
        self.logger.debug(f"[DEBUG][APIClient.fetch_paginated_data][ENTER] Fetching paginated data for {endpoint}")
        base_query = pagination_options.get("base_query", {})
        total_count = pagination_options.get("total_count", 0)
        results_field = pagination_options.get("results_field", "result")
        page_size = base_query.get('page_size')
        if not page_size or page_size <= 0:
            raise ValueError("'page_size' должен быть положительным числом.")
        # Ceiling division: a partially filled last page still needs a request.
        total_pages = (total_count + page_size - 1) // page_size
        effective_timeout = timeout or self.request_settings.get("timeout")
        results = []
        for page in range(total_pages):
            query = {**base_query, 'page': page}
            response_json = self.request(
                method="GET",
                endpoint=endpoint,
                params={"q": json.dumps(query)},
                timeout=effective_timeout
            )
            results.extend(response_json.get(results_field, []))
        self.logger.debug(f"[DEBUG][APIClient.fetch_paginated_data][SUCCESS] Fetched paginated data. Total items: {len(results)}")
        return results