-
Notifications
You must be signed in to change notification settings - Fork 732
/
callbacks.py
181 lines (153 loc) · 5.93 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from typing import Any, Dict, List, Optional
from chainlit.context import context_var
from chainlit.element import Text
from chainlit.step import Step, StepType
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
from literalai.helper import utc_now
from llama_index.callbacks import TokenCountingHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse
# Event types that are not mirrored as Chainlit steps. Used as the default
# for both the start- and end-ignore lists of LlamaIndexCallbackHandler,
# which only surfaces RETRIEVE and LLM events.
DEFAULT_IGNORE = [
    CBEventType.CHUNKING,
    CBEventType.SYNTHESIZE,
    CBEventType.EMBEDDING,
    CBEventType.NODE_PARSING,
    CBEventType.QUERY,
    CBEventType.TREE,
]
class LlamaIndexCallbackHandler(TokenCountingHandler):
    """Chainlit callback handler for LlamaIndex.

    Mirrors LlamaIndex RETRIEVE and LLM events as Chainlit steps so they
    appear in the Chainlit UI, and attaches generation metadata (messages,
    completion, token count) to LLM steps.
    """

    # Steps currently in flight, keyed by LlamaIndex event id.
    steps: Dict[str, Step]

    def __init__(
        self,
        event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
        event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
    ) -> None:
        """Initialize the base callback handler.

        Args:
            event_starts_to_ignore: Event types whose start is not tracked.
            event_ends_to_ignore: Event types whose end is not tracked.
        """
        super().__init__(
            event_starts_to_ignore=event_starts_to_ignore,
            event_ends_to_ignore=event_ends_to_ignore,
        )
        # Capture the Chainlit context of the creating (main) thread so it
        # can be restored later inside LlamaIndex's worker threads.
        self.context = context_var.get()
        self.steps = {}

    def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
        """Resolve the parent step id for a new step.

        Preference order: the LlamaIndex-provided parent (when we created a
        step for it), then the current Chainlit step, then the session's
        root message, and finally no parent.
        """
        if event_parent_id and event_parent_id in self.steps:
            return event_parent_id
        if self.context.current_step:
            return self.context.current_step.id
        if self.context.session.root_message:
            return self.context.session.root_message.id
        return None

    def _restore_context(self) -> None:
        """Restore Chainlit context in the current thread.

        Chainlit context is local to the main thread, and LlamaIndex
        runs the callbacks in its own threads, so they don't have a
        Chainlit context by default.

        This method restores the context in which the callback handler
        has been created (it's always created in the main thread), so
        that we can actually send messages.
        """
        context_var.set(self.context)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        self._restore_context()

        # Only retrieval and LLM events are surfaced as Chainlit steps.
        step_type: StepType = "undefined"
        if event_type == CBEventType.RETRIEVE:
            step_type = "retrieval"
        elif event_type == CBEventType.LLM:
            step_type = "llm"
        else:
            return event_id

        step = Step(
            name=event_type.value,
            type=step_type,
            parent_id=self._get_parent_id(parent_id),
            id=event_id,
            disable_feedback=False,
        )
        self.steps[event_id] = step
        step.start = utc_now()
        step.input = payload or {}
        # Sending is async; schedule it on the main thread's event loop.
        self.context.loop.create_task(step.send())
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        step = self.steps.get(event_id)

        # NOTE(review): when payload is None the step is never popped below
        # and lingers in self.steps — presumably LlamaIndex always sends a
        # payload on end; confirm upstream.
        if payload is None or step is None:
            return

        self._restore_context()
        step.end = utc_now()

        if event_type == CBEventType.RETRIEVE:
            sources = payload.get(EventPayload.NODES)
            if sources:
                # Fix: the original joined with "\, " — an invalid escape
                # sequence that left a literal backslash in the output.
                source_refs = ", ".join(
                    f"Source {idx}" for idx, _ in enumerate(sources)
                )
                step.elements = [
                    Text(
                        name=f"Source {idx}",
                        content=source.node.get_text() or "Empty node",
                    )
                    for idx, source in enumerate(sources)
                ]
                step.output = f"Retrieved the following sources: {source_refs}"
            self.context.loop.create_task(step.update())

        if event_type == CBEventType.LLM:
            formatted_messages: Optional[List[ChatMessage]] = payload.get(
                EventPayload.MESSAGES
            )
            formatted_prompt = payload.get(EventPayload.PROMPT)
            response = payload.get(EventPayload.RESPONSE)

            if formatted_messages:
                messages = [
                    GenerationMessage(
                        role=m.role.value, content=m.content or ""  # type: ignore
                    )
                    for m in formatted_messages
                ]
            else:
                messages = None

            # Extract the completion text whichever response flavor we got.
            if isinstance(response, ChatResponse):
                content = response.message.content or ""
            elif isinstance(response, CompletionResponse):
                content = response.text
            else:
                content = ""
            step.output = content

            token_count = self.total_llm_token_count or None

            if messages and isinstance(response, ChatResponse):
                msg: ChatMessage = response.message
                step.generation = ChatGeneration(
                    messages=messages,
                    message_completion=GenerationMessage(
                        role=msg.role.value,  # type: ignore
                        content=content,
                    ),
                    token_count=token_count,
                )
            elif formatted_prompt:
                step.generation = CompletionGeneration(
                    prompt=formatted_prompt,
                    completion=content,
                    token_count=token_count,
                )

            self.context.loop.create_task(step.update())

        # The event is finished; forget its step.
        self.steps.pop(event_id, None)

    def _noop(self, *args, **kwargs):
        """Ignore trace begin/end notifications — traces are not mirrored."""
        pass

    start_trace = _noop
    end_trace = _noop