12 changes: 6 additions & 6 deletions core/src/main/scala/geotrellis/raster/ArrowTensor.scala
@@ -15,16 +15,16 @@ import org.apache.arrow.vector.{Float8Vector, VectorSchemaRoot}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

import org.locationtech.rasterframes.encoders.CatalystSerializerEncoder
import org.locationtech.rasterframes.tensors.RFTensor

import spire.syntax.cfor._

import scala.collection.JavaConverters._

case class ArrowTensor(val vector: Float8Vector, val shape: Seq[Int]) extends CellGrid {


case class ArrowTensor(val vector: Float8Vector, val shape: Seq[Int]) extends RFTensor {
// TODO: Should we be using ArrowBuf here directly, since Arrow Tensor can not have pages?
def rows = shape(1)
def cols = shape(2)
val cellType = DoubleCellType

// TODO: Figure out how to work this crazy thing
// def copy(implicit alloc: BufferAllocator) = {
@@ -189,8 +189,8 @@ case class ArrowTensor(val vector: Float8Vector, val shape: Seq[Int]) extends Ce

object ArrowTensor {
import org.apache.spark.sql.rf.TensorUDT._
implicit val arrowTensorEncoder: ExpressionEncoder[ArrowTensor] =
CatalystSerializerEncoder[ArrowTensor](true)
implicit val arrowTensorEncoder: ExpressionEncoder[RFTensor] =
CatalystSerializerEncoder[RFTensor](true)

val allocator = new RootAllocator(Long.MaxValue)

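The net effect of the change above is that ArrowTensor now derives its 2-D accessors from a 3-D shape of (depth, rows, cols) through RFTensor rather than from CellGrid. A minimal sketch of that convention, assuming the companion object's RootAllocator may be reused for ad-hoc vectors (not something this PR itself demonstrates):

import org.apache.arrow.vector.Float8Vector
import geotrellis.raster.ArrowTensor

// Build a tiny depth=3, rows=2, cols=2 tensor backed by a Float8Vector.
val vec = new Float8Vector("cells", ArrowTensor.allocator)
vec.allocateNew(12)
(0 until 12).foreach(i => vec.setSafe(i, i.toDouble))
vec.setValueCount(12)

val tensor = ArrowTensor(vec, Seq(3, 2, 2))
assert(tensor.rows == 2)  // rows = shape(1)
assert(tensor.cols == 2)  // cols = shape(2)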
2 changes: 1 addition & 1 deletion core/src/main/scala/geotrellis/raster/BufferedTensor.scala
@@ -26,7 +26,7 @@ case class BufferedTensor(
val bufferRows: Int,
val bufferCols: Int,
val extent: Option[Extent]
) extends CellGrid {
) extends RFTensor {

val cellType = DoubleCellType

50 changes: 31 additions & 19 deletions core/src/main/scala/org/apache/spark/sql/rf/TensorUDT.scala
@@ -26,31 +26,32 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, _}
import org.locationtech.rasterframes.encoders.CatalystSerializer
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.model.{Cells, TileDataContext}
import org.locationtech.rasterframes.ref.RasterRef.RasterRefTile
import org.locationtech.rasterframes.tiles.InternalRowTile
import org.locationtech.rasterframes.model.{Voxels, TensorDataContext}
import org.locationtech.rasterframes.ref.DeferredTensorRef
//import org.locationtech.rasterframes.ref.RasterRef.RasterRefTile
import org.locationtech.rasterframes.tensors.{RFTensor, InternalRowTensor}


@SQLUserDefinedType(udt = classOf[TensorUDT])
class TensorUDT extends UserDefinedType[ArrowTensor] {
class TensorUDT extends UserDefinedType[RFTensor] {
import TensorUDT._
override def typeName = TensorUDT.typeName

override def pyUDT: String = "pyrasterframes.rf_types.TensorUDT"

def userClass: Class[ArrowTensor] = classOf[ArrowTensor]
def userClass: Class[RFTensor] = classOf[RFTensor]

def sqlType: StructType = schemaOf[ArrowTensor]
def sqlType: StructType = schemaOf[RFTensor]

override def serialize(obj: ArrowTensor): InternalRow =
override def serialize(obj: RFTensor): InternalRow =
Option(obj)
.map(_.toInternalRow)
.orNull

override def deserialize(datum: Any): ArrowTensor =
override def deserialize(datum: Any): RFTensor =
Option(datum)
.collect {
case ir: InternalRow ⇒ ir.to[ArrowTensor]
case ir: InternalRow ⇒ ir.to[RFTensor]
}
.orNull

@@ -61,23 +62,34 @@ class TensorUDT extends UserDefinedType[ArrowTensor] {
}

case object TensorUDT {
UDTRegistration.register(classOf[ArrowTensor].getName, classOf[TensorUDT].getName)
UDTRegistration.register(classOf[RFTensor].getName, classOf[TensorUDT].getName)

final val typeName: String = "tensor"

implicit def tensorSerializer: CatalystSerializer[ArrowTensor] = new CatalystSerializer[ArrowTensor] {
implicit def tensorSerializer: CatalystSerializer[RFTensor] = new CatalystSerializer[RFTensor] {

override val schema: StructType = StructType(Seq(
StructField("arrow_tensor", BinaryType, true)
StructField("tensor_context", schemaOf[TensorDataContext], true),
StructField("tensor_data", schemaOf[Voxels], false)
))

override def to[R](t: ArrowTensor, io: CatalystIO[R]): R = io.create {
t.toArrowBytes()
}

override def from[R](row: R, io: CatalystIO[R]): ArrowTensor = {
val bytes = io.getByteArray(row, 0)
ArrowTensor.fromArrowMessage(bytes)
override def to[R](t: RFTensor, io: CatalystIO[R]): R = io.create(
t match {
case _: DeferredTensorRef => null
case o => io.to(TensorDataContext(o))
},
io.to(Voxels(t))
)

override def from[R](row: R, io: CatalystIO[R]): RFTensor = {
val voxels = io.get[Voxels](row, 1)

row match {
case ir: InternalRow if !voxels.isRef ⇒ new InternalRowTensor(ir)
case _ ⇒
val ctx = io.get[TensorDataContext](row, 0)
voxels.toTensor(ctx)
}
}
}
}
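With the UDT retargeted from ArrowTensor to RFTensor, a tensor column now serializes to a two-field row (tensor_context, tensor_data). A quick roundtrip sketch using only the serialize/deserialize methods above, reusing the `tensor` value from the ArrowTensor example earlier; it assumes the returned RFTensor exposes shape, as TensorDataContext.apply relies on:

import org.apache.spark.sql.rf.TensorUDT
import org.locationtech.rasterframes.tensors.RFTensor

val udt = new TensorUDT()
val row = udt.serialize(tensor)           // InternalRow(tensor_context, tensor_data)
val back: RFTensor = udt.deserialize(row) // realized rows come back as an InternalRowTensor
assert(back.shape == tensor.shape)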
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure,
import org.apache.spark.sql.catalyst.expressions.BinaryExpression
import org.apache.spark.sql.rf.{TileUDT, TensorUDT}
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes.tensors.RFTensor
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.expressions.DynamicExtractors._
import org.slf4j.LoggerFactory
@@ -92,9 +93,9 @@ trait BinaryLocalRasterOp extends BinaryExpression {
}

(context, isTensor) match {
case (Some(ctx), true) ⇒ result.asInstanceOf[ArrowTensor].toInternalRow
case (Some(ctx), true) ⇒ result.asInstanceOf[RFTensor].toInternalRow
case (Some(ctx), false) ⇒ ctx.toProjectRasterTile(result.asInstanceOf[Tile]).toInternalRow
case (None, true) ⇒ result.asInstanceOf[ArrowTensor].toInternalRow
case (None, true) ⇒ result.asInstanceOf[RFTensor].toInternalRow
case (None, false) ⇒ result.asInstanceOf[Tile].toInternalRow
}
}
@@ -34,6 +34,7 @@ import org.locationtech.jts.geom.{Envelope, Point}
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.model.{LazyCRS, TileContext}
import org.locationtech.rasterframes.ref.{ProjectedRasterLike, RasterRef, RasterSource}
import org.locationtech.rasterframes.tensors.RFTensor
import org.locationtech.rasterframes.tiles.ProjectedRasterTile

private[rasterframes]
@@ -95,7 +96,7 @@ object DynamicExtractors {
case _: TileUDT =>
(row: InternalRow) => row.to[Tile](TileUDT.tileSerializer)
case _: TensorUDT =>
(row: InternalRow) => row.to[ArrowTensor](TensorUDT.tensorSerializer)
(row: InternalRow) => row.to[RFTensor](TensorUDT.tensorSerializer)
case _: BufferedTensorUDT =>
(row: InternalRow) => row.to[BufferedTensor](BufferedTensorUDT.bufferedTensorSerializer)
case _: RasterSourceUDT =>
@@ -47,7 +47,7 @@ import scala.util.control.NonFatal
* @since 9/6/18
*/
// BUFFER HERE
case class RasterSourcesToTensorRefs(child: Expression, subtileDims: Option[TileDimensions] = None) extends UnaryExpression
case class RasterSourcesToTensorRefs(child: Expression, subtileDims: Option[TileDimensions] = None, bufferPixels: Int = 0) extends UnaryExpression
with Generator with CodegenFallback with ExpectsInputTypes {
import TensorRef._
import org.locationtech.rasterframes.expressions.transformers.PatternToRasterSources._
@@ -78,14 +78,11 @@ case class RasterSourcesToTensorRefs(child: Expression, subtileDims: Option[Tile

val sampleRS = rss.head.source

val maybeSubs = subtileDims.map { dims =>
val subGB = sampleRS.layoutBounds(dims)
subGB.map(gb => (gb, sampleRS.rasterExtent.extentFor(gb, clamp = true)))
}
val maybeSubs = subtileDims.map { dims => sampleRS.layoutBounds(dims) }

val trefs = maybeSubs.map { subs =>
subs.map { case (gb, extent) => TensorRef(rss, Some(extent), Some(gb)) }
}.getOrElse(Seq(TensorRef(rss, None, None)))
subs.map { gb => TensorRef(rss, Some(gb), bufferPixels) }
}.getOrElse(Seq(TensorRef(rss, None, bufferPixels)))

trefs.map{ tref => InternalRow(tref.toInternalRow) }
}
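The generator above now emits one TensorRef per GridBounds returned by layoutBounds, all sharing the same bufferPixels. A back-of-the-envelope check of how many refs a given subtile size produces (pure arithmetic, independent of the RasterSource API; the window sizes are illustrative):

// A width x height source split into w x h windows yields ceil(width/w) * ceil(height/h) refs.
def windowCount(width: Int, height: Int, w: Int, h: Int): Int =
  math.ceil(width.toDouble / w).toInt * math.ceil(height.toDouble / h).toInt

assert(windowCount(1024, 1024, 256, 256) == 16)  // 16 TensorRefs, each carrying the same bufferPixels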
@@ -55,9 +55,7 @@ case class TensorRefToTensor(child: Expression, bufferPixels: Int) extends Unary
override protected def nullSafeEval(input: Any): Any = {
implicit val ser = ProjectedBufferedTensor.serializer
val ref = row(input).to[TensorRef]
val realized = ref.realizedTensor(bufferPixels)

realized.toInternalRow
ref.tensor.toInternalRow
}
}

@@ -0,0 +1,65 @@
/*
* This software is licensed under the Apache 2 license, quoted below.
*
* Copyright 2019 Astraea, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* [http://www.apache.org/licenses/LICENSE-2.0]
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.locationtech.rasterframes.model

import org.locationtech.rasterframes.encoders.CatalystSerializer._
import geotrellis.raster.{CellType, Tile, ArrowTensor}
import org.locationtech.rasterframes.tensors.RFTensor
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
import org.locationtech.rasterframes.encoders.{CatalystSerializer, CatalystSerializerEncoder}

/** Encapsulates all information about a tensor aside from its actual cell values. */
case class TensorDataContext(depth: Int, rows: Int, cols: Int, bufferPixels: Int)
object TensorDataContext {

/** Extracts the TensorDataContext from an RFTensor. */
def apply(t: RFTensor): TensorDataContext = {
require(t.depth <= Short.MaxValue, s"RasterFrames doesn't support tensors with depth ${t.depth}")
require(t.cols <= Short.MaxValue, s"RasterFrames doesn't support tensors with ${t.cols} columns")
require(t.rows <= Short.MaxValue, s"RasterFrames doesn't support tensors with ${t.rows} rows")
TensorDataContext(t.shape(0), t.shape(1), t.shape(2), t.bufferPixels) // assumes RFTensor carries its buffer width
}

implicit val serializer: CatalystSerializer[TensorDataContext] = new CatalystSerializer[TensorDataContext] {
override val schema: StructType = StructType(Seq(
StructField("depth", IntegerType, false),
StructField("rows", IntegerType, false),
StructField("cols", IntegerType, false),
StructField("bufferPixels", IntegerType, false)
))

override protected def to[R](t: TensorDataContext, io: CatalystIO[R]): R = io.create(
t.depth,
t.rows,
t.cols,
t.bufferPixels
)
override protected def from[R](t: R, io: CatalystIO[R]): TensorDataContext = TensorDataContext(
io.getInt(t, 0),
io.getInt(t, 1),
io.getInt(t, 2),
io.getInt(t, 3)
)
}

implicit def encoder: ExpressionEncoder[TensorDataContext] = CatalystSerializerEncoder[TensorDataContext]()
}
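A roundtrip through the CatalystSerializer syntax already imported in this file (toInternalRow / to[...]) pins down the field order; the dimensions here are arbitrary:

import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.model.TensorDataContext

val ctx = TensorDataContext(depth = 4, rows = 256, cols = 256, bufferPixels = 2)
val row = ctx.toInternalRow           // fields: depth, rows, cols, bufferPixels
val back = row.to[TensorDataContext]
assert(back == ctx)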
@@ -0,0 +1,81 @@
/*
* This software is licensed under the Apache 2 license, quoted below.
*
* Copyright 2019 Astraea, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* [http://www.apache.org/licenses/LICENSE-2.0]
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.locationtech.rasterframes.model

import geotrellis.raster._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
import org.locationtech.rasterframes
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.encoders.{CatalystSerializer, CatalystSerializerEncoder}
import org.locationtech.rasterframes.ref.{TensorRef, DeferredTensorRef}
import org.locationtech.rasterframes.tensors.RFTensor

/** Represents either the binary voxel data or a reference to it. */
case class Voxels(data: Either[Array[Byte], TensorRef]) {
def isRef: Boolean = data.isRight

/** Convert voxels into either a DeferredTensorRef or a realized, buffered tensor. */
def toTensor(ctx: TensorDataContext): RFTensor = data.fold(
bytes => {
val nakedTensor = ArrowTensor.fromArrowMessage(bytes)
// Assumes BufferedTensor(tensor, bufferRows, bufferCols, extent); the buffer width comes from the context.
BufferedTensor(nakedTensor, ctx.bufferPixels, ctx.bufferPixels, None)
},
ref => DeferredTensorRef(ref)
)
}

object Voxels {
/** Extracts the Voxels from a Tensor. */
def apply(t: RFTensor): Voxels = {
t match {
case arrowTensor: ArrowTensor =>
Voxels(Left(arrowTensor.toArrowBytes()))
case ref: DeferredTensorRef =>
Voxels(Right(ref.deferred))
case other =>
throw new IllegalArgumentException(s"Unsupported tensor type: ${other.getClass.getName}")
}
}

implicit def voxelsSerializer: CatalystSerializer[Voxels] = new CatalystSerializer[Voxels] {
override val schema: StructType =
StructType(
Seq(
StructField("voxels", BinaryType, true),
StructField("ref", schemaOf[TensorRef], true)
))
override protected def to[R](t: Voxels, io: CatalystSerializer.CatalystIO[R]): R = io.create(
t.data.left.getOrElse(null),
t.data.right.map(tr => io.to(tr)).right.getOrElse(null)
)
override protected def from[R](t: R, io: CatalystSerializer.CatalystIO[R]): Voxels = {
if (!io.isNullAt(t, 0))
Voxels(Left(io.getByteArray(t, 0)))
else if (!io.isNullAt(t, 1))
Voxels(Right(io.get[TensorRef](t, 1)))
else throw new IllegalArgumentException("must be either arrow tensor data or a ref, but not null")
}
}

implicit def encoder: ExpressionEncoder[Voxels] = CatalystSerializerEncoder[Voxels]()
}
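The Either keeps the two storage paths explicit: Left carries the Arrow IPC bytes of a realized tensor, Right carries only a TensorRef. A small sketch of the realized branch, reusing `tensor` from the ArrowTensor example earlier; the deferred branch is only noted in a comment since constructing a TensorRef is outside this snippet:

import org.locationtech.rasterframes.model.{TensorDataContext, Voxels}

// Realized path: cells travel as Arrow IPC bytes.
val realized = Voxels(tensor)        // equivalent to Voxels(Left(tensor.toArrowBytes()))
assert(!realized.isRef)

// Rehydrate with a context so the buffer width is reattached.
val rehydrated = realized.toTensor(TensorDataContext(tensor))

// Deferred path (not constructed here): Voxels(Right(someTensorRef)) has isRef == true,
// and toTensor returns a DeferredTensorRef without reading any pixels.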
@@ -0,0 +1,42 @@
/*
* This software is licensed under the Apache 2 license, quoted below.
*
* Copyright 2018 Astraea, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* [http://www.apache.org/licenses/LICENSE-2.0]
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.locationtech.rasterframes.ref

import com.typesafe.scalalogging.LazyLogging
import geotrellis.proj4.CRS
import geotrellis.raster._
import geotrellis.vector.{Extent, ProjectedExtent}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.rf.RasterSourceUDT
import org.apache.spark.sql.types.{IntegerType, StructField, StructType, ArrayType}
import org.apache.spark.sql.Encoder
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.encoders.CatalystSerializer.{CatalystIO, _}
import org.locationtech.rasterframes.encoders.{CatalystSerializer, CatalystSerializerEncoder}
import org.locationtech.rasterframes.ref.RasterSource._
import org.locationtech.rasterframes.tensors.{ProjectedBufferedTensor, RFTensor}
import org.locationtech.rasterframes.expressions.transformers.PatternToRasterSources._


case class DeferredTensorRef(deferred: TensorRef) extends RFTensor {
def shape = Seq(deferred.sources.length, deferred.rows, deferred.cols)
}
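DeferredTensorRef carries no cells at all; its shape is simply (number of sources, rows, cols) of the wrapped TensorRef, so shape queries stay cheap. A one-liner illustrating the relationship:

import org.locationtech.rasterframes.ref.{DeferredTensorRef, TensorRef}

// For a ref over 4 sources on a 512 x 512 grid this returns Seq(4, 512, 512).
def deferredShape(tref: TensorRef): Seq[Int] =
  DeferredTensorRef(tref).shape  // == Seq(tref.sources.length, tref.rows, tref.cols)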