Reconfigure how sampling is done so that we do it in 16x16 chunks

This commit is contained in:
Eryn Wells 2018-10-27 10:14:23 -07:00
parent af3031ece6
commit a14943c42f
3 changed files with 48 additions and 20 deletions

View file

@ -35,11 +35,15 @@ class MarchingSquares {
private(set) var gridGeometry: MTLBuffer? private(set) var gridGeometry: MTLBuffer?
private var xSamples: Int { private var xSamples: Int {
return Int(field.size.x / sampleGridSize.x) let xSize = field.size.x / sampleGridSize.x
let xRem = field.size.x % sampleGridSize.x
return Int(xSize + (sampleGridSize.x - xRem))
} }
private var ySamples: Int { private var ySamples: Int {
return Int(field.size.y / sampleGridSize.y) let ySize = field.size.y / sampleGridSize.y
let yRem = field.size.y % sampleGridSize.y
return Int(ySize + (sampleGridSize.y - yRem))
} }
private var lastSamplesCount = 0 private var lastSamplesCount = 0
@ -49,9 +53,12 @@ class MarchingSquares {
} }
var contourIndexesCount: Int { var contourIndexesCount: Int {
return (xSamples - 1) * (ySamples - 1) return samplesCount
} }
/// Threadgroup size for the compute kernels.
private let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
init(field: Field) { init(field: Field) {
self.field = field self.field = field
semaphore = DispatchSemaphore(value: 1) semaphore = DispatchSemaphore(value: 1)
@ -136,6 +143,7 @@ class MarchingSquares {
func populateGrid(withDevice device: MTLDevice) { func populateGrid(withDevice device: MTLDevice) {
guard lastSamplesCount != samplesCount else { guard lastSamplesCount != samplesCount else {
print("Populate requested, but lastSampleCount(\(lastSamplesCount) == samplesCount(\(samplesCount))")
return return
} }
@ -181,9 +189,7 @@ class MarchingSquares {
encoder.setBuffer(samplesBuffer, offset: 0, index: 2) encoder.setBuffer(samplesBuffer, offset: 0, index: 2)
// Dispatch! // Dispatch!
// TODO: Kernel threadgroup size limit is 256. Figure out how to make this work. let gridSize = computeGridSize(forCellGridSize: Size(x: UInt32(xSamples), y: UInt32(ySamples)))
let gridSize = MTLSize(width: xSamples, height: ySamples, depth: 1)
let threadgroupSize = MTLSize(width: xSamples, height: 1, depth: 1)
encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize) encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding() encoder.endEncoding()
@ -206,12 +212,24 @@ class MarchingSquares {
encoder.setBuffer(contourIndexesBuffer, offset: 0, index: 2) encoder.setBuffer(contourIndexesBuffer, offset: 0, index: 2)
// Dispatch! // Dispatch!
let gridSize = MTLSize(width: contourIndexesCount, height: 1, depth: 1) let gridSize = computeGridSize(forCellGridSize: Size(x: UInt32(xSamples - 1), y: UInt32(ySamples - 1)))
let threadgroupSize = MTLSize(width: xSamples - 1, height: 1, depth: 1)
encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize) encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding() encoder.endEncoding()
} }
/// Computes the dispatch grid size for the compute kernels.
///
/// Each dimension of the cell grid is rounded up to the nearest multiple of
/// `threadgroupSize`, so the dispatch always consists of whole threadgroups.
/// A dimension that is already an exact multiple is returned unchanged rather
/// than being padded with an extra, entirely idle threadgroup. Threads that
/// land outside the cell grid are rejected by the kernels' own bounds checks.
///
/// - Parameter gridSize: The number of cells along x and y.
/// - Returns: The padded grid size, with a depth of 1.
func computeGridSize(forCellGridSize gridSize: Size) -> MTLSize {
    // Round `count` up to the nearest multiple of `multiple` using ceiling
    // division. At least one full threadgroup is always kept so that an empty
    // cell grid never produces a zero-sized dispatch.
    func roundedUp(_ count: Int, toMultipleOf multiple: Int) -> Int {
        let groups = max(1, (count + multiple - 1) / multiple)
        return groups * multiple
    }

    return MTLSize(width: roundedUp(Int(gridSize.x), toMultipleOf: threadgroupSize.width),
                   height: roundedUp(Int(gridSize.y), toMultipleOf: threadgroupSize.height),
                   depth: 1)
}
} }
struct Variants { struct Variants {

View file

@ -206,7 +206,7 @@ public class Renderer: NSObject, MTKViewDelegate {
if self.pixelGeometry == nil { if self.pixelGeometry == nil {
self.pixelGeometry = self.pixelGeometry(forViewSize: view.drawableSize) self.pixelGeometry = self.pixelGeometry(forViewSize: view.drawableSize)
} }
let pixelGeometry = self.pixelGeometry! // let pixelGeometry = self.pixelGeometry!
if let renderPass = view.currentRenderPassDescriptor { if let renderPass = view.currentRenderPassDescriptor {
// Render the per-pixel metaballs // Render the per-pixel metaballs
@ -240,7 +240,7 @@ public class Renderer: NSObject, MTKViewDelegate {
encoder.setVertexBuffer(marchingSquares.gridGeometry, offset: 0, index: 1) encoder.setVertexBuffer(marchingSquares.gridGeometry, offset: 0, index: 1)
encoder.setVertexBuffer(parametersBuffer, offset: 0, index: 2) encoder.setVertexBuffer(parametersBuffer, offset: 0, index: 2)
encoder.setFragmentBuffer(marchingSquares.contourIndexesBuffer, offset: 0, index: 0) encoder.setFragmentBuffer(marchingSquares.contourIndexesBuffer, offset: 0, index: 0)
encoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: Rect.geometry.count, instanceCount: marchingSquares.samplesCount) encoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: Rect.geometry.count, instanceCount: marchingSquares.contourIndexesCount)
encoder.endEncoding() encoder.endEncoding()
didEncode = true didEncode = true
} }

View file

@ -45,6 +45,11 @@ samplingKernel(constant MarchingSquaresParameters &parameters [[buffer(0)]],
device float *samples [[buffer(2)]], device float *samples [[buffer(2)]],
uint2 position [[thread_position_in_grid]]) uint2 position [[thread_position_in_grid]])
{ {
if (position.x >= parameters.gridSize.x || position.y >= parameters.gridSize.y)
{
return;
}
// Find the midpoint of this grid cell. // Find the midpoint of this grid cell.
const float2 point = float2(position.x * parameters.cellSize.x + (parameters.cellSize.x / 2.0), const float2 point = float2(position.x * parameters.cellSize.x + (parameters.cellSize.x / 2.0),
position.y * parameters.cellSize.y + (parameters.cellSize.y / 2.0)); position.y * parameters.cellSize.y + (parameters.cellSize.y / 2.0));
@ -68,38 +73,43 @@ kernel void
contouringKernel(constant MarchingSquaresParameters &parameters [[buffer(0)]], contouringKernel(constant MarchingSquaresParameters &parameters [[buffer(0)]],
constant float *samples [[buffer(1)]], constant float *samples [[buffer(1)]],
device ushort *contourIndexes [[buffer(2)]], device ushort *contourIndexes [[buffer(2)]],
uint position [[thread_position_in_grid]]) uint2 position [[thread_position_in_grid]])
{ {
if (position.x >= (parameters.gridSize.x - 1) || position.y >= (parameters.gridSize.y - 1)) {
return;
}
// Calculate an index based on the samples at the four points around this cell. // Calculate an index based on the samples at the four points around this cell.
// If the point is above the threshold, adjust the value accordingly. // If the point is above the threshold, adjust the value accordingly.
// d--c 8--4 // d--c 8--4
// | | -> | | // | | -> | |
// a--b 1--2 // a--b 1--2
uint a = position + parameters.gridSize.x; uint rowSize = parameters.gridSize.x - 1;
uint b = position + parameters.gridSize.x + 1; uint d = position.y * rowSize + position.x;
uint c = position + 1; uint c = d + 1;
uint d = position; uint b = d + rowSize + 1;
uint a = d + rowSize;
uint index = (samples[d] >= 1.0 ? 0b1000 : 0) + uint index = (samples[d] >= 1.0 ? 0b1000 : 0) +
(samples[c] >= 1.0 ? 0b0100 : 0) + (samples[c] >= 1.0 ? 0b0100 : 0) +
(samples[b] >= 1.0 ? 0b0010 : 0) + (samples[b] >= 1.0 ? 0b0010 : 0) +
(samples[a] >= 1.0 ? 0b0001 : 0); (samples[a] >= 1.0 ? 0b0001 : 0);
contourIndexes[position] = index; contourIndexes[d] = index;
} }
vertex RasterizerData vertex RasterizerData
gridVertexShader(constant Vertex *vertexes [[buffer(0)]], gridVertexShader(constant Vertex *vertexes [[buffer(0)]],
constant Rect *rects [[buffer(1)]], constant Rect *cells [[buffer(1)]],
constant RenderParameters &renderParameters [[buffer(2)]], constant RenderParameters &renderParameters [[buffer(2)]],
uint vid [[vertex_id]], uint vid [[vertex_id]],
uint instid [[instance_id]]) uint instid [[instance_id]])
{ {
Vertex v = vertexes[vid]; Vertex v = vertexes[vid];
Rect rect = rects[instid]; Rect cell = cells[instid];
RasterizerData out; RasterizerData out;
out.position = renderParameters.projection * rect.transform * float4(v.position.xy, 0, 1); out.position = renderParameters.projection * cell.transform * float4(v.position.xy, 0, 1);
out.color = rect.color; out.color = cell.color;
out.textureCoordinate = v.textureCoordinate; out.textureCoordinate = v.textureCoordinate;
out.instance = instid; out.instance = instid;
return out; return out;