RabbitMQに登録したメッセージからstable-diffusionを実行

フォルダの作成とdocker-composeの変更(WSL作業)

  1. 前回 の環境を変更していく
    画像input,outputフォルダの作成
work
├── Dockerfile
├── docker-compose.yml
├── src
│   ├── sample.py
│   ├── send_img2img.py
│   └── send_txt2img.py
├── inimg
├── outimg
│
 mkdir ./inimg
 mkdir ./outimg
  1. docker-composeの編集
    RabbitMQのコンテナの追加と画像フォルダの追加
    ローカルパス(\wsl$\Ubuntu-20.04\home\ubuntu)でWSL上のフォルダにアクセスできるため画像の置き換えが容易にできるようになる
    また、エクスプローラの特大アイコンで出力した画像が確認できるようになる
version: '3.6'
services:
  cuda:
    build:
      context: .
      dockerfile: Dockerfile
    runtime: nvidia
    environment:
      - NVIDIA_VISIBLE_DEVICES=all
    tty: true
    volumes:
      - ./src:/src
      - ./inimg:/inimg
      - ./outimg:/outimg
    working_dir: /src
  rabbitmq:
      image: rabbitmq:3.8.14-management
      container_name: 'rabbitmq'
      restart: always
      ports:
          - 5672:5672
          - 15672:15672
      tty: true
  • 画像の削除などができなくなった場合は所有者を変更する
    (ubuntu:ubuntuはWSLの設定に合わせて変更する)
 sudo chown  ubuntu:ubuntu /inimg -R
 sudo chown  ubuntu:ubuntu /outimg -R

RabbitMQからメッセージを取得して実行するスクリプト(docker内作業)

  1. CUDA側のdockerコンテナにアタッチしてMQアクセス用のライブラリを取得しておく
 pip install pika
  1. ソースファイルの変更と作成
    sample.pyを以下のように変更する
# pip install pika

import os
import pika
workdir = "/work/stable-diffusion"

pika_param = pika.ConnectionParameters(host='rabbitmq', heartbeat=0)
connection = pika.BlockingConnection(pika_param)
channel = connection.channel()

channel.queue_declare(queue='txt2img')
channel.queue_declare(queue='img2img')

def img2img_callback(ch, method, properties, body):
  exeimg2img = " ".join(
      ["python"
    , "optimizedSD/optimized_img2img.py"])
  optimg2img = {
      "--prompt" : "Medium shot, alone, an anime girl"
    , "--init-img" : "/inimg/test.jpg"
    , "--outdir" : "/outimg"
    , "--strength" : "0.6"
    , "--n_samples" : "10"
    , "--H" : "512"
    , "--W" : "512"}

  if properties.headers is not None:
    # ヘッダの内容でオプションを置き換え、なければ作成
    for key in properties.headers:
      optimg2img[key] = properties.headers[key]
  # プロンプトだけは後でbodyで置き換え
  optimg2img["--prompt"] = "{}".format(str(body, 'utf-8'))

  exeimg2img = exeimg2img + " " + \
        " ".join("{} '{}'".format( \
            key, optimg2img[key]) for key in optimg2img)
  ch.basic_ack(delivery_tag = method.delivery_tag)

  print("img2img Start {}".format(str(body, 'utf-8')))
  os.system(exeimg2img)
  print("img2img End")


def txt2img_callback(ch, method, properties, body):
  exetxt2img = " ".join(
    ["python"
  , "optimizedSD/optimized_txt2img.py"])
  opttxt2img = {
    "--prompt" : ""
  , "--outdir" : "/outimg"
  , "--seed" : "27"
  , "--n_samples" : "10"
  , "--ddim_steps" : "50"
  , "--H" : "512"
  , "--W" : "512"}

  if properties.headers is not None:
    # ヘッダの内容でオプションを置き換え、なければ作成
    for key in properties.headers:
      opttxt2img[key] = properties.headers[key]
  # プロンプトだけは後でbodyで置き換え
  opttxt2img["--prompt"] = "{}".format(str(body, 'utf-8'))

  exetxt2img = exetxt2img + " " + \
        " ".join("{} '{}'".format( \
            key, opttxt2img[key]) for key in opttxt2img)
  ch.basic_ack(delivery_tag = method.delivery_tag)
  print("txt2img Start {}".format(str(body, 'utf-8')))
  os.system(exetxt2img)
  print("txt2img End")

channel.basic_consume(
    queue='txt2img', on_message_callback=txt2img_callback)
channel.basic_consume(
    queue='img2img', on_message_callback=img2img_callback)

os.chdir(workdir)
channel.start_consuming()
  1. send_img2img.pyを作成する
# pip install pika

from wsgiref import headers
import pika

pika_param = pika.ConnectionParameters(host='rabbitmq')
connection = pika.BlockingConnection(pika_param)
channel = connection.channel()

channel.queue_declare(queue='img2img')

opt=pika.BasicProperties(headers={'--init-img': '/inimg/hogehoge.png'})
channel.basic_publish(exchange='', routing_key='img2img', body='Anime, full body', properties=opt)

connection.close()
  • hogehoge.pngは予め用意しておく

  • send_txt2img.pyを作成する

# pip install pika

import pika

pika_param = pika.ConnectionParameters(host='rabbitmq')
connection = pika.BlockingConnection(pika_param)
channel = connection.channel()

channel.queue_declare(queue='txt2img')

channel.basic_publish(exchange='', routing_key='txt2img', body='Anime, full body')

connection.close()
  1. sample.pyを実行するとMQの監視が開始する
 python3 sample.py

メッセージの登録

Publish Message