Esta entrada del blog explica cómo implementar operaciones personalizadas en PyTorch utilizando C++ y CUDA, y cómo integrarlas tanto en modelos de PyTorch como en programas de inferencia compilados con AOTInductor. El ejemplo guía es una convolución identidad mínima que ilustra el ciclo completo: desde la definición en C++/CUDA hasta su uso en Python y su exportación con torch.export.
La primera parte aborda las funciones personalizadas sin estado, registradas mediante la macro TORCH_LIBRARY_IMPL. El autor muestra cómo proporcionar implementaciones separadas para CPU (un clone() elemento a elemento) y CUDA (un kernel que respeta la forma, el tipo de datos y los strides del tensor de entrada). La macro gestiona el envío automático a la implementación adecuada según el dispositivo del tensor. Esquema, kernel CUDA y registro de la operación conviven en un mismo archivo, lo que facilita la compilación en una biblioteca compartida (libidentity_conv_ops.so).
La segunda parte cubre las clases personalizadas con estado —capaces de almacenar parámetros— mediante torch::CustomClassHolder y la macro TORCH_LIBRARY. La clase IdentityConvClass se expone a Python como torch.classes.my_ops.IdentityConvClass, con constructor, método forward, descriptor channels, el protocolo de pytree (obj_flatten/obj_unflatten) que requiere torch.export y serialización TorchScript mediante def_pickle. La biblioteca se compila sin dependencia de pybind11, de modo que un binario C++ puro puede cargarla vía dlopen.
La tercera parte describe la integración en Python: la biblioteca se carga con torch.ops.load_library y, para que torch.compile y torch.export operen con FakeTensor, se registran versiones abstractas o “fake” mediante @register_fake_class y @torch.library.register_fake. Se detallan los métodos obj_flatten y obj_unflatten que la clase falsa debe implementar para que el trazado simbólico recorra correctamente los atributos del módulo que contienen una instancia de la clase C++.
El artículo está dirigido a desarrolladores que necesitan extender PyTorch con kernels CUDA optimizados, ya sea para acelerar cargas de trabajo concretas o para desplegar modelos compilados con AOTInductor. No trata la optimización interna del kernel ni la depuración, sino la infraestructura de registro y la compatibilidad con el resto de la cadena de herramientas.
