Replace View by adrf View

master
Anulax ago%!(EXTRA string=1 month)
parent 5c3cb5344a
commit f2ff95562a
  1. 15
      api/management/commands/create_conversations.py
  2. 21
      api/migrations/0008_alter_conversation_user.py
  3. 2
      api/models.py
  4. 3
      api/urls.py
  5. 54
      api/views.py
  6. 11
      app/templates/index.html
  7. 1
      botzilla/settings.py
  8. 4
      static/drf-yasg/style.css

@ -0,0 +1,15 @@
from django.core.management.base import BaseCommand
from api.models import Conversation
class Command(BaseCommand):
help = 'Create N test conversations'
def add_arguments(self, parser):
parser.add_argument('--number', type=int, help='Number of conversations to create')
def handle(self, *args, **kwargs):
number = kwargs.get('number') or 1
for i in range(number):
conv = Conversation.objects.create(title="test-conversation")
self.stdout.write(self.style.SUCCESS(f'Successfully created test conversation (id : {conv.id})'))

@ -0,0 +1,21 @@
# Generated by Django 5.2.1 on 2025-05-19 09:49
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('api', '0007_conversation_user'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AlterField(
model_name='conversation',
name='user',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='conversations', to=settings.AUTH_USER_MODEL),
),
]

@ -37,7 +37,7 @@ class AiModel(models.Model):
class Conversation(models.Model): class Conversation(models.Model):
title = models.CharField(max_length=255) title = models.CharField(max_length=255)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="conversations") user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="conversations", null=True)
class Message(models.Model): class Message(models.Model):
content = models.TextField() content = models.TextField()

@ -22,7 +22,6 @@ from rest_framework_simplejwt.views import (
) )
from .views import MasterKeyTokenObtainView from .views import MasterKeyTokenObtainView
from .views import ConversationView from .views import ConversationView
from .views import ConversationActions
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
router = DefaultRouter() router = DefaultRouter()
@ -31,6 +30,6 @@ urlpatterns = [
path('token/', MasterKeyTokenObtainView.as_view(), name='token_obtain_pair'), path('token/', MasterKeyTokenObtainView.as_view(), name='token_obtain_pair'),
path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'), path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
path('token/verify/', TokenVerifyView.as_view(), name='token_verify'), path('token/verify/', TokenVerifyView.as_view(), name='token_verify'),
path('conversations/<int:id>/prompt/', ConversationActions.as_view()), #path('conversations/<int:id>/prompt/', conversation_prompt),
path('', include(router.urls)), path('', include(router.urls)),
] ]

@ -1,31 +1,27 @@
from drf_yasg.utils import swagger_auto_schema from drf_yasg.utils import swagger_auto_schema
from drf_yasg import openapi from drf_yasg import openapi
import base64
import json
from rest_framework_simplejwt.views import TokenObtainPairView from rest_framework_simplejwt.views import TokenObtainPairView
from rest_framework_simplejwt.tokens import RefreshToken from rest_framework_simplejwt.tokens import RefreshToken
from django.views.decorators.http import require_http_methods
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework import status
from rest_framework.decorators import action, api_view
from adrf import viewsets
from django.views.decorators.http import require_POST
from django.utils import timezone from django.utils import timezone
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from .models import MasterKey from .models import MasterKey, Conversation
from .models import Conversation
from django.http import StreamingHttpResponse from django.http import StreamingHttpResponse
from .models import Message from .models import Message
from .serializers import ConversationSerializer, MessageSerializer from .serializers import ConversationSerializer, MessageSerializer
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response
from .utils import add_swagger_summaries from .utils import add_swagger_summaries
from django.views import View from django.views import View
import ollama import ollama
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
import asyncio import asyncio, base64, json
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
User = get_user_model() User = get_user_model()
@ -35,7 +31,7 @@ class ConversationView(viewsets.ModelViewSet):
queryset = Conversation.objects.all() queryset = Conversation.objects.all()
serializer_class = ConversationSerializer serializer_class = ConversationSerializer
permission_classes = [IsAuthenticated] # JWT token auth required permission_classes = [AllowAny] # JWT token auth required
def get_queryset(self): def get_queryset(self):
queryset = Conversation.objects.all() queryset = Conversation.objects.all()
@ -64,16 +60,14 @@ class ConversationView(viewsets.ModelViewSet):
] ]
) )
@action(detail=True, methods=['get']) @action(detail=True, methods=['get'])
def contents(self, request, pk=None): async def contents(self, request, pk=None):
conversation = self.get_object() conversation = self.get_object()
messages = conversation.messages.all() messages = conversation.messages.all()
return Response(data=list(messages.values())) return Response(data=list(messages.values()))
@method_decorator(csrf_exempt, name='dispatch')
class ConversationActions(View):
@swagger_auto_schema( @swagger_auto_schema(
operation_description="Discutes with the ai", operation_description="Discutes with the ai",
operation_summary="Make a new prompt", operation_summary="Make a prompt",
request_body=openapi.Schema( request_body=openapi.Schema(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
properties={ properties={
@ -82,14 +76,17 @@ class ConversationActions(View):
required=['content'] required=['content']
), ),
responses={ responses={
200: openapi.Response('Message processed successfully', MessageSerializer), 200: openapi.Response(
description="Text stream response",
schema=openapi.Schema(type=openapi.TYPE_STRING, format='binary')
),
400: 'Bad Request', 400: 'Bad Request',
} },
tags=['conversations']
) )
@sync_to_async @action(detail=True, methods=['post'])
def post(self, request, *args, **kwargs): def prompt(self, request, pk=None):
conversation = Conversation.objects.get(pk=self.kwargs['id']) conversation = self.get_object()
data = json.loads(request.body)
messages = []#{ messages = []#{
# "role": "system", # "role": "system",
# "content": """ # "content": """
@ -114,9 +111,9 @@ class ConversationActions(View):
}) })
messages.append({ messages.append({
"role": "user", "role": "user",
"content": data['content'] or "" "content": request.data.get('content', '')
}) })
Message(role="user", content=data['content'] or "", conversation=conversation).save() Message(role="user", content=request.data.get('content', '') or "", conversation=conversation).save()
ai_message = Message(role="assistant", content="", conversation=conversation) ai_message = Message(role="assistant", content="", conversation=conversation)
@sync_to_async @sync_to_async
@ -126,12 +123,12 @@ class ConversationActions(View):
async def chat_event_stream(): async def chat_event_stream():
message = "" message = ""
try: try:
stream = ollama.chat(model="llama2-uncensored", messages=messages, stream=True) stream = ollama.chat(model="gemma3:12b", messages=messages, stream=True)
for chunk in stream: for chunk in stream:
message = chunk['message']['content'] message = chunk['message']['content']
#print(message, base64.b64encode(message.encode("utf-8"))) #print(message, base64.b64encode(message.encode("utf-8")))
ai_message.content += message ai_message.content += message
yield f"{base64.b64encode(message.encode("utf-8")).decode("utf-8")}" yield f"{message}"
finally: finally:
await save_message() await save_message()
response = StreamingHttpResponse(chat_event_stream(), content_type='text/event-stream') response = StreamingHttpResponse(chat_event_stream(), content_type='text/event-stream')
@ -141,6 +138,11 @@ class ConversationActions(View):
class MasterKeyTokenObtainView(TokenObtainPairView): class MasterKeyTokenObtainView(TokenObtainPairView):
permission_classes = [AllowAny] permission_classes = [AllowAny]

@ -31,10 +31,6 @@
margin-top: 10px; margin-top: 10px;
border-radius: 5px; border-radius: 5px;
} }
input {
padding: 8px;
width: 70%;
}
button { button {
padding: 8px 16px; padding: 8px 16px;
background-color: #4CAF50; background-color: #4CAF50;
@ -50,7 +46,7 @@
</head> </head>
<body> <body>
<h1>Simple Chat App</h1> <h1>Simple Chat App</h1>
<input id="conversation-id" type="number" value="1">
<div id="chat-container"> <div id="chat-container">
<h3>Enter your message:</h3> <h3>Enter your message:</h3>
<textarea type="text" id="prompt-input" placeholder="Type your message..." value="hello"> <textarea type="text" id="prompt-input" placeholder="Type your message..." value="hello">
@ -64,6 +60,7 @@
<script> <script>
const responseText = document.querySelector('#response-text'); const responseText = document.querySelector('#response-text');
const convId = document.querySelector('#conversation-id');
// Function to send the prompt and handle streaming response // Function to send the prompt and handle streaming response
responseText.innerHTML = marked.parse(""); responseText.innerHTML = marked.parse("");
async function sendPrompt() { async function sendPrompt() {
@ -79,7 +76,7 @@
} }
try { try {
// Make the fetch request // Make the fetch request
const response = await fetch('/api/conversations/14/prompt/', { const response = await fetch('/api/conversations/' + convId.value + '/prompt/', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -95,7 +92,7 @@
reader.read().then(({ done, value }) => { reader.read().then(({ done, value }) => {
if (done) return; if (done) return;
console.log(value); console.log(value);
buffer += atob((new TextDecoder()).decode(value, { stream: true })) buffer += (new TextDecoder()).decode(value, { stream: true })
console.log(buffer) console.log(buffer)
responseText.innerHTML = marked.parse(buffer); responseText.innerHTML = marked.parse(buffer);
readChunk(); readChunk();

@ -39,6 +39,7 @@ INSTALLED_APPS = [
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'rest_framework', 'rest_framework',
'adrf',
'rest_framework_simplejwt', 'rest_framework_simplejwt',
'rest_framework_simplejwt.token_blacklist', 'rest_framework_simplejwt.token_blacklist',
'api', 'api',

@ -1,4 +0,0 @@
img {
max-height: 20px;
width: auto;
}
Loading…
Cancel
Save