Enabling & Exploring Stable Diffusion – Part 2

We started explaining the importance & usage of Stable Diffusion in our previous post:

Enabling & Exploring Stable Diffusion – Part 1

In today’s post, we’ll discuss another approach: a custom Python-based solution that consumes the Hugging Face libraries to generate a video from the supplied prompt.

But before that, let us view the demo generated by the custom solution.

Isn’t it exciting? Let us dive deep into the details.


Let us understand the basic flow of events for the custom solution –

So, the application will interact with Hugging Face models such as “stable-diffusion-3.5-large” & “dreamshaper-xl-1-0”. As part of the process, the supporting libraries will download these large models to your local laptop, which takes some time depending on your internet bandwidth.
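One practical aside (not from the original flow): the huggingface_hub library caches the downloaded weights, by default under ~/.cache/huggingface, so subsequent runs skip the download. If your home partition is small, you can point the cache at a larger disk before running the scripts:

```shell
# Optional: relocate the Hugging Face model cache (path is illustrative)
export HF_HOME=/path/to/large/disk/huggingface
```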

Before we dive deep into the code, let us understand the flow of the Python scripts, as shown below:

From the above diagram, we can understand that the main application is triggered by “generateText2Video.py”. As you can see, “clsConfigClient.py” holds all the necessary parameters that are supplied to the other scripts.

“generateText2Video.py” will trigger the main class, clsText2Video (defined in “clsText2Video.py”), which then calls all the subsequent classes.
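To make the wiring concrete, here is a hypothetical sketch of the configuration that the driver script feeds to the main class. The parameter names mirror clsText2Video.__init__; the dictionary shape and values are illustrative assumptions only — in the real solution they come from clsConfigClient.py.

```python
# Hypothetical configuration, standing in for what clsConfigClient.py supplies.
# Model IDs and paths are illustrative assumptions, not the author's exact values.
CONFIG = {
    "model_id_1": "stabilityai/stable-diffusion-3.5-large",  # text-to-image model
    "model_id_2": "Lykon/dreamshaper-xl-1-0",                # second-pass model
    "output_path": "./output/",
    "filename": "intermediate.png",
    "vidfilename": "result.mp4",
    "fps": 8,
    "force_cpu": False,
}

# In generateText2Video.py, the class would then be driven like this:
# t2v = clsText2Video(**CONFIG)
# status = t2v.getPrompt2Video("A serene mountain lake at dawn")
```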

Great! Since we now have better visibility of the script flow, let’s examine the key snippets individually.


import os
import torch

# Helper modules behind the aliases used below:
#   cm  -> clsMaster (device & pipeline setup, shown later in this post)
#   cti -> clsText2Image (text-to-image wrapper)
#   civ -> clsImage2Video (image-to-video wrapper)

class clsText2Video:
    def __init__(self, model_id_1, model_id_2, output_path, filename, vidfilename, fps, force_cpu=False):
        self.model_id_1 = model_id_1
        self.model_id_2 = model_id_2
        self.output_path = output_path
        self.filename = filename
        self.vidfilename = vidfilename
        self.force_cpu = force_cpu
        self.fps = fps

        # Initialize in main process
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        self.r1 = cm.clsMaster(force_cpu)
        self.torch_type = self.r1.getTorchType()
        
        # Free cached Apple Silicon (MPS) GPU memory before loading the large models
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        self.pipe = self.r1.getText2ImagePipe(self.model_id_1, self.torch_type)
        self.pipeline = self.r1.getImage2VideoPipe(self.model_id_2, self.torch_type)

        self.text2img = cti.clsText2Image(self.pipe, self.output_path, self.filename)
        self.img2vid = civ.clsImage2Video(self.pipeline)

    def getPrompt2Video(self, prompt):
        try:
            input_image = os.path.join(self.output_path, self.filename)
            target_video = os.path.join(self.output_path, self.vidfilename)

            if self.text2img.genImage(prompt) == 0:
                print('Pass 1: Text to intermediate images generated!')
                
                if self.img2vid.genVideo(prompt, input_image, target_video, self.fps) == 0:
                    print('Pass 2: Successfully generated!')
                    return 0
            return 1
        except Exception as e:
            print(f"\nAn unexpected error occurred: {str(e)}")
            return 1

Now, let us interpret:

This is the initialization method (__init__) for the clsText2Video class. It does the following:

  • Sets up configurations like model IDs, output paths, filenames, video filename, frames per second (fps), and whether to use the CPU (force_cpu).
  • Configures an environment variable for tokenizer parallelism.
  • Initializes helper classes (clsMaster) to manage system resources and retrieve appropriate PyTorch settings.
  • Creates two pipelines:
    • pipe: For converting text to images using the first model.
    • pipeline: For converting images to video using the second model.
  • Initializes text2img and img2vid objects:
    • text2img handles text-to-image conversions.
    • img2vid handles image-to-video conversions.

The getPrompt2Video() method generates a video from a text prompt in two steps:

  1. Text-to-Image Conversion:
    • Calls genImage(prompt) using the text2img object to create an intermediate image file.
    • If successful, it prints confirmation.
  2. Image-to-Video Conversion:
    • Uses the img2vid object to convert the intermediate image into a video file.
    • Includes the input image path, target video path, and frames per second (fps).
    • If successful, it prints confirmation.
  • If either step fails, the method returns 1.
  • Logs any unexpected errors and returns 1 in such cases.
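The two-pass flow above can be sketched in isolation. This is a minimal, self-contained stand-in, assuming the same 0-on-success / 1-on-failure return-code convention as the post; the stub functions replace the real Stable Diffusion pipeline calls.

```python
# Stubs standing in for the real pipeline calls (assumed: 0 = success, 1 = failure).
def gen_image(prompt, image_path):
    # Real code: text-to-image via the first pipeline
    return 0

def gen_video(prompt, image_path, video_path, fps):
    # Real code: image-to-video via the second pipeline
    return 0

def prompt_to_video(prompt, image_path, video_path, fps=8):
    try:
        if gen_image(prompt, image_path) == 0:                       # Pass 1
            if gen_video(prompt, image_path, video_path, fps) == 0:  # Pass 2
                return 0
        return 1
    except Exception as e:
        print(f"Unexpected error: {e}")
        return 1

print(prompt_to_video("A sunrise over the hills", "frame.png", "clip.mp4"))  # 0
```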
Next, let us look at the supporting script that manages the device and the pipelines:

import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLImg2ImgPipeline

# Set device for Apple Silicon GPU
def setup_gpu(force_cpu=False):
    if not force_cpu and torch.backends.mps.is_available() and torch.backends.mps.is_built():
        print('Running on Apple Silicon MPS GPU!')
        return torch.device("mps")
    return torch.device("cpu")

######################################
####         Global Flag      ########
######################################

class clsMaster:
    def __init__(self, force_cpu=False):
        self.device = setup_gpu(force_cpu)

    def getTorchType(self):
        try:
            # Use half precision on the MPS (Apple Silicon) GPU; fall back to
            # full precision on the CPU, which does not handle float16 well
            if torch.backends.mps.is_available():
                return torch.float16

            print('Warning: MPS (Metal Performance Shaders) is not available; defaulting to torch.float32.')
            return torch.float32
        except Exception as e:
            print(f'Error: {str(e)}')
            return torch.float32

    def getText2ImagePipe(self, model_id, torchType):
        device = self.device

        try:
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
            self.pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torchType, use_safetensors=True, variant="fp16").to(device)

            return self.pipe
        except Exception as e:
            print('Error: ', str(e))

            # Fallback: retry without the safetensors / fp16 variant options
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
            self.pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torchType).to(device)

            return self.pipe
        
    def getImage2VideoPipe(self, model_id, torchType):
        device = self.device

        try:
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
            self.pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torchType, use_safetensors=True).to(device)

            return self.pipeline
        except Exception as e:
            print('Error: ', str(e))

            # Fallback: retry with default loading options
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
            self.pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torchType).to(device)

            return self.pipeline

Let us interpret:

The setup_gpu() function determines whether to use the Apple Silicon GPU (MPS) or the CPU:

  • If force_cpu is False and the MPS GPU is available, it sets the device to “mps” (Apple GPU) and prints a message.
  • Otherwise, it defaults to the CPU.

This is the initializer for the clsMaster class:

  • It sets the device to either GPU or CPU using the setup_gpu function (mentioned above) based on the force_cpu flag.

The getTorchType() method determines the PyTorch data type to use:

  • Checks if the MPS GPU is available:
    • If available, uses torch.float16 for optimized performance.
    • If unavailable, defaults to torch.float32 and prints a warning.
  • Handles any unexpected error gracefully by also defaulting to torch.float32 and printing the error.
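The device and dtype decisions above boil down to two small rules. This pure-Python sketch (torch-free, so it runs anywhere) isolates just that decision logic; the strings stand in for the actual torch.device and torch dtype objects.

```python
# Device selection: prefer the Apple Silicon GPU unless forced onto the CPU.
def pick_device(force_cpu, mps_available, mps_built):
    if not force_cpu and mps_available and mps_built:
        return "mps"
    return "cpu"

# Dtype selection: half precision on the GPU, full precision otherwise.
def pick_dtype(device):
    return "float16" if device == "mps" else "float32"

print(pick_device(False, True, True))  # mps
print(pick_dtype("cpu"))               # float32
```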

The getText2ImagePipe() method initializes a text-to-image pipeline:

  • Loads the Stable Diffusion model with the given model_id and torchType.
  • Configures it for MPS GPU or CPU, based on the device.
  • Clears the GPU cache before loading the model to optimize memory usage.
  • If an error occurs, attempts to reload the pipeline without safetensors.

The getImage2VideoPipe() method initializes an image-to-video pipeline:

  • Similar to getText2ImagePipe, it loads the Stable Diffusion XL Img2Img pipeline with the specified model_id and torchType.
  • Configures it for MPS GPU or CPU and clears the cache before loading.
  • On error, reloads the pipeline without additional optimization settings and prints the error.
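Both pipeline methods follow the same try-optimized-then-retry-with-defaults shape. Here is a minimal, generic sketch of that fallback-loading pattern, with a hypothetical stand-in loader in place of the diffusers from_pretrained call:

```python
# Generic fallback-loading pattern: try the optimized settings first,
# then retry with defaults if loading fails.
def load_with_fallback(loader, optimized_kwargs, default_kwargs):
    try:
        return loader(**optimized_kwargs)
    except Exception as e:
        print(f'Optimized load failed ({e}); retrying with defaults.')
        return loader(**default_kwargs)

# Illustrative stand-in loader that rejects the fp16 variant
def fake_loader(variant=None):
    if variant == "fp16":
        raise ValueError("fp16 variant not found")
    return "pipeline-object"

pipe = load_with_fallback(fake_loader, {"variant": "fp16"}, {})
print(pipe)  # pipeline-object
```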

Let us continue this in the next post:

Enabling & Exploring Stable Diffusion – Part 3

Till then, Happy Avenging! 🙂
