Coverage for src/dataknobs_llm/fsm_integration/workflows.py: 31%
320 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
1"""LLM workflow pattern implementation.
3This module provides pre-configured FSM patterns for LLM-based workflows,
4including RAG pipelines, chain-of-thought reasoning, and multi-agent systems.
6Note: This module was migrated from dataknobs_fsm.patterns.llm_workflow to
7consolidate all LLM functionality in the dataknobs-llm package.
8"""
10from typing import Any, Dict, List, Union, Callable
11from dataclasses import dataclass
12from enum import Enum
13import asyncio
15from dataknobs_fsm.api.simple import SimpleFSM
16from dataknobs_fsm.core.data_modes import DataHandlingMode
17from dataknobs_llm.llm.base import LLMConfig, LLMMessage, LLMResponse
18from dataknobs_llm.llm.providers import create_llm_provider
19from dataknobs_llm.llm.utils import (
20 MessageTemplate, MessageBuilder, ResponseParser
21)
class WorkflowType(Enum):
    """Enumerate the LLM workflow patterns this module can describe.

    Each member's value is the lowercase string identifier of the pattern.

    NOTE(review): in the workflow code visible in this file only SIMPLE,
    CHAIN, RAG and COT have dedicated build/execute branches; confirm
    whether TREE, AGENT and MULTI_AGENT are handled elsewhere.
    """
    SIMPLE = "simple"  # Single LLM call
    CHAIN = "chain"  # Sequential chain of LLM calls
    RAG = "rag"  # Retrieval-augmented generation
    COT = "cot"  # Chain-of-thought reasoning
    TREE = "tree"  # Tree-of-thought reasoning
    AGENT = "agent"  # Agent with tools
    MULTI_AGENT = "multi_agent"  # Multiple cooperating agents
@dataclass
class LLMStep:
    """Describe a single step in an LLM workflow.

    A step carries the prompt template to render plus optional hooks for
    pre-processing, validation, parsing and post-processing of the result.
    """
    name: str
    prompt_template: MessageTemplate
    model_config: LLMConfig | None = None  # Override default

    # Processing
    # pre_processor is applied to the step's input data before the prompt is
    # formatted; post_processor is applied to the (optionally parsed) result.
    pre_processor: Callable[[Any], Any] | None = None
    post_processor: Callable[[LLMResponse], Any] | None = None

    # Validation
    # validator is called with the provider response; a falsy return value
    # counts as a validation failure.
    validator: Callable[[Any], bool] | None = None
    retry_on_failure: bool = True
    max_retries: int = 3

    # Dependencies
    # depends_on lists result keys of earlier steps that a chain workflow
    # copies into this step's context before executing it.
    depends_on: List[str] | None = None
    pass_context: bool = True  # Pass previous results

    # Output
    # When output_key is None the chain workflow stores the result under the
    # step name, and the simple workflow stores it under 'output'.
    output_key: str | None = None  # Key in results dict
    parse_json: bool = False
    extract_code: bool = False
61@dataclass
62class RAGConfig:
63 """Configuration for RAG (Retrieval-Augmented Generation)."""
64 retriever_type: str # 'vector', 'keyword', 'hybrid'
65 index_path: str | None = None
66 embedding_model: str | None = None
68 # Retrieval settings
69 top_k: int = 5
70 similarity_threshold: float = 0.7
71 rerank: bool = False
72 rerank_model: str | None = None
74 # Context settings
75 max_context_length: int = 2000
76 context_template: MessageTemplate | None = None
78 # Chunking settings
79 chunk_size: int = 500
80 chunk_overlap: int = 50
@dataclass
class AgentConfig:
    """Configuration for one agent in agent-based workflows.

    NOTE(review): none of these fields are read by the workflow code visible
    in this file, so the notes below restate the field names and the
    original inline comments only; confirm semantics against the consumer.
    """
    agent_name: str
    role: str
    capabilities: List[str]

    # Tools
    tools: List[Dict[str, Any]] | None = None
    tool_descriptions: str | None = None

    # Memory
    memory_type: str | None = None  # 'buffer', 'summary', 'vector'
    memory_size: int = 10

    # Planning
    planning_enabled: bool = False
    planning_steps: int = 5

    # Reflection
    reflection_enabled: bool = False
    reflection_prompt: MessageTemplate | None = None
@dataclass
class LLMWorkflowConfig:
    """Top-level configuration consumed by ``LLMWorkflow``.

    NOTE(review): ``max_iterations``, ``early_stop_condition``,
    ``agent_configs``, ``context_window``, ``aggregate_outputs``,
    ``fallback_response``, ``log_prompts``, ``log_responses`` and
    ``track_cost`` are not read by the workflow code visible in this file;
    confirm whether they are consumed elsewhere.
    """
    workflow_type: WorkflowType
    steps: List[LLMStep]
    default_model_config: LLMConfig

    # Workflow settings
    max_iterations: int = 10
    early_stop_condition: Callable[[Dict[str, Any]], bool] | None = None

    # RAG settings (if applicable)
    # A retriever is only created when workflow_type is RAG and this is set.
    rag_config: RAGConfig | None = None

    # Agent settings (if applicable)
    agent_configs: List[AgentConfig] | None = None

    # Memory and context
    # When maintain_history is true, each step's prompt and response are
    # appended to the history, and the last max_history_length messages are
    # replayed into the next request.
    maintain_history: bool = True
    max_history_length: int = 20
    context_window: int = 4000

    # Output settings
    aggregate_outputs: bool = False
    # Called with the results dict; its return value replaces the results.
    output_formatter: Callable[[Dict[str, Any]], Any] | None = None

    # Error handling
    # Called with (exception, step name) once a step has exhausted its
    # retries; its return value becomes the step output.
    error_handler: Callable[[Exception, str], Any] | None = None
    fallback_response: str | None = None

    # Monitoring
    log_prompts: bool = False
    log_responses: bool = False
    # When true, token usage is accumulated per step and reported under the
    # '_tokens' key of the results.
    track_tokens: bool = True
    track_cost: bool = False
144class VectorRetriever:
145 """Simple vector-based retriever for RAG."""
147 def __init__(self, config: RAGConfig):
148 self.config = config
149 self.documents = []
150 self.embeddings = []
152 async def index_documents(self, documents: List[str]) -> None:
153 """Index documents for retrieval.
155 Generates embeddings for documents using the configured LLM provider.
156 In production, these would be stored in a vector database.
158 Args:
159 documents: List of documents to index
160 """
161 from dataknobs_fsm.llm.providers import get_provider
163 self.documents = documents
165 # Try to use a real embedding provider if available
166 if self.config.provider_config:
167 try:
168 provider = get_provider(self.config.provider_config)
170 # Check if provider supports embeddings
171 if hasattr(provider, 'embed'):
172 # Generate embeddings for all documents
173 self.embeddings = await provider.embed(documents)
175 # Normalize embeddings for cosine similarity
176 self.embeddings = [
177 self._normalize_embedding(emb) for emb in self.embeddings
178 ]
179 else:
180 # Fallback to mock embeddings if provider doesn't support them
181 self.embeddings = self._generate_mock_embeddings(documents)
182 except Exception as e:
183 # Log error and fallback to mock embeddings
184 import logging
185 logger = logging.getLogger(__name__)
186 logger.warning(f"Failed to generate real embeddings: {e}. Using mock embeddings.")
187 self.embeddings = self._generate_mock_embeddings(documents)
188 else:
189 # No provider configured, use mock embeddings
190 self.embeddings = self._generate_mock_embeddings(documents)
192 def _normalize_embedding(self, embedding: List[float]) -> List[float]:
193 """Normalize an embedding vector for cosine similarity.
195 Args:
196 embedding: Raw embedding vector
198 Returns:
199 Normalized embedding vector
200 """
201 import math
203 norm = math.sqrt(sum(x * x for x in embedding))
204 if norm == 0:
205 return embedding
206 return [x / norm for x in embedding]
208 def _generate_mock_embeddings(self, documents: List[str]) -> List[List[float]]:
209 """Generate mock embeddings for testing.
211 Args:
212 documents: Documents to generate embeddings for
214 Returns:
215 Mock embedding vectors
216 """
217 import hashlib
219 embeddings = []
220 for doc in documents:
221 # Create deterministic mock embedding based on document content
222 doc_hash = hashlib.sha256(doc.encode()).digest()
223 # Convert hash to 768-dimensional embedding (standard size)
224 embedding = []
225 for i in range(96): # 768 / 8 = 96
226 # Take 8 bytes at a time and convert to float
227 if i * 8 < len(doc_hash):
228 byte_chunk = doc_hash[i*8:(i+1)*8]
229 value = sum(b for b in byte_chunk) / 2040.0 # Normalize to ~[0, 1]
230 else:
231 # Pad with deterministic values if needed
232 value = (i % 10) / 10.0
234 # Expand to 8 dimensions
235 for j in range(8):
236 embedding.append(value * (1 + j * 0.1))
238 embeddings.append(self._normalize_embedding(embedding))
240 return embeddings
242 async def retrieve(self, query: str, top_k: int = None) -> List[str]:
243 """Retrieve relevant documents using semantic similarity.
245 Args:
246 query: Query string
247 top_k: Number of documents to retrieve
249 Returns:
250 List of most relevant documents
251 """
252 from dataknobs_fsm.llm.providers import get_provider
254 top_k = top_k or self.config.top_k
256 if not self.documents:
257 return []
259 # Generate embedding for query
260 query_embedding = None
262 if self.config.provider_config:
263 try:
264 provider = get_provider(self.config.provider_config)
265 if hasattr(provider, 'embed'):
266 query_embedding = await provider.embed(query)
267 query_embedding = self._normalize_embedding(query_embedding)
268 except Exception:
269 pass
271 if query_embedding is None:
272 # Fallback to mock embedding
273 query_embedding = self._generate_mock_embeddings([query])[0]
275 # Calculate cosine similarities
276 similarities = []
277 for i, doc_embedding in enumerate(self.embeddings):
278 similarity = self._cosine_similarity(query_embedding, doc_embedding)
279 similarities.append((similarity, i))
281 # Sort by similarity and return top-k documents
282 similarities.sort(reverse=True)
283 top_indices = [idx for _, idx in similarities[:top_k]]
285 return [self.documents[idx] for idx in top_indices]
287 def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
288 """Calculate cosine similarity between two vectors.
290 Args:
291 vec1: First vector
292 vec2: Second vector
294 Returns:
295 Cosine similarity score
296 """
297 if len(vec1) != len(vec2):
298 # Handle dimension mismatch by padding or truncating
299 min_len = min(len(vec1), len(vec2))
300 vec1 = vec1[:min_len]
301 vec2 = vec2[:min_len]
303 dot_product = sum(a * b for a, b in zip(vec1, vec2, strict=False))
304 return dot_product # Already normalized
class LLMWorkflow:
    """LLM workflow orchestrator using FSM pattern.

    An FSM description is built from the configuration at construction time.
    Execution itself is driven by ``execute``, which dispatches on the
    configured workflow type.
    """

    def __init__(self, config: LLMWorkflowConfig):
        """Initialize LLM workflow.

        Args:
            config: Workflow configuration
        """
        self.config = config
        self._fsm = self._build_fsm()
        self._providers = {}  # provider cache keyed by "<provider>_<model>"
        self._history = []  # alternating user/assistant messages
        self._context = {}  # run-wide counters such as 'total_tokens'
        self._retriever = None

        # Initialize retriever if RAG
        if config.workflow_type == WorkflowType.RAG and config.rag_config:
            self._retriever = VectorRetriever(config.rag_config)

    def _build_fsm(self) -> SimpleFSM:
        """Build FSM for LLM workflow.

        Returns:
            A ``SimpleFSM`` with a start state, the task states of the
            configured workflow type and an end state. Workflow types
            without a branch below get only the start and end states.
        """
        # Add start state
        states = [{'name': 'start', 'type': 'initial', 'is_start': True}]
        arcs = []

        if self.config.workflow_type == WorkflowType.SIMPLE:
            # Single LLM call
            states.append({'name': 'llm_call', 'type': 'task'})
            arcs.append({'from': 'start', 'to': 'llm_call', 'name': 'init'})
            arcs.append({'from': 'llm_call', 'to': 'end', 'name': 'complete'})

        elif self.config.workflow_type == WorkflowType.CHAIN:
            # Sequential chain: one task state per step, linked in order
            for i, step in enumerate(self.config.steps):
                state_name = f"step_{step.name}"
                states.append({'name': state_name, 'type': 'task'})

                if i == 0:
                    arcs.append({'from': 'start', 'to': state_name, 'name': f'init_{step.name}'})
                else:
                    prev_state = f"step_{self.config.steps[i-1].name}"
                    arcs.append({
                        'from': prev_state,
                        'to': state_name,
                        'name': f'{self.config.steps[i-1].name}_to_{step.name}'
                    })

                if i == len(self.config.steps) - 1:
                    arcs.append({'from': state_name, 'to': 'end', 'name': f'{step.name}_complete'})

        elif self.config.workflow_type == WorkflowType.RAG:
            # RAG pipeline
            states.extend([
                {'name': 'retrieve', 'type': 'task'},
                {'name': 'augment', 'type': 'task'},
                {'name': 'generate', 'type': 'task'}
            ])

            arcs.extend([
                {'from': 'start', 'to': 'retrieve', 'name': 'init_retrieval'},
                {'from': 'retrieve', 'to': 'augment', 'name': 'retrieve_to_augment'},
                {'from': 'augment', 'to': 'generate', 'name': 'augment_to_generate'},
                {'from': 'generate', 'to': 'end', 'name': 'generation_complete'}
            ])

        elif self.config.workflow_type == WorkflowType.COT:
            # Chain-of-thought reasoning
            states.extend([
                {'name': 'decompose', 'type': 'task'},
                {'name': 'reason', 'type': 'task'},
                {'name': 'synthesize', 'type': 'task'}
            ])

            arcs.extend([
                {'from': 'start', 'to': 'decompose', 'name': 'init_decompose'},
                {'from': 'decompose', 'to': 'reason', 'name': 'decompose_to_reason'},
                {'from': 'reason', 'to': 'synthesize', 'name': 'reason_to_synthesize'},
                {'from': 'synthesize', 'to': 'end', 'name': 'synthesis_complete'}
            ])

        # Add end state
        states.append({
            'name': 'end',
            'type': 'terminal'
        })

        # Build FSM configuration
        fsm_config = {
            'name': 'LLM_Workflow',
            'data_mode': DataHandlingMode.REFERENCE.value,
            'states': states,
            'arcs': arcs,
            'resources': []
        }

        return SimpleFSM(fsm_config)

    async def _get_provider(self, step: LLMStep | None = None):
        """Get LLM provider for step.

        Providers are created and initialized on first use and then cached
        per provider/model pair.

        Args:
            step: Step whose ``model_config`` overrides the workflow default
                when set

        Returns:
            The initialized provider
        """
        config = step.model_config if step and step.model_config else self.config.default_model_config

        key = f"{config.provider}_{config.model}"
        if key not in self._providers:
            self._providers[key] = create_llm_provider(config, is_async=True)
            await self._providers[key].initialize()

        return self._providers[key]

    async def _execute_step(
        self,
        step: LLMStep,
        input_data: Dict[str, Any]
    ) -> Any:
        """Execute a single workflow step.

        Args:
            step: Workflow step
            input_data: Input data with template variables

        Returns:
            Step output: the response content, optionally parsed and
            post-processed, or the return value of the configured error
            handler once the step has finally failed

        Raises:
            ValueError: If validation fails and no retry is left or allowed,
                and no error handler is configured
            Exception: The last error raised while generating or processing
                the response, once retries are exhausted and no error
                handler is configured
        """
        # Pre-process input
        if step.pre_processor:
            input_data = step.pre_processor(input_data)

        # Format prompt
        prompt = step.prompt_template.format(**input_data)

        # Build messages; the system prompt and streaming settings always
        # come from the workflow's default model config.
        builder = MessageBuilder()
        if self.config.default_model_config.system_prompt:
            builder.system(self.config.default_model_config.system_prompt)

        # Add history if maintaining
        if self.config.maintain_history and self._history:
            for msg in self._history[-self.config.max_history_length:]:
                builder.messages.append(msg)

        builder.user(prompt)
        messages = builder.build()

        # Get provider and generate
        provider = await self._get_provider(step)

        retry_count = 0
        while retry_count <= step.max_retries:
            # Set when a validation failure must not be retried, so that the
            # generic handler below does not retry it anyway.
            validation_final = False
            try:
                # Generate response
                if self.config.default_model_config.stream:
                    parts = []
                    async for chunk in provider.stream_complete(messages):
                        parts.append(chunk.delta)
                        if self.config.default_model_config.stream_callback:
                            self.config.default_model_config.stream_callback(chunk)
                    response = LLMResponse(content="".join(parts), model=provider.config.model)
                else:
                    response = await provider.complete(messages)

                # Validate response
                if step.validator and not step.validator(response):
                    if step.retry_on_failure and retry_count < step.max_retries:
                        # Validation retries are immediate (no backoff)
                        retry_count += 1
                        continue
                    validation_final = True
                    raise ValueError(f"Validation failed for step {step.name}")

                # Parse response if needed
                result = response.content
                if step.parse_json:
                    result = ResponseParser.extract_json(response)
                elif step.extract_code:
                    result = ResponseParser.extract_code(response)

                # Post-process
                if step.post_processor:
                    result = step.post_processor(result)  # type: ignore

                # Update history
                if self.config.maintain_history:
                    self._history.append(LLMMessage(role='user', content=prompt))
                    self._history.append(LLMMessage(role='assistant', content=response.content))

                # Track tokens and cost
                if self.config.track_tokens and response.usage:
                    self._context['total_tokens'] = self._context.get('total_tokens', 0) + response.usage.get('total_tokens', 0)

                return result

            except Exception as e:
                if validation_final or retry_count >= step.max_retries:
                    if self.config.error_handler:
                        return self.config.error_handler(e, step.name)
                    raise
                retry_count += 1
                await asyncio.sleep(1.0 * retry_count)  # Linear backoff: 1s, 2s, 3s, ...

    async def _execute_rag(self, query: str) -> str:
        """Execute RAG workflow.

        Args:
            query: User query

        Returns:
            Generated response

        Raises:
            ValueError: If no retriever was configured
        """
        if not self._retriever:
            raise ValueError("RAG configuration not provided")

        # Retrieve relevant documents
        documents = await self._retriever.retrieve(query)

        # Build augmented prompt
        context = "\n\n".join(documents)
        if self.config.rag_config.context_template:
            augmented_prompt = self.config.rag_config.context_template.format(
                context=context,
                query=query
            )
        else:
            augmented_prompt = f"""Context:
{context}

Question: {query}

Answer based on the context provided:"""

        # Generate response
        provider = await self._get_provider()
        response = await provider.complete(augmented_prompt)

        return response.content

    async def _execute_cot(self, problem: str) -> str:
        """Execute chain-of-thought reasoning.

        The problem is decomposed into steps, each step is reasoned about
        with its own provider call, and a final call synthesizes the answer.

        Args:
            problem: Problem to solve

        Returns:
            Solution
        """
        provider = await self._get_provider()

        # Step 1: Decompose problem
        decompose_prompt = f"""Break down this problem into smaller steps:
{problem}

List the steps needed to solve this:"""

        decompose_response = await provider.complete(decompose_prompt)
        steps = ResponseParser.extract_list(decompose_response)

        # Step 2: Reason through each step
        reasoning = []
        for i, step in enumerate(steps, 1):
            reason_prompt = f"""Problem: {problem}
Step {i}: {step}

Explain how to complete this step:"""

            reason_response = await provider.complete(reason_prompt)
            reasoning.append(f"Step {i}: {step}\n{reason_response.content}")

        # Step 3: Synthesize solution
        synthesis_prompt = f"""Problem: {problem}

Reasoning:
{chr(10).join(reasoning)}

Based on the reasoning above, provide the final solution:"""

        synthesis_response = await provider.complete(synthesis_prompt)

        return synthesis_response.content

    async def execute(
        self,
        input_data: Union[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Execute LLM workflow.

        Args:
            input_data: Input data or query; a plain string is wrapped as
                ``{'query': <string>}``

        Returns:
            Workflow results. When an output formatter is configured its
            return value is returned instead. The ``_tokens`` entry is added
            only when token tracking is on and the results are a dict.
            Workflow types without a branch here yield empty results.
        """
        # Normalize input
        if isinstance(input_data, str):
            input_data = {'query': input_data}

        results = {}

        if self.config.workflow_type == WorkflowType.SIMPLE:
            # Single step execution
            if self.config.steps:
                output = await self._execute_step(self.config.steps[0], input_data)
                results[self.config.steps[0].output_key or 'output'] = output
            else:
                # Direct LLM call
                provider = await self._get_provider()
                response = await provider.complete(input_data.get('query', ''))
                results['output'] = response.content

        elif self.config.workflow_type == WorkflowType.CHAIN:
            # Sequential chain execution
            current_context = input_data.copy()

            for step in self.config.steps:
                # Add dependencies to context
                if step.depends_on:
                    for dep in step.depends_on:
                        if dep in results:
                            current_context[dep] = results[dep]

                # Execute step
                output = await self._execute_step(step, current_context)

                # Store result
                output_key = step.output_key or step.name
                results[output_key] = output

                # Update context if passing
                if step.pass_context:
                    current_context[output_key] = output

        elif self.config.workflow_type == WorkflowType.RAG:
            # RAG pipeline
            output = await self._execute_rag(input_data.get('query', ''))
            results['output'] = output

        elif self.config.workflow_type == WorkflowType.COT:
            # Chain-of-thought
            output = await self._execute_cot(input_data.get('problem', input_data.get('query', '')))
            results['output'] = output

        # Format output if configured
        if self.config.output_formatter:
            results = self.config.output_formatter(results)

        # Add metadata. The formatter may return any type, so the token count
        # is attached only when the results still support it.
        if self.config.track_tokens and isinstance(results, dict):
            results['_tokens'] = self._context.get('total_tokens', 0)

        return results

    async def index_documents(self, documents: List[str]) -> None:
        """Index documents for RAG.

        Args:
            documents: Documents to index

        Raises:
            ValueError: If no retriever was configured
        """
        if not self._retriever:
            raise ValueError("RAG configuration not provided")
        await self._retriever.index_documents(documents)

    async def close(self) -> None:
        """Close all providers and drop them from the cache."""
        for provider in self._providers.values():
            await provider.close()
        # Forget closed providers so a later call creates fresh ones instead
        # of reusing a closed instance.
        self._providers.clear()
def create_simple_llm_workflow(
    prompt_template: str,
    model: str = 'gpt-3.5-turbo',
    provider: str = 'openai',
    **kwargs
) -> LLMWorkflow:
    """Build a single-step workflow around one prompt template.

    Args:
        prompt_template: Prompt template string
        model: Model name
        provider: Provider name
        **kwargs: Extra keyword arguments forwarded to ``LLMConfig``

    Returns:
        An ``LLMWorkflow`` of type SIMPLE with one step named 'generate'
    """
    # Construction order matches the data flow: template, step, model
    # settings, then the workflow configuration that ties them together.
    generate_step = LLMStep(
        name='generate',
        prompt_template=MessageTemplate(prompt_template),
    )
    model_settings = LLMConfig(provider=provider, model=model, **kwargs)
    workflow_config = LLMWorkflowConfig(
        workflow_type=WorkflowType.SIMPLE,
        steps=[generate_step],
        default_model_config=model_settings,
    )
    return LLMWorkflow(workflow_config)
def create_rag_workflow(
    model: str = 'gpt-3.5-turbo',
    provider: str = 'openai',
    retriever_type: str = 'vector',
    top_k: int = 5,
    **kwargs
) -> LLMWorkflow:
    """Build a retrieval-augmented generation workflow.

    Args:
        model: Model name
        provider: Provider name
        retriever_type: Type of retriever
        top_k: Number of documents to retrieve
        **kwargs: Extra keyword arguments forwarded to ``LLMConfig``

    Returns:
        An ``LLMWorkflow`` of type RAG with no explicit steps
    """
    model_settings = LLMConfig(provider=provider, model=model, **kwargs)
    retrieval_settings = RAGConfig(retriever_type=retriever_type, top_k=top_k)
    workflow_config = LLMWorkflowConfig(
        workflow_type=WorkflowType.RAG,
        steps=[],
        default_model_config=model_settings,
        rag_config=retrieval_settings,
    )
    return LLMWorkflow(workflow_config)
def create_chain_workflow(
    steps: List[Dict[str, Any]],
    model: str = 'gpt-3.5-turbo',
    provider: str = 'openai',
    **kwargs
) -> LLMWorkflow:
    """Build a sequential chain workflow from plain step dictionaries.

    Each step dictionary must provide 'name' and 'prompt'; 'output_key',
    'parse_json' and 'depends_on' are optional.

    Args:
        steps: List of step configurations
        model: Model name
        provider: Provider name
        **kwargs: Extra keyword arguments forwarded to ``LLMConfig``

    Returns:
        An ``LLMWorkflow`` of type CHAIN with one ``LLMStep`` per entry
    """
    chain_steps = [
        LLMStep(
            name=spec['name'],
            prompt_template=MessageTemplate(spec['prompt']),
            output_key=spec.get('output_key'),
            parse_json=spec.get('parse_json', False),
            depends_on=spec.get('depends_on'),
        )
        for spec in steps
    ]
    model_settings = LLMConfig(provider=provider, model=model, **kwargs)
    workflow_config = LLMWorkflowConfig(
        workflow_type=WorkflowType.CHAIN,
        steps=chain_steps,
        default_model_config=model_settings,
    )
    return LLMWorkflow(workflow_config)