Fix define target + track utils
c3p0-upgini committed Jan 17, 2024
1 parent d906432 commit e427aa3
Showing 7 changed files with 73 additions and 68 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -40,7 +40,7 @@ def send_log(msg: str):


here = Path(__file__).parent.resolve()
version = "1.1.238"
version = "1.1.239"
try:
send_log(f"Start setup PyLib version {version}")
setup(
6 changes: 2 additions & 4 deletions src/upgini/dataset.py
@@ -81,7 +81,6 @@ def __init__(
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
logger: Optional[logging.Logger] = None,
-client_ip: Optional[str] = None,
warning_counter: Optional[WarningCounter] = None,
**kwargs,
):
@@ -125,7 +124,6 @@ def __init__(
else:
self.logger = logging.getLogger()
self.logger.setLevel("FATAL")
-self.client_ip = client_ip
self.warning_counter = warning_counter or WarningCounter()

def __len__(self):
@@ -1019,7 +1017,7 @@ def search(
task_type=self.task_type,
endpoint=self.endpoint,
api_key=self.api_key,
-client_ip=self.client_ip,
+logger=self.logger,
)

def validation(
@@ -1089,7 +1087,7 @@ def validation(
initial_search_task_id=initial_search_task_id,
endpoint=self.endpoint,
api_key=self.api_key,
-client_ip=self.client_ip,
+logger=self.logger,
)

def prepare_uploading_file(self, base_path: str) -> str:
26 changes: 13 additions & 13 deletions src/upgini/features_enricher.py
@@ -180,17 +180,19 @@ def __init__(
exclude_columns: Optional[List[str]] = None,
baseline_score_column: Optional[Any] = None,
client_ip: Optional[str] = None,
+client_visitorid: Optional[str] = None,
**kwargs,
):
self._api_key = api_key or os.environ.get(UPGINI_API_KEY)
if api_key is not None and not isinstance(api_key, str):
raise ValidationError(f"api_key should be `string`, but passed: `{api_key}`")
-self.rest_client = get_rest_client(endpoint, self._api_key)
+self.rest_client = get_rest_client(endpoint, self._api_key, client_ip, client_visitorid)
self.client_ip = client_ip
+self.client_visitorid = client_visitorid

self.logs_enabled = logs_enabled
if logs_enabled:
-self.logger = LoggerFactory().get_logger(endpoint, self._api_key, client_ip)
+self.logger = LoggerFactory().get_logger(endpoint, self._api_key, client_ip, client_visitorid)
else:
self.logger = logging.getLogger()
self.logger.setLevel("FATAL")
@@ -231,7 +233,7 @@ def __init__(
self.feature_importances_ = []
self.search_id = search_id
if search_id:
-search_task = SearchTask(search_id, endpoint=self.endpoint, api_key=self._api_key, client_ip=client_ip)
+search_task = SearchTask(search_id, endpoint=self.endpoint, api_key=self._api_key, logger=self.logger)

print(bundle.get("search_by_task_id_start"))
trace_id = str(uuid.uuid4())
@@ -295,7 +297,7 @@ def _get_api_key(self):
def _set_api_key(self, api_key: str):
self._api_key = api_key
if self.logs_enabled:
-self.logger = LoggerFactory().get_logger(self.endpoint, self._api_key, self.client_ip)
+self.logger = LoggerFactory().get_logger(self.endpoint, self._api_key, self.client_ip, self.client_visitorid)

api_key = property(_get_api_key, _set_api_key)

@@ -678,7 +680,7 @@ def transform(
return None

if not metrics_calculation:
-transform_usage = get_rest_client(self.endpoint, self.api_key).get_current_transform_usage(trace_id)
+transform_usage = self.rest_client.get_current_transform_usage(trace_id)
self.logger.info(f"Current transform usage: {transform_usage}. Transforming {len(X)} rows")
if transform_usage.has_limit:
if len(X) > transform_usage.rest_rows:
@@ -1805,7 +1807,6 @@ def __inner_transform(
api_key=self.api_key, # type: ignore
date_format=self.date_format, # type: ignore
logger=self.logger,
-client_ip=self.client_ip,
)
dataset.meaning_types = meaning_types
dataset.search_keys = combined_search_keys
@@ -1868,7 +1869,7 @@ def __inner_transform(
progress = self.get_progress(trace_id, validation_task)
except KeyboardInterrupt as e:
print(bundle.get("search_stopping"))
-get_rest_client(self.endpoint, self.api_key).stop_search_task_v2(
+self.rest_client.stop_search_task_v2(
trace_id, validation_task.search_task_id
)
self.logger.warning(f"Search {validation_task.search_task_id} stopped by user")
@@ -2140,7 +2141,6 @@ def __inner_fit(
date_format=self.date_format, # type: ignore
random_state=self.random_state, # type: ignore
logger=self.logger,
-client_ip=self.client_ip,
)
dataset.meaning_types = meaning_types
dataset.search_keys = combined_search_keys
@@ -2197,7 +2197,7 @@ def __inner_fit(
progress = self.get_progress(trace_id)
except KeyboardInterrupt as e:
print(bundle.get("search_stopping"))
-get_rest_client(self.endpoint, self.api_key).stop_search_task_v2(trace_id, self._search_task.search_task_id)
+self.rest_client.stop_search_task_v2(trace_id, self._search_task.search_task_id)
self.logger.warning(f"Search {self._search_task.search_task_id} stopped by user")
print(bundle.get("search_stopped"))
raise e
@@ -3183,7 +3183,7 @@ def __show_report_button(self):
metrics_df=self.metrics,
autofe_descriptions_df=self.get_autofe_features_description(),
search_id=self._search_task.search_task_id,
-email=get_rest_client(self.endpoint, self.api_key).get_current_email(),
+email=self.rest_client.get_current_email(),
search_keys=[str(sk) for sk in self.search_keys.values()],
)
except Exception:
@@ -3367,21 +3367,21 @@ def sample(inp, sample_index):
pickle.dump(sample(eval_set[0][0], eval_xy_sample_index), eval_x_file)
with open(f"{tmp_dir}/eval_y.pickle", "wb") as eval_y_file:
pickle.dump(sample(eval_set[0][1], eval_xy_sample_index), eval_y_file)
-get_rest_client(self.endpoint, self.api_key).dump_input_files(
+self.rest_client.dump_input_files(
trace_id,
f"{tmp_dir}/x.pickle",
f"{tmp_dir}/y.pickle",
f"{tmp_dir}/eval_x.pickle",
f"{tmp_dir}/eval_y.pickle",
)
else:
-get_rest_client(self.endpoint, self.api_key).dump_input_files(
+self.rest_client.dump_input_files(
trace_id,
f"{tmp_dir}/x.pickle",
f"{tmp_dir}/y.pickle",
)
else:
-get_rest_client(self.endpoint, self.api_key).dump_input_files(
+self.rest_client.dump_input_files(
trace_id,
f"{tmp_dir}/x.pickle",
)
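Taken together, the features_enricher.py hunks thread the new client_visitorid through __init__ alongside client_ip and route every API call through the single cached self.rest_client instead of re-resolving a client via get_rest_client(...) at each call site. A minimal usage sketch, not from the commit: the search-key mapping, API key, and identifier values below are placeholders for illustration only.

# Hypothetical usage sketch; all argument values are placeholders.
from upgini import FeaturesEnricher, SearchKey

enricher = FeaturesEnricher(
    search_keys={"signup_date": SearchKey.DATE},  # assumed key mapping
    api_key="YOUR_UPGINI_API_KEY",
    client_ip="203.0.113.7",      # forwarded to get_rest_client and the logger
    client_visitorid="a1b2c3d4",  # new in this commit
)

# REST calls now reuse the one cached client, e.g.:
# enricher.rest_client.get_current_email()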
30 changes: 19 additions & 11 deletions src/upgini/http.py
@@ -301,11 +301,13 @@ class _RestClient:
USER_AGENT_HEADER_VALUE = "pyupgini/" + __version__
SEARCH_KEYS_HEADER_NAME = "Search-Keys"

-def __init__(self, service_endpoint, refresh_token, silent_mode=False):
+def __init__(self, service_endpoint, refresh_token, silent_mode=False, client_ip=None, client_visitorid=None):
# debug_requests_on()
self._service_endpoint = service_endpoint
self._refresh_token = refresh_token
self.silent_mode = silent_mode
+self.client_ip = client_ip
+self.client_visitorid = client_visitorid
self._access_token = self._refresh_access_token()
# self._access_token: Optional[str] = None # self._refresh_access_token()
self.last_refresh_time = time.time()
@@ -470,7 +472,7 @@ def open_and_send():
)
files["tracking"] = (
"tracking.json",
-dumps(get_track_metrics()).encode(),
+dumps(get_track_metrics(self.client_ip, self.client_visitorid)).encode(),
"application/json",
)
additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
@@ -554,7 +556,7 @@ def open_and_send():
)
files["tracking"] = (
"ide",
-dumps(get_track_metrics()).encode(),
+dumps(get_track_metrics(self.client_ip, self.client_visitorid)).encode(),
"application/json",
)

@@ -662,7 +664,7 @@ def get_provider_search_metadata_v3(self, provider_search_task_id: str, trace_id
return ProviderTaskMetadataV2.parse_obj(response)

def get_current_transform_usage(self, trace_id) -> TransformUsage:
-track_metrics = get_track_metrics()
+track_metrics = get_track_metrics(self.client_ip, self.client_visitorid)
visitor_id = track_metrics.get("visitorId")
response = self._with_unauth_retry(
lambda: self._send_get_req(
@@ -905,35 +907,40 @@ def resolve_api_token(api_token: Optional[str]) -> str:
return DEMO_API_KEY


-def get_rest_client(backend_url: Optional[str] = None, api_token: Optional[str] = None) -> _RestClient:
+def get_rest_client(backend_url: Optional[str] = None, api_token: Optional[str] = None,
+                    client_ip: Optional[str] = None, client_visitorid: Optional[str] = None) -> _RestClient:
url = _resolve_backend_url(backend_url)
token = resolve_api_token(api_token)

-return _get_rest_client(url, token)
+return _get_rest_client(url, token, client_ip, client_visitorid)


def is_demo_api_key(api_token: Optional[str]) -> bool:
return api_token is None or api_token == "" or api_token == DEMO_API_KEY


@lru_cache()
-def _get_rest_client(backend_url: str, api_token: str) -> _RestClient:
+def _get_rest_client(backend_url: str, api_token: str,
+                     client_ip: Optional[str] = None, client_visitorid: Optional[str] = None) -> _RestClient:
return _RestClient(backend_url, api_token)


class BackendLogHandler(logging.Handler):
-def __init__(self, rest_client: _RestClient, client_ip: Optional[str] = None, *args, **kwargs) -> None:
+def __init__(self, rest_client: _RestClient,
+             client_ip: Optional[str] = None, client_visitorid: Optional[str] = None,
+             *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.rest_client = rest_client
self.track_metrics = None
self.hostname = "0.0.0.0"
self.client_ip = client_ip
+self.client_visitorid = client_visitorid

def emit(self, record: logging.LogRecord) -> None:
def task():
try:
if self.track_metrics is None or len(self.track_metrics) == 0:
-self.track_metrics = get_track_metrics(self.client_ip)
+self.track_metrics = get_track_metrics(self.client_ip, self.client_visitorid)
self.hostname = self.track_metrics.get("ip") or "0.0.0.0"
text = self.format(record)
tags = self.track_metrics
@@ -975,7 +982,8 @@ def __init__(self, *args, **kwargs):
root.handlers.clear()

def get_logger(
-self, backend_url: Optional[str] = None, api_token: Optional[str] = None, client_ip: Optional[str] = None
+self, backend_url: Optional[str] = None, api_token: Optional[str] = None,
+client_ip: Optional[str] = None, client_visitorid: Optional[str] = None
) -> logging.Logger:
url = _resolve_backend_url(backend_url)
token = resolve_api_token(api_token)
@@ -987,7 +995,7 @@ def get_logger(
upgini_logger = logging.getLogger(f"upgini.{hash(key)}")
upgini_logger.handlers.clear()
rest_client = get_rest_client(backend_url, api_token)
-datadog_handler = BackendLogHandler(rest_client, client_ip)
+datadog_handler = BackendLogHandler(rest_client, client_ip, client_visitorid)
json_formatter = jsonlogger.JsonFormatter(
"%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s",
timestamp=True,
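Because _get_rest_client is wrapped in functools.lru_cache, the two new arguments become part of the cache key; note, though, that the hunk above still shows the body constructing _RestClient(backend_url, api_token) without forwarding them, and whether they are forwarded elsewhere is not visible in this diff. A minimal sketch of what the cache-key behaviour means, using stdlib only and illustrative names rather than upgini code:

# Sketch: every distinct argument tuple produces (and caches) its own instance.
from functools import lru_cache
from typing import Optional

@lru_cache()
def get_client(url: str, token: str,
               ip: Optional[str] = None, visitor: Optional[str] = None) -> object:
    return object()  # stands in for _RestClient construction

a = get_client("https://search.upgini.com", "token")
b = get_client("https://search.upgini.com", "token")
c = get_client("https://search.upgini.com", "token", ip="203.0.113.7")
assert a is b      # identical argument tuple -> cached instance reused
assert a is not c  # a different client_ip yields a separate cached client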
9 changes: 7 additions & 2 deletions src/upgini/search_task.py
@@ -1,3 +1,4 @@
+import logging
import tempfile
import time
from functools import lru_cache
@@ -43,7 +44,7 @@ def __init__(
task_type: Optional[ModelTaskType] = None,
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
-client_ip: Optional[str] = None,
+logger: Optional[logging.Logger] = None,
):
self.search_task_id = search_task_id
self.initial_search_task_id = initial_search_task_id
@@ -55,7 +56,11 @@ def __init__(
self.summary = None
self.endpoint = endpoint
self.api_key = api_key
-self.logger = LoggerFactory().get_logger(endpoint, api_key, client_ip)
+if logger is not None:
+    self.logger = logger
+else:
+    self.logger = logging.getLogger()
+    self.logger.setLevel("FATAL")
self.provider_metadata_v2: Optional[List[ProviderTaskMetadataV2]] = None
self.unused_features_for_generation: Optional[List[str]] = None

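SearchTask now accepts an injected logging.Logger and falls back to a muted default instead of building one through LoggerFactory, which removes the client_ip plumbing from this class entirely. A standalone sketch of the same inject-or-silence pattern, with an illustrative class name (stdlib only, mirroring the hunk above):

import logging
from typing import Optional

class Worker:
    def __init__(self, logger: Optional[logging.Logger] = None):
        if logger is not None:
            self.logger = logger           # reuse the caller's configured logger
        else:
            self.logger = logging.getLogger()
            self.logger.setLevel("FATAL")  # effectively discard records

Worker().logger.info("dropped")  # below FATAL, never emitted
Worker(logging.getLogger("upgini")).logger.warning("goes to the caller's handlers")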
15 changes: 9 additions & 6 deletions src/upgini/utils/target_utils.py
@@ -30,15 +30,18 @@ def define_task(y: pd.Series, logger: Optional[logging.Logger] = None, silent: b
     target_items = target.nunique()
     if target_items == 1:
         raise ValidationError(bundle.get("dataset_constant_target"))
-    target_ratio = target_items / len(target)
 
     if target_items == 2:
         task = ModelTaskType.BINARY
-    elif (target.dtype.kind == "f" and np.any(target != target.astype(int))) or (
-        is_numeric_dtype(target) and (target_items > 50 or target_ratio > 0.2)
-    ):
-        task = ModelTaskType.REGRESSION
     else:
-        task = ModelTaskType.MULTICLASS
+        non_zero_target = target[target != 0]
+        target_ratio = target_items / len(non_zero_target)
+        if (target.dtype.kind == "f" and np.any(target != target.astype(int))) or (
+            is_numeric_dtype(target) and (target_items > 50 or target_ratio > 0.2)
+        ):
+            task = ModelTaskType.REGRESSION
+        else:
+            task = ModelTaskType.MULTICLASS
     logger.info(f"Detected task type: {task}")
     if not silent:
         print(bundle.get("target_type_detected").format(task))
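This hunk is the "fix define target" half of the commit: for non-binary targets, target_ratio is now computed against the number of non-zero rows, so a zero-inflated integer target with many distinct non-zero values crosses the 0.2 threshold and is detected as REGRESSION rather than MULTICLASS. A standalone worked example of the two ratios, with illustrative data, mirroring the hunk above:

import pandas as pd

# 900 zeros plus 100 rows spread over the values 1..40
target = pd.Series([0] * 900 + [i % 40 + 1 for i in range(100)])
target_items = target.nunique()                  # 41 distinct values, not binary

old_ratio = target_items / len(target)           # 41 / 1000 = 0.041 -> MULTICLASS
non_zero_target = target[target != 0]
new_ratio = target_items / len(non_zero_target)  # 41 /  100 = 0.41  -> REGRESSION

print(old_ratio, new_ratio)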