@@ -1,31 +1,27 @@
from drf_yasg.utils import swagger_auto_schema
from drf_yasg import openapi

import base64
import json
from rest_framework_simplejwt.views import TokenObtainPairView
from rest_framework_simplejwt.tokens import RefreshToken
from django.views.decorators.http import require_http_methods
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework import status
from rest_framework.decorators import action, api_view
from adrf import viewsets
from django.views.decorators.http import require_POST
from django.utils import timezone
from django.contrib.auth import get_user_model
from .models import MasterKey
from .models import Conversation
from .models import MasterKey, Conversation
from django.http import StreamingHttpResponse
from .models import Message
from .serializers import ConversationSerializer, MessageSerializer
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response

from .utils import add_swagger_summaries
from django.views import View
import ollama
from asgiref.sync import sync_to_async
from django.utils.decorators import method_decorator
import asyncio

import asyncio, base64, json
from django.views.decorators.csrf import csrf_exempt

User = get_user_model()

@@ -35,7 +31,7 @@ class ConversationView(viewsets.ModelViewSet):
    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    permission_classes = [IsAuthenticated] # JWT token auth required
    permission_classes = [AllowAny] # JWT token auth required

    def get_queryset(self):
        queryset = Conversation.objects.all()

@@ -64,16 +60,14 @@ class ConversationView(viewsets.ModelViewSet):
        ]
    )
    @action(detail=True, methods=['get'])
    def contents(self, request, pk=None):
    async def contents(self, request, pk=None):
        conversation = self.get_object()
        messages = conversation.messages.all()
        return Response(data=list(messages.values()))

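# Plain Django View (not a DRF ViewSet) that streams the model's reply;
# csrf_exempt is applied at dispatch so clients can POST prompts without a CSRF token.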
@method_decorator(csrf_exempt, name='dispatch')
class ConversationActions(View):
    @swagger_auto_schema(
        operation_description="Chat with the AI",
        operation_summary="Make a new prompt",
        operation_summary="Make a prompt",
        request_body=openapi.Schema(
            type=openapi.TYPE_OBJECT,
            properties={
@@ -82,14 +76,17 @@ class ConversationActions(View):
            required=['content']
        ),
        responses={
            200: openapi.Response('Message processed successfully', MessageSerializer),
            200: openapi.Response(
                description="Text stream response",
                schema=openapi.Schema(type=openapi.TYPE_STRING, format='binary')
            ),
            400: 'Bad Request',
        }
        },
        tags=['conversations']
    )
    @sync_to_async
    def post(self, request, *args, **kwargs):
        conversation = Conversation.objects.get(pk=self.kwargs['id'])
        data = json.loads(request.body)
    @action(detail=True, methods=['post'])
    def prompt(self, request, pk=None):
        conversation = self.get_object()
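        # Chat history is passed to ollama as a list of {"role", "content"} dicts;
        # the commented-out block below appears to be a disabled system prompt.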
        messages = []#{
        #    "role": "system",
        #    "content": """
@@ -114,24 +111,24 @@ class ConversationActions(View):
            })
        messages.append({
            "role": "user",
            "content": data['content'] or ""
            "content": request.data.get('content', '')
        })
        Message(role="user", content=data['content'] or "", conversation=conversation).save()
        Message(role="user", content=request.data.get('content', '') or "", conversation=conversation).save()
        ai_message = Message(role="assistant", content="", conversation=conversation)

        @sync_to_async
        def save_message():
            ai_message.save()

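        # Async generator: relays each ollama chunk to the client as it arrives,
        # accumulates the full reply on ai_message, and persists it in `finally`
        # through the sync_to_async wrapper (the ORM save is synchronous).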
        async def chat_event_stream():
            message = ""
            try:
                stream = ollama.chat(model="llama2-uncensored", messages=messages, stream=True)
                stream = ollama.chat(model="gemma3:12b", messages=messages, stream=True)
                for chunk in stream:
                    message = chunk['message']['content']
                    #print(message, base64.b64encode(message.encode("utf-8")))
                    ai_message.content += message
                    yield f"{base64.b64encode(message.encode('utf-8')).decode('utf-8')}"
                    yield f"{message}"
            finally:
                await save_message()

        response = StreamingHttpResponse(chat_event_stream(), content_type='text/event-stream')
@@ -139,6 +136,11 @@ class ConversationActions(View):
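        # 'X-Accel-Buffering: no' tells nginx not to buffer this response, so each
        # yielded chunk is flushed to the client immediately.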
        response['X-Accel-Buffering'] = 'no'
        return response
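
A minimal sketch of how a client might consume this stream (the route below is an assumption for illustration, not taken from the diff; adjust it to however the URL conf exposes this view):

import requests

# POST a prompt and print the streamed reply as the server flushes it.
resp = requests.post(
    "http://localhost:8000/api/conversations/1/prompt/",  # hypothetical URL
    json={"content": "Hello"},
    stream=True,
)
resp.raise_for_status()
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="", flush=True)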


class MasterKeyTokenObtainView(TokenObtainPairView):