Add realtime response when prompting and fixed swagger docs

master
Anulax committed 3 months ago
parent 8f8d6db35d
commit 6d65ab7030
  1. 4
      api/templates/drf-yasg/redoc.html
  2. 2
      api/urls.py
  3. 22
      api/utils.py
  4. 87
      api/views.py

@ -4,8 +4,8 @@
<link rel="icon" type="image/ico" href="/static/favicon.ico"/> <link rel="icon" type="image/ico" href="/static/favicon.ico"/>
<style> <style>
img[alt="Botzilla logo"] { img[alt="Botzilla logo"] {
margin: auto;
max-height: 80px !important; max-height: 60px !important;
width: auto; width: auto;
padding-top: 3px !important; padding-top: 3px !important;
padding-bottom: 3px !important; padding-bottom: 3px !important;

@ -22,6 +22,7 @@ from rest_framework_simplejwt.views import (
) )
from .views import MasterKeyTokenObtainView from .views import MasterKeyTokenObtainView
from .views import ConversationView from .views import ConversationView
from .views import ConversationActions
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
router = DefaultRouter() router = DefaultRouter()
@ -30,5 +31,6 @@ urlpatterns = [
path('token/', MasterKeyTokenObtainView.as_view(), name='token_obtain_pair'), path('token/', MasterKeyTokenObtainView.as_view(), name='token_obtain_pair'),
path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'), path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
path('token/verify/', TokenVerifyView.as_view(), name='token_verify'), path('token/verify/', TokenVerifyView.as_view(), name='token_verify'),
path('conversations/<int:id>/prompt/', ConversationActions.as_view()),
path('', include(router.urls)), path('', include(router.urls)),
] ]

@ -0,0 +1,22 @@
from drf_yasg.utils import swagger_auto_schema
def add_swagger_summaries(viewset_class):
    """Attach default swagger ``operation_summary`` text to a ModelViewSet's CRUD actions.

    For each standard DRF action defined on *viewset_class* that has not
    already been decorated with ``swagger_auto_schema``, wrap it with a
    generated summary derived from the serializer's model name.  The class
    is mutated in place and returned, so this works as a class decorator.
    """
    model_name = viewset_class.serializer_class.Meta.model.__name__.lower()
    summaries = (
        ('list', f"List all {model_name}s"),
        ('create', f"Create new {model_name}"),
        ('retrieve', f"Get specific {model_name}"),
        ('update', f"Update {model_name} completely"),
        ('partial_update', f"Update {model_name} partially"),
        ('destroy', f"Delete {model_name}"),
    )
    for action_name, summary in summaries:
        # Viewsets may disable actions; only decorate the ones present.
        if not hasattr(viewset_class, action_name):
            continue
        handler = getattr(viewset_class, action_name)
        # Respect an explicit per-action swagger_auto_schema override
        # (the decorator marks wrapped handlers with this attribute).
        if hasattr(handler, '_swagger_auto_schema'):
            continue
        setattr(viewset_class, action_name,
                swagger_auto_schema(operation_summary=summary)(handler))
    return viewset_class

@ -1,6 +1,8 @@
from drf_yasg.utils import swagger_auto_schema from drf_yasg.utils import swagger_auto_schema
from drf_yasg import openapi from drf_yasg import openapi
import base64
import json
from rest_framework_simplejwt.views import TokenObtainPairView from rest_framework_simplejwt.views import TokenObtainPairView
from rest_framework_simplejwt.tokens import RefreshToken from rest_framework_simplejwt.tokens import RefreshToken
from django.views.decorators.http import require_http_methods from django.views.decorators.http import require_http_methods
@ -11,15 +13,24 @@ from django.utils import timezone
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from .models import MasterKey from .models import MasterKey
from .models import Conversation from .models import Conversation
from django.http import StreamingHttpResponse
from .models import Message from .models import Message
from .serializers import ConversationSerializer, MessageSerializer from .serializers import ConversationSerializer, MessageSerializer
from rest_framework import viewsets, status from rest_framework import viewsets, status
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from .utils import add_swagger_summaries
from django.views import View
import ollama import ollama
from asgiref.sync import sync_to_async
from django.utils.decorators import method_decorator
import asyncio
from django.views.decorators.csrf import csrf_exempt
User = get_user_model() User = get_user_model()
@add_swagger_summaries
class ConversationView(viewsets.ModelViewSet): class ConversationView(viewsets.ModelViewSet):
queryset = Conversation.objects.all() queryset = Conversation.objects.all()
@ -36,7 +47,8 @@ class ConversationView(viewsets.ModelViewSet):
@swagger_auto_schema( @swagger_auto_schema(
method='get', method='get',
operation_description="Get messages of the conversation", operation_description="Get all the messages of the conversation",
operation_summary="Get all messages",
responses={ responses={
200: openapi.Response('List of messages', MessageSerializer(many=True)), 200: openapi.Response('List of messages', MessageSerializer(many=True)),
400: 'Bad Request', 400: 'Bad Request',
@ -57,9 +69,11 @@ class ConversationView(viewsets.ModelViewSet):
messages = conversation.messages.all() messages = conversation.messages.all()
return Response(data=list(messages.values())) return Response(data=list(messages.values()))
@method_decorator(csrf_exempt, name='dispatch')
class ConversationActions(View):
@swagger_auto_schema( @swagger_auto_schema(
method='post',
operation_description="Discutes with the ai", operation_description="Discutes with the ai",
operation_summary="Make a new prompt",
request_body=openapi.Schema( request_body=openapi.Schema(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
properties={ properties={
@ -72,43 +86,58 @@ class ConversationView(viewsets.ModelViewSet):
400: 'Bad Request', 400: 'Bad Request',
} }
) )
@action(detail=True, methods=['post']) @sync_to_async
def prompt(self, request, pk=None): def post(self, request, *args, **kwargs):
conversation = self.get_object() conversation = Conversation.objects.get(pk=self.kwargs['id'])
messages = [{ data = json.loads(request.body)
"role": "system", messages = []#{
"content": """ # "role": "system",
You must strictly refuse to engage with ANY of the following: # "content": """
1. Violence or harm (even fictional or hypothetical scenarios) # You must strictly refuse to engage with ANY of the following categories:
2. ANY explicit, suggestive, or romantic content # 1. Violence or harm (even fictional or hypothetical scenarios)
3. Controversial political topics # 2. ANY explicit, suggestive, or romantic content
4. ANY content that could potentially be misused # 3. Controversial political topics
5. Medical, legal, or financial advice # 4. ANY content that could potentially be illegal
6. Personal information or privacy violations # 5. Medical, legal, or financial advice
7. Anything that could be remotely offensive to anyone # 6. Personal information or privacy violations
# 7. Anything that could be remotely offensive to anyone
If you detect such content, immediately respond with: # If you detect such content, immediately respond with:
"I cannot assist with that request as it appears to be inappropriate. I'm designed to be helpful, but within strict ethical boundaries. Is there something else I can help you with?" # "I cannot assist with that request as it appears to be inappropriate. I'm designed to be helpful, but within strict ethical boundaries. Is there something else I can help you with?"
""" # """
}] # }]
for message in conversation.messages.all(): for message in conversation.messages.all():
if message: if message:
messages.append({ messages.append({
"role": message.role, "role": message.role,
"content": message.content "content": message.content
}) })
messages.append({ messages.append({
"role": "user", "role": "user",
"content": request.data.get("content", "") "content": data['content'] or ""
}) })
res = ollama.chat(model="gemma3", messages=messages) Message(role="user", content=data['content'] or "", conversation=conversation).save()
Message(role="user", content=request.data.get("content", ""), conversation=conversation).save() ai_message = Message(role="assistant", content="", conversation=conversation)
Message(role="assistant", content=res['message']['content'], conversation=conversation).save()
return Response(data={
"content": res['message']['content']
})
@sync_to_async
def save_message():
ai_message.save()
async def chat_event_stream():
message = ""
try:
stream = ollama.chat(model="llama2-uncensored", messages=messages, stream=True)
for chunk in stream:
message = chunk['message']['content']
#print(message, base64.b64encode(message.encode("utf-8")))
ai_message.content += message
yield f"{base64.b64encode(message.encode("utf-8")).decode("utf-8")}"
finally:
await save_message()
response = StreamingHttpResponse(chat_event_stream(), content_type='text/event-stream')
response['Cache-Control'] = 'no-cache'
response['X-Accel-Buffering'] = 'no'
return response

Loading…
Cancel
Save