-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Keras notifications
- Loading branch information
Showing
12 changed files
with
527 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
|
||
notifyker/__pycache__/ | ||
|
||
notifyker/callbackNK/__pycache__/ | ||
|
||
notifyker/notifiers/__pycache__/ | ||
|
||
notifyker/config\.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,18 @@ | ||
NotifyKer | ||
# NotifyKer | ||
|
||
*Callback notifier and manager bot for Keras ML library* | ||
|
||
##### Simple to use: | ||
|
||
Set your TOKEN and PROXY settings in **notifyker/config_default.py** and rename to **notifyker/config.py** | ||
|
||
```python | ||
from notifyker import NotifierTelegram, CallbackSimple | ||
|
||
|
||
nfk = NotifierTelegram() | ||
callback = CallbackSimple(notifier=nfk) | ||
|
||
model.fit(... | ||
callbacks=[callback]) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .notifiers import NotifierTelegram | ||
from .callbackNK import CallbackSimple | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .callback_simple import CallbackSimple |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from keras.callbacks import Callback | ||
|
||
|
||
class CallbackBase(Callback): | ||
def __init__(self, notifier=None, custom_metrics=None): | ||
super().__init__() | ||
|
||
if notifier is not None: | ||
self.notifier = notifier | ||
|
||
if custom_metrics is not None: | ||
self.custom_metrics = custom_metrics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import time | ||
|
||
from .callback_base import CallbackBase | ||
|
||
|
||
class CallbackSimple(CallbackBase): | ||
def __init__(self, verbose=0, notifier=None, custom_metrics=None): | ||
super().__init__() | ||
|
||
if notifier is not None: | ||
self.notifier = notifier | ||
else: | ||
raise ValueError | ||
|
||
if custom_metrics is not None: | ||
self.custom_metrics = custom_metrics | ||
|
||
self.details = {} | ||
self.starting_time = None | ||
self.batch_update_freq = None | ||
self.current_epoch = 1 | ||
|
||
def on_train_begin(self, logs=None): | ||
self.notifier._connect() | ||
self.notifier.flags_batch = [] | ||
self.notifier.flags_epoch = [] | ||
|
||
for i in self.params: | ||
self.details[i] = self.params[i] | ||
|
||
self.starting_time = time.time() | ||
self.batch_update_freq = max(self.details['samples'] // self.details['batch_size'] // 10, 1) | ||
|
||
message = [] | ||
message.append('\nTraining started in {}\n'.format(self.starting_time)) | ||
|
||
if self.notifier.verbose_value == 0: | ||
self.notifier.message(' \n'.join(message)) | ||
|
||
return | ||
|
||
if self.notifier.verbose_value == 2: | ||
message.append('With the following parameters:') | ||
for i in self.details: | ||
if isinstance(self.details[i], list): | ||
value = ', '.join(self.details[i]) | ||
else: | ||
value = str(self.details[i]) | ||
|
||
message.append('{0:25s}: {1:25s}'.format(i, value)) | ||
|
||
self.notifier.message(' \n'.join(message)) | ||
|
||
def on_train_end(self, logs=None): | ||
if 's' in self.notifier.flags_batch: | ||
tag = 'forcibly' | ||
else: | ||
tag = 'Successfully' | ||
self.notifier.message('Training completed {}'.format(tag)) | ||
self.notifier._close_connect() | ||
|
||
def on_batch_end(self, batch, logs=None): | ||
if self.notifier.flags_batch: | ||
self.flags_handler() | ||
|
||
if self.notifier.verbose_value == 0: | ||
return | ||
|
||
if batch % self.batch_update_freq == 0: | ||
message = [] | ||
|
||
message.append('Epoch {} / {}'.format(self.current_epoch, self.details['epochs'])) | ||
|
||
pad_bar = '[{}{}]'.format('++' * (batch // self.batch_update_freq), '==' * (10 - batch // self.batch_update_freq)) | ||
message.append('{} / {} {}'.format(2 * logs['batch'], self.details['samples'], pad_bar)) | ||
|
||
if self.notifier.verbose_value == 2: | ||
for i in self.details['metrics']: | ||
if 'val_' not in i: | ||
message.append('{:15s}: {:15s}'.format(i, str(logs[i]))) | ||
|
||
ack = self.notifier.message(' \n'.join(message), self.notifier.cache_message_id) | ||
self.notifier.cache_message_id = ack.message_id | ||
|
||
def on_epoch_begin(self, epoch, logs=None): | ||
if self.notifier.flags_epoch: | ||
self.flags_handler() | ||
|
||
self.notifier.cache_message_id = None | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
if self.notifier.flags_epoch: | ||
self.flags_handler() | ||
|
||
message = [] | ||
|
||
message.append('Epoch {} / {}'.format(self.current_epoch, self.details['epochs'])) | ||
|
||
pad_bar = '[{}]'.format('++' * 10) | ||
message.append('{} / {} {}'.format(self.details['samples'], self.details['samples'], pad_bar)) | ||
for i in logs: | ||
message.append('{:15s}: {:15s}'.format(i, str(logs[i]))) | ||
|
||
self.notifier._status = ' \n'.join(message) | ||
|
||
if self.notifier.verbose_value != 0: | ||
ack = self.notifier.message(' \n'.join(message), self.notifier.cache_message_id) | ||
self.notifier.cache_message_id = ack.message_id | ||
|
||
self.current_epoch += 1 | ||
|
||
def flags_handler(self): | ||
if 'p' in self.notifier.flags_batch: | ||
self.notifier.flags_batch.remove('p') | ||
|
||
while 'c' not in self.notifier.flags_batch and 's' not in self.notifier.flags_batch: | ||
time.sleep(10) | ||
|
||
if 's' in self.notifier.flags_batch: | ||
self.model.stop_training = True | ||
|
||
if 'c' in self.notifier.flags_batch: | ||
self.notifier.flags_batch.remove('c') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
TOKEN = 'XXXX:YYYY' | ||
PROXY = {'proxy_url': 'socks5h://ip:port', 'urllib3_proxy_kwargs': {'username': 'username', 'password': 'password'}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .notifier_telegram_menu import NotifierTelegram | ||
from .notifier_telegram import NotifierTelegram as NotifierTelegramSimple |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
class NotifierBase: | ||
""" | ||
Abstract class for notifiers | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize mandatory variables of notifier | ||
""" | ||
self.cache_message_id = None | ||
self.flags_batch = [] | ||
self.flags_epoch = [] | ||
self._status = None | ||
|
||
def status(self): | ||
""" | ||
Status message update | ||
""" | ||
if self._status is None: | ||
text = 'Status undefined. Probably, first epoch is still performing' | ||
else: | ||
text = self._status | ||
|
||
self.message(text) | ||
|
||
def message(self, message, message_id=None): | ||
""" | ||
Abstract method of message sending | ||
Method must be redefined with the return variable ack (can be None, used to edit message of batches) | ||
""" | ||
pass | ||
|
||
def _close_connect(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from telegram.ext import CommandHandler, Updater | ||
|
||
from ..config import TOKEN, PROXY | ||
from .notifier_base import NotifierBase | ||
|
||
|
||
class NotifierTelegram(NotifierBase): | ||
""" | ||
Telegram notifier bot | ||
""" | ||
def __init__(self): | ||
""" | ||
Create handlers and chat id for message edits | ||
""" | ||
super().__init__() | ||
self.active = False | ||
self.cache_message_id = None | ||
self.flags_batch = [] | ||
self.flags_epoch = [] | ||
self._status = None | ||
self.verbose_value = 1 | ||
self.chat_id = None | ||
|
||
self._connect() | ||
|
||
def _connect(self): | ||
if not self.active: | ||
self.updater = Updater(TOKEN, request_kwargs=PROXY) | ||
|
||
self.handlers() | ||
self.updater.start_polling() | ||
|
||
self.active = True | ||
|
||
def message(self, message, message_id=None, reply_markup=None): | ||
""" | ||
Telegram specific method of message sending | ||
""" | ||
if message_id is not None: | ||
ack = self.updater.bot.edit_message_text(chat_id=self.chat_id, text=message, message_id=message_id) | ||
else: | ||
ack = self.updater.bot.send_message(chat_id=self.chat_id, text=message, reply_markup=reply_markup) | ||
|
||
return ack | ||
|
||
def handlers(self): | ||
""" | ||
Method of activation of telegram bot handlers | ||
""" | ||
self.updater.dispatcher.add_handler(CommandHandler('start', self.start)) | ||
self.updater.dispatcher.add_handler(CommandHandler('interrupt', self.interrupt)) | ||
self.updater.dispatcher.add_handler(CommandHandler('help', self._help)) | ||
self.updater.dispatcher.add_handler(CommandHandler('status', self.status)) | ||
self.updater.dispatcher.add_handler(CommandHandler('pause', self.pause)) | ||
self.updater.dispatcher.add_handler(CommandHandler('verbose', self.verbose)) | ||
self.updater.dispatcher.add_handler(CommandHandler('continue', self.cont)) | ||
|
||
def start(self, bot, update): | ||
""" | ||
Method of start message processing required to obtain chat_id | ||
""" | ||
self.chat_id = update.message.chat_id | ||
update.message.reply_text('Hello, my friend') | ||
|
||
def _help(self, bot, update): | ||
""" | ||
Method of help command processing | ||
""" | ||
message = 'Welcome! Enter /start to add your chat_id before you start training\n\ | ||
/help - Show available commands\n\ | ||
/status - Show current training status - epoch, metrics\n\ | ||
/pause - Suspend training process (model still in a memory)\n\ | ||
/continue - Continue training process\n\ | ||
/interrupt - Interrupt training process ATTENTION: You will not be able to continue by this bot\n' | ||
self.message(message) | ||
|
||
def pause(self, bot, update): | ||
""" | ||
Method of pause command processing. Suspend the training process | ||
""" | ||
self.flags_batch.append('p') | ||
self.message('Training suspended. Use /stop or /cont now') | ||
|
||
def cont(self, bot, update): | ||
""" | ||
Method of continue command processing. Continue the training process | ||
""" | ||
self.flags_batch.append('c') | ||
self.message('Training continues') | ||
|
||
def verbose(self, bot, update): | ||
self.message('Current verbose: {}'.format(self.verbose_value)) | ||
|
||
def interrupt(self, bot, update): | ||
""" | ||
Method of stop (training) command processing | ||
""" | ||
self.flags_batch.append('s') | ||
self.message('') | ||
|
||
self.message('Training interrupting...') | ||
|
||
def _close_connect(self): | ||
self.updater.stop() | ||
self.active = False |
Oops, something went wrong.