diff --git a/include/neural-graphics-primitives/common_device.cuh b/include/neural-graphics-primitives/common_device.cuh
index dfcdbb791..88c56c599 100644
--- a/include/neural-graphics-primitives/common_device.cuh
+++ b/include/neural-graphics-primitives/common_device.cuh
@@ -357,15 +357,6 @@ inline NGP_HOST_DEVICE Ray pixel_to_ray_pinhole(
 	return {camera_matrix[3], dir};
 }
 
-inline NGP_HOST_DEVICE mat4x3 get_xform_given_rolling_shutter(const TrainingXForm& training_xform, const vec4& rolling_shutter, const vec2& uv, float motionblur_time) {
-	float pixel_t = rolling_shutter.x + rolling_shutter.y * uv.x + rolling_shutter.z * uv.y + rolling_shutter.w * motionblur_time;
-
-	vec3 pos = training_xform.start[3] + (training_xform.end[3] - training_xform.start[3]) * pixel_t;
-	mat3 rot = to_mat3(normalize(slerp(quat(mat3(training_xform.start)), quat(mat3(training_xform.end)), pixel_t)));
-
-	return mat4x3(rot[0], rot[1], rot[2], pos);
-}
-
 inline NGP_HOST_DEVICE vec3 f_theta_undistortion(const vec2& uv, const float* params, const vec3& error_direction) {
 	// we take f_theta intrinsics to be: r0, r1, r2, r3, resx, resy; we rescale to whatever res the intrinsics specify.
 	float xpix = uv.x * params[5];
@@ -663,6 +654,12 @@ inline NGP_HOST_DEVICE mat4x3 camera_slerp(const mat4x3& a, const mat4x3& b, flo
 	return {rot[0], rot[1], rot[2], mix(a[3], b[3], t)};
 }
 
+inline NGP_HOST_DEVICE mat4x3 get_xform_given_rolling_shutter(const TrainingXForm& training_xform, const vec4& rolling_shutter, const vec2& uv, float motionblur_time) {
+	float pixel_t = rolling_shutter.x + rolling_shutter.y * uv.x + rolling_shutter.z * uv.y + rolling_shutter.w * motionblur_time;
+	return camera_log_lerp(training_xform.start, training_xform.end, pixel_t);
+	// return camera_slerp(training_xform.start, training_xform.end, pixel_t);
+}
+
 inline NGP_HOST_DEVICE void apply_quilting(uint32_t* x, uint32_t* y, const ivec2& resolution, vec3& parallax_shift, const ivec2& quilting_dims) {
 	float resx = float(resolution.x) / quilting_dims.x;
 	float resy = float(resolution.y) / quilting_dims.y;