Skip to content

Commit bd1f654

Browse files
committed
Use up-to-date loader.
1 parent 867250e commit bd1f654

1 file changed

Lines changed: 139 additions & 57 deletions

File tree

plugins/OIIO/OIIOLoader.py

Lines changed: 139 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
logger.error("Required modules not found. Please install OpenImageIO and numpy.")
2828
raise
2929

30-
3130
have_matplot = False
3231
try:
3332
import matplotlib.pyplot as plt
@@ -62,19 +61,48 @@ def __init__(self):
6261
self._extensions.update(ext.strip() for ext in group.split(","))
6362
logger.debug(f"Cache supported extensions: {self._extensions}")
6463

64+
self.preview = False
65+
self.identifier = "OpenImageIO Custom Image Loader"
66+
self.color_space = {}
67+
6568
def supportedExtensions(self):
6669
"""
67-
Return a set of supported image file extensions.
70+
Derived method to return a set of supported image file extensions.
6871
"""
69-
logger.info(f"Supported OIIO supported extensions: {self._extensions}")
72+
logger.info(f"OIIO supported extensions: {self._extensions}")
7073
return self._extensions
7174

72-
def previewImage(self, data, width, height, nchannels):
75+
def set_preview(self, value):
76+
"""
77+
Set whether to preview images when loading and saving
78+
79+
@param value: Boolean indicating whether to enable preview
80+
"""
81+
self.preview = value
82+
83+
def get_identifier(self):
84+
return "OIIO Custom Loader"
85+
86+
def previewImage(self, title, data, width, height, nchannels, color_space):
7387
"""
7488
Utility method to preview an image using matplotlib.
7589
Handles normalization and dtype for correct display.
90+
91+
@param title: Title for the preview window
92+
@param data: Image data array
93+
@param width: Image width
94+
@param height: Image height
95+
@param nchannels: Number of image channels
96+
@param color_space: Color space of the image
7697
"""
98+
if not self.preview:
99+
return
100+
77101
if have_matplot:
102+
# If the image is float16 (half), convert to float32
103+
if data.dtype == np.float16:
104+
data = data.astype(np.float32)
105+
78106
flat = data.reshape(height, width, nchannels)
79107
# Always display as RGB (first 3 channels or repeat if less)
80108
if nchannels >= 3:
@@ -99,19 +127,23 @@ def previewImage(self, data, width, height, nchannels):
99127
else:
100128
rgb_disp = rgb
101129

102-
plt.imshow(rgb_disp)
103-
plt.axis('off')
130+
# Set title bar text for the preview window
131+
fig, ax = plt.subplots()
132+
ax.imshow(rgb_disp)
133+
ax.axis("off")
134+
#fig.patch.set_facecolor("black")
135+
fig.canvas.manager.set_window_title(title)
136+
info = f"Dimensions:({width}x{height}), {nchannels} channels, type={data.dtype}, colorspace={color_space}"
137+
fig.suptitle(title, fontsize=12)
138+
plt.title(info, fontsize=9)
104139
plt.show()
105140

106141
def loadImage(self, filePath):
107142
"""
108143
Load an image from the file system (MaterialX interface method).
109-
110-
Args:
111-
filePath (MaterialX.FilePath): Path to the image file
112-
113-
Returns:
114-
MaterialX.ImagePtr: MaterialX Image object or None if loading fails
144+
145+
@param filePath (MaterialX.FilePath): Path to the image file
146+
@returns MaterialX.ImagePtr: MaterialX Image object or None if loading fails
115147
"""
116148
file_path_str = filePath.asString()
117149
logger.info(f"Load using OIIO loader: {file_path_str}")
@@ -129,8 +161,14 @@ def loadImage(self, filePath):
129161

130162
# Get image specifications
131163
spec = img_input.spec()
132-
self.last_spec = spec
133-
self.last_loaded_path = file_path_str
164+
color_space = spec.getattribute("oiio:ColorSpace")
165+
logger.info(f"ColorSpace: {color_space}")
166+
self.color_space[file_path_str] = color_space
167+
168+
# Check channel count
169+
channels = spec.nchannels
170+
if channels > 4:
171+
channels = 4
134172

135173
# Determine MaterialX base type from OIIO format
136174
base_type = self._oiio_to_materialx_type(spec.format.basetype)
@@ -140,25 +178,28 @@ def loadImage(self, filePath):
140178
return None
141179

142180
# Create MaterialX image
143-
mx_image = mx_render.Image.create(spec.width, spec.height, spec.nchannels, base_type)
181+
mx_image = mx_render.Image.create(spec.width, spec.height, channels, base_type)
144182
mx_image.createResourceBuffer()
145-
logger.debug(f"Create buffer with width: {spec.width}, height: {spec.height}, channels: {spec.nchannels}")
183+
logger.debug(f"Create buffer with width: {spec.width}, height: {spec.height}, channels: {spec.nchannels} -> {channels}")
146184

147185
# Read the image data using the correct OIIO Python API (returns a bytes object)
148186
logger.debug(f"Reading image data from '{file_path_str}' with spec: {spec}")
149-
data = img_input.read_image(0, 0, 0, spec.nchannels, spec.format)
187+
data = mx_image.getResourceBuffer()
188+
data = img_input.read_image(0, 0, 0, channels, spec.format)
150189
if len(data) > 0:
151190
logger.debug(f"Done Reading image data from '{file_path_str}' with spec: {spec}")
152191
else:
153192
logger.error(f"Could not read image data.")
154193
return None
155194

195+
self.previewImage("Loaded MaterialX Image", data, spec.width, spec.height, channels, color_space)
196+
156197
# Steps:
157198
# - Copy the OIIO data into the MaterialX image resource buffer
158199
resource_buffer_ptr = mx_image.getResourceBuffer()
159200
bytes_per_channel = spec.format.size()
160-
total_bytes = spec.width * spec.height * spec.nchannels * bytes_per_channel
161-
logger.info(f"Total bytes read in: {total_bytes} (width: {spec.width}, height: {spec.height}, channels: {spec.nchannels}, format: {spec.format})")
201+
total_bytes = spec.width * spec.height * channels * bytes_per_channel
202+
logger.info(f"Total bytes read in: {total_bytes} (width: {spec.width}, height: {spec.height}, channels: {channels}, format: {spec.format})")
162203
try:
163204
ctypes.memmove(resource_buffer_ptr, (ctypes.c_char * total_bytes).from_buffer_copy(data), total_bytes)
164205
except Exception as e:
@@ -186,19 +227,26 @@ def saveImage(self, filePath, image, verticalFlip=False):
186227
filename = filePath.asString()
187228
width = image.getWidth()
188229
height = image.getHeight()
189-
channels = image.getChannelCount()
230+
231+
# Clamp to RGBA
232+
src_channels = image.getChannelCount()
233+
channels = min(src_channels, 4)
234+
if src_channels > 4:
235+
logger.warning(f"Image has {src_channels} channels. Saving only first {channels} (RGBA).")
236+
190237
mx_basetype = image.getBaseType()
191238
oiio_format = self._materialx_to_oiio_type(mx_basetype)
192239
logger.info(f"mx_basetype: {mx_basetype}, oiio_format: {oiio_format}, base_stride: {image.getBaseStride()}")
193240
if oiio_format is None:
194-
logger.error(f"Error: Unsupported MaterialX base type for OIIO: {mx_basetype}")
241+
logger.error(f"Unsupported MaterialX base type for OIIO: {mx_basetype}")
195242
return False
196-
197-
buffer = image.getResourceBuffer()
198-
243+
244+
buffer_addr = image.getResourceBuffer()
199245
np_type = self._materialx_type_to_np_type(mx_basetype)
200-
pixels = np.zeros((height, width, channels), dtype=np_type)
201-
# Copy from buffer to pixels
246+
if np_type is None:
247+
logger.error(f"No NumPy dtype mapping for base type: {mx_basetype}")
248+
return False
249+
202250
try:
203251
# Steps:
204252
# - Maps the MaterialX base type to OIIO and NumPy types.
@@ -207,37 +255,64 @@ def saveImage(self, filePath, image, verticalFlip=False):
207255
# - Optionally previews the image for debugging.
208256
# - Creates an OIIO ImageOutput and writes the image to disk.
209257
#
210-
base_stride = image.getBaseStride()
211-
total_bytes = width * height * channels * base_stride
258+
base_stride = image.getBaseStride() # bytes per channel element
259+
total_bytes = width * height * src_channels * base_stride
260+
212261
buf_type = (ctypes.c_char * total_bytes)
213-
buf = buf_type.from_address(buffer)
214-
np_buffer = np.frombuffer(buf, dtype=np_type).reshape((height, width, channels))
215-
np.copyto(pixels, np_buffer)
262+
buf = buf_type.from_address(buffer_addr)
263+
264+
np_buffer = np.frombuffer(buf, dtype=np_type)
265+
266+
# Validate total elements before reshape to catch mismatches early
267+
expected_elems = width * height * src_channels
268+
if np_buffer.size != expected_elems:
269+
logger.error(f"Buffer element count mismatch: got {np_buffer.size}, expected {expected_elems}.")
270+
return False
271+
272+
np_buffer = np_buffer.reshape((height, width, src_channels))
273+
274+
# Keep only up to RGBA
275+
pixels = np_buffer[..., :channels].copy()
216276

217-
# Handle vertical flip if requested
218277
if verticalFlip:
219278
logger.info("Applying vertical flip before saving image.")
220279
pixels = np.flipud(pixels)
221280

222281
logger.info("Previewing image after load into Image and reload for save...")
223-
self.previewImage(pixels, width, height, channels)
282+
# Remove "saved_" prefix if present
283+
search_name = filename.replace("saved_", "")
284+
color_space = "Unknown"
285+
for key in self.color_space:
286+
value = self.color_space[key]
287+
path = os.path.basename(key)
288+
if path in search_name:
289+
color_space = value
290+
logger.info(f"colorspace lookup for: {search_name}. list: {color_space}")
291+
self.previewImage("OpenImageIO Output Image", pixels, width, height, channels, color_space)
224292

225293
except Exception as e:
226294
logger.error(f"Error copying buffer to pixels: {e}")
227295
return False
228-
296+
229297
out = oiio.ImageOutput.create(filename)
230-
if out:
231-
if np_type is None:
232-
logger.error(f"Error: Unsupported NumPy type for OIIO: {mx_basetype}")
233-
return False
234-
spec = oiio.ImageSpec(width, height, channels, np_type)
298+
if not out:
299+
logger.error("Failed to create OIIO ImageOutput.")
300+
return False
301+
302+
try:
303+
spec = oiio.ImageSpec(width, height, channels, oiio_format)
235304
out.open(filename, spec)
236305
out.write_image(pixels)
237-
logger.info(f"Image saved to {filename} with width: {width}, height: {height}, channels: {channels}, base type: {mx_basetype}")
306+
logger.info(f"Image saved to {filename} (w={width}, h={height}, c={channels}, type={mx_basetype})")
238307
out.close()
239308
return True
240-
return False
309+
except Exception as e:
310+
logger.error(f"Failed to write image: {e}")
311+
try:
312+
out.close()
313+
finally:
314+
pass
315+
return False
241316

242317
def _oiio_to_materialx_type(self, oiio_basetype):
243318
"""Convert OIIO base type to MaterialX Image base type."""
@@ -268,14 +343,14 @@ def _materialx_to_oiio_type(self, mx_basetype):
268343
return return_val
269344

270345
def _materialx_type_to_np_type(self, mx_basetype):
271-
"""Map MaterialX base type to NumPy dtype."""
346+
"""Map MaterialX base type to NumPy dtype with explicit widths."""
272347
type_mapping = {
273-
mx_render.BaseType.UINT8: 'uint8',
274-
mx_render.BaseType.UINT16: 'uint16',
275-
mx_render.BaseType.INT8: 'int8',
276-
mx_render.BaseType.INT16: 'int16',
277-
mx_render.BaseType.HALF: 'half',
278-
mx_render.BaseType.FLOAT: 'float',
348+
mx_render.BaseType.UINT8: np.uint8,
349+
mx_render.BaseType.UINT16: np.uint16,
350+
mx_render.BaseType.INT8: np.int8,
351+
mx_render.BaseType.INT16: np.int16,
352+
mx_render.BaseType.HALF: np.float16,
353+
mx_render.BaseType.FLOAT: np.float32,
279354
}
280355
return type_mapping.get(mx_basetype, None)
281356

@@ -287,6 +362,7 @@ def test_load_save():
287362
parser = argparse.ArgumentParser(description="MaterialX OIIO Image Handler")
288363
parser.add_argument("path", help="Path to the image file")
289364
parser.add_argument("--flip", action="store_true", help="Flip the image vertically")
365+
parser.add_argument("--preview", action="store_true", help="Preview the image before saving")
290366
args = parser.parse_args()
291367

292368
test_image_path = args.path
@@ -296,15 +372,15 @@ def test_load_save():
296372

297373
# Create MaterialX handler with custom OIIO image loader
298374
loader = OiioImageLoader()
299-
#handler = mx_render.ImageHandler.create(loader)
300-
manager = mx_render.getPluginManager()
301-
handler = manager.getImageHandler()
302-
logger.info(f"Got handler from plugin manager {handler}")
375+
loader.set_preview(args.preview)
376+
handler = mx_render.ImageHandler.create(loader)
377+
logger.info(f"Created image handler with loader ({loader.get_identifier()}): {handler is not None}")
303378
handler.addLoader(loader)
304379

305380
mx_filepath = mx.FilePath(test_image_path)
306381

307382
# Load image using handler API
383+
logger.info('-'*45)
308384
logger.info(f"Loading image from path: {mx_filepath.asString()}")
309385
mx_image = handler.acquireImage(mx_filepath)
310386
if mx_image:
@@ -319,12 +395,18 @@ def test_load_save():
319395
logger.info(f" Base type: {mx_image.getBaseType()}")
320396

321397
# Save image using handler API (to a new file)
322-
logger.info('*'*45)
323-
out_path = mx.FilePath("saved_" + os.path.basename(test_image_path))
324-
if handler.saveImage(out_path, mx_image, verticalFlip=args.flip):
325-
logger.info(f"MaterialX Image saved to {out_path.asString()}")
398+
logger.info('-'*45)
399+
400+
# Retrieve cached image
401+
mx_image = handler.acquireImage(mx_filepath)
402+
if mx_image:
403+
out_path = mx.FilePath("saved_" + os.path.basename(test_image_path))
404+
if handler.saveImage(out_path, mx_image, verticalFlip=args.flip):
405+
logger.info(f"MaterialX Image saved to {out_path.asString()}")
406+
else:
407+
logger.error("Failed to save image.")
326408
else:
327-
logger.error("Failed to save image.")
409+
logger.error("Failed to acquire image for saving.")
328410
else:
329411
logger.error("Failed to load image.")
330412

0 commit comments

Comments
 (0)