add ci workflows #1

Merged
godber merged 8 commits from add-ci into main 2025-08-13 21:03:42 -07:00
22 changed files with 965 additions and 528 deletions
Showing only changes of commit a1f533c6a8 - Show all commits

View File

@@ -1,3 +1,3 @@
"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors.""" """EmbeddingBuddy - Interactive exploration and visualization of embedding vectors."""
__version__ = "0.1.0" __version__ = "0.1.0"

View File

@@ -8,32 +8,29 @@ from .ui.callbacks.interactions import InteractionCallbacks
def create_app(): def create_app():
app = dash.Dash( app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
__name__,
external_stylesheets=[dbc.themes.BOOTSTRAP]
)
layout_manager = AppLayout() layout_manager = AppLayout()
app.layout = layout_manager.create_layout() app.layout = layout_manager.create_layout()
DataProcessingCallbacks() DataProcessingCallbacks()
VisualizationCallbacks() VisualizationCallbacks()
InteractionCallbacks() InteractionCallbacks()
return app return app
def run_app(app=None, debug=None, host=None, port=None): def run_app(app=None, debug=None, host=None, port=None):
if app is None: if app is None:
app = create_app() app = create_app()
app.run( app.run(
debug=debug if debug is not None else AppSettings.DEBUG, debug=debug if debug is not None else AppSettings.DEBUG,
host=host if host is not None else AppSettings.HOST, host=host if host is not None else AppSettings.HOST,
port=port if port is not None else AppSettings.PORT port=port if port is not None else AppSettings.PORT,
) )
if __name__ == '__main__': if __name__ == "__main__":
app = create_app() app = create_app()
run_app(app) run_app(app)

View File

@@ -3,105 +3,100 @@ import os
class AppSettings: class AppSettings:
# UI Configuration # UI Configuration
UPLOAD_STYLE = { UPLOAD_STYLE = {
'width': '100%', "width": "100%",
'height': '60px', "height": "60px",
'lineHeight': '60px', "lineHeight": "60px",
'borderWidth': '1px', "borderWidth": "1px",
'borderStyle': 'dashed', "borderStyle": "dashed",
'borderRadius': '5px', "borderRadius": "5px",
'textAlign': 'center', "textAlign": "center",
'margin-bottom': '20px' "margin-bottom": "20px",
} }
PROMPTS_UPLOAD_STYLE = { PROMPTS_UPLOAD_STYLE = {**UPLOAD_STYLE, "borderColor": "#28a745"}
**UPLOAD_STYLE,
'borderColor': '#28a745' PLOT_CONFIG = {"responsive": True, "displayModeBar": True}
}
PLOT_STYLE = {"height": "85vh", "width": "100%"}
PLOT_CONFIG = {
'responsive': True,
'displayModeBar': True
}
PLOT_STYLE = {
'height': '85vh',
'width': '100%'
}
PLOT_LAYOUT_CONFIG = { PLOT_LAYOUT_CONFIG = {
'height': None, "height": None,
'autosize': True, "autosize": True,
'margin': dict(l=0, r=0, t=50, b=0) "margin": dict(l=0, r=0, t=50, b=0),
} }
# Dimensionality Reduction Settings # Dimensionality Reduction Settings
DEFAULT_N_COMPONENTS_3D = 3 DEFAULT_N_COMPONENTS_3D = 3
DEFAULT_N_COMPONENTS_2D = 2 DEFAULT_N_COMPONENTS_2D = 2
DEFAULT_RANDOM_STATE = 42 DEFAULT_RANDOM_STATE = 42
# Available Methods # Available Methods
REDUCTION_METHODS = [ REDUCTION_METHODS = [
{'label': 'PCA', 'value': 'pca'}, {"label": "PCA", "value": "pca"},
{'label': 't-SNE', 'value': 'tsne'}, {"label": "t-SNE", "value": "tsne"},
{'label': 'UMAP', 'value': 'umap'} {"label": "UMAP", "value": "umap"},
] ]
COLOR_OPTIONS = [ COLOR_OPTIONS = [
{'label': 'Category', 'value': 'category'}, {"label": "Category", "value": "category"},
{'label': 'Subcategory', 'value': 'subcategory'}, {"label": "Subcategory", "value": "subcategory"},
{'label': 'Tags', 'value': 'tags'} {"label": "Tags", "value": "tags"},
] ]
DIMENSION_OPTIONS = [ DIMENSION_OPTIONS = [{"label": "2D", "value": "2d"}, {"label": "3D", "value": "3d"}]
{'label': '2D', 'value': '2d'},
{'label': '3D', 'value': '3d'}
]
# Default Values # Default Values
DEFAULT_METHOD = 'pca' DEFAULT_METHOD = "pca"
DEFAULT_COLOR_BY = 'category' DEFAULT_COLOR_BY = "category"
DEFAULT_DIMENSIONS = '3d' DEFAULT_DIMENSIONS = "3d"
DEFAULT_SHOW_PROMPTS = ['show'] DEFAULT_SHOW_PROMPTS = ["show"]
# Plot Marker Settings # Plot Marker Settings
DOCUMENT_MARKER_SIZE_2D = 8 DOCUMENT_MARKER_SIZE_2D = 8
DOCUMENT_MARKER_SIZE_3D = 5 DOCUMENT_MARKER_SIZE_3D = 5
PROMPT_MARKER_SIZE_2D = 10 PROMPT_MARKER_SIZE_2D = 10
PROMPT_MARKER_SIZE_3D = 6 PROMPT_MARKER_SIZE_3D = 6
DOCUMENT_MARKER_SYMBOL = 'circle' DOCUMENT_MARKER_SYMBOL = "circle"
PROMPT_MARKER_SYMBOL = 'diamond' PROMPT_MARKER_SYMBOL = "diamond"
DOCUMENT_OPACITY = 1.0 DOCUMENT_OPACITY = 1.0
PROMPT_OPACITY = 0.8 PROMPT_OPACITY = 0.8
# Text Processing # Text Processing
TEXT_PREVIEW_LENGTH = 100 TEXT_PREVIEW_LENGTH = 100
# App Configuration # App Configuration
DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true' DEBUG = os.getenv("EMBEDDINGBUDDY_DEBUG", "True").lower() == "true"
HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1') HOST = os.getenv("EMBEDDINGBUDDY_HOST", "127.0.0.1")
PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050')) PORT = int(os.getenv("EMBEDDINGBUDDY_PORT", "8050"))
# Bootstrap Theme # Bootstrap Theme
EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css'] EXTERNAL_STYLESHEETS = [
"https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
]
@classmethod @classmethod
def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]: def get_plot_marker_config(
cls, dimensions: str, is_prompt: bool = False
) -> Dict[str, Any]:
if is_prompt: if is_prompt:
size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D size = (
cls.PROMPT_MARKER_SIZE_3D
if dimensions == "3d"
else cls.PROMPT_MARKER_SIZE_2D
)
symbol = cls.PROMPT_MARKER_SYMBOL symbol = cls.PROMPT_MARKER_SYMBOL
opacity = cls.PROMPT_OPACITY opacity = cls.PROMPT_OPACITY
else: else:
size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D size = (
cls.DOCUMENT_MARKER_SIZE_3D
if dimensions == "3d"
else cls.DOCUMENT_MARKER_SIZE_2D
)
symbol = cls.DOCUMENT_MARKER_SYMBOL symbol = cls.DOCUMENT_MARKER_SYMBOL
opacity = cls.DOCUMENT_OPACITY opacity = cls.DOCUMENT_OPACITY
return { return {"size": size, "symbol": symbol, "opacity": opacity}
'size': size,
'symbol': symbol,
'opacity': opacity
}

View File

@@ -6,34 +6,33 @@ from ..models.schemas import Document
class NDJSONParser: class NDJSONParser:
@staticmethod @staticmethod
def parse_upload_contents(contents: str) -> List[Document]: def parse_upload_contents(contents: str) -> List[Document]:
content_type, content_string = contents.split(',') content_type, content_string = contents.split(",")
decoded = base64.b64decode(content_string) decoded = base64.b64decode(content_string)
text_content = decoded.decode('utf-8') text_content = decoded.decode("utf-8")
return NDJSONParser.parse_text(text_content) return NDJSONParser.parse_text(text_content)
@staticmethod @staticmethod
def parse_text(text_content: str) -> List[Document]: def parse_text(text_content: str) -> List[Document]:
documents = [] documents = []
for line in text_content.strip().split('\n'): for line in text_content.strip().split("\n"):
if line.strip(): if line.strip():
doc_dict = json.loads(line) doc_dict = json.loads(line)
doc = NDJSONParser._dict_to_document(doc_dict) doc = NDJSONParser._dict_to_document(doc_dict)
documents.append(doc) documents.append(doc)
return documents return documents
@staticmethod @staticmethod
def _dict_to_document(doc_dict: dict) -> Document: def _dict_to_document(doc_dict: dict) -> Document:
if 'id' not in doc_dict: if "id" not in doc_dict:
doc_dict['id'] = str(uuid.uuid4()) doc_dict["id"] = str(uuid.uuid4())
return Document( return Document(
id=doc_dict['id'], id=doc_dict["id"],
text=doc_dict['text'], text=doc_dict["text"],
embedding=doc_dict['embedding'], embedding=doc_dict["embedding"],
category=doc_dict.get('category'), category=doc_dict.get("category"),
subcategory=doc_dict.get('subcategory'), subcategory=doc_dict.get("subcategory"),
tags=doc_dict.get('tags') tags=doc_dict.get("tags"),
) )

View File

@@ -5,18 +5,19 @@ from .parser import NDJSONParser
class DataProcessor: class DataProcessor:
def __init__(self): def __init__(self):
self.parser = NDJSONParser() self.parser = NDJSONParser()
def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData: def process_upload(
self, contents: str, filename: Optional[str] = None
) -> ProcessedData:
try: try:
documents = self.parser.parse_upload_contents(contents) documents = self.parser.parse_upload_contents(contents)
embeddings = self._extract_embeddings(documents) embeddings = self._extract_embeddings(documents)
return ProcessedData(documents=documents, embeddings=embeddings) return ProcessedData(documents=documents, embeddings=embeddings)
except Exception as e: except Exception as e:
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
def process_text(self, text_content: str) -> ProcessedData: def process_text(self, text_content: str) -> ProcessedData:
try: try:
documents = self.parser.parse_text(text_content) documents = self.parser.parse_text(text_content)
@@ -24,31 +25,35 @@ class DataProcessor:
return ProcessedData(documents=documents, embeddings=embeddings) return ProcessedData(documents=documents, embeddings=embeddings)
except Exception as e: except Exception as e:
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
def _extract_embeddings(self, documents: List[Document]) -> np.ndarray: def _extract_embeddings(self, documents: List[Document]) -> np.ndarray:
if not documents: if not documents:
return np.array([]) return np.array([])
return np.array([doc.embedding for doc in documents]) return np.array([doc.embedding for doc in documents])
def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]: def combine_data(
self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None
) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
if not doc_data or doc_data.error: if not doc_data or doc_data.error:
raise ValueError("Invalid document data") raise ValueError("Invalid document data")
all_embeddings = doc_data.embeddings all_embeddings = doc_data.embeddings
documents = doc_data.documents documents = doc_data.documents
prompts = None prompts = None
if prompt_data and not prompt_data.error and prompt_data.documents: if prompt_data and not prompt_data.error and prompt_data.documents:
all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings]) all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings])
prompts = prompt_data.documents prompts = prompt_data.documents
return all_embeddings, documents, prompts return all_embeddings, documents, prompts
def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]: def split_reduced_data(
self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
doc_reduced = reduced_embeddings[:n_documents] doc_reduced = reduced_embeddings[:n_documents]
prompt_reduced = None prompt_reduced = None
if n_prompts > 0: if n_prompts > 0:
prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts] prompt_reduced = reduced_embeddings[n_documents : n_documents + n_prompts]
return doc_reduced, prompt_reduced return doc_reduced, prompt_reduced

View File

@@ -7,88 +7,89 @@ from .schemas import ReducedData
class DimensionalityReducer(ABC): class DimensionalityReducer(ABC):
def __init__(self, n_components: int = 3, random_state: int = 42): def __init__(self, n_components: int = 3, random_state: int = 42):
self.n_components = n_components self.n_components = n_components
self.random_state = random_state self.random_state = random_state
self._reducer = None self._reducer = None
@abstractmethod @abstractmethod
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
pass pass
@abstractmethod @abstractmethod
def get_method_name(self) -> str: def get_method_name(self) -> str:
pass pass
class PCAReducer(DimensionalityReducer): class PCAReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = PCA(n_components=self.n_components) self._reducer = PCA(n_components=self.n_components)
reduced = self._reducer.fit_transform(embeddings) reduced = self._reducer.fit_transform(embeddings)
variance_explained = self._reducer.explained_variance_ratio_ variance_explained = self._reducer.explained_variance_ratio_
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=variance_explained, variance_explained=variance_explained,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components n_components=self.n_components,
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
return "PCA" return "PCA"
class TSNEReducer(DimensionalityReducer): class TSNEReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state) self._reducer = TSNE(
n_components=self.n_components, random_state=self.random_state
)
reduced = self._reducer.fit(embeddings) reduced = self._reducer.fit(embeddings)
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=None, variance_explained=None,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components n_components=self.n_components,
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
return "t-SNE" return "t-SNE"
class UMAPReducer(DimensionalityReducer): class UMAPReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state) self._reducer = umap.UMAP(
n_components=self.n_components, random_state=self.random_state
)
reduced = self._reducer.fit_transform(embeddings) reduced = self._reducer.fit_transform(embeddings)
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=None, variance_explained=None,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components n_components=self.n_components,
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
return "UMAP" return "UMAP"
class ReducerFactory: class ReducerFactory:
@staticmethod @staticmethod
def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer: def create_reducer(
method: str, n_components: int = 3, random_state: int = 42
) -> DimensionalityReducer:
method_lower = method.lower() method_lower = method.lower()
if method_lower == 'pca': if method_lower == "pca":
return PCAReducer(n_components=n_components, random_state=random_state) return PCAReducer(n_components=n_components, random_state=random_state)
elif method_lower == 'tsne': elif method_lower == "tsne":
return TSNEReducer(n_components=n_components, random_state=random_state) return TSNEReducer(n_components=n_components, random_state=random_state)
elif method_lower == 'umap': elif method_lower == "umap":
return UMAPReducer(n_components=n_components, random_state=random_state) return UMAPReducer(n_components=n_components, random_state=random_state)
else: else:
raise ValueError(f"Unknown reduction method: {method}") raise ValueError(f"Unknown reduction method: {method}")
@staticmethod @staticmethod
def get_available_methods() -> list: def get_available_methods() -> list:
return ['pca', 'tsne', 'umap'] return ["pca", "tsne", "umap"]

View File

@@ -50,9 +50,11 @@ class PlotData:
coordinates: np.ndarray coordinates: np.ndarray
prompts: Optional[List[Document]] = None prompts: Optional[List[Document]] = None
prompt_coordinates: Optional[np.ndarray] = None prompt_coordinates: Optional[np.ndarray] = None
def __post_init__(self): def __post_init__(self):
if not isinstance(self.coordinates, np.ndarray): if not isinstance(self.coordinates, np.ndarray):
self.coordinates = np.array(self.coordinates) self.coordinates = np.array(self.coordinates)
if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray): if self.prompt_coordinates is not None and not isinstance(
self.prompt_coordinates = np.array(self.prompt_coordinates) self.prompt_coordinates, np.ndarray
):
self.prompt_coordinates = np.array(self.prompt_coordinates)

View File

@@ -3,58 +3,60 @@ from ...data.processor import DataProcessor
class DataProcessingCallbacks: class DataProcessingCallbacks:
def __init__(self): def __init__(self):
self.processor = DataProcessor() self.processor = DataProcessor()
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output('processed-data', 'data'), Output("processed-data", "data"),
Input('upload-data', 'contents'), Input("upload-data", "contents"),
State('upload-data', 'filename') State("upload-data", "filename"),
) )
def process_uploaded_file(contents, filename): def process_uploaded_file(contents, filename):
if contents is None: if contents is None:
return None return None
processed_data = self.processor.process_upload(contents, filename) processed_data = self.processor.process_upload(contents, filename)
if processed_data.error: if processed_data.error:
return {'error': processed_data.error} return {"error": processed_data.error}
return { return {
'documents': [self._document_to_dict(doc) for doc in processed_data.documents], "documents": [
'embeddings': processed_data.embeddings.tolist() self._document_to_dict(doc) for doc in processed_data.documents
],
"embeddings": processed_data.embeddings.tolist(),
} }
@callback( @callback(
Output('processed-prompts', 'data'), Output("processed-prompts", "data"),
Input('upload-prompts', 'contents'), Input("upload-prompts", "contents"),
State('upload-prompts', 'filename') State("upload-prompts", "filename"),
) )
def process_uploaded_prompts(contents, filename): def process_uploaded_prompts(contents, filename):
if contents is None: if contents is None:
return None return None
processed_data = self.processor.process_upload(contents, filename) processed_data = self.processor.process_upload(contents, filename)
if processed_data.error: if processed_data.error:
return {'error': processed_data.error} return {"error": processed_data.error}
return { return {
'prompts': [self._document_to_dict(doc) for doc in processed_data.documents], "prompts": [
'embeddings': processed_data.embeddings.tolist() self._document_to_dict(doc) for doc in processed_data.documents
],
"embeddings": processed_data.embeddings.tolist(),
} }
@staticmethod @staticmethod
def _document_to_dict(doc): def _document_to_dict(doc):
return { return {
'id': doc.id, "id": doc.id,
'text': doc.text, "text": doc.text,
'embedding': doc.embedding, "embedding": doc.embedding,
'category': doc.category, "category": doc.category,
'subcategory': doc.subcategory, "subcategory": doc.subcategory,
'tags': doc.tags "tags": doc.tags,
} }

View File

@@ -4,63 +4,79 @@ import dash_bootstrap_components as dbc
class InteractionCallbacks: class InteractionCallbacks:
def __init__(self): def __init__(self):
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output('point-details', 'children'), Output("point-details", "children"),
Input('embedding-plot', 'clickData'), Input("embedding-plot", "clickData"),
[State('processed-data', 'data'), [State("processed-data", "data"), State("processed-prompts", "data")],
State('processed-prompts', 'data')]
) )
def display_click_data(clickData, data, prompts_data): def display_click_data(clickData, data, prompts_data):
if not clickData or not data: if not clickData or not data:
return "Click on a point to see details" return "Click on a point to see details"
point_data = clickData['points'][0] point_data = clickData["points"][0]
trace_name = point_data.get('fullData', {}).get('name', 'Documents') trace_name = point_data.get("fullData", {}).get("name", "Documents")
if 'pointIndex' in point_data: if "pointIndex" in point_data:
point_index = point_data['pointIndex'] point_index = point_data["pointIndex"]
elif 'pointNumber' in point_data: elif "pointNumber" in point_data:
point_index = point_data['pointNumber'] point_index = point_data["pointNumber"]
else: else:
return "Could not identify clicked point" return "Could not identify clicked point"
if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data: if (
item = prompts_data['prompts'][point_index] trace_name.startswith("Prompts")
item_type = 'Prompt' and prompts_data
and "prompts" in prompts_data
):
item = prompts_data["prompts"][point_index]
item_type = "Prompt"
else: else:
item = data['documents'][point_index] item = data["documents"][point_index]
item_type = 'Document' item_type = "Document"
return self._create_detail_card(item, item_type) return self._create_detail_card(item, item_type)
@callback( @callback(
[Output('processed-data', 'data', allow_duplicate=True), [
Output('processed-prompts', 'data', allow_duplicate=True), Output("processed-data", "data", allow_duplicate=True),
Output('point-details', 'children', allow_duplicate=True)], Output("processed-prompts", "data", allow_duplicate=True),
Input('reset-button', 'n_clicks'), Output("point-details", "children", allow_duplicate=True),
prevent_initial_call=True ],
Input("reset-button", "n_clicks"),
prevent_initial_call=True,
) )
def reset_data(n_clicks): def reset_data(n_clicks):
if n_clicks is None or n_clicks == 0: if n_clicks is None or n_clicks == 0:
return dash.no_update, dash.no_update, dash.no_update return dash.no_update, dash.no_update, dash.no_update
return None, None, "Click on a point to see details" return None, None, "Click on a point to see details"
@staticmethod @staticmethod
def _create_detail_card(item, item_type): def _create_detail_card(item, item_type):
return dbc.Card([ return dbc.Card(
dbc.CardBody([ [
html.H5(f"{item_type}: {item['id']}", className="card-title"), dbc.CardBody(
html.P(f"Text: {item['text']}", className="card-text"), [
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"), html.H5(f"{item_type}: {item['id']}", className="card-title"),
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"), html.P(f"Text: {item['text']}", className="card-text"),
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"), html.P(
html.P(f"Type: {item_type}", className="card-text text-muted") f"Category: {item.get('category', 'Unknown')}",
]) className="card-text",
]) ),
html.P(
f"Subcategory: {item.get('subcategory', 'Unknown')}",
className="card-text",
),
html.P(
f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}",
className="card-text",
),
html.P(f"Type: {item_type}", className="card-text text-muted"),
]
)
]
)

View File

@@ -7,81 +7,102 @@ from ...visualization.plots import PlotFactory
class VisualizationCallbacks: class VisualizationCallbacks:
def __init__(self): def __init__(self):
self.plot_factory = PlotFactory() self.plot_factory = PlotFactory()
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output('embedding-plot', 'figure'), Output("embedding-plot", "figure"),
[Input('processed-data', 'data'), [
Input('processed-prompts', 'data'), Input("processed-data", "data"),
Input('method-dropdown', 'value'), Input("processed-prompts", "data"),
Input('color-dropdown', 'value'), Input("method-dropdown", "value"),
Input('dimension-toggle', 'value'), Input("color-dropdown", "value"),
Input('show-prompts-toggle', 'value')] Input("dimension-toggle", "value"),
Input("show-prompts-toggle", "value"),
],
) )
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts): def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
if not data or 'error' in data: if not data or "error" in data:
return go.Figure().add_annotation( return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization", text="Upload a valid NDJSON file to see visualization",
xref="paper", yref="paper", xref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle', yref="paper",
showarrow=False, font=dict(size=16) x=0.5,
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
font=dict(size=16),
) )
try: try:
doc_embeddings = np.array(data['embeddings']) doc_embeddings = np.array(data["embeddings"])
all_embeddings = doc_embeddings all_embeddings = doc_embeddings
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts') has_prompts = (
prompts_data
and "error" not in prompts_data
and prompts_data.get("prompts")
)
if has_prompts: if has_prompts:
prompt_embeddings = np.array(prompts_data['embeddings']) prompt_embeddings = np.array(prompts_data["embeddings"])
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings]) all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
n_components = 3 if dimensions == '3d' else 2 n_components = 3 if dimensions == "3d" else 2
reducer = ReducerFactory.create_reducer(method, n_components=n_components) reducer = ReducerFactory.create_reducer(
method, n_components=n_components
)
reduced_data = reducer.fit_transform(all_embeddings) reduced_data = reducer.fit_transform(all_embeddings)
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)] doc_reduced = reduced_data.reduced_embeddings[: len(doc_embeddings)]
prompt_reduced = None prompt_reduced = None
if has_prompts: if has_prompts:
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):] prompt_reduced = reduced_data.reduced_embeddings[
len(doc_embeddings) :
documents = [self._dict_to_document(doc) for doc in data['documents']] ]
documents = [self._dict_to_document(doc) for doc in data["documents"]]
prompts = None prompts = None
if has_prompts: if has_prompts:
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']] prompts = [
self._dict_to_document(prompt)
for prompt in prompts_data["prompts"]
]
plot_data = PlotData( plot_data = PlotData(
documents=documents, documents=documents,
coordinates=doc_reduced, coordinates=doc_reduced,
prompts=prompts, prompts=prompts,
prompt_coordinates=prompt_reduced prompt_coordinates=prompt_reduced,
) )
return self.plot_factory.create_plot( return self.plot_factory.create_plot(
plot_data, dimensions, color_by, reduced_data.method, show_prompts plot_data, dimensions, color_by, reduced_data.method, show_prompts
) )
except Exception as e: except Exception as e:
return go.Figure().add_annotation( return go.Figure().add_annotation(
text=f"Error creating visualization: {str(e)}", text=f"Error creating visualization: {str(e)}",
xref="paper", yref="paper", xref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle', yref="paper",
showarrow=False, font=dict(size=16) x=0.5,
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
font=dict(size=16),
) )
@staticmethod @staticmethod
def _dict_to_document(doc_dict): def _dict_to_document(doc_dict):
return Document( return Document(
id=doc_dict['id'], id=doc_dict["id"],
text=doc_dict['text'], text=doc_dict["text"],
embedding=doc_dict['embedding'], embedding=doc_dict["embedding"],
category=doc_dict.get('category'), category=doc_dict.get("category"),
subcategory=doc_dict.get('subcategory'), subcategory=doc_dict.get("subcategory"),
tags=doc_dict.get('tags', []) tags=doc_dict.get("tags", []),
) )

View File

@@ -4,79 +4,81 @@ from .upload import UploadComponent
class SidebarComponent: class SidebarComponent:
def __init__(self): def __init__(self):
self.upload_component = UploadComponent() self.upload_component = UploadComponent()
def create_layout(self): def create_layout(self):
return dbc.Col([ return dbc.Col(
html.H5("Upload Data", className="mb-3"), [
self.upload_component.create_data_upload(), html.H5("Upload Data", className="mb-3"),
self.upload_component.create_prompts_upload(), self.upload_component.create_data_upload(),
self.upload_component.create_reset_button(), self.upload_component.create_prompts_upload(),
self.upload_component.create_reset_button(),
html.H5("Visualization Controls", className="mb-3"), html.H5("Visualization Controls", className="mb-3"),
self._create_method_dropdown(), self._create_method_dropdown(),
self._create_color_dropdown(), self._create_color_dropdown(),
self._create_dimension_toggle(), self._create_dimension_toggle(),
self._create_prompts_toggle(), self._create_prompts_toggle(),
html.H5("Point Details", className="mb-3"),
html.H5("Point Details", className="mb-3"), html.Div(
html.Div(id='point-details', children="Click on a point to see details") id="point-details", children="Click on a point to see details"
),
], width=3, style={'padding-right': '20px'}) ],
width=3,
style={"padding-right": "20px"},
)
def _create_method_dropdown(self): def _create_method_dropdown(self):
return [ return [
dbc.Label("Method:"), dbc.Label("Method:"),
dcc.Dropdown( dcc.Dropdown(
id='method-dropdown', id="method-dropdown",
options=[ options=[
{'label': 'PCA', 'value': 'pca'}, {"label": "PCA", "value": "pca"},
{'label': 't-SNE', 'value': 'tsne'}, {"label": "t-SNE", "value": "tsne"},
{'label': 'UMAP', 'value': 'umap'} {"label": "UMAP", "value": "umap"},
], ],
value='pca', value="pca",
style={'margin-bottom': '15px'} style={"margin-bottom": "15px"},
) ),
] ]
def _create_color_dropdown(self): def _create_color_dropdown(self):
return [ return [
dbc.Label("Color by:"), dbc.Label("Color by:"),
dcc.Dropdown( dcc.Dropdown(
id='color-dropdown', id="color-dropdown",
options=[ options=[
{'label': 'Category', 'value': 'category'}, {"label": "Category", "value": "category"},
{'label': 'Subcategory', 'value': 'subcategory'}, {"label": "Subcategory", "value": "subcategory"},
{'label': 'Tags', 'value': 'tags'} {"label": "Tags", "value": "tags"},
], ],
value='category', value="category",
style={'margin-bottom': '15px'} style={"margin-bottom": "15px"},
) ),
] ]
def _create_dimension_toggle(self): def _create_dimension_toggle(self):
return [ return [
dbc.Label("Dimensions:"), dbc.Label("Dimensions:"),
dcc.RadioItems( dcc.RadioItems(
id='dimension-toggle', id="dimension-toggle",
options=[ options=[
{'label': '2D', 'value': '2d'}, {"label": "2D", "value": "2d"},
{'label': '3D', 'value': '3d'} {"label": "3D", "value": "3d"},
], ],
value='3d', value="3d",
style={'margin-bottom': '20px'} style={"margin-bottom": "20px"},
) ),
] ]
def _create_prompts_toggle(self): def _create_prompts_toggle(self):
return [ return [
dbc.Label("Show Prompts:"), dbc.Label("Show Prompts:"),
dcc.Checklist( dcc.Checklist(
id='show-prompts-toggle', id="show-prompts-toggle",
options=[{'label': 'Show prompts on plot', 'value': 'show'}], options=[{"label": "Show prompts on plot", "value": "show"}],
value=['show'], value=["show"],
style={'margin-bottom': '20px'} style={"margin-bottom": "20px"},
) ),
] ]

View File

@@ -3,58 +3,51 @@ import dash_bootstrap_components as dbc
class UploadComponent: class UploadComponent:
@staticmethod @staticmethod
def create_data_upload(): def create_data_upload():
return dcc.Upload( return dcc.Upload(
id='upload-data', id="upload-data",
children=html.Div([ children=html.Div(["Drag and Drop or ", html.A("Select Files")]),
'Drag and Drop or ',
html.A('Select Files')
]),
style={ style={
'width': '100%', "width": "100%",
'height': '60px', "height": "60px",
'lineHeight': '60px', "lineHeight": "60px",
'borderWidth': '1px', "borderWidth": "1px",
'borderStyle': 'dashed', "borderStyle": "dashed",
'borderRadius': '5px', "borderRadius": "5px",
'textAlign': 'center', "textAlign": "center",
'margin-bottom': '20px' "margin-bottom": "20px",
}, },
multiple=False multiple=False,
) )
@staticmethod @staticmethod
def create_prompts_upload(): def create_prompts_upload():
return dcc.Upload( return dcc.Upload(
id='upload-prompts', id="upload-prompts",
children=html.Div([ children=html.Div(["Drag and Drop Prompts or ", html.A("Select Files")]),
'Drag and Drop Prompts or ',
html.A('Select Files')
]),
style={ style={
'width': '100%', "width": "100%",
'height': '60px', "height": "60px",
'lineHeight': '60px', "lineHeight": "60px",
'borderWidth': '1px', "borderWidth": "1px",
'borderStyle': 'dashed', "borderStyle": "dashed",
'borderRadius': '5px', "borderRadius": "5px",
'textAlign': 'center', "textAlign": "center",
'margin-bottom': '20px', "margin-bottom": "20px",
'borderColor': '#28a745' "borderColor": "#28a745",
}, },
multiple=False multiple=False,
) )
@staticmethod @staticmethod
def create_reset_button(): def create_reset_button():
return dbc.Button( return dbc.Button(
"Reset All Data", "Reset All Data",
id='reset-button', id="reset-button",
color='danger', color="danger",
outline=True, outline=True,
size='sm', size="sm",
className='mb-3', className="mb-3",
style={'width': '100%'} style={"width": "100%"},
) )

View File

@@ -4,41 +4,43 @@ from .components.sidebar import SidebarComponent
class AppLayout: class AppLayout:
def __init__(self): def __init__(self):
self.sidebar = SidebarComponent() self.sidebar = SidebarComponent()
def create_layout(self): def create_layout(self):
return dbc.Container([ return dbc.Container(
self._create_header(), [self._create_header(), self._create_main_content(), self._create_stores()],
self._create_main_content(), fluid=True,
self._create_stores() )
], fluid=True)
def _create_header(self): def _create_header(self):
return dbc.Row([ return dbc.Row(
dbc.Col([ [
html.H1("EmbeddingBuddy", className="text-center mb-4"), dbc.Col(
], width=12) [
]) html.H1("EmbeddingBuddy", className="text-center mb-4"),
],
width=12,
)
]
)
def _create_main_content(self): def _create_main_content(self):
return dbc.Row([ return dbc.Row(
self.sidebar.create_layout(), [self.sidebar.create_layout(), self._create_visualization_area()]
self._create_visualization_area() )
])
def _create_visualization_area(self): def _create_visualization_area(self):
return dbc.Col([ return dbc.Col(
dcc.Graph( [
id='embedding-plot', dcc.Graph(
style={'height': '85vh', 'width': '100%'}, id="embedding-plot",
config={'responsive': True, 'displayModeBar': True} style={"height": "85vh", "width": "100%"},
) config={"responsive": True, "displayModeBar": True},
], width=9) )
],
width=9,
)
def _create_stores(self): def _create_stores(self):
return [ return [dcc.Store(id="processed-data"), dcc.Store(id="processed-prompts")]
dcc.Store(id='processed-data'),
dcc.Store(id='processed-prompts')
]

View File

@@ -4,30 +4,33 @@ from ..models.schemas import Document
class ColorMapper: class ColorMapper:
@staticmethod @staticmethod
def create_color_mapping(documents: List[Document], color_by: str) -> List[str]: def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
if color_by == 'category': if color_by == "category":
return [doc.category for doc in documents] return [doc.category for doc in documents]
elif color_by == 'subcategory': elif color_by == "subcategory":
return [doc.subcategory for doc in documents] return [doc.subcategory for doc in documents]
elif color_by == 'tags': elif color_by == "tags":
return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents] return [", ".join(doc.tags) if doc.tags else "No tags" for doc in documents]
else: else:
return ['All'] * len(documents) return ["All"] * len(documents)
@staticmethod @staticmethod
def to_grayscale_hex(color_str: str) -> str: def to_grayscale_hex(color_str: str) -> str:
try: try:
if color_str.startswith('#'): if color_str.startswith("#"):
rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5)) rgb = tuple(int(color_str[i : i + 2], 16) for i in (1, 3, 5))
else: else:
rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0]) rgb = pc.hex_to_rgb(
pc.convert_colors_to_same_type([color_str], colortype="hex")[0][0]
)
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3, gray_rgb = (
gray_value * 0.7 + rgb[1] * 0.3, gray_value * 0.7 + rgb[0] * 0.3,
gray_value * 0.7 + rgb[2] * 0.3) gray_value * 0.7 + rgb[1] * 0.3,
return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})' gray_value * 0.7 + rgb[2] * 0.3,
except: # noqa: E722 )
return 'rgb(128,128,128)' return f"rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})"
except: # noqa: E722
return "rgb(128,128,128)"

View File

@@ -7,139 +7,172 @@ from .colors import ColorMapper
class PlotFactory: class PlotFactory:
def __init__(self): def __init__(self):
self.color_mapper = ColorMapper() self.color_mapper = ColorMapper()
def create_plot(self, plot_data: PlotData, dimensions: str = '3d', def create_plot(
color_by: str = 'category', method: str = 'PCA', self,
show_prompts: Optional[List[str]] = None) -> go.Figure: plot_data: PlotData,
dimensions: str = "3d",
if plot_data.prompts and show_prompts and 'show' in show_prompts: color_by: str = "category",
method: str = "PCA",
show_prompts: Optional[List[str]] = None,
) -> go.Figure:
if plot_data.prompts and show_prompts and "show" in show_prompts:
return self._create_dual_plot(plot_data, dimensions, color_by, method) return self._create_dual_plot(plot_data, dimensions, color_by, method)
else: else:
return self._create_single_plot(plot_data, dimensions, color_by, method) return self._create_single_plot(plot_data, dimensions, color_by, method)
def _create_single_plot(self, plot_data: PlotData, dimensions: str, def _create_single_plot(
color_by: str, method: str) -> go.Figure: self, plot_data: PlotData, dimensions: str, color_by: str, method: str
df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) ) -> go.Figure:
color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) df = self._prepare_dataframe(
plot_data.documents, plot_data.coordinates, dimensions
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] )
color_values = self.color_mapper.create_color_mapping(
if dimensions == '3d': plot_data.documents, color_by
)
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
if dimensions == "3d":
fig = px.scatter_3d( fig = px.scatter_3d(
df, x='dim_1', y='dim_2', z='dim_3', df,
x="dim_1",
y="dim_2",
z="dim_3",
color=color_values, color=color_values,
hover_data=hover_fields, hover_data=hover_fields,
title=f'3D Embedding Visualization - {method} (colored by {color_by})' title=f"3D Embedding Visualization - {method} (colored by {color_by})",
) )
fig.update_traces(marker=dict(size=5)) fig.update_traces(marker=dict(size=5))
else: else:
fig = px.scatter( fig = px.scatter(
df, x='dim_1', y='dim_2', df,
x="dim_1",
y="dim_2",
color=color_values, color=color_values,
hover_data=hover_fields, hover_data=hover_fields,
title=f'2D Embedding Visualization - {method} (colored by {color_by})' title=f"2D Embedding Visualization - {method} (colored by {color_by})",
) )
fig.update_traces(marker=dict(size=8)) fig.update_traces(marker=dict(size=8))
fig.update_layout( fig.update_layout(height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0))
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig return fig
def _create_dual_plot(self, plot_data: PlotData, dimensions: str, def _create_dual_plot(
color_by: str, method: str) -> go.Figure: self, plot_data: PlotData, dimensions: str, color_by: str, method: str
) -> go.Figure:
fig = go.Figure() fig = go.Figure()
doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) doc_df = self._prepare_dataframe(
doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) plot_data.documents, plot_data.coordinates, dimensions
)
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] doc_color_values = self.color_mapper.create_color_mapping(
plot_data.documents, color_by
if dimensions == '3d': )
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
if dimensions == "3d":
doc_fig = px.scatter_3d( doc_fig = px.scatter_3d(
doc_df, x='dim_1', y='dim_2', z='dim_3', doc_df,
x="dim_1",
y="dim_2",
z="dim_3",
color=doc_color_values, color=doc_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
else: else:
doc_fig = px.scatter( doc_fig = px.scatter(
doc_df, x='dim_1', y='dim_2', doc_df,
x="dim_1",
y="dim_2",
color=doc_color_values, color=doc_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
for trace in doc_fig.data: for trace in doc_fig.data:
trace.name = f'Documents - {trace.name}' trace.name = f"Documents - {trace.name}"
if dimensions == '3d': if dimensions == "3d":
trace.marker.size = 5 trace.marker.size = 5
trace.marker.symbol = 'circle' trace.marker.symbol = "circle"
else: else:
trace.marker.size = 8 trace.marker.size = 8
trace.marker.symbol = 'circle' trace.marker.symbol = "circle"
trace.marker.opacity = 1.0 trace.marker.opacity = 1.0
fig.add_trace(trace) fig.add_trace(trace)
if plot_data.prompts and plot_data.prompt_coordinates is not None: if plot_data.prompts and plot_data.prompt_coordinates is not None:
prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions) prompt_df = self._prepare_dataframe(
prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by) plot_data.prompts, plot_data.prompt_coordinates, dimensions
)
if dimensions == '3d': prompt_color_values = self.color_mapper.create_color_mapping(
plot_data.prompts, color_by
)
if dimensions == "3d":
prompt_fig = px.scatter_3d( prompt_fig = px.scatter_3d(
prompt_df, x='dim_1', y='dim_2', z='dim_3', prompt_df,
x="dim_1",
y="dim_2",
z="dim_3",
color=prompt_color_values, color=prompt_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
else: else:
prompt_fig = px.scatter( prompt_fig = px.scatter(
prompt_df, x='dim_1', y='dim_2', prompt_df,
x="dim_1",
y="dim_2",
color=prompt_color_values, color=prompt_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
for trace in prompt_fig.data: for trace in prompt_fig.data:
if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str): if hasattr(trace.marker, "color") and isinstance(
trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color) trace.marker.color, str
):
trace.name = f'Prompts - {trace.name}' trace.marker.color = self.color_mapper.to_grayscale_hex(
if dimensions == '3d': trace.marker.color
)
trace.name = f"Prompts - {trace.name}"
if dimensions == "3d":
trace.marker.size = 6 trace.marker.size = 6
trace.marker.symbol = 'diamond' trace.marker.symbol = "diamond"
else: else:
trace.marker.size = 10 trace.marker.size = 10
trace.marker.symbol = 'diamond' trace.marker.symbol = "diamond"
trace.marker.opacity = 0.8 trace.marker.opacity = 0.8
fig.add_trace(trace) fig.add_trace(trace)
title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})' title = f"{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})"
fig.update_layout( fig.update_layout(
title=title, title=title, height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0)
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
) )
return fig return fig
def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame: def _prepare_dataframe(
self, documents: List[Document], coordinates, dimensions: str
) -> pd.DataFrame:
df_data = [] df_data = []
for i, doc in enumerate(documents): for i, doc in enumerate(documents):
row = { row = {
'id': doc.id, "id": doc.id,
'text': doc.text, "text": doc.text,
'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text, "text_preview": doc.text[:100] + "..."
'category': doc.category, if len(doc.text) > 100
'subcategory': doc.subcategory, else doc.text,
'tags_str': ', '.join(doc.tags) if doc.tags else 'None', "category": doc.category,
'dim_1': coordinates[i, 0], "subcategory": doc.subcategory,
'dim_2': coordinates[i, 1], "tags_str": ", ".join(doc.tags) if doc.tags else "None",
"dim_1": coordinates[i, 0],
"dim_2": coordinates[i, 1],
} }
if dimensions == '3d': if dimensions == "3d":
row['dim_3'] = coordinates[i, 2] row["dim_3"] = coordinates[i, 2]
df_data.append(row) df_data.append(row)
return pd.DataFrame(df_data) return pd.DataFrame(df_data)

View File

@@ -6,62 +6,64 @@ from src.embeddingbuddy.models.schemas import Document
class TestNDJSONParser: class TestNDJSONParser:
def test_parse_text_basic(self): def test_parse_text_basic(self):
text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}' text_content = (
'{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
)
documents = NDJSONParser.parse_text(text_content) documents = NDJSONParser.parse_text(text_content)
assert len(documents) == 1 assert len(documents) == 1
assert documents[0].id == "test1" assert documents[0].id == "test1"
assert documents[0].text == "Hello world" assert documents[0].text == "Hello world"
assert documents[0].embedding == [0.1, 0.2, 0.3] assert documents[0].embedding == [0.1, 0.2, 0.3]
def test_parse_text_with_metadata(self): def test_parse_text_with_metadata(self):
text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}' text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}'
documents = NDJSONParser.parse_text(text_content) documents = NDJSONParser.parse_text(text_content)
assert documents[0].category == "greeting" assert documents[0].category == "greeting"
assert documents[0].tags == ["test"] assert documents[0].tags == ["test"]
def test_parse_text_missing_id(self): def test_parse_text_missing_id(self):
text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}' text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}'
documents = NDJSONParser.parse_text(text_content) documents = NDJSONParser.parse_text(text_content)
assert len(documents) == 1 assert len(documents) == 1
assert documents[0].id is not None # Should be auto-generated assert documents[0].id is not None # Should be auto-generated
class TestDataProcessor: class TestDataProcessor:
def test_extract_embeddings(self): def test_extract_embeddings(self):
documents = [ documents = [
Document(id="1", text="test1", embedding=[0.1, 0.2]), Document(id="1", text="test1", embedding=[0.1, 0.2]),
Document(id="2", text="test2", embedding=[0.3, 0.4]) Document(id="2", text="test2", embedding=[0.3, 0.4]),
] ]
processor = DataProcessor() processor = DataProcessor()
embeddings = processor._extract_embeddings(documents) embeddings = processor._extract_embeddings(documents)
assert embeddings.shape == (2, 2) assert embeddings.shape == (2, 2)
assert np.allclose(embeddings[0], [0.1, 0.2]) assert np.allclose(embeddings[0], [0.1, 0.2])
assert np.allclose(embeddings[1], [0.3, 0.4]) assert np.allclose(embeddings[1], [0.3, 0.4])
def test_combine_data(self): def test_combine_data(self):
from src.embeddingbuddy.models.schemas import ProcessedData from src.embeddingbuddy.models.schemas import ProcessedData
doc_data = ProcessedData( doc_data = ProcessedData(
documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])], documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])],
embeddings=np.array([[0.1, 0.2]]) embeddings=np.array([[0.1, 0.2]]),
) )
prompt_data = ProcessedData( prompt_data = ProcessedData(
documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])], documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])],
embeddings=np.array([[0.3, 0.4]]) embeddings=np.array([[0.3, 0.4]]),
) )
processor = DataProcessor() processor = DataProcessor()
all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data) all_embeddings, documents, prompts = processor.combine_data(
doc_data, prompt_data
)
assert all_embeddings.shape == (2, 2) assert all_embeddings.shape == (2, 2)
assert len(documents) == 1 assert len(documents) == 1
assert len(prompts) == 1 assert len(prompts) == 1
@@ -70,4 +72,4 @@ class TestDataProcessor:
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])

View File

@@ -1,89 +1,90 @@
import pytest import pytest
import numpy as np import numpy as np
from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer from src.embeddingbuddy.models.reducers import (
ReducerFactory,
PCAReducer,
TSNEReducer,
UMAPReducer,
)
class TestReducerFactory: class TestReducerFactory:
def test_create_pca_reducer(self): def test_create_pca_reducer(self):
reducer = ReducerFactory.create_reducer('pca', n_components=2) reducer = ReducerFactory.create_reducer("pca", n_components=2)
assert isinstance(reducer, PCAReducer) assert isinstance(reducer, PCAReducer)
assert reducer.n_components == 2 assert reducer.n_components == 2
def test_create_tsne_reducer(self): def test_create_tsne_reducer(self):
reducer = ReducerFactory.create_reducer('tsne', n_components=3) reducer = ReducerFactory.create_reducer("tsne", n_components=3)
assert isinstance(reducer, TSNEReducer) assert isinstance(reducer, TSNEReducer)
assert reducer.n_components == 3 assert reducer.n_components == 3
def test_create_umap_reducer(self): def test_create_umap_reducer(self):
reducer = ReducerFactory.create_reducer('umap', n_components=2) reducer = ReducerFactory.create_reducer("umap", n_components=2)
assert isinstance(reducer, UMAPReducer) assert isinstance(reducer, UMAPReducer)
assert reducer.n_components == 2 assert reducer.n_components == 2
def test_invalid_method(self): def test_invalid_method(self):
with pytest.raises(ValueError, match="Unknown reduction method"): with pytest.raises(ValueError, match="Unknown reduction method"):
ReducerFactory.create_reducer('invalid_method') ReducerFactory.create_reducer("invalid_method")
def test_available_methods(self): def test_available_methods(self):
methods = ReducerFactory.get_available_methods() methods = ReducerFactory.get_available_methods()
assert 'pca' in methods assert "pca" in methods
assert 'tsne' in methods assert "tsne" in methods
assert 'umap' in methods assert "umap" in methods
class TestPCAReducer: class TestPCAReducer:
def test_fit_transform(self): def test_fit_transform(self):
embeddings = np.random.rand(100, 512) embeddings = np.random.rand(100, 512)
reducer = PCAReducer(n_components=2) reducer = PCAReducer(n_components=2)
result = reducer.fit_transform(embeddings) result = reducer.fit_transform(embeddings)
assert result.reduced_embeddings.shape == (100, 2) assert result.reduced_embeddings.shape == (100, 2)
assert result.variance_explained is not None assert result.variance_explained is not None
assert result.method == "PCA" assert result.method == "PCA"
assert result.n_components == 2 assert result.n_components == 2
def test_method_name(self): def test_method_name(self):
reducer = PCAReducer() reducer = PCAReducer()
assert reducer.get_method_name() == "PCA" assert reducer.get_method_name() == "PCA"
class TestTSNEReducer: class TestTSNEReducer:
def test_fit_transform_small_dataset(self): def test_fit_transform_small_dataset(self):
embeddings = np.random.rand(30, 10) # Small dataset for faster testing embeddings = np.random.rand(30, 10) # Small dataset for faster testing
reducer = TSNEReducer(n_components=2) reducer = TSNEReducer(n_components=2)
result = reducer.fit_transform(embeddings) result = reducer.fit_transform(embeddings)
assert result.reduced_embeddings.shape == (30, 2) assert result.reduced_embeddings.shape == (30, 2)
assert result.variance_explained is None # t-SNE doesn't provide this assert result.variance_explained is None # t-SNE doesn't provide this
assert result.method == "t-SNE" assert result.method == "t-SNE"
assert result.n_components == 2 assert result.n_components == 2
def test_method_name(self): def test_method_name(self):
reducer = TSNEReducer() reducer = TSNEReducer()
assert reducer.get_method_name() == "t-SNE" assert reducer.get_method_name() == "t-SNE"
class TestUMAPReducer: class TestUMAPReducer:
def test_fit_transform(self): def test_fit_transform(self):
embeddings = np.random.rand(50, 10) embeddings = np.random.rand(50, 10)
reducer = UMAPReducer(n_components=2) reducer = UMAPReducer(n_components=2)
result = reducer.fit_transform(embeddings) result = reducer.fit_transform(embeddings)
assert result.reduced_embeddings.shape == (50, 2) assert result.reduced_embeddings.shape == (50, 2)
assert result.variance_explained is None # UMAP doesn't provide this assert result.variance_explained is None # UMAP doesn't provide this
assert result.method == "UMAP" assert result.method == "UMAP"
assert result.n_components == 2 assert result.n_components == 2
def test_method_name(self): def test_method_name(self):
reducer = UMAPReducer() reducer = UMAPReducer()
assert reducer.get_method_name() == "UMAP" assert reducer.get_method_name() == "UMAP"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])